You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

242 lines
8.8 KiB

import argparse
import copy
import os

import torch
from torch.utils.data import DataLoader

from data.dataset import LungXrayDataset
from models.autoencoder import Autoencoder
from models.simplecnn import SimpleCNN
from train_autoencoder import train_autoencoder
from train_cnn import train_cnn
from utils import load_config
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing an
            explicit list makes the function usable from tests or other code.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description='COVID-19 X-ray Classification Project')

    # Basic options
    parser.add_argument('--config', type=str, default='config/config.yaml',
                        help='Path to config file')
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Path to data directory')

    # Which training stages to run
    parser.add_argument('--train_autoencoder', action='store_true',
                        help='Train autoencoder model')
    parser.add_argument('--train_cnn', action='store_true',
                        help='Train CNN model')

    # Output directories
    parser.add_argument('--autoencoder_dir', type=str, default='results/autoencoder',
                        help='Output directory for autoencoder')
    parser.add_argument('--cnn_dir', type=str, default='results/cnn',
                        help='Output directory for CNN')

    # Autoencoder training options; None means "keep the config-file value"
    parser.add_argument('--ae_epochs', type=int, default=None,
                        help='Number of epochs for autoencoder')
    parser.add_argument('--ae_batch_size', type=int, default=None,
                        help='Batch size for autoencoder')
    parser.add_argument('--ae_lr', type=float, default=None,
                        help='Learning rate for autoencoder')

    # CNN training options; None means "keep the config-file value"
    parser.add_argument('--cnn_epochs', type=int, default=None,
                        help='Number of epochs for CNN')
    parser.add_argument('--cnn_batch_size', type=int, default=None,
                        help='Batch size for CNN')
    parser.add_argument('--cnn_lr', type=float, default=None,
                        help='Learning rate for CNN')
    parser.add_argument('--noise_factor', type=float, default=0.3,
                        help='Noise factor for data augmentation')

    # Device options
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use (cuda or cpu)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # Checkpoint loading
    parser.add_argument('--resume_autoencoder', type=str, default=None,
                        help='Path to autoencoder checkpoint to resume from')
    parser.add_argument('--resume_cnn', type=str, default=None,
                        help='Path to CNN checkpoint to resume from')

    return parser.parse_args(argv)
