You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
130 lines
4.4 KiB
130 lines
4.4 KiB
import os
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
import argparse
|
|
from data.dataset import LungXrayDataset
|
|
from models.autoencoder import Autoencoder
|
|
from utils import load_config
|
|
from train_cnn import train_cnn
|
|
from models.simplecnn import SimpleCNN
|
|
|
|
def parse_args():
    """Build the command-line parser for the CNN training script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: parsed options with the attributes ``config``,
        ``data_dir``, ``output_dir``, ``epochs``, ``batch_size``, ``lr``,
        ``autoencoder_path``, ``noise_factor``, ``device``, ``seed`` and
        ``resume``.
    """
    parser = argparse.ArgumentParser(description='COVID-19 X-ray Classification Project')

    # Basic options
    parser.add_argument('--config', type=str, default='configs/config.yaml',
                        help='Path to config file')
    parser.add_argument('--data_dir', type=str, default='data',
                        help='Path to data directory')
    parser.add_argument('--output_dir', type=str, default='results_cnn',
                        help='Path to output directory')

    # Training options
    # NOTE(review): --epochs and --lr have non-None defaults, so code that
    # treats None as "not given on the command line" will always see a value
    # for them; only --batch_size defaults to None. Confirm this is intended.
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of epochs (override config file)')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Batch size (override config file)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate (override config file)')

    # Autoencoder options
    # The option is intentionally not marked required: argparse ignores
    # `default` for required options, which would make the default path below
    # unreachable and force every invocation to pass the flag.
    parser.add_argument('--autoencoder_path', type=str,
                        default="./results_autoencoder/checkpoints/best_model.pth",
                        help='Path to pretrained autoencoder model')
    parser.add_argument('--noise_factor', type=float, default=0.3,
                        help='Noise factor for data augmentation')

    # Device and run options
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'],
                        help='Device to use (cuda or cpu)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')

    return parser.parse_args()
|
|
|
|
def main():
    """Run CNN training end to end.

    Parses the command line, seeds torch, loads the config file, applies
    command-line overrides to ``config['training']``, selects the device,
    creates the output directories, builds the train/test data loaders,
    loads the pretrained autoencoder, optionally restores a CNN checkpoint,
    and finally calls ``train_cnn``.
    """
    cli = parse_args()

    # Seed torch's RNG.
    torch.manual_seed(cli.seed)

    config = load_config(cli.config)

    # Command-line values take precedence over the config file when given.
    overrides = (
        ('num_epochs', cli.epochs),
        ('batch_size', cli.batch_size),
        ('learning_rate', cli.lr),
    )
    for key, value in overrides:
        if value is not None:
            config['training'][key] = value

    # Fall back to the CPU when CUDA was requested but is not available.
    device_name = cli.device
    if device_name == 'cuda' and not torch.cuda.is_available():
        print('Warning: CUDA is not available, using CPU instead')
        device_name = 'cpu'
    device = torch.device(device_name)
    print(f'Using device: {device}')

    # Output directory plus its checkpoint and plot subdirectories.
    os.makedirs(cli.output_dir, exist_ok=True)
    for subdir in ('checkpoints', 'plots'):
        os.makedirs(os.path.join(cli.output_dir, subdir), exist_ok=True)

    # Datasets and loaders; only the training loader shuffles.
    train_dataset = LungXrayDataset(root_dir=cli.data_dir, is_train=True)
    test_dataset = LungXrayDataset(root_dir=cli.data_dir, is_train=False)

    batch_size = config['training']['batch_size']
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Pretrained autoencoder, put in evaluation mode.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    autoencoder = Autoencoder()
    autoencoder_state = torch.load(cli.autoencoder_path, map_location=device)
    autoencoder.load_state_dict(autoencoder_state)
    autoencoder = autoencoder.to(device)
    autoencoder.eval()

    cnn_model = SimpleCNN()

    # Restore CNN weights when a checkpoint to resume from was given.
    if cli.resume:
        print(f'Loading checkpoint from {cli.resume}')
        cnn_state = torch.load(cli.resume, map_location=device)
        cnn_model.load_state_dict(cnn_state)

    training_cfg = config['training']
    # The value returned by train_cnn is bound but not used further here.
    history = train_cnn(
        cnn_model=cnn_model,
        autoencoder=autoencoder,
        lr=training_cfg['learning_rate'],
        train_loader=train_loader,
        test_loader=test_loader,
        num_epochs=training_cfg['num_epochs'],
        device=device,
        output_dir=cli.output_dir,
        noise_factor=cli.noise_factor,
    )
|
|
|
|
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()