You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pli7oqjkf 448d632675
Update README.md
1 year ago
README.md Update README.md 1 year ago

README.md

shenduxuexi

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Generate training data: evenly spaced x values on [0, 1] and the target curve.
num_samples = 100
# unsqueeze(1) turns the 1-D grid into a (num_samples, 1) column, one sample per row.
x = torch.linspace(0, 1, num_samples).unsqueeze(1).float()
# Target function to fit: y = 2x^8 + 3x^6 + sqrt(x).
y = 2 * x ** 8 + 3 * x ** 6 + torch.sqrt(x)

# Define the neural network model.
class MLP(nn.Module):
    """Fully connected regressor mapping one input feature to one output.

    Layer widths are 1 -> 64 -> 32 -> 16 -> 1, with ReLU after each hidden
    layer and no activation on the final layer.
    """

    def __init__(self):
        # nn.Module.__init__ must run before submodules are assigned.
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(1, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 16)
        self.fc4 = nn.Linear(16, 1)

    def forward(self, x):
        """Run the network on ``x``; the last dimension must have size 1."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        # Final layer is linear so the output can take any real value.
        x = self.fc4(x)
        return x


model = MLP()

# Define the loss function and optimizer.
criterion = nn.MSELoss()  # mean squared error between prediction and target
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model: each step is full-batch, using the whole x / y tensors.
epochs = 500
losses = []  # training loss recorded once per epoch
for epoch in range(epochs):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    # Loss as measured before this epoch's update; convert to a Python float once.
    loss_value = loss.item()
    losses.append(loss_value)
    if (epoch + 1) % 50 == 0:
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss_value}')

# Plot the fitting result: true target curve vs. the trained model's prediction.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
x_np = x.detach().numpy()
with torch.no_grad():  # inference only; no autograd graph is needed for plotting
    y_pred = model(x).numpy()
plt.plot(x_np, y.detach().numpy(), label='True Function')
plt.plot(x_np, y_pred, label='Fitted Function')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Function Fitting Result')
plt.legend()

# Plot the loss curve: training loss per epoch, in the right-hand subplot.
plt.subplot(1, 2, 2)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.tight_layout()  # keep the two subplots' titles and labels from overlapping
plt.show()