"""Training entry point for the Synapse and ACDC datasets."""

import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn

from src.core.networks import EMCADNet
from src.utils.trainer import trainer_ACDC, trainer_synapse

# Parser defaults; build_snapshot_path omits the matching suffix from the
# directory name when the run uses the default value. Shared here so the
# parser and the naming logic cannot drift apart.
_DEFAULT_MAX_ITERATIONS = 50000
_DEFAULT_MAX_EPOCHS = 300
_DEFAULT_BASE_LR = 0.0001
# NOTE(review): the seed suffix is skipped only for 1234, not for the parser
# default (2222), so default runs always carry "_s2222". Kept as-is so
# existing checkpoint directory names remain valid.
_SEED_WITHOUT_SUFFIX = 1234


def build_parser():
    """Build the command-line argument parser for training.

    Returns:
        argparse.ArgumentParser: parser covering dataset paths, EMCAD
        decoder options and optimisation settings. Defaults target ACDC.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--root_path",
        type=str,
        default="/data/ACDC/train",
        help="root dir for training data (ACDC: /data/ACDC/train)",
    )
    parser.add_argument(
        "--volume_path",
        type=str,
        default="/data/ACDC/test",
        help="root dir for validation/test volume data",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ACDC",
        choices=["Synapse", "ACDC"],
        help="experiment name",
    )
    parser.add_argument(
        "--list_dir",
        type=str,
        default="/data/ACDC/lists_ACDC",
        help="list dir (ACDC: /data/ACDC/lists_ACDC)",
    )
    parser.add_argument(
        "--num_classes",
        type=int,
        default=4,
        help="output channel of network (ACDC = 4)",
    )
    parser.add_argument(
        "--encoder",
        type=str,
        default="pvt_v2_b2",
        help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...",
    )
    parser.add_argument(
        "--expansion_factor",
        type=int,
        default=2,
        help="expansion factor in MSCB block",
    )
    parser.add_argument(
        "--kernel_sizes",
        type=int,
        nargs="+",
        default=[1, 3, 5],
        help="multi-scale kernel sizes in MSDC block",
    )
    parser.add_argument(
        "--lgag_ks", type=int, default=3, help="Kernel size in LGAG block"
    )
    parser.add_argument(
        "--activation_mscb",
        type=str,
        default="relu6",
        help="activation used in MSCB: relu6 or relu",
    )
    parser.add_argument(
        "--no_dw_parallel",
        action="store_true",
        default=False,
        help="use this flag to disable depth-wise parallel convolutions",
    )
    parser.add_argument(
        "--concatenation",
        action="store_true",
        default=False,
        help="use this flag to concatenate feature maps in MSDC block",
    )
    parser.add_argument(
        "--no_pretrain",
        action="store_true",
        default=False,
        help="use this flag to turn off loading pretrained encoder weights",
    )
    parser.add_argument(
        "--pretrained_dir",
        type=str,
        default="./model_pth/",
        help="path to pretrained encoder dir, e.g. ./model_pth/",
    )
    parser.add_argument(
        "--supervision",
        type=str,
        default="mutation",
        help="loss supervision: mutation, deep_supervision or last_layer",
    )
    parser.add_argument(
        "--max_iterations",
        type=int,
        default=_DEFAULT_MAX_ITERATIONS,
        help="maximum total iterations",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=_DEFAULT_MAX_EPOCHS,
        help="maximum epoch number to train",
    )
    parser.add_argument("--batch_size", type=int, default=6, help="batch_size per gpu")
    parser.add_argument(
        "--base_lr",
        type=float,
        default=_DEFAULT_BASE_LR,
        help="segmentation network learning rate",
    )
    parser.add_argument(
        "--img_size",
        type=int,
        default=224,
        help="input patch size of network input",
    )
    parser.add_argument("--n_gpu", type=int, default=1, help="total gpu")
    parser.add_argument(
        "--deterministic",
        type=int,
        default=1,
        help="whether use deterministic training",
    )
    parser.add_argument("--seed", type=int, default=2222, help="random seed")
    return parser


def set_deterministic(seed, deterministic):
    """Seed all RNGs and configure cuDNN determinism.

    Args:
        seed: seed applied to Python ``random``, NumPy and PyTorch.
        deterministic: truthy -> disable cuDNN autotuning and request
            deterministic kernels; falsy -> enable cuDNN autotuning.
    """
    # Autotuning (benchmark) picks the fastest kernels but is not
    # reproducible, so it is only enabled when determinism is not requested.
    cudnn.benchmark = not deterministic
    cudnn.deterministic = bool(deterministic)

    # Seeding is done in both modes so runs are at least seed-controlled.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def _format_iterations(max_iterations):
    """Return a compact, collision-free directory tag for an iteration count.

    Multiples of 1000 are abbreviated (``30000`` -> ``"30k"``,
    ``100000`` -> ``"100k"``); other values are written in full so distinct
    iteration counts never map to the same directory name.
    """
    if max_iterations % 1000 == 0:
        return f"{max_iterations // 1000}k"
    return str(max_iterations)


def build_snapshot_path(args, dataset_name):
    """Derive the experiment name and checkpoint directory from CLI args.

    Args:
        args: parsed namespace from :func:`build_parser`.
        dataset_name: dataset label embedded in the experiment name.

    Returns:
        tuple[str, str]: ``(exp, snapshot_path)``. ``exp`` keeps the raw list
        formatting of ``kernel_sizes`` (e.g. ``[1, 3, 5]``); in
        ``snapshot_path`` brackets are removed and ``", "`` becomes ``"_"``.
    """
    aggregation = "concat" if args.concatenation else "add"
    dw_mode = "series" if args.no_dw_parallel else "parallel"
    run = 1

    # Shared prefix: used both inside the experiment name and as the leaf
    # directory under it.
    run_tag = (
        f"{args.encoder}_EMCAD_kernel_sizes_{args.kernel_sizes}"
        f"_dw_{dw_mode}_{aggregation}"
        f"_lgag_ks_{args.lgag_ks}_ef{args.expansion_factor}"
        f"_act_mscb_{args.activation_mscb}_loss_{args.supervision}"
        f"_output_final_layer_Run{run}"
    )
    exp = f"{run_tag}_{dataset_name}{args.img_size}"

    snapshot_path = f"model_pth/{exp}/{run_tag}"
    snapshot_path = snapshot_path.replace("[", "").replace("]", "").replace(", ", "_")

    # Optional suffixes are appended only when a setting differs from its
    # default, keeping directory names short for common runs.
    if not args.no_pretrain:
        snapshot_path += "_pretrain"
    if args.max_iterations != _DEFAULT_MAX_ITERATIONS:
        snapshot_path += "_" + _format_iterations(args.max_iterations)
    if args.max_epochs != _DEFAULT_MAX_EPOCHS:
        snapshot_path += "_epo" + str(args.max_epochs)
    snapshot_path += "_bs" + str(args.batch_size)
    if args.base_lr != _DEFAULT_BASE_LR:
        snapshot_path += "_lr" + str(args.base_lr)
    snapshot_path += "_" + str(args.img_size)
    if args.seed != _SEED_WITHOUT_SUFFIX:
        snapshot_path += "_s" + str(args.seed)
    return exp, snapshot_path


def _resolve_acdc_root(root_path):
    """Return the ACDC root directory, dropping a trailing ``train`` component.

    ``/data/ACDC/train`` and ``/data/ACDC/train/`` both become ``/data/ACDC``;
    any other path is returned with trailing slashes stripped.
    NOTE(review): presumably trainer_ACDC locates its split sub-directories
    under this root itself — confirm against the trainer.
    """
    stripped = root_path.rstrip("/")
    if os.path.basename(stripped) == "train":
        return os.path.dirname(stripped)
    return stripped


def main():
    """Parse CLI args, build EMCADNet on the GPU and run the dataset trainer."""
    args = build_parser().parse_args()
    set_deterministic(args.seed, args.deterministic)

    dataset_name = args.dataset
    if dataset_name == "ACDC":
        args.root_path = _resolve_acdc_root(args.root_path)
    # Both datasets use unit z-spacing; volume_path, list_dir and num_classes
    # are taken from the CLI unchanged.
    args.z_spacing = 1

    args.exp, snapshot_path = build_snapshot_path(args, dataset_name)
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(snapshot_path, exist_ok=True)

    model = EMCADNet(
        num_classes=args.num_classes,
        kernel_sizes=args.kernel_sizes,
        expansion_factor=args.expansion_factor,
        dw_parallel=not args.no_dw_parallel,
        add=not args.concatenation,
        lgag_ks=args.lgag_ks,
        activation=args.activation_mscb,
        encoder=args.encoder,
        pretrain=not args.no_pretrain,
        pretrained_dir=args.pretrained_dir,
    )
    model.cuda()
    print("Model successfully created.")

    trainer_map = {"Synapse": trainer_synapse, "ACDC": trainer_ACDC}
    trainer_map[dataset_name](args, model, snapshot_path)


if __name__ == "__main__":
    main()