From 60af4f0345f3ec224abc4ccfdedd5043460b51b9 Mon Sep 17 00:00:00 2001
From: hnu202409060624 <2804411502@qq.com>
Date: Mon, 30 Dec 2024 18:29:13 +0800
Subject: [PATCH] Delete 'main.py'

---
 main.py | 63 ---------------------------------------------------------
 1 file changed, 63 deletions(-)
 delete mode 100644 main.py

diff --git a/main.py b/main.py
deleted file mode 100644
index 7be0943..0000000
--- a/main.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import json,torch,config
-import torch.nn as nn
-from date_build import Chem_Dataset
-from chem_trans import Chem_trans
-from torch.utils.data import Dataset, DataLoader
-from date_build import chem_collate_fn
-from date_build import Chem_Dataset
-import matplotlib.pyplot as plt
-import os
-
-def train(model, train_loader, optimizer, epoch, criterion, L1=True, lambda_l1=1e-5):
-    x = []
-    y = []
-    model.train()
-    for i in range(epoch):
-        total_loss = 0.0
-        for batch_idx, (input_, edge_index, target, enc) in enumerate(train_loader):
-
-
-            model = model.to(config.device)
-            input_ = input_.to(config.device)
-            enc = enc.to(config.device)
-            optimizer.zero_grad()
-            target = target.to(config.device)
-
-            output = model(input_, edge_index.to(config.device), enc)
-
-            print(output)
-            print(target)
-
-            loss = criterion(output, target.float())
-
-            if L1:
-                # compute the L1 regularization term
-                l1_norm = sum(p.abs().sum() for p in model.parameters())
-                # add the L1 regularization term to the loss
-                loss += lambda_l1 * l1_norm
-            # backward pass and optimizer step
-            loss.backward()
-            optimizer.step()
-
-            # accumulate the loss
-            total_loss = total_loss + loss.item()
-            torch.cuda.empty_cache()
-        print(f'epoch: {i}, loss: {loss.item()}')
-        if i % 50 == 0:
-            plt.figure()
-            plt.plot(x, y, label='Line 1')
-            plt.title('Simple Plot')
-            plt.xlabel('X epoch')
-            plt.ylabel('Y Loss')
-            torch.save(model, f'D:\models\model3.1({i}).pt')
-            plt.savefig(f'training_progress/model3.1({i}).png')
-        x.append(i)
-        y.append(total_loss/len(train_loader))
-
-model = Chem_trans().to(config.device)
-
-
-criterion = nn.L1Loss().to(config.device)
-optimizer = torch.optim.SGD(model.parameters(), lr=0.000002,)
-if __name__ == '__main__':
-    train(model, Chem_Dataset, optimizer, epoch=800, criterion=criterion, L1=False, lambda_l1=1e-5)
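
For reference, below is a minimal, self-contained sketch of the training pattern the removed main.py implemented: an optional L1 penalty added to the loss before the usual backward/step. The nn.Linear stand-in model, the random tensors, and the helper names l1_penalty and train_step are illustrative assumptions only; the repository's Chem_trans, Chem_Dataset, and config modules are not reproduced here.

# Sketch of an L1-penalized training step (assumed stand-ins, not the deleted script).
import torch
import torch.nn as nn

def l1_penalty(model: nn.Module) -> torch.Tensor:
    # Sum of absolute values of all trainable parameters.
    return sum(p.abs().sum() for p in model.parameters())

def train_step(model, batch, target, optimizer, criterion,
               use_l1=False, lambda_l1=1e-5):
    model.train()
    optimizer.zero_grad()
    output = model(batch)
    loss = criterion(output, target)
    if use_l1:
        # Add the L1 regularization term to the data loss.
        loss = loss + lambda_l1 * l1_penalty(model)
    loss.backward()
    optimizer.step()
    return loss.item()

if __name__ == '__main__':
    model = nn.Linear(8, 1)                    # stand-in for the Chem_trans model
    criterion = nn.L1Loss()
    optimizer = torch.optim.SGD(model.parameters(), lr=2e-6)
    batch, target = torch.randn(4, 8), torch.randn(4, 1)  # stand-in data
    print(train_step(model, batch, target, optimizer, criterion, use_l1=True))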