import json

import numpy as np
import torch
import torch.nn.functional as F

import config
from date_build import Chem_Dataset

# Load the trained checkpoint (forward slashes keep the path portable).
model = torch.load('models/model3.1(150).pt')
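# Note (assumption, not in the original script): if the checkpoint was saved on a
# GPU and this script runs on a CPU-only machine, torch.load may need a
# map_location, e.g.
#     model = torch.load('models/model3.1(150).pt', map_location=config.device)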

def padding(chemic, lens=config.seq_len):
    """Zero-pad a token-id sequence up to `lens`."""
    while len(chemic) < lens:
        chemic.append(0)
    return chemic


def target_padding(chemical, lens=config.seq_len):
    """Append an end marker (1), then zero-pad up to `lens`."""
    chemical.append(1)
    while len(chemical) < lens:
        chemical.append(0)
    return chemical
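# Illustrative check (not in the original script): with an explicit lens of 8,
# padding zero-fills to length 8, while target_padding first appends a 1, which
# appears to act as an end-of-sequence marker, before zero-filling.
assert padding([5, 3, 7], lens=8) == [5, 3, 7, 0, 0, 0, 0, 0]
assert target_padding([5, 3, 7], lens=8) == [5, 3, 7, 1, 0, 0, 0, 0]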

class Chemical_Trans:
    def __init__(self, seq_len=config.seq_len):
        with open('vocab/vocabulary.json', 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)
        self.len = seq_len

    def CCO(self, adj_matrix):
        """Convert an adjacency matrix into [source_nodes, target_nodes] edge lists."""
        source_nodes = []
        target_nodes = []
        # Traverse the adjacency matrix, recording only edges with i < j
        for i in range(len(adj_matrix)):
            for j in range(i + 1, len(adj_matrix[i])):  # note: start from i + 1
                if adj_matrix[i][j] != 0:  # an edge exists
                    source_nodes.append(i)
                    target_nodes.append(j)
        return [source_nodes, target_nodes]

    def chemical_trans(self, chemical):
        """Map a '-'-separated token string to a list of vocabulary ids."""
        chemical = chemical.split('-')
        sequence = []
        for i in chemical:
            sequence.append(self.vocab[i])
        return sequence

    def generate_adjacency_matrix(self, num_vertices, edges, directed=False):
        # Initialize the adjacency matrix with zeros
        adjacency_matrix = [[0] * num_vertices for _ in range(num_vertices)]
        print(edges)  # debug: show the raw edge list
        # Iterate over the edge list and update the adjacency matrix
        for u, v in edges:
            # Make sure the (1-based) vertex indices are within range
            if 0 <= u - 1 < num_vertices and 0 <= v - 1 < num_vertices:
                adjacency_matrix[u - 1][v - 1] = 1
                if not directed:
                    adjacency_matrix[v - 1][u - 1] = 1
        return adjacency_matrix

    def adjacency_to_degree_matrix(self, adjacency_matrix):
        # Make sure the input is a NumPy array
        adjacency_matrix = np.array(adjacency_matrix)
        # Compute each vertex's degree (the sum of each row)
        degrees = np.sum(adjacency_matrix, axis=1)
        # Build the degree matrix (a diagonal matrix whose entries are the degrees)
        degree_matrix = np.zeros_like(adjacency_matrix)
        np.fill_diagonal(degree_matrix, degrees)
        return degree_matrix

    def __len__(self):
        return len(self.vocab)

    def __getitem__(self, key):
        # Note: returns the full key view of the vocabulary regardless of `key`
        return self.vocab.keys()
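# Illustrative sketch (not part of the original pipeline; the `_demo*` names exist
# only for this example): for a 3-atom chain, generate_adjacency_matrix takes
# 1-based edges, CCO returns 0-based [source, target] lists, and
# adjacency_to_degree_matrix yields the diagonal degree matrix.
_demo = Chemical_Trans()
_demo_adj = _demo.generate_adjacency_matrix(3, [(1, 2), (2, 3)])  # triggers the debug print above
assert _demo_adj == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
assert _demo.CCO(_demo_adj) == [[0, 1], [1, 2]]
assert np.array_equal(_demo.adjacency_to_degree_matrix(_demo_adj), np.diag([1, 2, 1]))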

# Note: this rebinds the name Chemical_Trans from the class to an instance.
Chemical_Trans = Chemical_Trans()

model.eval()
model.to(config.device)

# Evaluate the model on the dataset and report the mean L1 loss.
with torch.no_grad():
    losses = []
    for batch_idx, (input_, edge_index, target, enc) in enumerate(Chem_Dataset):
        input_ = input_.to(config.device)
        output = model(input_, edge_index.to(config.device), enc.to(config.device))
        losses.append(F.l1_loss(output.to('cpu'), target.to('cpu')).item())
        print(output)
        print(target)
    # Mean L1 loss across all batches
    print(sum(losses) / len(losses))