import json
import os

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import config
from chem_trans import Chem_trans
from date_build import Chem_Dataset


def _save_loss_plot(epochs, losses, path):
    """Plot mean loss against epoch index and write the figure to ``path``."""
    fig = plt.figure()
    plt.plot(epochs, losses, label='Line 1')
    plt.title('Simple Plot')
    plt.xlabel('X epoch')
    plt.ylabel('Y Loss')
    plt.savefig(path)
    # Close explicitly: pyplot keeps every open figure alive otherwise.
    plt.close(fig)


def train(model, train_loader, optimizer, epoch, criterion, L1=True, lambda_l1=1e-5, *, verbose=False):
    """Train ``model`` for ``epoch`` epochs, checkpointing and plotting every 50 epochs.

    Args:
        model: network called as ``model(input_, edge_index, enc)``.
        train_loader: iterable yielding ``(input_, edge_index, target, enc)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: number of epochs to run.
        criterion: loss called as ``criterion(output, target.float())``.
        L1: if True, add ``lambda_l1 * sum(|p|)`` over all model parameters to the loss.
        lambda_l1: weight of the L1 penalty.
        verbose: if True, print each batch's output and target (debugging aid).

    Returns:
        ``(epochs, losses)``: epoch indices and the mean per-batch loss of each epoch.

    Side effects:
        Every 50 epochs, saves the whole model to ``models/model3.1(<epoch>).pt`` and a
        loss plot to ``training_progress/model3.1(<epoch>).png`` (directories are created).
    """
    ckpt_dir = 'models'
    plot_dir = 'training_progress'
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)

    # Move the model once; nn.Module.to() is in-place, so the optimizer's params stay valid.
    model = model.to(config.device)
    model.train()

    x = []
    y = []
    for i in range(epoch):
        total_loss = 0.0
        n_batches = 0
        for input_, edge_index, target, enc in train_loader:
            input_ = input_.to(config.device)
            edge_index = edge_index.to(config.device)
            enc = enc.to(config.device)
            target = target.to(config.device)

            optimizer.zero_grad()
            output = model(input_, edge_index, enc)
            if verbose:
                print(output)
                print(target)

            loss = criterion(output, target.float())
            if L1:
                # L1 regularisation over every parameter tensor.
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                loss = loss + lambda_l1 * l1_norm

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        avg_loss = total_loss / max(n_batches, 1)
        x.append(i)
        y.append(avg_loss)

        # Releasing cached blocks once per epoch is enough; per batch it only slows training.
        torch.cuda.empty_cache()
        print(f'epoch: {i}, loss: {avg_loss}')

        if i % 50 == 0:
            # NOTE(review): torch.save(model) pickles the whole module; kept for compatibility
            # with existing loaders of these checkpoints (state_dict would be more portable).
            torch.save(model, os.path.join(ckpt_dir, f'model3.1({i}).pt'))
            # Plot after appending so the figure includes the current epoch.
            _save_loss_plot(x, y, os.path.join(plot_dir, f'model3.1({i}).png'))

    return x, y


model = Chem_trans().to(config.device)
criterion = nn.L1Loss().to(config.device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.000002)

if __name__ == '__main__':
    # NOTE(review): train() iterates Chem_Dataset as a loader of 4-tuples. If
    # date_build.Chem_Dataset is a Dataset *class* rather than a ready loader instance,
    # it must be instantiated and wrapped in torch.utils.data.DataLoader — confirm.
    train(model, Chem_Dataset, optimizer, epoch=800, criterion=criterion, L1=False, lambda_l1=1e-5)