forked from hnu202409060624/python
parent
9ef2e7ccde
commit
60af4f0345
@ -1,63 +0,0 @@
|
|||||||
import json,torch,config
|
|
||||||
import torch.nn as nn
|
|
||||||
from date_build import Chem_Dataset
|
|
||||||
from chem_trans import Chem_trans
|
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
from date_build import chem_collate_fn
|
|
||||||
from date_build import Chem_Dataset
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import os
|
|
||||||
|
|
||||||
def train(model, train_loader, optimizer, epoch,criterion,L1=True,lambda_l1=1e-5):
|
|
||||||
x=[]
|
|
||||||
y=[]
|
|
||||||
model.train()
|
|
||||||
for i in range(epoch):
|
|
||||||
total_loss = 0.0
|
|
||||||
for batch_idx, (input_,edge_index,target,enc) in enumerate(train_loader):
|
|
||||||
|
|
||||||
|
|
||||||
model = model.to(config.device)
|
|
||||||
input_ = input_.to(config.device)
|
|
||||||
enc=enc.to(config.device)
|
|
||||||
optimizer.zero_grad()
|
|
||||||
target = target.to(config.device)
|
|
||||||
|
|
||||||
output = model(input_,edge_index.to(config.device),enc)
|
|
||||||
|
|
||||||
print(output)
|
|
||||||
print(target)
|
|
||||||
|
|
||||||
loss = criterion(output,target.float())
|
|
||||||
|
|
||||||
if L1:
|
|
||||||
# 计算 L1 正则化项
|
|
||||||
l1_norm = sum(p.abs().sum() for p in model.parameters())
|
|
||||||
# 将 L1 正则化项添加到损失中
|
|
||||||
loss += lambda_l1 * l1_norm
|
|
||||||
# 反向传播和优化
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
# 累计损失
|
|
||||||
total_loss =total_loss+ loss.item()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
print(f'epoch: {i}, loss: {loss.item()}')
|
|
||||||
if i % 50 == 0:
|
|
||||||
plt.figure()
|
|
||||||
plt.plot(x, y, label='Line 1')
|
|
||||||
plt.title('Simple Plot')
|
|
||||||
plt.xlabel('X epoch')
|
|
||||||
plt.ylabel('Y Loss')
|
|
||||||
torch.save(model,f'D:\models\model3.1({i}).pt')
|
|
||||||
plt.savefig(f'training_progress/model3.1({i}).png')
|
|
||||||
x.append(i)
|
|
||||||
y.append(total_loss/len(train_loader))
|
|
||||||
|
|
||||||
model=Chem_trans().to(config.device)
|
|
||||||
|
|
||||||
|
|
||||||
criterion = nn.L1Loss().to(config.device)
|
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.000002,)
|
|
||||||
if __name__ == '__main__':
|
|
||||||
train(model,Chem_Dataset,optimizer,epoch=800,criterion=criterion,L1=False,lambda_l1=1e-5)
|
|
Loading…
Reference in new issue