You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
6.0 KiB
135 lines
6.0 KiB
import argparse
|
|
import logging
|
|
import os
|
|
import random
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.backends.cudnn as cudnn
|
|
|
|
from lib.networks import EMCADNet
|
|
from trainer import trainer_synapse
|
|
|
|
def _build_parser():
    """Build the command-line parser for this training script.

    Options are registered in three groups: dataset locations, network
    hyper-parameters, and the optimisation / reproducibility settings.
    The suffix "(适配8GB显存)" that appears in several help strings means
    "tuned for 8 GB of GPU memory".
    """
    cli = argparse.ArgumentParser()

    # --- dataset locations and label space ---------------------------------
    cli.add_argument('--root_path', type=str, default='./data/synapse/train_npz',
                     help='root dir for data')
    cli.add_argument('--volume_path', type=str, default='./data/synapse/test_vol_h5',
                     help='root dir for validation volume data')
    cli.add_argument('--dataset', type=str, default='Synapse',
                     help='experiment_name')
    cli.add_argument('--list_dir', type=str, default='./lists/lists_Synapse',
                     help='list dir')
    cli.add_argument('--num_classes', type=int, default=9,
                     help='output channel of network')

    # --- network related parameters ----------------------------------------
    cli.add_argument('--encoder', type=str, default='pvt_v2_b2',
                     help='Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...')
    cli.add_argument('--expansion_factor', type=int, default=2,
                     help='expansion factor in MSCB block')
    cli.add_argument('--kernel_sizes', type=int, nargs='+', default=[1, 3, 5],
                     help='multi-scale kernel sizes in MSDC block')
    cli.add_argument('--lgag_ks', type=int, default=3,
                     help='Kernel size in LGAG')
    cli.add_argument('--activation_mscb', type=str, default='relu6',
                     help='activation used in MSCB: relu6 or relu')
    # The three switches below are opt-out / opt-in flags: absent means False.
    cli.add_argument('--no_dw_parallel', action='store_true', default=False,
                     help='use this flag to disable depth-wise parallel convolutions')
    cli.add_argument('--concatenation', action='store_true', default=False,
                     help='use this flag to concatenate feature maps in MSDC block')
    cli.add_argument('--no_pretrain', action='store_true', default=False,
                     help='use this flag to turn off loading pretrained enocder weights')
    cli.add_argument('--pretrained_dir', type=str, default='./pretrained_pth/pvt/',
                     help='path to pretrained encoder dir')
    cli.add_argument('--supervision', type=str, default='mutation',
                     help='loss supervision: mutation, deep_supervision or last_layer')

    # --- training schedule, input size, reproducibility --------------------
    cli.add_argument('--max_iterations', type=int, default=1000,
                     help='maximum iteration number to train (适配8GB显存)')
    cli.add_argument('--max_epochs', type=int, default=10,
                     help='maximum epoch number to train (适配8GB显存)')
    cli.add_argument('--batch_size', type=int, default=4,
                     help='batch_size per gpu (适配8GB显存)')
    cli.add_argument('--base_lr', type=float, default=0.0001,
                     help='segmentation network learning rate')
    cli.add_argument('--img_size', type=int, default=224,
                     help='input patch size of network input (适配8GB显存)')
    cli.add_argument('--n_gpu', type=int, default=1, help='total gpu')
    # Integer rather than a boolean switch: 0 disables, any other value enables.
    cli.add_argument('--deterministic', type=int, default=1,
                     help='whether use deterministic training')
    cli.add_argument('--seed', type=int, default=2222, help='random seed')

    return cli


# Parsed at import time so that `parser` and `args` remain module-level names.
parser = _build_parser()
args = parser.parse_args()
|
|
|
|
def _experiment_name(args, dw_mode, aggregation, run=1):
    """Return the top-level experiment directory name for this configuration.

    Kernel sizes are concatenated without separators (``[1, 3, 5]`` -> ``"135"``)
    to keep the resulting path short; long paths can hit the Windows
    260-character limit.

    Args:
        args: Namespace providing ``encoder``, ``kernel_sizes``, ``dataset``
            and ``img_size``.
        dw_mode: ``'parallel'`` or ``'series'``.
        aggregation: ``'add'`` or ``'concat'``.
        run: Run index embedded in the name.
    """
    ks_str = ''.join(map(str, args.kernel_sizes))
    return (
        f"{args.encoder}_ks{ks_str}_{dw_mode}_{aggregation}"
        f"_Run{run}_{args.dataset}{args.img_size}"
    )


