You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
182 lines
6.1 KiB
182 lines
6.1 KiB
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
from torchvision.utils import make_grid
|
|
import os
|
|
import matplotlib.pyplot as plt
|
|
from data.dataset import LungXrayDataset
|
|
from models.autoencoder import Autoencoder
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
import time
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
def train_autoencoder(model, lr, train_loader, test_loader, num_epochs=100, device='cuda', output_dir='results_autoencoder'):
|
|
"""
|
|
训练自编码器, 使用TensorBoard进行可视化
|
|
|
|
Args:
|
|
model: 自编码器模型
|
|
train_loader: 训练数据加载器
|
|
test_loader: 测试数据加载器
|
|
num_epochs: 训练轮数
|
|
device: 使用的设备 ('cuda' 或 'cpu')
|
|
output_dir: 输出目录
|
|
"""
|
|
# 创建输出目录
|
|
checkpoint_dir = os.path.join(output_dir, 'checkpoints')
|
|
tensorboard_dir = os.path.join(output_dir, 'tensorboard')
|
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
|
os.makedirs(tensorboard_dir, exist_ok=True)
|
|
|
|
# 初始化TensorBoard writer
|
|
writer = SummaryWriter(tensorboard_dir)
|
|
|
|
# 将模型移至指定设备
|
|
model = model.to(device)
|
|
|
|
# 定义损失函数和优化器
|
|
criterion = nn.MSELoss()
|
|
optimizer = optim.Adam(model.parameters(), lr=lr)# 记录最佳模型
|
|
best_test_loss = float('inf')
|
|
|
|
# 获取一批固定的测试数据用于可视化
|
|
fixed_test_data, _ = next(iter(test_loader))
|
|
fixed_test_data = fixed_test_data.to(device)
|
|
|
|
# 添加模型图到TensorBoard
|
|
writer.add_graph(model, fixed_test_data)
|
|
|
|
# 训练开始时间
|
|
start_time = time.time()
|
|
global_step = 0
|
|
|
|
# 训练循环
|
|
for epoch in range(num_epochs):
|
|
# 训练阶段
|
|
model.train()
|
|
train_loss = 0
|
|
train_pbar = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}] Training')
|
|
|
|
for batch_idx, (data, _) in enumerate(train_pbar):
|
|
data = data.to(device)
|
|
|
|
# 前向传播
|
|
output = model(data)
|
|
loss = criterion(output, data)
|
|
|
|
# 反向传播和优化
|
|
optimizer.zero_grad()
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
# 记录损失
|
|
train_loss += loss.item()
|
|
writer.add_scalar('Loss/train_step', loss.item(), global_step)
|
|
|
|
# 更新进度条
|
|
train_pbar.set_postfix({'loss': loss.item()})
|
|
global_step += 1
|
|
|
|
# 计算平均训练损失
|
|
train_loss /= len(train_loader)
|
|
|
|
# 测试阶段
|
|
model.eval()
|
|
test_loss = 0
|
|
test_pbar = tqdm(test_loader, desc=f'Epoch [{epoch+1}/{num_epochs}] Testing')
|
|
|
|
with torch.no_grad():
|
|
for data, _ in test_pbar:
|
|
data = data.to(device)
|
|
output = model(data)
|
|
loss = criterion(output, data)
|
|
test_loss += loss.item()
|
|
test_pbar.set_postfix({'loss': loss.item()})
|
|
|
|
test_loss /= len(test_loader)
|
|
|
|
# 记录每个epoch的损失
|
|
writer.add_scalars('Loss/epoch', {
|
|
'train': train_loss,
|
|
'test': test_loss
|
|
}, epoch)
|
|
|
|
# 可视化重建结果
|
|
with torch.no_grad():
|
|
reconstructed = model(fixed_test_data)
|
|
# 创建原始图像和重建图像的对比网格
|
|
comparison = torch.cat([fixed_test_data[:8], reconstructed[:8]])
|
|
grid = make_grid(comparison, nrow=8, normalize=True)
|
|
writer.add_image('Reconstruction', grid, epoch)
|
|
|
|
# 记录模型参数分布
|
|
for name, param in model.named_parameters():
|
|
writer.add_histogram(f'Parameters/{name}', param, epoch)
|
|
if param.grad is not None:
|
|
writer.add_histogram(f'Gradients/{name}', param.grad, epoch)
|
|
|
|
# 打印进度
|
|
elapsed_time = time.time() - start_time
|
|
print(f'Epoch [{epoch+1}/{num_epochs}], '
|
|
f'Train Loss: {train_loss:.6f}, '
|
|
f'Test Loss: {test_loss:.6f}, '
|
|
f'Time: {elapsed_time:.2f}s')
|
|
|
|
# 保存最佳模型
|
|
if test_loss < best_test_loss:
|
|
best_test_loss = test_loss
|
|
torch.save(model.state_dict(),
|
|
os.path.join(checkpoint_dir, 'best_model.pth'))
|
|
|
|
# 每10个epoch保存一次检查点
|
|
if (epoch + 1) % 10 == 0:
|
|
checkpoint = {
|
|
'epoch': epoch,
|
|
'model_state_dict': model.state_dict(),
|
|
'optimizer_state_dict': optimizer.state_dict(),
|
|
'train_loss': train_loss,
|
|
'test_loss': test_loss,
|
|
}
|
|
torch.save(checkpoint,
|
|
os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'))
|
|
|
|
# 保存最终模型
|
|
torch.save(model.state_dict(),
|
|
os.path.join(checkpoint_dir, 'final_model.pth'))
|
|
|
|
# 记录总训练时间
|
|
total_time = time.time() - start_time
|
|
print(f'Training completed in {total_time:.2f} seconds')
|
|
|
|
# 关闭TensorBoard writer
|
|
writer.close()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 设置随机种子
|
|
torch.manual_seed(42)
|
|
|
|
# 检查是否可以使用GPU
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
print(f'Using device: {device}')
|
|
|
|
# 创建数据加载器
|
|
train_dataset = LungXrayDataset(root_dir="/home/lgz/Code/class/ML/e1/covid19", is_train=True)
|
|
test_dataset = LungXrayDataset(root_dir="/home/lgz/Code/class/ML/e1/covid19", is_train=False)
|
|
|
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
|
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
|
|
|
|
# 创建模型
|
|
model = Autoencoder()# 训练模型
|
|
train_losses, test_losses = train_autoencoder(
|
|
model=model,
|
|
lr=1e-3,
|
|
train_loader=train_loader,
|
|
test_loader=test_loader,
|
|
num_epochs=100,
|
|
device=device
|
|
) |