forked from hnu202409060624/python
parent
d32ec7f19f
commit
d6b6d1be88
@ -0,0 +1,60 @@
|
||||
import json,torch,config
|
||||
import torch.nn as nn
|
||||
from chem_trans import Chem_trans
|
||||
from date_build import Chem_Dataset
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def train(model, train_loader, optimizer, epoch,criterion,L1=True,lambda_l1=1e-5):
|
||||
x=[]
|
||||
y=[]
|
||||
model.train()
|
||||
for i in range(epoch):
|
||||
total_loss = 0.0
|
||||
for batch_idx, (input_,edge_index,target,enc) in enumerate(train_loader):
|
||||
|
||||
|
||||
model = model.to(config.device)
|
||||
input_ = input_.to(config.device)
|
||||
enc=enc.to(config.device)
|
||||
optimizer.zero_grad()
|
||||
target = target.to(config.device)
|
||||
|
||||
output = model(input_,edge_index.to(config.device),enc)
|
||||
|
||||
print(output)
|
||||
print(target)
|
||||
|
||||
loss = criterion(output,target.float())
|
||||
|
||||
if L1:
|
||||
# 计算 L1 正则化项
|
||||
l1_norm = sum(p.abs().sum() for p in model.parameters())
|
||||
# 将 L1 正则化项添加到损失中
|
||||
loss += lambda_l1 * l1_norm
|
||||
# 反向传播和优化
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# 累计损失
|
||||
total_loss =total_loss+ loss.item()
|
||||
torch.cuda.empty_cache()
|
||||
print(f'epoch: {i}, loss: {loss.item()}')
|
||||
if i % 50 == 0:
|
||||
plt.figure()
|
||||
plt.plot(x, y, label='Line 1')
|
||||
plt.title('Simple Plot')
|
||||
plt.xlabel('X epoch')
|
||||
plt.ylabel('Y Loss')
|
||||
torch.save(model,f'models\model3.1({i}).pt')
|
||||
plt.savefig(f'training_progress/model3.1({i}).png')
|
||||
x.append(i)
|
||||
y.append(total_loss/len(train_loader))
|
||||
|
||||
model=Chem_trans().to(config.device)
|
||||
|
||||
|
||||
criterion = nn.L1Loss().to(config.device)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.000002,)
|
||||
if __name__ == '__main__':
|
||||
train(model,Chem_Dataset,optimizer,epoch=800,criterion=criterion,L1=False,lambda_l1=1e-5)
|
Loading…
Reference in new issue