def train_phase_autoencoder(args, config, device, train_loader, test_loader):
    """Run the autoencoder training stage.

    Args:
        args: Parsed CLI namespace; ``autoencoder_dir`` and
            ``resume_autoencoder`` are read here.
        config: Mapping whose ``'training'`` entry supplies
            ``'learning_rate'`` and ``'num_epochs'``.
        device: Device handed to ``torch.load`` (as ``map_location``) and to
            ``train_autoencoder``.
        train_loader: Loader forwarded to ``train_autoencoder``.
        test_loader: Loader forwarded to ``train_autoencoder``.

    Returns:
        The ``Autoencoder`` instance that was passed to ``train_autoencoder``.
        The value returned by ``train_autoencoder`` itself is discarded.
    """
    print("=== Starting Autoencoder Training ===")

    # Ensure the output directory and its checkpoints sub-directory exist
    output_root = args.autoencoder_dir
    for directory in (output_root, os.path.join(output_root, 'checkpoints')):
        os.makedirs(directory, exist_ok=True)

    model = Autoencoder()

    # Optionally start from a previously saved state dict
    checkpoint_path = args.resume_autoencoder
    if checkpoint_path:
        print(f'Loading autoencoder checkpoint from {checkpoint_path}')
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)

    hyperparams = config['training']
    train_autoencoder(
        model=model,
        lr=hyperparams['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=hyperparams['num_epochs'],
        device=device,
        output_dir=args.autoencoder_dir,
    )
    return model
def train_phase_cnn(args, config, device, train_loader, test_loader, autoencoder):
    """Run the CNN training stage.

    Args:
        args: Parsed CLI namespace; ``cnn_dir``, ``resume_cnn`` and
            ``noise_factor`` are read here.
        config: Mapping whose ``'training'`` entry supplies
            ``'learning_rate'`` and ``'num_epochs'``.
        device: Device handed to ``torch.load`` (as ``map_location``) and to
            ``train_cnn``.
        train_loader: Loader forwarded to ``train_cnn``.
        test_loader: Loader forwarded to ``train_cnn``.
        autoencoder: Autoencoder model forwarded to ``train_cnn``.

    Returns:
        Whatever ``train_cnn`` returns (named ``cnn_history`` in the original).
    """
    print("=== Starting CNN Training ===")

    # Ensure the output directory and its checkpoints sub-directory exist
    output_root = args.cnn_dir
    for directory in (output_root, os.path.join(output_root, 'checkpoints')):
        os.makedirs(directory, exist_ok=True)

    classifier = SimpleCNN()

    # Optionally start from a previously saved state dict
    checkpoint_path = args.resume_cnn
    if checkpoint_path:
        print(f'Loading CNN checkpoint from {checkpoint_path}')
        state_dict = torch.load(checkpoint_path, map_location=device)
        classifier.load_state_dict(state_dict)

    hyperparams = config['training']
    return train_cnn(
        cnn_model=classifier,
        autoencoder=autoencoder,
        lr=hyperparams['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=hyperparams['num_epochs'],
        device=device,
        output_dir=args.cnn_dir,
        noise_factor=args.noise_factor,
    )
def _stage_config(config, epochs, batch_size, lr):
    """Return an independent copy of *config* with non-None overrides applied."""
    # Deep copy: a shallow copy would share the nested 'training' dict, so
    # overrides for one stage would leak into the other stage's config.
    stage = copy.deepcopy(config)
    overrides = {
        'num_epochs': epochs,
        'batch_size': batch_size,
        'learning_rate': lr,
    }
    for key, value in overrides.items():
        if value is not None:
            stage['training'][key] = value
    return stage


def _build_loaders(data_dir, batch_size):
    """Create the (train_loader, test_loader) pair for one training stage."""
    train_dataset = LungXrayDataset(root_dir=data_dir, is_train=True)
    test_dataset = LungXrayDataset(root_dir=data_dir, is_train=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def main():
    """Entry point: parse options, then run the requested training stages.

    Steps, in order: seed torch, load the config file, derive one config per
    stage (CLI values override the file), pick the device, build the data
    loaders, train or load the autoencoder, and finally train the CNN if
    requested.

    Raises:
        FileNotFoundError: If autoencoder training is skipped and no
            pretrained autoencoder checkpoint file can be found.
    """
    args = parse_args()

    # Seed torch's RNG
    torch.manual_seed(args.seed)

    config = load_config(args.config)

    # Per-stage configs; command-line values take precedence over the file
    ae_config = _stage_config(config, args.ae_epochs, args.ae_batch_size, args.ae_lr)
    cnn_config = _stage_config(config, args.cnn_epochs, args.cnn_batch_size, args.cnn_lr)

    # Fall back to CPU when CUDA was requested but is not available
    if args.device == 'cuda' and not torch.cuda.is_available():
        print('Warning: CUDA is not available, using CPU instead')
        device = 'cpu'
    else:
        device = args.device
    device = torch.device(device)
    print(f'Using device: {device}')

    # Build all loaders up front (autoencoder first, then CNN), before any
    # training starts, matching the original construction order.
    if args.train_autoencoder:
        train_loader_ae, test_loader_ae = _build_loaders(
            args.data_dir, ae_config['training']['batch_size'])
    if args.train_cnn:
        train_loader_cnn, test_loader_cnn = _build_loaders(
            args.data_dir, cnn_config['training']['batch_size'])

    # Stage 1: train the autoencoder, or load a pretrained one
    if args.train_autoencoder:
        print("\n=== Autoencoder Training Configuration ===")
        print(f"Epochs: {ae_config['training']['num_epochs']}")
        print(f"Batch Size: {ae_config['training']['batch_size']}")
        print(f"Learning Rate: {ae_config['training']['learning_rate']}\n")
        autoencoder = train_phase_autoencoder(args, ae_config, device,
                                              train_loader_ae, test_loader_ae)
    else:
        autoencoder = Autoencoder()
        # Prefer an explicit checkpoint path; otherwise keep the historical
        # behaviour of treating --autoencoder_dir as the checkpoint location.
        autoencoder_path = args.resume_autoencoder or args.autoencoder_dir
        # Require a regular file: the default --autoencoder_dir is a
        # directory, which torch.load cannot read.
        if not os.path.isfile(autoencoder_path):
            raise FileNotFoundError(f"No pretrained autoencoder found at {autoencoder_path}")
        print(f'Loading pretrained autoencoder from {autoencoder_path}')
        autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))

    # Stage 2: train the CNN
    if args.train_cnn:
        print("\n=== CNN Training Configuration ===")
        print(f"Epochs: {cnn_config['training']['num_epochs']}")
        print(f"Batch Size: {cnn_config['training']['batch_size']}")
        print(f"Learning Rate: {cnn_config['training']['learning_rate']}")
        print(f"Noise Factor: {args.noise_factor}\n")
        autoencoder.eval()  # put the autoencoder in evaluation mode
        train_phase_cnn(args, cnn_config, device, train_loader_cnn,
                        test_loader_cnn, autoencoder)


if __name__ == "__main__":
    main()