You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

135 lines
6.0 KiB

import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from lib.networks import EMCADNet
from trainer import trainer_synapse
# Command-line interface for the training script.
#
# Every option is described once in the table below as a
# ``(flag, add_argument keyword arguments)`` pair and registered in table
# order, so ``--help`` lists the options in the same order as the table.
parser = argparse.ArgumentParser()

_ARGUMENT_SPECS = (
    # --- dataset locations and label count ---
    ('--root_path', dict(type=str, default='./data/synapse/train_npz',
                         help='root dir for data')),
    ('--volume_path', dict(type=str, default='./data/synapse/test_vol_h5',
                           help='root dir for validation volume data')),
    ('--dataset', dict(type=str, default='Synapse',
                       help='experiment_name')),
    ('--list_dir', dict(type=str, default='./lists/lists_Synapse',
                        help='list dir')),
    ('--num_classes', dict(type=int, default=9,
                           help='output channel of network')),

    # --- network related parameters ---
    ('--encoder', dict(type=str, default='pvt_v2_b2',
                       help='Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...')),
    ('--expansion_factor', dict(type=int, default=2,
                                help='expansion factor in MSCB block')),
    # One or more integers, e.g. ``--kernel_sizes 1 3 5``.
    ('--kernel_sizes', dict(type=int, nargs='+', default=[1, 3, 5],
                            help='multi-scale kernel sizes in MSDC block')),
    ('--lgag_ks', dict(type=int, default=3,
                       help='Kernel size in LGAG')),
    ('--activation_mscb', dict(type=str, default='relu6',
                               help='activation used in MSCB: relu6 or relu')),
    # Boolean switches: absent -> False, present -> True.
    ('--no_dw_parallel', dict(action='store_true', default=False,
                              help='use this flag to disable depth-wise parallel convolutions')),
    ('--concatenation', dict(action='store_true', default=False,
                             help='use this flag to concatenate feature maps in MSDC block')),
    ('--no_pretrain', dict(action='store_true', default=False,
                           help='use this flag to turn off loading pretrained enocder weights')),
    ('--pretrained_dir', dict(type=str, default='./pretrained_pth/pvt/',
                              help='path to pretrained encoder dir')),
    ('--supervision', dict(type=str, default='mutation',
                           help='loss supervision: mutation, deep_supervision or last_layer')),

    # --- optimisation schedule and input size ---
    # The parenthesised Chinese note in some help texts reads
    # "adapted for 8 GB of GPU memory".
    ('--max_iterations', dict(type=int, default=1000,
                              help='maximum iteration number to train (适配8GB显存)')),
    ('--max_epochs', dict(type=int, default=10,
                          help='maximum epoch number to train (适配8GB显存)')),
    ('--batch_size', dict(type=int, default=4,
                          help='batch_size per gpu (适配8GB显存)')),
    ('--base_lr', dict(type=float, default=0.0001,
                       help='segmentation network learning rate')),
    ('--img_size', dict(type=int, default=224,
                        help='input patch size of network input (适配8GB显存)')),

    # --- hardware and reproducibility ---
    ('--n_gpu', dict(type=int, default=1, help='total gpu')),
    # Integer flag: 0 disables deterministic training, any other value enables it.
    ('--deterministic', dict(type=int, default=1,
                             help='whether use deterministic training')),
    ('--seed', dict(type=int, default=2222, help='random seed')),
)

for _flag, _kwargs in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_kwargs)

