From 1980f0b1f136dd013189620f3658e8d6a56eff8f Mon Sep 17 00:00:00 2001 From: hnu202409060624 <2804411502@qq.com> Date: Mon, 30 Dec 2024 18:26:06 +0800 Subject: [PATCH] ADD file via upload --- main.py | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..7be0943 --- /dev/null +++ b/main.py @@ -0,0 +1,63 @@ +import json,torch,config +import torch.nn as nn +from date_build import Chem_Dataset +from chem_trans import Chem_trans +from torch.utils.data import Dataset, DataLoader +from date_build import chem_collate_fn +from date_build import Chem_Dataset +import matplotlib.pyplot as plt +import os + +def train(model, train_loader, optimizer, epoch,criterion,L1=True,lambda_l1=1e-5): + x=[] + y=[] + model.train() + for i in range(epoch): + total_loss = 0.0 + for batch_idx, (input_,edge_index,target,enc) in enumerate(train_loader): + + + model = model.to(config.device) + input_ = input_.to(config.device) + enc=enc.to(config.device) + optimizer.zero_grad() + target = target.to(config.device) + + output = model(input_,edge_index.to(config.device),enc) + + print(output) + print(target) + + loss = criterion(output,target.float()) + + if L1: + # 计算 L1 正则化项 + l1_norm = sum(p.abs().sum() for p in model.parameters()) + # 将 L1 正则化项添加到损失中 + loss += lambda_l1 * l1_norm + # 反向传播和优化 + loss.backward() + optimizer.step() + + # 累计损失 + total_loss =total_loss+ loss.item() + torch.cuda.empty_cache() + print(f'epoch: {i}, loss: {loss.item()}') + if i % 50 == 0: + plt.figure() + plt.plot(x, y, label='Line 1') + plt.title('Simple Plot') + plt.xlabel('X epoch') + plt.ylabel('Y Loss') + torch.save(model,f'D:\models\model3.1({i}).pt') + plt.savefig(f'training_progress/model3.1({i}).png') + x.append(i) + y.append(total_loss/len(train_loader)) + +model=Chem_trans().to(config.device) + + +criterion = nn.L1Loss().to(config.device) +optimizer = torch.optim.SGD(model.parameters(), lr=0.000002,) +if __name__ == '__main__': + train(model,Chem_Dataset,optimizer,epoch=800,criterion=criterion,L1=False,lambda_l1=1e-5)