parent 060ef84fec
commit 69d720a171
@@ -0,0 +1,44 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Virtual environment
.venv/
env/
venv/

# Logs and checkpoints
*.log
*.pth
*.pt

# TensorBoard events
events.out.tfevents.*

# Results
results/**/checkpoints/
results/**/tensorboard/
results/**/plots/
results/**/metrics*.png
results/**/confusion_matrix*.png
results/**/final_metrics.png

# Jupyter Notebook checkpoints
.ipynb_checkpoints

# PyTorch Lightning cache (if used)
lightning_logs/

# VSCode
.vscode/

# System files
.DS_Store
Thumbs.db

# Config (optional: if you want to keep it private)
# config/config.yaml

# Ignore model-specific large files optionally
# results/**/checkpoints/*.pth
@@ -0,0 +1,20 @@
training:
  num_epochs: 100
  batch_size: 32
  learning_rate: 0.001
  save_interval: 10

model:
  input_channels: 1
  hidden_channels: 32
  latent_channels: 64

data:
  image_size: 256
  train_dir: "data/train"
  test_dir: "data/noisy_test"
  preprocess:
    resize_size: [256, 256]  # target image size; must match the Resize in dataset.py
    normalize: True          # whether to normalize
    mean: [0.5]              # grayscale mean; must match the Normalize in dataset.py
    std: [0.5]               # grayscale std; must match the Normalize in dataset.py
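For reference, a minimal sketch of how this config is consumed, via the load_config helper added in utils.py at the end of this commit; the keys match the file above:

from utils import load_config

config = load_config('config/config.yaml')
print(config['training']['num_epochs'])      # 100
print(config['data']['preprocess']['mean'])  # [0.5]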
@@ -0,0 +1,66 @@
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import os
from PIL import Image

class LungXrayDataset(Dataset):
    def __init__(self, root_dir, is_train=True, transform=None):
        """
        Args:
            root_dir (str): root directory of the dataset
            is_train (bool): whether this is the training split
            transform (callable, optional): optional image preprocessing
        """
        self.root_dir = root_dir
        self.is_train = is_train
        self.transform = transform
        self.classes = ['Covid', 'Normal', 'Viral Pneumonia']

        # Default image transforms
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),                 # resize the image
                transforms.Grayscale(num_output_channels=1),   # convert to grayscale
                transforms.ToTensor(),                         # convert to tensor
                transforms.Normalize(mean=[0.5], std=[0.5])    # normalization for grayscale
            ])

        # Collect image paths and labels
        self.data_info = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(root_dir, 'train' if is_train else 'noisy_test', class_name)
            for img_name in os.listdir(class_path):
                if img_name.endswith(('.png', '.jpg', '.jpeg')):
                    self.data_info.append({
                        'path': os.path.join(class_path, img_name),
                        'label': class_idx
                    })

    def add_gaussian_noise(self, image):
        """Add Gaussian noise."""
        noise = torch.randn_like(image) * 0.1  # 0.1 is the noise strength; adjust as needed
        noisy_image = image + noise
        return torch.clamp(noisy_image, 0, 1)

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        img_path = self.data_info[idx]['path']
        label = self.data_info[idx]['label']

        # Load the image directly as grayscale
        image = Image.open(img_path).convert('L')
        image = self.transform(image)

        # Add Gaussian noise to training samples
        if self.is_train:
            image = self.add_gaussian_noise(image)
        return image, label


if __name__ == "__main__":
    test_dataset = LungXrayDataset(root_dir="dataset", is_train=False)
    print(test_dataset[0][0].shape)
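A short usage sketch, assuming the dataset/train/&lt;class&gt; and dataset/noisy_test/&lt;class&gt; layout that __init__ expects:

from torch.utils.data import DataLoader

train_dataset = LungXrayDataset(root_dir="dataset", is_train=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([32, 1, 256, 256]) torch.Size([32])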
@@ -0,0 +1,106 @@
import torch
import argparse
from models.autoencoder import Autoencoder
from models.simplecnn import SimpleCNN
from utils import load_config
from torchvision import transforms
from PIL import Image
import warnings

# Suppress warning output
warnings.filterwarnings("ignore")

def parse_args():
    parser = argparse.ArgumentParser(description='COVID-19 X-ray Classification Inference')

    # Model paths
    parser.add_argument('--autoencoder_path', type=str, required=True,
                        help='Path to the trained autoencoder model')
    parser.add_argument('--cnn_path', type=str, required=True,
                        help='Path to the trained CNN model')

    # Data path
    parser.add_argument('--image_path', type=str, required=True,
                        help='Path to the input X-ray image')

    # Config path
    parser.add_argument('--config', type=str, default='config/config.yaml',
                        help='Path to config file')

    # Device options
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use for inference (cuda or cpu)')

    return parser.parse_args()

def load_models(args, config, device):
    """Load the autoencoder and CNN models."""
    autoencoder = Autoencoder().to(device)
    cnn_model = SimpleCNN().to(device)

    autoencoder.load_state_dict(torch.load(args.autoencoder_path, map_location=device))
    cnn_model.load_state_dict(torch.load(args.cnn_path, map_location=device))

    autoencoder.eval()
    cnn_model.eval()

    return autoencoder, cnn_model

def preprocess_image(image_path, config):
    """Preprocess the input image."""
    img = Image.open(image_path).convert('L')  # convert to grayscale

    preprocess_config = config['data']['preprocess']
    transform_list = [
        transforms.Resize(preprocess_config['resize_size']),
        transforms.ToTensor(),
    ]
    if preprocess_config.get('normalize', False):
        transform_list.append(transforms.Normalize(
            mean=preprocess_config['mean'],
            std=preprocess_config['std']
        ))
    transform = transforms.Compose(transform_list)

    img_tensor = transform(img).unsqueeze(0)  # add a batch dimension
    return img_tensor

def main():
    args = parse_args()
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    config = load_config(args.config)

    # Load the models
    autoencoder, cnn_model = load_models(args, config, device)

    # Preprocess the image
    input_tensor = preprocess_image(args.image_path, config).to(device)

    with torch.no_grad():
        # Denoise with the autoencoder
        denoised_image = autoencoder(input_tensor)

        # Classify with the CNN
        output = cnn_model(denoised_image)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()

    # Class labels (must match the training dataset)
    class_names = ['Covid', 'Normal', 'Viral Pneumonia']

    # Convert probabilities to percentages
    probabilities_percentage = probabilities.cpu().numpy()[0] * 100

    # Formatted output
    print("Prediction Result:")
    print("--------------------------------------")
    print(f"Predicted Class: {class_names[predicted_class]}")
    print(f"Probabilities: Covid: {probabilities_percentage[0]:.2f}%, Normal: {probabilities_percentage[1]:.2f}%, Viral Pneumonia: {probabilities_percentage[2]:.2f}%")
    print("--------------------------------------")

if __name__ == "__main__":
    main()
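An example invocation; the script name inference.py and the image path are assumptions, while the checkpoint paths follow the output-directory defaults used by the training entry point below:

python inference.py --autoencoder_path results/autoencoder/checkpoints/best_model.pth --cnn_path results/cnn/checkpoints/best_model.pth --image_path data/noisy_test/Covid/example.png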
@@ -0,0 +1,242 @@
import os
import copy
import torch
from torch.utils.data import DataLoader
import argparse
from data.dataset import LungXrayDataset
from models.autoencoder import Autoencoder
from models.simplecnn import SimpleCNN
from utils import load_config
from train_autoencoder import train_autoencoder
from train_cnn import train_cnn

def parse_args():
    parser = argparse.ArgumentParser(description='COVID-19 X-ray Classification Project')

    # Basic arguments
    parser.add_argument('--config', type=str, default='config/config.yaml',
                        help='Path to config file')
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Path to data directory')

    # Training phase selection
    parser.add_argument('--train_autoencoder', action='store_true',
                        help='Train autoencoder model')
    parser.add_argument('--train_cnn', action='store_true',
                        help='Train CNN model')

    # Output directories
    parser.add_argument('--autoencoder_dir', type=str, default='results/autoencoder',
                        help='Output directory for autoencoder')
    parser.add_argument('--cnn_dir', type=str, default='results/cnn',
                        help='Output directory for CNN')

    # Autoencoder training arguments
    parser.add_argument('--ae_epochs', type=int, default=None,
                        help='Number of epochs for autoencoder')
    parser.add_argument('--ae_batch_size', type=int, default=None,
                        help='Batch size for autoencoder')
    parser.add_argument('--ae_lr', type=float, default=None,
                        help='Learning rate for autoencoder')

    # CNN training arguments
    parser.add_argument('--cnn_epochs', type=int, default=None,
                        help='Number of epochs for CNN')
    parser.add_argument('--cnn_batch_size', type=int, default=None,
                        help='Batch size for CNN')
    parser.add_argument('--cnn_lr', type=float, default=None,
                        help='Learning rate for CNN')
    parser.add_argument('--noise_factor', type=float, default=0.3,
                        help='Noise factor for data augmentation')

    # Device options
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use (cuda or cpu)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # Model loading
    parser.add_argument('--resume_autoencoder', type=str, default=None,
                        help='Path to autoencoder checkpoint to resume from')
    parser.add_argument('--resume_cnn', type=str, default=None,
                        help='Path to CNN checkpoint to resume from')

    return parser.parse_args()


def train_phase_autoencoder(args, config, device, train_loader, test_loader):
    """Autoencoder training phase."""
    print("=== Starting Autoencoder Training ===")

    # Create the autoencoder output directories
    os.makedirs(args.autoencoder_dir, exist_ok=True)
    os.makedirs(os.path.join(args.autoencoder_dir, 'checkpoints'), exist_ok=True)

    # Create the model
    autoencoder = Autoencoder()

    # Resume from a checkpoint if one was given
    if args.resume_autoencoder:
        print(f'Loading autoencoder checkpoint from {args.resume_autoencoder}')
        autoencoder.load_state_dict(torch.load(args.resume_autoencoder, map_location=device))

    # Train the autoencoder
    autoencoder_history = train_autoencoder(
        model=autoencoder,
        lr=config['training']['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=config['training']['num_epochs'],
        device=device,
        output_dir=args.autoencoder_dir
    )

    return autoencoder

def train_phase_cnn(args, config, device, train_loader, test_loader, autoencoder):
    """CNN training phase."""
    print("=== Starting CNN Training ===")

    # Create the CNN output directories
    os.makedirs(args.cnn_dir, exist_ok=True)
    os.makedirs(os.path.join(args.cnn_dir, 'checkpoints'), exist_ok=True)

    # Create the CNN model
    cnn_model = SimpleCNN()

    # Resume from a checkpoint if one was given
    if args.resume_cnn:
        print(f'Loading CNN checkpoint from {args.resume_cnn}')
        cnn_model.load_state_dict(torch.load(args.resume_cnn, map_location=device))

    # Train the CNN
    cnn_history = train_cnn(
        cnn_model=cnn_model,
        autoencoder=autoencoder,
        lr=config['training']['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=config['training']['num_epochs'],
        device=device,
        output_dir=args.cnn_dir,
        noise_factor=args.noise_factor
    )

    return cnn_history

def main():
    # Parse command-line arguments
    args = parse_args()

    # Set the random seed
    torch.manual_seed(args.seed)

    # Load the config
    config = load_config(args.config)

    # Make per-phase copies of the config; deep copies, so that the
    # autoencoder overrides below do not leak into the CNN config
    ae_config = copy.deepcopy(config)
    cnn_config = copy.deepcopy(config)

    # Command-line overrides - autoencoder
    if args.ae_epochs is not None:
        ae_config['training']['num_epochs'] = args.ae_epochs
    if args.ae_batch_size is not None:
        ae_config['training']['batch_size'] = args.ae_batch_size
    if args.ae_lr is not None:
        ae_config['training']['learning_rate'] = args.ae_lr

    # Command-line overrides - CNN
    if args.cnn_epochs is not None:
        cnn_config['training']['num_epochs'] = args.cnn_epochs
    if args.cnn_batch_size is not None:
        cnn_config['training']['batch_size'] = args.cnn_batch_size
    if args.cnn_lr is not None:
        cnn_config['training']['learning_rate'] = args.cnn_lr

    # Set the device
    if args.device == 'cuda' and not torch.cuda.is_available():
        print('Warning: CUDA is not available, using CPU instead')
        device = 'cpu'
    else:
        device = args.device
    device = torch.device(device)
    print(f'Using device: {device}')

    # Create the data loaders - autoencoder
    if args.train_autoencoder:
        train_dataset_ae = LungXrayDataset(
            root_dir=args.data_dir,
            is_train=True
        )
        test_dataset_ae = LungXrayDataset(
            root_dir=args.data_dir,
            is_train=False
        )

        train_loader_ae = DataLoader(
            train_dataset_ae,
            batch_size=ae_config['training']['batch_size'],
            shuffle=True
        )
        test_loader_ae = DataLoader(
            test_dataset_ae,
            batch_size=ae_config['training']['batch_size'],
            shuffle=False
        )

    # Create the data loaders - CNN
    if args.train_cnn:
        train_dataset_cnn = LungXrayDataset(
            root_dir=args.data_dir,
            is_train=True
        )
        test_dataset_cnn = LungXrayDataset(
            root_dir=args.data_dir,
            is_train=False
        )

        train_loader_cnn = DataLoader(
            train_dataset_cnn,
            batch_size=cnn_config['training']['batch_size'],
            shuffle=True
        )
        test_loader_cnn = DataLoader(
            test_dataset_cnn,
            batch_size=cnn_config['training']['batch_size'],
            shuffle=False
        )

    # Train the autoencoder
    if args.train_autoencoder:
        print("\n=== Autoencoder Training Configuration ===")
        print(f"Epochs: {ae_config['training']['num_epochs']}")
        print(f"Batch Size: {ae_config['training']['batch_size']}")
        print(f"Learning Rate: {ae_config['training']['learning_rate']}\n")

        autoencoder = train_phase_autoencoder(args, ae_config, device,
                                              train_loader_ae, test_loader_ae)
    else:
        # Otherwise load a pretrained autoencoder; the weights live under
        # checkpoints/best_model.pth inside the autoencoder output directory
        autoencoder = Autoencoder()
        autoencoder_path = os.path.join(args.autoencoder_dir, 'checkpoints', 'best_model.pth')
        if os.path.exists(autoencoder_path):
            print(f'Loading pretrained autoencoder from {autoencoder_path}')
            autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))
        else:
            raise FileNotFoundError(f"No pretrained autoencoder found at {autoencoder_path}")

    # Train the CNN
    if args.train_cnn:
        print("\n=== CNN Training Configuration ===")
        print(f"Epochs: {cnn_config['training']['num_epochs']}")
        print(f"Batch Size: {cnn_config['training']['batch_size']}")
        print(f"Learning Rate: {cnn_config['training']['learning_rate']}")
        print(f"Noise Factor: {args.noise_factor}\n")

        autoencoder.eval()  # put the autoencoder in eval mode
        train_phase_cnn(args, cnn_config, device, train_loader_cnn,
                        test_loader_cnn, autoencoder)

if __name__ == "__main__":
    main()
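An end-to-end example of running both training phases; the script name main.py is an assumption, the flags match parse_args above:

python main.py --train_autoencoder --train_cnn --ae_epochs 100 --cnn_epochs 50 --noise_factor 0.3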
@@ -0,0 +1,110 @@
import os
import torch
from torch.utils.data import DataLoader
import argparse
from data.dataset import LungXrayDataset
from models.autoencoder import Autoencoder
from utils import load_config
from train_autoencoder import train_autoencoder

def parse_args():
    parser = argparse.ArgumentParser(description='COVID-19 X-ray Denoising Project')

    # Basic arguments
    parser.add_argument('--config', type=str, default='configs/config.yaml',
                        help='Path to config file')
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Path to data directory')
    parser.add_argument('--output_dir', type=str, default='results',
                        help='Path to output directory')

    # Training arguments (None means: use the value from the config file)
    parser.add_argument('--epochs', type=int, default=None,
                        help='Number of epochs (override config file)')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Batch size (override config file)')
    parser.add_argument('--lr', type=float, default=None,
                        help='Learning rate (override config file)')

    # Device options
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use (cuda or cpu)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')

    return parser.parse_args()

def main():
    # Parse command-line arguments
    args = parse_args()

    # Set the random seed
    torch.manual_seed(args.seed)

    # Load the config
    config = load_config(args.config)

    # Command-line overrides
    if args.epochs is not None:
        config['training']['num_epochs'] = args.epochs
    if args.batch_size is not None:
        config['training']['batch_size'] = args.batch_size
    if args.lr is not None:
        config['training']['learning_rate'] = args.lr

    # Set the device
    if args.device == 'cuda' and not torch.cuda.is_available():
        print('Warning: CUDA is not available, using CPU instead')
        device = 'cpu'
    else:
        device = args.device
    device = torch.device(device)
    print(f'Using device: {device}')

    # Create the output directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'visualizations'), exist_ok=True)

    # Create the data loaders
    train_dataset = LungXrayDataset(
        root_dir=args.data_dir,
        is_train=True
    )

    test_dataset = LungXrayDataset(
        root_dir=args.data_dir,
        is_train=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=False
    )

    # Create the model
    model = Autoencoder()

    # Resume from a checkpoint if one was given
    if args.resume:
        print(f'Loading checkpoint from {args.resume}')
        model.load_state_dict(torch.load(args.resume, map_location=device))

    # Train the model
    train_losses, test_losses = train_autoencoder(
        model=model,
        lr=config['training']['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=config['training']['num_epochs'],
        device=device,
        output_dir=args.output_dir
    )

if __name__ == "__main__":
    main()
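An example run of this entry point (the script name run_autoencoder.py is an assumption). Note that --resume expects a bare state_dict such as best_model.pth, not the checkpoint_epoch_N.pth dicts that train_autoencoder saves every 10 epochs:

python run_autoencoder.py --epochs 100 --batch_size 32 --output_dir results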
@@ -0,0 +1,130 @@
import os
import torch
from torch.utils.data import DataLoader
import argparse
from data.dataset import LungXrayDataset
from models.autoencoder import Autoencoder
from utils import load_config
from train_cnn import train_cnn
from models.simplecnn import SimpleCNN

def parse_args():
    parser = argparse.ArgumentParser(description='COVID-19 X-ray Classification Project')

    # Basic arguments
    parser.add_argument('--config', type=str, default='configs/config.yaml',
                        help='Path to config file')
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Path to data directory')
    parser.add_argument('--output_dir', type=str, default='results_cnn',
                        help='Path to output directory')

    # Training arguments (None means: use the value from the config file)
    parser.add_argument('--epochs', type=int, default=None,
                        help='Number of epochs (override config file)')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Batch size (override config file)')
    parser.add_argument('--lr', type=float, default=None,
                        help='Learning rate (override config file)')

    # Autoencoder-related arguments
    parser.add_argument('--autoencoder_path', type=str,
                        default="./results_autoencoder/checkpoints/best_model.pth",
                        help='Path to pretrained autoencoder model')
    parser.add_argument('--noise_factor', type=float, default=0.3,
                        help='Noise factor for data augmentation')

    # Device options
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use (cuda or cpu)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')

    return parser.parse_args()


def main():
    # Parse command-line arguments
    args = parse_args()

    # Set the random seed
    torch.manual_seed(args.seed)

    # Load the config
    config = load_config(args.config)

    # Command-line overrides
    if args.epochs is not None:
        config['training']['num_epochs'] = args.epochs
    if args.batch_size is not None:
        config['training']['batch_size'] = args.batch_size
    if args.lr is not None:
        config['training']['learning_rate'] = args.lr

    # Set the device
    if args.device == 'cuda' and not torch.cuda.is_available():
        print('Warning: CUDA is not available, using CPU instead')
        device = 'cpu'
    else:
        device = args.device
    device = torch.device(device)
    print(f'Using device: {device}')

    # Create the output directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'plots'), exist_ok=True)

    # Create the data loaders
    train_dataset = LungXrayDataset(
        root_dir=args.data_dir,
        is_train=True
    )

    test_dataset = LungXrayDataset(
        root_dir=args.data_dir,
        is_train=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=False
    )

    # Load the pretrained autoencoder
    autoencoder = Autoencoder()
    autoencoder.load_state_dict(torch.load(args.autoencoder_path, map_location=device))
    autoencoder = autoencoder.to(device)
    autoencoder.eval()  # put it in eval mode

    # Create the CNN model
    cnn_model = SimpleCNN()

    # Resume from a checkpoint if one was given
    if args.resume:
        print(f'Loading checkpoint from {args.resume}')
        cnn_model.load_state_dict(torch.load(args.resume, map_location=device))

    # Train the model
    history = train_cnn(
        cnn_model=cnn_model,
        autoencoder=autoencoder,
        lr=config['training']['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=config['training']['num_epochs'],
        device=device,
        output_dir=args.output_dir,
        noise_factor=args.noise_factor
    )

if __name__ == "__main__":
    main()
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Encoder: two conv blocks, each halving the spatial size (256 -> 128 -> 64)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        # Decoder: nearest-neighbor upsampling back to the input size (64 -> 128 -> 256)
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
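A quick shape sanity check, a minimal sketch assuming the 256x256 grayscale inputs used elsewhere in this commit:

model = Autoencoder()
x = torch.randn(1, 1, 256, 256)
print(model(x).shape)  # torch.Size([1, 1, 256, 256]): two 2x pools down, two 2x upsamples back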
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # For 256x256 inputs, after two pooling layers:
        # Input -> Conv1 -> Pool -> Conv2 -> Pool
        # Size reduces from 256x256 -> 128x128 -> 64x64, giving 32*64*64 features for fc1
        self.fc1 = nn.Linear(32 * 64 * 64, 128)  # adjust if the input dimensions change
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool(x)

        # Flatten the tensor for the fully connected layers
        x = x.view(x.size(0), -1)  # flatten to (batch_size, feature_size)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x
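The fc1 width follows from the pooling arithmetic; a small sketch of the general rule for other input sizes:

def fc1_in_features(h, w, channels=32, num_pools=2):
    # each 2x2 max-pool halves both spatial dimensions
    return channels * (h // 2 ** num_pools) * (w // 2 ** num_pools)

print(fc1_in_features(256, 256))  # 131072 == 32 * 64 * 64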
@@ -0,0 +1,10 @@
gradio==5.9.1
matplotlib==3.10.0
numpy==2.2.1
Pillow==11.0.0
PyYAML==6.0.2
scikit_learn==1.6.0
seaborn==0.13.2
torch==2.5.1
torchvision==0.20.1
tqdm==4.67.1
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import os
from data.dataset import LungXrayDataset
from models.autoencoder import Autoencoder
from tqdm import tqdm
import time
from torch.utils.tensorboard import SummaryWriter


def train_autoencoder(model, lr, train_loader, test_loader, num_epochs=100, device='cuda', output_dir='results_autoencoder'):
    """
    Train the autoencoder, with TensorBoard for visualization.

    Args:
        model: autoencoder model
        lr: learning rate
        train_loader: training data loader
        test_loader: test data loader
        num_epochs: number of training epochs
        device: device to use ('cuda' or 'cpu')
        output_dir: output directory

    Returns:
        (train_losses, test_losses): per-epoch average losses
    """
    # Create the output directories
    checkpoint_dir = os.path.join(output_dir, 'checkpoints')
    tensorboard_dir = os.path.join(output_dir, 'tensorboard')
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)

    # Initialize the TensorBoard writer
    writer = SummaryWriter(tensorboard_dir)

    # Move the model to the target device
    model = model.to(device)

    # Define the loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Track the best model and the per-epoch loss history
    best_test_loss = float('inf')
    train_losses, test_losses = [], []

    # Grab a fixed batch of test data for visualization
    fixed_test_data, _ = next(iter(test_loader))
    fixed_test_data = fixed_test_data.to(device)

    # Add the model graph to TensorBoard
    writer.add_graph(model, fixed_test_data)

    # Training start time
    start_time = time.time()
    global_step = 0

    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_pbar = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}] Training')

        for batch_idx, (data, _) in enumerate(train_pbar):
            data = data.to(device)

            # Forward pass
            output = model(data)
            loss = criterion(output, data)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Record the loss
            train_loss += loss.item()
            writer.add_scalar('Loss/train_step', loss.item(), global_step)

            # Update the progress bar
            train_pbar.set_postfix({'loss': loss.item()})
            global_step += 1

        # Average training loss
        train_loss /= len(train_loader)

        # Test phase
        model.eval()
        test_loss = 0
        test_pbar = tqdm(test_loader, desc=f'Epoch [{epoch+1}/{num_epochs}] Testing')

        with torch.no_grad():
            for data, _ in test_pbar:
                data = data.to(device)
                output = model(data)
                loss = criterion(output, data)
                test_loss += loss.item()
                test_pbar.set_postfix({'loss': loss.item()})

        test_loss /= len(test_loader)

        # Record the per-epoch loss history
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        # Log per-epoch losses
        writer.add_scalars('Loss/epoch', {
            'train': train_loss,
            'test': test_loss
        }, epoch)

        # Visualize reconstructions
        with torch.no_grad():
            reconstructed = model(fixed_test_data)
            # Build a comparison grid of original and reconstructed images
            comparison = torch.cat([fixed_test_data[:8], reconstructed[:8]])
            grid = make_grid(comparison, nrow=8, normalize=True)
            writer.add_image('Reconstruction', grid, epoch)

        # Log parameter distributions
        for name, param in model.named_parameters():
            writer.add_histogram(f'Parameters/{name}', param, epoch)
            if param.grad is not None:
                writer.add_histogram(f'Gradients/{name}', param.grad, epoch)

        # Print progress
        elapsed_time = time.time() - start_time
        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {train_loss:.6f}, '
              f'Test Loss: {test_loss:.6f}, '
              f'Time: {elapsed_time:.2f}s')

        # Save the best model
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            torch.save(model.state_dict(),
                       os.path.join(checkpoint_dir, 'best_model.pth'))

        # Save a checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'test_loss': test_loss,
            }
            torch.save(checkpoint,
                       os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'))

    # Save the final model
    torch.save(model.state_dict(),
               os.path.join(checkpoint_dir, 'final_model.pth'))

    # Total training time
    total_time = time.time() - start_time
    print(f'Training completed in {total_time:.2f} seconds')

    # Close the TensorBoard writer
    writer.close()

    # Return the loss history, which callers unpack
    return train_losses, test_losses


if __name__ == "__main__":
    # Set the random seed
    torch.manual_seed(42)

    # Check whether a GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Create the data loaders
    train_dataset = LungXrayDataset(root_dir="/home/lgz/Code/class/ML/e1/covid19", is_train=True)
    test_dataset = LungXrayDataset(root_dir="/home/lgz/Code/class/ML/e1/covid19", is_train=False)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Create and train the model
    model = Autoencoder()
    train_losses, test_losses = train_autoencoder(
        model=model,
        lr=1e-3,
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=100,
        device=device
    )
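A minimal denoising sanity check against the best checkpoint saved by the loop above; the results_autoencoder path assumes the default output_dir. Training curves and reconstruction grids can be inspected with tensorboard --logdir results_autoencoder/tensorboard.

import torch
from models.autoencoder import Autoencoder

model = Autoencoder()
model.load_state_dict(torch.load('results_autoencoder/checkpoints/best_model.pth',
                                 map_location='cpu'))
model.eval()
# a random "image" with additive noise, clamped to [0, 1]
noisy = torch.clamp(torch.rand(1, 1, 256, 256) + 0.1 * torch.randn(1, 1, 256, 256), 0, 1)
with torch.no_grad():
    denoised = model(noisy)
print(denoised.shape)  # torch.Size([1, 1, 256, 256])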
@@ -0,0 +1,325 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import numpy as np
from models.autoencoder import Autoencoder
from models.simplecnn import SimpleCNN
from data.dataset import LungXrayDataset

def add_noise(images, noise_factor=0.3):
    """Add Gaussian noise."""
    noisy_images = images + noise_factor * torch.randn_like(images)
    return torch.clamp(noisy_images, 0., 1.)

def plot_confusion_matrix(cm, classes, output_path):
    """Plot a confusion matrix."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def train_cnn(cnn_model, autoencoder, lr, train_loader, test_loader, num_epochs=100,
              device='cuda', output_dir='results_cnn', noise_factor=0.3):
    """
    Train the CNN model.

    Args:
        cnn_model: CNN model
        autoencoder: pretrained autoencoder model
        lr: learning rate
        train_loader: training data loader
        test_loader: test data loader
        num_epochs: number of training epochs
        device: device to use
        output_dir: output directory
        noise_factor: noise factor
    """
    # Create the output directories
    checkpoint_dir = os.path.join(output_dir, 'checkpoints')
    tensorboard_dir = os.path.join(output_dir, 'tensorboard')
    plot_dir = os.path.join(output_dir, 'plots')
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)

    # Initialize the TensorBoard writer
    writer = SummaryWriter(tensorboard_dir)

    # Move the models to the target device
    cnn_model = cnn_model.to(device)
    autoencoder = autoencoder.to(device)
    autoencoder.eval()  # the autoencoder stays in eval mode

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(cnn_model.parameters(), lr=lr)

    # Track the best model
    best_test_acc = 0.0

    # Training history
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_acc': [],
        'test_acc': []
    }

    # Training start time
    start_time = time.time()
    global_step = 0

    # Class names
    classes = ['Covid', 'Normal', 'Viral Pneumonia']

    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        cnn_model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        train_pbar = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}] Training')

        for batch_idx, (data, targets) in enumerate(train_pbar):
            data, targets = data.to(device), targets.to(device)

            # Add noise
            noisy_data = add_noise(data, noise_factor)

            # Denoise through the autoencoder
            with torch.no_grad():
                denoised_data = autoencoder(noisy_data)

            # Forward pass
            outputs = cnn_model(denoised_data)
            loss = criterion(outputs, targets)

            # Compute accuracy
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Record the loss
            train_loss += loss.item()

            # Log to TensorBoard
            writer.add_scalar('Loss/train_step', loss.item(), global_step)

            # Update the progress bar
            train_pbar.set_postfix({
                'loss': loss.item(),
                'acc': 100. * train_correct / train_total
            })
            global_step += 1

        # Average training metrics
        train_loss = train_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total

        # Test phase
        cnn_model.eval()
        test_loss = 0
        test_correct = 0
        test_total = 0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            test_pbar = tqdm(test_loader, desc=f'Epoch [{epoch+1}/{num_epochs}] Testing')
            for data, targets in test_pbar:
                data, targets = data.to(device), targets.to(device)

                # Add noise and denoise through the autoencoder
                noisy_data = add_noise(data, noise_factor)
                denoised_data = autoencoder(noisy_data)

                outputs = cnn_model(denoised_data)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                test_total += targets.size(0)
                test_correct += predicted.eq(targets).sum().item()

                # Collect predictions for the confusion matrix
                all_predictions.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

                test_pbar.set_postfix({
                    'loss': loss.item(),
                    'acc': 100. * test_correct / test_total
                })

        # Average test metrics
        test_loss = test_loss / len(test_loader)
        test_acc = 100. * test_correct / test_total

        # Record history
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)

        # Log to TensorBoard
        writer.add_scalars('Loss/epoch', {
            'train': train_loss,
            'test': test_loss
        }, epoch)

        writer.add_scalars('Accuracy/epoch', {
            'train': train_acc,
            'test': test_acc
        }, epoch)

        # Plot the confusion matrix at the end of each epoch
        cm = confusion_matrix(all_targets, all_predictions)
        plot_confusion_matrix(cm, classes,
                              os.path.join(plot_dir, f'confusion_matrix_epoch_{epoch+1}.png'))

        # Print the classification report
        report = classification_report(all_targets, all_predictions, target_names=classes)
        print(f"\nClassification Report - Epoch {epoch+1}:")
        print(report)

        # Print progress
        elapsed_time = time.time() - start_time
        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, '
              f'Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%, '
              f'Time: {elapsed_time:.2f}s')

        # Save the best model
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            torch.save(cnn_model.state_dict(),
                       os.path.join(checkpoint_dir, 'best_model.pth'))

        # Every 10 epochs, save a checkpoint and plot the curves
        if (epoch + 1) % 10 == 0:
            # Save a checkpoint
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': cnn_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'test_loss': test_loss,
                'train_acc': train_acc,
                'test_acc': test_acc
            }
            torch.save(checkpoint,
                       os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'))

            # Plot the loss and accuracy curves
            plt.figure(figsize=(12, 5))

            # Loss curves
            plt.subplot(1, 2, 1)
            plt.plot(history['train_loss'], label='Train Loss')
            plt.plot(history['test_loss'], label='Test Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend()
            plt.title('Training and Testing Losses')

            # Accuracy curves
            plt.subplot(1, 2, 2)
            plt.plot(history['train_acc'], label='Train Acc')
            plt.plot(history['test_acc'], label='Test Acc')
            plt.xlabel('Epoch')
            plt.ylabel('Accuracy (%)')
            plt.legend()
            plt.title('Training and Testing Accuracies')

            plt.tight_layout()
            plt.savefig(os.path.join(plot_dir, f'metrics_epoch_{epoch+1}.png'))
            plt.close()

    # Save the final model
    torch.save(cnn_model.state_dict(),
               os.path.join(checkpoint_dir, 'final_model.pth'))

    # Plot the final loss and accuracy curves
    plt.figure(figsize=(12, 5))

    # Loss curves
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['test_loss'], label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Testing Losses')

    # Accuracy curves
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['test_acc'], label='Test Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.title('Training and Testing Accuracies')

    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, 'final_metrics.png'))
    plt.close()

    # Total training time
    total_time = time.time() - start_time
    print(f'Training completed in {total_time:.2f} seconds')

    # Close the TensorBoard writer
    writer.close()

    return history

if __name__ == "__main__":
    # Set the random seed
    torch.manual_seed(42)

    # Check whether a GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Create the data loaders
    train_dataset = LungXrayDataset(root_dir="/home/lgz/Code/class/ML/e1/covid19", is_train=True)
    test_dataset = LungXrayDataset(root_dir="/home/lgz/Code/class/ML/e1/covid19", is_train=False)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Load the pretrained autoencoder
    autoencoder = Autoencoder()
    autoencoder.load_state_dict(torch.load('results/checkpoints/best_model.pth', map_location=device))

    # Create the CNN model
    cnn_model = SimpleCNN()

    # Train the CNN model
    history = train_cnn(
        cnn_model=cnn_model,
        autoencoder=autoencoder,
        lr=1e-3,
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=100,
        device=device,
        noise_factor=0.3
    )
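The periodic checkpoints above store a full training-state dict rather than a bare state_dict; a minimal sketch of restoring one (the epoch number in the path is illustrative):

ckpt = torch.load('results_cnn/checkpoints/checkpoint_epoch_10.pth', map_location='cpu')
cnn_model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
print(ckpt['epoch'], ckpt['train_acc'], ckpt['test_acc'])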
@@ -0,0 +1,26 @@
import yaml
import torch
import matplotlib.pyplot as plt
from torchvision.utils import save_image

def load_config(config_path):
    """Load a YAML config file."""
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def save_reconstructed_images(original, reconstructed, path, nrow=8):
    """Save a side-by-side comparison of original and reconstructed images."""
    comparison = torch.cat([original[:nrow], reconstructed[:nrow]])
    save_image(comparison.cpu(), path, nrow=nrow)

def plot_losses(train_losses, test_losses, save_path):
    """Plot the loss curves."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(save_path)
    plt.close()
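A usage sketch tying these helpers to train_autoencoder, which returns the per-epoch loss lists that plot_losses expects; the values and output path here are illustrative:

from utils import plot_losses

# train_losses, test_losses as returned by train_autoencoder(...)
train_losses, test_losses = [0.05, 0.03, 0.02], [0.06, 0.04, 0.03]
plot_losses(train_losses, test_losses, 'results/visualizations/losses.png')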