You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

325 lines
11 KiB

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import numpy as np
from models.autoencoder import Autoencoder
from models.simplecnn import SimpleCNN
from data.dataset import LungXrayDataset
def add_noise(images: torch.Tensor, noise_factor: float = 0.3) -> torch.Tensor:
    """Return a copy of ``images`` corrupted with additive Gaussian noise.

    Standard-normal noise drawn with ``torch.randn_like`` is scaled by
    ``noise_factor`` and added to ``images``; the sum is then clamped to the
    closed interval [0, 1].

    Args:
        images: Floating-point tensor of pixel values. Values are presumably
            already scaled to [0, 1]; anything outside that range is clipped
            by the final clamp.
        noise_factor: Multiplier applied to the unit-variance noise, i.e. the
            standard deviation of the perturbation.

    Returns:
        A new tensor with the same shape as ``images``; the input is not
        modified.
    """
    noise = noise_factor * torch.randn_like(images)
    return torch.clamp(images + noise, 0., 1.)
def plot_confusion_matrix(cm, classes, output_path):
    """Render a confusion matrix as an annotated heatmap and save it to a file.

    Args:
        cm: 2-D array of integer counts (cells are formatted with ``'d'``).
            Rows are labelled as true classes and columns as predicted
            classes.
        classes: Sequence of class names used as tick labels on both axes.
        output_path: Destination passed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=classes, yticklabels=classes)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Close this specific figure even if drawing or saving fails, so
        # repeated per-epoch calls cannot accumulate open figures.
        plt.close(fig)
def _save_history_plot(history, output_path):
    """Plot loss and accuracy curves from ``history`` side by side and save them."""
    fig = plt.figure(figsize=(12, 5))
    try:
        # Left panel: loss curves.
        plt.subplot(1, 2, 1)
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['test_loss'], label='Test Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training and Testing Losses')

        # Right panel: accuracy curves.
        plt.subplot(1, 2, 2)
        plt.plot(history['train_acc'], label='Train Acc')
        plt.plot(history['test_acc'], label='Test Acc')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()
        plt.title('Training and Testing Accuracies')

        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        plt.close(fig)


def _train_one_epoch(cnn_model, autoencoder, loader, criterion, optimizer,
                     device, noise_factor, writer, global_step, desc):
    """Run one optimisation pass; return (mean batch loss, accuracy %, global_step)."""
    cnn_model.train()
    running_loss = 0
    correct = 0
    seen = 0

    progress = tqdm(loader, desc=desc)
    for data, targets in progress:
        data, targets = data.to(device), targets.to(device)

        # Corrupt the batch, then let the frozen autoencoder denoise it; the
        # classifier only ever sees the autoencoder's reconstruction.
        noisy_data = add_noise(data, noise_factor)
        with torch.no_grad():
            denoised_data = autoencoder(noisy_data)

        outputs = cnn_model(denoised_data)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        batch_loss = loss.item()
        running_loss += batch_loss

        _, predicted = outputs.max(1)
        seen += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        writer.add_scalar('Loss/train_step', batch_loss, global_step)
        progress.set_postfix({
            'loss': batch_loss,
            'acc': 100. * correct / seen
        })
        global_step += 1

    return running_loss / len(loader), 100. * correct / seen, global_step


def _evaluate(cnn_model, autoencoder, loader, criterion, device, noise_factor, desc):
    """Score the model on ``loader``; return (mean loss, accuracy %, targets, predictions)."""
    cnn_model.eval()
    running_loss = 0
    correct = 0
    seen = 0
    all_targets = []
    all_predictions = []

    with torch.no_grad():
        progress = tqdm(loader, desc=desc)
        for data, targets in progress:
            data, targets = data.to(device), targets.to(device)

            # Same noise -> denoise pipeline as training. The noise is freshly
            # sampled on every call, so test metrics are not exactly
            # repeatable from epoch to epoch.
            noisy_data = add_noise(data, noise_factor)
            denoised_data = autoencoder(noisy_data)
            outputs = cnn_model(denoised_data)

            batch_loss = criterion(outputs, targets).item()
            running_loss += batch_loss

            _, predicted = outputs.max(1)
            seen += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Keep per-sample results for the confusion matrix / report.
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

            progress.set_postfix({
                'loss': batch_loss,
                'acc': 100. * correct / seen
            })

    return running_loss / len(loader), 100. * correct / seen, all_targets, all_predictions


def train_cnn(cnn_model, autoencoder, lr, train_loader, test_loader, num_epochs=100,
              device='cuda', output_dir='results_cnn', noise_factor=0.3):
    """Train a CNN classifier on images denoised by a pre-trained autoencoder.

    Every batch (training and test alike) is corrupted with Gaussian noise via
    ``add_noise``, passed through ``autoencoder`` without gradient tracking,
    and the reconstruction is fed to ``cnn_model``. Only ``cnn_model`` is
    optimised (Adam with cross-entropy loss).

    Files written under ``output_dir``:
        * ``checkpoints/best_model.pth`` - state dict with the highest test
          accuracy seen so far.
        * ``checkpoints/checkpoint_epoch_<n>.pth`` - full checkpoint every
          10 epochs.
        * ``checkpoints/final_model.pth`` - state dict after the last epoch.
        * ``plots/confusion_matrix_epoch_<n>.png`` - one per epoch.
        * ``plots/metrics_epoch_<n>.png`` every 10 epochs and
          ``plots/final_metrics.png`` at the end.
        * ``tensorboard/`` - scalar logs.

    Args:
        cnn_model: Classifier to train. Presumably it emits one logit per
            entry of the three hard-coded class names - confirm against the
            model definition.
        autoencoder: Pre-trained denoising autoencoder; put in eval mode and
            never updated here.
        lr: Learning rate for the Adam optimiser.
        train_loader: Iterable of ``(data, targets)`` training batches.
        test_loader: Iterable of ``(data, targets)`` test batches.
        num_epochs: Number of passes over ``train_loader``.
        device: Device on which both models and all batches are placed.
        output_dir: Root directory for checkpoints, plots and logs.
        noise_factor: Noise scale forwarded to ``add_noise``.

    Returns:
        Dict with keys ``'train_loss'``, ``'test_loss'``, ``'train_acc'`` and
        ``'test_acc'``, each a list holding one value per completed epoch
        (accuracies are percentages).
    """
    # Output locations.
    checkpoint_dir = os.path.join(output_dir, 'checkpoints')
    tensorboard_dir = os.path.join(output_dir, 'tensorboard')
    plot_dir = os.path.join(output_dir, 'plots')
    for directory in (checkpoint_dir, tensorboard_dir, plot_dir):
        os.makedirs(directory, exist_ok=True)

    writer = SummaryWriter(tensorboard_dir)
    try:
        cnn_model = cnn_model.to(device)
        autoencoder = autoencoder.to(device)
        autoencoder.eval()  # the autoencoder is a fixed pre-processing stage

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(cnn_model.parameters(), lr=lr)

        best_test_acc = 0.0
        history = {
            'train_loss': [],
            'test_loss': [],
            'train_acc': [],
            'test_acc': []
        }

        start_time = time.time()
        global_step = 0

        classes = ['Covid', 'Normal', 'Viral Pneumonia']
        # Pin the label set so the confusion matrix is always
        # len(classes) x len(classes) and the report's target_names always
        # line up, even if some class is absent from an epoch's targets and
        # predictions.
        label_ids = list(range(len(classes)))

        for epoch in range(num_epochs):
            train_loss, train_acc, global_step = _train_one_epoch(
                cnn_model, autoencoder, train_loader, criterion, optimizer,
                device, noise_factor, writer, global_step,
                desc=f'Epoch [{epoch+1}/{num_epochs}] Training')

            test_loss, test_acc, all_targets, all_predictions = _evaluate(
                cnn_model, autoencoder, test_loader, criterion,
                device, noise_factor,
                desc=f'Epoch [{epoch+1}/{num_epochs}] Testing')

            history['train_loss'].append(train_loss)
            history['test_loss'].append(test_loss)
            history['train_acc'].append(train_acc)
            history['test_acc'].append(test_acc)

            writer.add_scalars('Loss/epoch', {
                'train': train_loss,
                'test': test_loss
            }, epoch)
            writer.add_scalars('Accuracy/epoch', {
                'train': train_acc,
                'test': test_acc
            }, epoch)

            # Per-epoch confusion matrix and classification report.
            cm = confusion_matrix(all_targets, all_predictions, labels=label_ids)
            plot_confusion_matrix(
                cm, classes,
                os.path.join(plot_dir, f'confusion_matrix_epoch_{epoch+1}.png'))

            report = classification_report(all_targets, all_predictions,
                                           labels=label_ids, target_names=classes)
            print(f"\nClassification Report - Epoch {epoch+1}:")
            print(report)

            elapsed_time = time.time() - start_time
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, '
                  f'Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%, '
                  f'Time: {elapsed_time:.2f}s')

            # Keep the weights with the best test accuracy so far.
            if test_acc > best_test_acc:
                best_test_acc = test_acc
                torch.save(cnn_model.state_dict(),
                           os.path.join(checkpoint_dir, 'best_model.pth'))

            # Every 10 epochs: full checkpoint plus a metrics snapshot.
            if (epoch + 1) % 10 == 0:
                checkpoint = {
                    # Zero-based epoch index; the file name below uses the
                    # one-based epoch number.
                    'epoch': epoch,
                    'model_state_dict': cnn_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_acc,
                    'test_acc': test_acc
                }
                torch.save(checkpoint,
                           os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'))
                _save_history_plot(
                    history, os.path.join(plot_dir, f'metrics_epoch_{epoch+1}.png'))

        torch.save(cnn_model.state_dict(),
                   os.path.join(checkpoint_dir, 'final_model.pth'))
        _save_history_plot(history, os.path.join(plot_dir, 'final_metrics.png'))

        total_time = time.time() - start_time
        print(f'Training completed in {total_time:.2f} seconds')
    finally:
        # Flush and release the event file even when training is interrupted.
        writer.close()

    return history
if __name__ == "__main__":
    # Seed PyTorch's global RNG for reproducibility.
    torch.manual_seed(42)

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Both splits are read from the same dataset root.
    data_root = "/home/lgz/Code/class/ML/e1/covid19"
    train_dataset = LungXrayDataset(root_dir=data_root, is_train=True)
    test_dataset = LungXrayDataset(root_dir=data_root, is_train=False)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Load the pre-trained autoencoder weights. map_location remaps tensors
    # that were saved on a GPU so the checkpoint also loads on the CPU
    # fallback selected above.
    autoencoder = Autoencoder()
    autoencoder.load_state_dict(
        torch.load('results/checkpoints/best_model.pth', map_location=device))

    cnn_model = SimpleCNN()

    history = train_cnn(
        cnn_model=cnn_model,
        autoencoder=autoencoder,
        lr=1e-3,
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=100,
        device=device,
        noise_factor=0.3
    )