def _snapshot_name(args):
    """Return the short per-run snapshot directory name.

    Some tags are omitted when the setting equals a fixed reference value
    (50000 iterations, 300 epochs, lr 0.0001, seed 1234). NOTE(review): of
    these, only the learning rate matches this script's CLI defaults, so a
    default run still carries the ``_1k``, ``_e10`` and ``_s2222`` tags.

    Args:
        args: Namespace providing ``encoder``, ``kernel_sizes``,
            ``expansion_factor``, ``no_pretrain``, ``max_iterations``,
            ``max_epochs``, ``batch_size``, ``base_lr``, ``img_size``
            and ``seed``.
    """
    ks_str = ''.join(map(str, args.kernel_sizes))
    parts = [f"{args.encoder}_ks{ks_str}_ef{args.expansion_factor}"]
    if not args.no_pretrain:
        parts.append('_pt')
    if args.max_iterations != 50000:
        # Integer division: anything below 1000 iterations is tagged "_0k".
        parts.append(f'_{args.max_iterations // 1000}k')
    if args.max_epochs != 300:
        parts.append(f'_e{args.max_epochs}')
    parts.append(f'_bs{args.batch_size}')
    if args.base_lr != 0.0001:
        parts.append(f'_lr{args.base_lr}')
    parts.append(f'_{args.img_size}')
    if args.seed != 1234:
        parts.append(f'_s{args.seed}')
    return ''.join(parts)


def main(args):
    """Seed the RNGs, build the EMCAD model and start training.

    Args:
        args: Parsed command-line namespace. Its ``num_classes``,
            ``root_path``, ``volume_path``, ``list_dir`` and ``z_spacing``
            attributes are (re)assigned from the dataset table, and ``exp``
            is set to the experiment name, before it is passed to the trainer.

    Raises:
        ValueError: If ``args.dataset`` has no entry in the dataset table.
    """
    # Exactly one of cuDNN autotuning / deterministic kernels is enabled.
    deterministic = bool(args.deterministic)
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Seed every CUDA device rather than only the current one (--n_gpu may be > 1).
    torch.cuda.manual_seed_all(args.seed)

    dataset_name = args.dataset
    dataset_config = {
        'Synapse': {
            'root_path': args.root_path,
            'volume_path': args.volume_path,
            'list_dir': args.list_dir,
            'num_classes': args.num_classes,
            'z_spacing': 1,
        },
    }
    if dataset_name not in dataset_config:
        supported = ', '.join(sorted(dataset_config))
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of: {supported}"
        )
    # Copy the per-dataset settings (including z_spacing, which has no CLI
    # flag) onto the namespace.
    for key, value in dataset_config[dataset_name].items():
        setattr(args, key, value)

    aggregation = 'concat' if args.concatenation else 'add'
    dw_mode = 'series' if args.no_dw_parallel else 'parallel'

    args.exp = _experiment_name(args, dw_mode, aggregation, run=1)

    # Absolute snapshot directory: model_pth/<experiment>/<snapshot>.
    snapshot_path = os.path.abspath(
        os.path.join("model_pth", args.exp, _snapshot_name(args))
    )
    os.makedirs(snapshot_path, exist_ok=True)

    model = EMCADNet(
        num_classes=args.num_classes,
        kernel_sizes=args.kernel_sizes,
        expansion_factor=args.expansion_factor,
        dw_parallel=not args.no_dw_parallel,
        add=not args.concatenation,
        lgag_ks=args.lgag_ks,
        activation=args.activation_mscb,
        encoder=args.encoder,
        pretrain=not args.no_pretrain,
        pretrained_dir=args.pretrained_dir,
    )
    model.cuda()
    print('Model successfully created.')

    trainer = {'Synapse': trainer_synapse}
    trainer[dataset_name](args, model, snapshot_path)


if __name__ == "__main__":
    main(args)
|