ADD file via upload

main
hnu202409060624 8 months ago
parent 219bdc8381
commit 451825159f

@ -0,0 +1,97 @@
import torch
import torch.nn.functional as F
import config
import json
import numpy as np
from date_build import Chem_Dataset
model=torch.load('D:\models\model3.1(150).pt')
def padding( chemic, lens=config.seq_len):
while len(chemic) < lens:
chemic.append(0)
return chemic
def target_padding( chemical, lens=config.seq_len):
chemical.append(1)
while len(chemical)<lens:
chemical.append(0)
return chemical
class Chemical_Trans:
def __init__(self,len=config.seq_len):
with open('vocab/vocabulary.json', 'r', encoding='utf-8') as f:
f = json.load(f)
self.vocab = f
self.len = len
def CCO(self,adj_matrix):
source_nodes = []
target_nodes = []
# 遍历邻接矩阵只记录i < j的边
for i in range(len(adj_matrix)):
for j in range(i + 1, len(adj_matrix[i])): # 注意这里从i + 1开始
if adj_matrix[i][j] != 0: # 存在边
source_nodes.append(i)
target_nodes.append(j)
return [source_nodes, target_nodes]
def chemical_trans(self, chemical):
chemical = chemical.split('-')
sequence = []
for i in chemical:
sequence.append(self.vocab[i])
return sequence
def generate_adjacency_matrix(self,num_vertices, edges, directed=False):
# 初始化邻接矩阵为0
adjacency_matrix = [[0] * num_vertices for _ in range(num_vertices)]
print(edges)
# 遍历边的列表,更新邻接矩阵
for u, v in edges:
# 确保顶点索引在有效范围内
if 0 <= u-1 < num_vertices and 0 <= v-1 < num_vertices:
adjacency_matrix[u-1][v-1] = 1
if not directed:
adjacency_matrix[v-1][u-1] = 1
return adjacency_matrix
def adjacency_to_degree_matrix(self,adjacency_matrix):
# 确保输入是NumPy数组
adjacency_matrix = np.array(adjacency_matrix)
# 计算每个顶点的度数(即每行的和)
degrees = np.sum(adjacency_matrix, axis=1)
# 构造度矩阵(对角矩阵,对角线元素为度数)
degree_matrix = np.zeros_like(adjacency_matrix)
np.fill_diagonal(degree_matrix, degrees)
return degree_matrix
def __len__(self):
return len(self.vocab)
def __getitem__(self, key):
return self.vocab.keys()
Chemical_Trans=Chemical_Trans()
model.eval()
model.to(config.device)
with torch.no_grad():
losses=[]
for batch_idx, (input_,edge_index,target,enc) in enumerate(Chem_Dataset):
model = model.to(config.device)
input_ = input_.to(config.device)
output = model(input_,edge_index.to(config.device),enc.to(config.device))
losses.append(F.l1_loss(output.to('cpu'),target.to('cpu')).item())
print(output)
print(target)
print(sum(losses)/len(losses))
Loading…
Cancel
Save