# Parsed at module level (not under the __main__ guard), so importing this
# module also parses the current process's command line.
args = parser.parse_args()
def _dataset_settings(args):
    """Return the per-dataset settings for ``args.dataset``.

    The settings are taken from the parsed command-line values, plus a
    fixed ``z_spacing`` of 1 for the Synapse dataset.

    Args:
        args: parsed arguments; reads ``dataset``, ``root_path``,
            ``volume_path``, ``list_dir`` and ``num_classes``.

    Returns:
        dict with keys ``root_path``, ``volume_path``, ``list_dir``,
        ``num_classes`` and ``z_spacing``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    dataset_config = {
        'Synapse': {
            'root_path': args.root_path,
            'volume_path': args.volume_path,
            'list_dir': args.list_dir,
            'num_classes': args.num_classes,
            'z_spacing': 1,
        },
    }
    try:
        return dataset_config[args.dataset]
    except KeyError:
        # Fail with an actionable message instead of a bare KeyError.
        supported = ', '.join(sorted(dataset_config))
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; supported datasets: {supported}"
        ) from None


def _kernel_sizes_tag(args):
    """Return the kernel sizes joined without separators, e.g. ``"135"`` for ``[1, 3, 5]``."""
    return ''.join(map(str, args.kernel_sizes))


def _experiment_name(args, dw_mode, aggregation, run=1):
    """Build the experiment directory name from the model/dataset options.

    Kept short to stay clear of the Windows 260-character path limit.
    """
    return (
        f"{args.encoder}_ks{_kernel_sizes_tag(args)}_{dw_mode}_{aggregation}"
        f"_Run{run}_{args.dataset}{args.img_size}"
    )


def _snapshot_name(args):
    """Build the snapshot directory name from the training options.

    Some suffixes are only appended when the option differs from a
    hard-coded reference value (50000 iterations, 300 epochs, learning
    rate 0.0001, seed 1234).
    """
    parts = [f"{args.encoder}_ks{_kernel_sizes_tag(args)}_ef{args.expansion_factor}"]
    if not args.no_pretrain:
        parts.append('_pt')
    if args.max_iterations != 50000:
        # Integer thousands: values below 1000 render as "_0k".
        parts.append(f'_{args.max_iterations // 1000}k')
    if args.max_epochs != 300:
        parts.append(f'_e{args.max_epochs}')
    parts.append(f'_bs{args.batch_size}')
    if args.base_lr != 0.0001:
        parts.append(f'_lr{args.base_lr}')
    parts.append(f'_{args.img_size}')
    if args.seed != 1234:
        parts.append(f'_s{args.seed}')
    return ''.join(parts)


if __name__ == "__main__":
    # cuDNN autotuning (benchmark) and deterministic kernels are set as
    # opposites of each other, driven by --deterministic.
    deterministic = bool(args.deterministic)
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic

    # Seed every random number generator used by the script.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Resolve dataset settings; an unsupported name fails here, before any
    # directory is created or model is built.
    dataset_name = args.dataset
    settings = _dataset_settings(args)
    args.num_classes = settings['num_classes']
    args.root_path = settings['root_path']
    args.volume_path = settings['volume_path']
    args.z_spacing = settings['z_spacing']
    args.list_dir = settings['list_dir']

    aggregation = 'concat' if args.concatenation else 'add'
    dw_mode = 'series' if args.no_dw_parallel else 'parallel'
    run = 1

    exp_name = _experiment_name(args, dw_mode, aggregation, run)
    args.exp = exp_name

    # Absolute snapshot directory: model_pth/<experiment>/<snapshot>.
    snapshot_path = os.path.abspath(
        os.path.join("model_pth", exp_name, _snapshot_name(args))
    )
    os.makedirs(snapshot_path, exist_ok=True)

    model = EMCADNet(
        num_classes=args.num_classes,
        kernel_sizes=args.kernel_sizes,
        expansion_factor=args.expansion_factor,
        dw_parallel=not args.no_dw_parallel,
        add=not args.concatenation,
        lgag_ks=args.lgag_ks,
        activation=args.activation_mscb,
        encoder=args.encoder,
        pretrain=not args.no_pretrain,
        pretrained_dir=args.pretrained_dir,
    )
    model.cuda()
    print('Model successfully created.')

    # Dispatch to the trainer registered for the selected dataset.
    trainer = {'Synapse': trainer_synapse}
    trainer[dataset_name](args, model, snapshot_path)