You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

290 lines
7.8 KiB

"""Synapse 与 ACDC 训练入口。"""
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from src.core.networks import EMCADNet
from src.utils.trainer import trainer_ACDC, trainer_synapse
def build_parser():
    """Create the command-line argument parser for Synapse/ACDC training runs.

    Returns:
        argparse.ArgumentParser: parser covering data paths, model
        architecture knobs (encoder, MSCB/MSDC/LGAG settings) and
        optimization hyper-parameters.
    """
    p = argparse.ArgumentParser()
    add = p.add_argument

    # --- data locations -------------------------------------------------
    add("--root_path", type=str, default="/data/ACDC/train",
        help="root dir for training data (ACDC: /data/ACDC/train)")
    add("--volume_path", type=str, default="/data/ACDC/test",
        help="root dir for validation/test volume data")
    add("--dataset", type=str, default="ACDC", choices=["Synapse", "ACDC"],
        help="experiment name")
    add("--list_dir", type=str, default="/data/ACDC/lists_ACDC",
        help="list dir (ACDC: /data/ACDC/lists_ACDC)")
    add("--num_classes", type=int, default=4,
        help="output channel of network (ACDC = 4)")

    # --- architecture ---------------------------------------------------
    add("--encoder", type=str, default="pvt_v2_b2",
        help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...")
    add("--expansion_factor", type=int, default=2,
        help="expansion factor in MSCB block")
    add("--kernel_sizes", type=int, nargs="+", default=[1, 3, 5],
        help="multi-scale kernel sizes in MSDC block")
    add("--lgag_ks", type=int, default=3, help="Kernel size in LGAG block")
    add("--activation_mscb", type=str, default="relu6",
        help="activation used in MSCB: relu6 or relu")
    add("--no_dw_parallel", action="store_true", default=False,
        help="use this flag to disable depth-wise parallel convolutions")
    add("--concatenation", action="store_true", default=False,
        help="use this flag to concatenate feature maps in MSDC block")
    add("--no_pretrain", action="store_true", default=False,
        help="use this flag to turn off loading pretrained encoder weights")
    add("--pretrained_dir", type=str, default="./model_pth/",
        help="path to pretrained encoder dir, e.g. ./model_pth/")
    add("--supervision", type=str, default="mutation",
        help="loss supervision: mutation, deep_supervision or last_layer")

    # --- optimization / runtime -----------------------------------------
    add("--max_iterations", type=int, default=50000,
        help="maximum total iterations")
    add("--max_epochs", type=int, default=300,
        help="maximum epoch number to train")
    add("--batch_size", type=int, default=6, help="batch_size per gpu")
    add("--base_lr", type=float, default=0.0001,
        help="segmentation network learning rate")
    add("--img_size", type=int, default=224,
        help="input patch size of network input")
    add("--n_gpu", type=int, default=1, help="total gpu")
    add("--deterministic", type=int, default=1,
        help="whether use deterministic training")
    add("--seed", type=int, default=2222, help="random seed")
    return p
def set_deterministic(seed, deterministic):
    """Seed all RNG sources and toggle cuDNN determinism.

    Args:
        seed: integer seed applied to ``random``, NumPy and torch (CPU + CUDA).
        deterministic: truthy to force deterministic cuDNN kernels,
            falsy to enable cuDNN autotuning (benchmark mode) instead.
    """
    use_det = bool(deterministic)
    # benchmark and deterministic are mutually exclusive cuDNN modes.
    cudnn.benchmark = not use_det
    cudnn.deterministic = use_det
    # Seed every RNG the training pipeline may draw from.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
def build_snapshot_path(args, dataset_name):
    """Derive the experiment tag and snapshot output directory from CLI args.

    Args:
        args: parsed argparse namespace (encoder, kernel_sizes, lgag_ks,
            expansion_factor, activation_mscb, supervision, img_size, ...).
        dataset_name: dataset key, e.g. "Synapse" or "ACDC".

    Returns:
        tuple[str, str]: ``(exp, snapshot_path)`` — the experiment tag
        (kernel_sizes list formatting left intact) and the sanitized
        output directory path under ``model_pth/``.
    """
    aggregation = "concat" if args.concatenation else "add"
    dw_mode = "series" if args.no_dw_parallel else "parallel"
    run = 1
    # Build the long settings tag ONCE; the original duplicated this
    # 10-term concatenation chain for exp and for the directory name,
    # which was a maintenance hazard (two places to update per setting).
    base = (
        f"{args.encoder}_EMCAD_kernel_sizes_{args.kernel_sizes}"
        f"_dw_{dw_mode}_{aggregation}"
        f"_lgag_ks_{args.lgag_ks}"
        f"_ef{args.expansion_factor}"
        f"_act_mscb_{args.activation_mscb}"
        f"_loss_{args.supervision}"
        f"_output_final_layer_Run{run}"
    )
    exp = f"{base}_{dataset_name}{args.img_size}"
    snapshot_path = f"model_pth/{exp}/{base}"
    # Strip list formatting of kernel_sizes ("[1, 3, 5]" -> "1_3_5") from
    # the filesystem path only; the returned exp keeps the brackets.
    snapshot_path = snapshot_path.replace("[", "").replace("]", "").replace(", ", "_")
    # Append only settings that differ from the historical defaults so
    # default runs keep shorter directory names.
    if not args.no_pretrain:
        snapshot_path += "_pretrain"
    if args.max_iterations != 50000:
        snapshot_path += "_" + str(args.max_iterations)[0:2] + "k"
    if args.max_epochs != 300:
        snapshot_path += "_epo" + str(args.max_epochs)
    snapshot_path += "_bs" + str(args.batch_size)
    if args.base_lr != 0.0001:
        snapshot_path += "_lr" + str(args.base_lr)
    snapshot_path += "_" + str(args.img_size)
    # NOTE(review): the parser default for --seed is 2222, so this 1234
    # comparison appends "_s<seed>" on every default run — presumably a
    # leftover from an older default; kept as-is to preserve existing paths.
    if args.seed != 1234:
        snapshot_path += "_s" + str(args.seed)
    return exp, snapshot_path
def main():
    """Entry point: parse args, seed RNGs, build EMCADNet and launch training."""
    args = build_parser().parse_args()
    set_deterministic(args.seed, args.deterministic)
    dataset_name = args.dataset

    # For ACDC the loaders expect the dataset root; if --root_path points
    # at the train/ subdirectory, strip that last path component.
    acdc_root = args.root_path
    if dataset_name == "ACDC":
        trimmed = args.root_path.rstrip("/")
        if os.path.basename(trimmed) == "train":
            acdc_root = os.path.dirname(trimmed)
        else:
            acdc_root = trimmed

    # Per-dataset configuration; both entries share everything but root_path.
    shared = {
        "volume_path": args.volume_path,
        "list_dir": args.list_dir,
        "num_classes": args.num_classes,
        "z_spacing": 1,
    }
    dataset_config = {
        "Synapse": {"root_path": args.root_path, **shared},
        "ACDC": {"root_path": acdc_root, **shared},
    }

    # Copy the selected dataset's settings back onto the args namespace,
    # which the trainers read directly.
    cfg = dataset_config[dataset_name]
    for field in ("num_classes", "root_path", "volume_path", "z_spacing", "list_dir"):
        setattr(args, field, cfg[field])

    args.exp, snapshot_path = build_snapshot_path(args, dataset_name)
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    model = EMCADNet(
        num_classes=args.num_classes,
        kernel_sizes=args.kernel_sizes,
        expansion_factor=args.expansion_factor,
        dw_parallel=not args.no_dw_parallel,
        add=not args.concatenation,
        lgag_ks=args.lgag_ks,
        activation=args.activation_mscb,
        encoder=args.encoder,
        pretrain=not args.no_pretrain,
        pretrained_dir=args.pretrained_dir,
    )
    model.cuda()
    print("Model successfully created.")

    # Dispatch to the dataset-specific training loop.
    trainers = {"Synapse": trainer_synapse, "ACDC": trainer_ACDC}
    trainers[dataset_name](args, model, snapshot_path)


if __name__ == "__main__":
    main()