ADD file via upload

main
hnu202409060624 8 months ago
parent d32ec7f19f
commit d6b6d1be88

@ -0,0 +1,60 @@
import json,torch,config
import torch.nn as nn
from chem_trans import Chem_trans
from date_build import Chem_Dataset
import matplotlib.pyplot as plt
def train(model, train_loader, optimizer, epoch,criterion,L1=True,lambda_l1=1e-5):
x=[]
y=[]
model.train()
for i in range(epoch):
total_loss = 0.0
for batch_idx, (input_,edge_index,target,enc) in enumerate(train_loader):
model = model.to(config.device)
input_ = input_.to(config.device)
enc=enc.to(config.device)
optimizer.zero_grad()
target = target.to(config.device)
output = model(input_,edge_index.to(config.device),enc)
print(output)
print(target)
loss = criterion(output,target.float())
if L1:
# 计算 L1 正则化项
l1_norm = sum(p.abs().sum() for p in model.parameters())
# 将 L1 正则化项添加到损失中
loss += lambda_l1 * l1_norm
# 反向传播和优化
loss.backward()
optimizer.step()
# 累计损失
total_loss =total_loss+ loss.item()
torch.cuda.empty_cache()
print(f'epoch: {i}, loss: {loss.item()}')
if i % 50 == 0:
plt.figure()
plt.plot(x, y, label='Line 1')
plt.title('Simple Plot')
plt.xlabel('X epoch')
plt.ylabel('Y Loss')
torch.save(model,f'models\model3.1({i}).pt')
plt.savefig(f'training_progress/model3.1({i}).png')
x.append(i)
y.append(total_loss/len(train_loader))
model=Chem_trans().to(config.device)
criterion = nn.L1Loss().to(config.device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.000002,)
if __name__ == '__main__':
train(model,Chem_Dataset,optimizer,epoch=800,criterion=criterion,L1=False,lambda_l1=1e-5)
Loading…
Cancel
Save