import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import config
from chem_trans import Chem_trans
from date_build import Chem_Dataset
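
# `config` is a local module; this script only uses `config.device`. A minimal
# sketch of what it presumably provides (an assumption, not the repo's actual file):
#
#     # config.py
#     import torch
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')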


def train(model, train_loader, optimizer, epoch, criterion, L1=True, lambda_l1=1e-5):
    """Train `model` for `epoch` epochs, optionally adding an L1 penalty to the loss."""
    x = []  # epoch indices for the loss curve
    y = []  # mean training loss per epoch
    model = model.to(config.device)  # move the model once, not per batch
    model.train()
    for i in range(epoch):
        total_loss = 0.0
        for batch_idx, (input_, edge_index, target, enc) in enumerate(train_loader):
            # Move the batch to the training device
            input_ = input_.to(config.device)
            edge_index = edge_index.to(config.device)
            enc = enc.to(config.device)
            target = target.to(config.device)

            optimizer.zero_grad()
            output = model(input_, edge_index, enc)
            loss = criterion(output, target.float())

            if L1:
                # Add an L1 penalty over all model parameters to the loss
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                loss = loss + lambda_l1 * l1_norm
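            # Note: the penalty above sums |p| over every parameter, biases
            # included. A common variant (an alternative, not what this script
            # does) restricts it to weight matrices:
            #   l1_norm = sum(p.abs().sum()
            #                 for name, p in model.named_parameters()
            #                 if 'weight' in name)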

            # Backpropagation and optimizer step
            loss.backward()
            optimizer.step()

            # Accumulate the batch loss
            total_loss = total_loss + loss.item()
            torch.cuda.empty_cache()

        # Record and report the mean loss for this epoch
        x.append(i)
        y.append(total_loss / len(train_loader))
        print(f'epoch: {i}, loss: {y[-1]}')

        if i % 50 == 0:
            # Periodically checkpoint the model and save the loss curve so far
            torch.save(model, f'models/model3.1({i}).pt')
            plt.figure()
            plt.plot(x, y, label='training loss')
            plt.title('Training Loss')
            plt.xlabel('epoch')
            plt.ylabel('loss')
            plt.legend()
            plt.savefig(f'training_progress/model3.1({i}).png')
            plt.close()  # release the figure so they don't accumulate


model = Chem_trans().to(config.device)
criterion = nn.L1Loss().to(config.device)
optimizer = torch.optim.SGD(model.parameters(), lr=2e-6)

if __name__ == '__main__':
    train(model, Chem_Dataset, optimizer, epoch=800,
          criterion=criterion, L1=False, lambda_l1=1e-5)
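
# Note: `train` is handed `Chem_Dataset` directly. If `Chem_Dataset` is a
# torch-style Dataset class, it would more commonly be instantiated and wrapped
# in a DataLoader first; a minimal sketch, assuming a no-argument constructor
# and a batch size of 32 (both assumptions, not from this repo). For graph
# batches with per-sample edge_index tensors, torch_geometric's DataLoader is
# the usual choice instead of the default collate.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(Chem_Dataset(), batch_size=32, shuffle=True)
#     train(model, loader, optimizer, epoch=800, criterion=criterion, L1=False)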