You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
242 lines
8.8 KiB
242 lines
8.8 KiB
import os
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
import argparse
|
|
from data.dataset import LungXrayDataset
|
|
from models.autoencoder import Autoencoder
|
|
from models.simplecnn import SimpleCNN
|
|
from utils import load_config
|
|
from train_autoencoder import train_autoencoder
|
|
from train_cnn import train_cnn
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the training entry point.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what callers that pass no argument have always got.

    Returns:
        argparse.Namespace: the parsed options. The ``--ae_*`` and ``--cnn_*``
        epoch / batch-size / learning-rate options default to ``None``.
    """
    parser = argparse.ArgumentParser(description='COVID-19 X-ray Classification Project')

    # Basic parameters
    parser.add_argument('--config', type=str, default='config/config.yaml',
                        help='Path to config file')
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Path to data directory')

    # Training stage selection
    parser.add_argument('--train_autoencoder', action='store_true',
                        help='Train autoencoder model')
    parser.add_argument('--train_cnn', action='store_true',
                        help='Train CNN model')

    # Output directories
    parser.add_argument('--autoencoder_dir', type=str, default='results/autoencoder',
                        help='Output directory for autoencoder')
    parser.add_argument('--cnn_dir', type=str, default='results/cnn',
                        help='Output directory for CNN')

    # Autoencoder training parameters (None when not given)
    parser.add_argument('--ae_epochs', type=int, default=None,
                        help='Number of epochs for autoencoder')
    parser.add_argument('--ae_batch_size', type=int, default=None,
                        help='Batch size for autoencoder')
    parser.add_argument('--ae_lr', type=float, default=None,
                        help='Learning rate for autoencoder')

    # CNN training parameters (None when not given)
    parser.add_argument('--cnn_epochs', type=int, default=None,
                        help='Number of epochs for CNN')
    parser.add_argument('--cnn_batch_size', type=int, default=None,
                        help='Batch size for CNN')
    parser.add_argument('--cnn_lr', type=float, default=None,
                        help='Learning rate for CNN')
    parser.add_argument('--noise_factor', type=float, default=0.3,
                        help='Noise factor for data augmentation')

    # Device options
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use (cuda or cpu)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # Model loading
    parser.add_argument('--resume_autoencoder', type=str, default=None,
                        help='Path to autoencoder checkpoint to resume from')
    parser.add_argument('--resume_cnn', type=str, default=None,
                        help='Path to CNN checkpoint to resume from')

    # argv=None makes argparse read sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)
|
|
|
|
|
|
def train_phase_autoencoder(args, config, device, train_loader, test_loader):
    """Run the autoencoder training stage.

    Creates ``args.autoencoder_dir`` and a ``checkpoints`` sub-directory,
    builds an ``Autoencoder``, optionally loads a state dict from
    ``args.resume_autoencoder``, and passes the model to
    ``train_autoencoder`` with the learning rate and epoch count taken from
    ``config['training']``.

    Returns:
        The ``Autoencoder`` instance that was handed to ``train_autoencoder``.
    """
    print("=== Starting Autoencoder Training ===")

    # Make sure the output directory and its checkpoint folder exist.
    for folder in (args.autoencoder_dir,
                   os.path.join(args.autoencoder_dir, 'checkpoints')):
        os.makedirs(folder, exist_ok=True)

    model = Autoencoder()

    # Optionally start from a previously saved state dict.
    checkpoint_path = args.resume_autoencoder
    if checkpoint_path:
        print(f'Loading autoencoder checkpoint from {checkpoint_path}')
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)

    # The value returned by train_autoencoder is not used by this stage.
    train_autoencoder(
        model=model,
        lr=config['training']['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=config['training']['num_epochs'],
        device=device,
        output_dir=args.autoencoder_dir,
    )

    return model
|
|
|
|
def train_phase_cnn(args, config, device, train_loader, test_loader, autoencoder):
    """Run the CNN training stage.

    Creates ``args.cnn_dir`` and a ``checkpoints`` sub-directory, builds a
    ``SimpleCNN``, optionally loads a state dict from ``args.resume_cnn``,
    and calls ``train_cnn`` with that model, the given ``autoencoder``, the
    learning rate and epoch count from ``config['training']`` and
    ``args.noise_factor``.

    Returns:
        Whatever ``train_cnn`` returns.
    """
    print("=== Starting CNN Training ===")

    # Make sure the output directory and its checkpoint folder exist.
    for folder in (args.cnn_dir, os.path.join(args.cnn_dir, 'checkpoints')):
        os.makedirs(folder, exist_ok=True)

    classifier = SimpleCNN()

    # Optionally start from a previously saved state dict.
    checkpoint_path = args.resume_cnn
    if checkpoint_path:
        print(f'Loading CNN checkpoint from {checkpoint_path}')
        state_dict = torch.load(checkpoint_path, map_location=device)
        classifier.load_state_dict(state_dict)

    return train_cnn(
        cnn_model=classifier,
        autoencoder=autoencoder,
        lr=config['training']['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=config['training']['num_epochs'],
        device=device,
        output_dir=args.cnn_dir,
        noise_factor=args.noise_factor,
    )
|
|
|
|
def _override_training(cfg, epochs, batch_size, lr):
    """Write the non-None command-line values into ``cfg['training']``."""
    if epochs is not None:
        cfg['training']['num_epochs'] = epochs
    if batch_size is not None:
        cfg['training']['batch_size'] = batch_size
    if lr is not None:
        cfg['training']['learning_rate'] = lr


def _build_loaders(data_dir, batch_size):
    """Return ``(train_loader, test_loader)`` over ``LungXrayDataset``."""
    train_dataset = LungXrayDataset(root_dir=data_dir, is_train=True)
    test_dataset = LungXrayDataset(root_dir=data_dir, is_train=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def main():
    """Entry point: parse options, then run the requested training stages.

    Steps:
      1. Parse the command line, seed torch and load the config file.
      2. Build independent autoencoder / CNN configs and apply the
         ``--ae_*`` / ``--cnn_*`` overrides to each.
      3. Pick the device, falling back to CPU when CUDA is unavailable.
      4. With ``--train_autoencoder``: train the autoencoder. Otherwise load
         pretrained weights from ``--resume_autoencoder`` if given, else from
         the path in ``--autoencoder_dir``.
      5. With ``--train_cnn``: put the autoencoder in eval mode and train
         the CNN.

    Raises:
        FileNotFoundError: the autoencoder is not being trained and no
            pretrained checkpoint file exists at the resolved path.
    """
    # Local import: only this function needs it.
    import copy

    # Parse command-line arguments
    args = parse_args()

    # Set the random seed
    torch.manual_seed(args.seed)

    # Load the configuration
    config = load_config(args.config)

    # Per-stage config copies. These must be deep copies: a shallow
    # ``config.copy()`` shares the nested 'training' dict, so the CNN
    # overrides below would also overwrite the autoencoder settings.
    ae_config = copy.deepcopy(config)
    cnn_config = copy.deepcopy(config)

    # Command-line values take precedence over the config file.
    _override_training(ae_config, args.ae_epochs, args.ae_batch_size, args.ae_lr)
    _override_training(cnn_config, args.cnn_epochs, args.cnn_batch_size, args.cnn_lr)

    # Select the device
    if args.device == 'cuda' and not torch.cuda.is_available():
        print('Warning: CUDA is not available, using CPU instead')
        device = 'cpu'
    else:
        device = args.device
    device = torch.device(device)
    print(f'Using device: {device}')

    # Data loaders - autoencoder
    if args.train_autoencoder:
        train_loader_ae, test_loader_ae = _build_loaders(
            args.data_dir, ae_config['training']['batch_size'])

    # Data loaders - CNN
    if args.train_cnn:
        train_loader_cnn, test_loader_cnn = _build_loaders(
            args.data_dir, cnn_config['training']['batch_size'])

    # Train the autoencoder, or load a pretrained one
    if args.train_autoencoder:
        print("\n=== Autoencoder Training Configuration ===")
        print(f"Epochs: {ae_config['training']['num_epochs']}")
        print(f"Batch Size: {ae_config['training']['batch_size']}")
        print(f"Learning Rate: {ae_config['training']['learning_rate']}\n")

        autoencoder = train_phase_autoencoder(args, ae_config, device,
                                              train_loader_ae, test_loader_ae)
    else:
        autoencoder = Autoencoder()
        # Prefer an explicit checkpoint. ``--autoencoder_dir`` is kept as a
        # fallback for callers that point it at a weights file, but it must
        # be a regular file: torch.load cannot read a directory.
        autoencoder_path = args.resume_autoencoder or args.autoencoder_dir
        if not os.path.isfile(autoencoder_path):
            raise FileNotFoundError(
                f"No pretrained autoencoder found at {autoencoder_path} "
                "(expected a checkpoint file; pass --resume_autoencoder "
                "or use --train_autoencoder)"
            )
        print(f'Loading pretrained autoencoder from {autoencoder_path}')
        autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))

    # Train the CNN
    if args.train_cnn:
        print("\n=== CNN Training Configuration ===")
        print(f"Epochs: {cnn_config['training']['num_epochs']}")
        print(f"Batch Size: {cnn_config['training']['batch_size']}")
        print(f"Learning Rate: {cnn_config['training']['learning_rate']}")
        print(f"Noise Factor: {args.noise_factor}\n")

        autoencoder.eval()  # put the autoencoder in evaluation mode
        train_phase_cnn(args, cnn_config, device, train_loader_cnn,
                        test_loader_cnn, autoencoder)
|
|
|
|
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()