ADD file via upload

main
p8cpi3xy4 4 months ago
parent 2323a106a9
commit 6ebc1820c3

@ -0,0 +1,133 @@
import sys
import os
# 获取当前文件train_synapse.py所在的目录EMCAD目录
current_dir = os.path.dirname(os.path.abspath(__file__))
# 将EMCAD目录添加到Python搜索路径这样就能找到同级的lib目录
sys.path.append(current_dir)
import timm.models
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from lib.networks import EMCADNet # 现在应该能找到了
from trainer import trainer_synapse
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
default='C:\Users\21608\PycharmProjects\pythonProject5\EMCAD\pvt_v2_b2_EMCAD_kernel_sizes_[1, 3, 5]_dw_parallel_add_lgag_ks_3_ef2_act_mscb_relu6_loss_mutation_output_final_layer_Run1_Synapse224\test_predictions', help='root dir for data')
parser.add_argument('--volume_path', type=str,
default='C:/Users/21608/PycharmProjects/pythonProject5/data/synapse/test_vol_h5_new', help='root dir for validation volume data')
parser.add_argument('--dataset', type=str,
default='Synapse', help='experiment_name')
parser.add_argument('--list_dir', type=str,
default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
default=9, help='output channel of network')
# network related parameters
parser.add_argument('--encoder', type=str,
default='pvt_v2_b2', help='Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...')
parser.add_argument('--expansion_factor', type=int,
default=2, help='expansion factor in MSCB block')
parser.add_argument('--kernel_sizes', type=int, nargs='+',
default=[1, 3, 5], help='multi-scale kernel sizes in MSDC block')
parser.add_argument('--lgag_ks', type=int,
default=3, help='Kernel size in LGAG')
parser.add_argument('--activation_mscb', type=str,
default='relu6', help='activation used in MSCB: relu6 or relu')
parser.add_argument('--no_dw_parallel', action='store_true',
default=False, help='use this flag to disable depth-wise parallel convolutions')
parser.add_argument('--concatenation', action='store_true',
default=False, help='use this flag to concatenate feature maps in MSDC block')
parser.add_argument('--no_pretrain', action='store_true',
default=False, help='use this flag to turn off loading pretrained enocder weights')
parser.add_argument('--pretrained_dir', type=str,
default='./pretrained_pth/pvt/', help='path to pretrained encoder dir')
parser.add_argument('--supervision', type=str,
default='mutation', help='loss supervision: mutation, deep_supervision or last_layer')
parser.add_argument('--max_iterations', type=int,
default=50000, help='maximum epoch number to train')
parser.add_argument('--max_epochs', type=int,
default=300, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
default=6, help='batch_size per gpu')
parser.add_argument('--base_lr', type=float, default=0.0001,
help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
default=224, help='input patch size of network input')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
help='whether use deterministic training')
parser.add_argument('--seed', type=int,
default=2222, help='random seed')
args = parser.parse_args()
if __name__ == "__main__":
if not args.deterministic:
cudnn.benchmark = True
cudnn.deterministic = False
else:
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
dataset_name = args.dataset
dataset_config = {
'Synapse': {
'root_path': args.root_path,
'volume_path': args.volume_path,
'list_dir': args.list_dir,
'num_classes': args.num_classes,
'z_spacing': 1,
},
}
args.num_classes = dataset_config[dataset_name]['num_classes']
args.root_path = dataset_config[dataset_name]['root_path']
args.volume_path = dataset_config[dataset_name]['volume_path']
args.z_spacing = dataset_config[dataset_name]['z_spacing']
args.list_dir = dataset_config[dataset_name]['list_dir']
if args.concatenation:
aggregation = 'concat'
else:
aggregation = 'add'
if args.no_dw_parallel:
dw_mode = 'series'
else:
dw_mode = 'parallel'
run = 1
args.exp = args.encoder + '_EMCAD_kernel_sizes_' + str(args.kernel_sizes) + '_dw_' + dw_mode + '_' + aggregation + '_lgag_ks_' + str(args.lgag_ks) + '_ef' + str(args.expansion_factor) + '_act_mscb_' + args.activation_mscb + '_loss_' + args.supervision + '_output_final_layer_Run'+str(run)+'_' + dataset_name + str(args.img_size)
snapshot_path = "model_pth/{}/{}".format(args.exp, args.encoder + '_EMCAD_kernel_sizes_' + str(args.kernel_sizes) + '_dw_' + dw_mode + '_' + aggregation + '_lgag_ks_' + str(args.lgag_ks) + '_ef' + str(args.expansion_factor) + '_act_mscb_' + args.activation_mscb + '_loss_' + args.supervision + '_output_final_layer_Run'+str(run))
snapshot_path = snapshot_path.replace('[', '').replace(']', '').replace(', ', '_')
snapshot_path = snapshot_path + '_pretrain' if not args.no_pretrain else snapshot_path
snapshot_path = snapshot_path+'_'+str(args.max_iterations)[0:2]+'k' if args.max_iterations != 50000 else snapshot_path
snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 300 else snapshot_path
snapshot_path = snapshot_path+'_bs'+str(args.batch_size)
snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.0001 else snapshot_path
snapshot_path = snapshot_path + '_'+str(args.img_size)
snapshot_path = snapshot_path + '_s'+str(args.seed) if args.seed!=1234 else snapshot_path
if not os.path.exists(snapshot_path):
os.makedirs(snapshot_path)
model = EMCADNet(num_classes=args.num_classes, kernel_sizes=args.kernel_sizes, expansion_factor=args.expansion_factor, dw_parallel=not args.no_dw_parallel, add=not args.concatenation, lgag_ks=args.lgag_ks, activation=args.activation_mscb, encoder=args.encoder, pretrain= not args.no_pretrain, pretrained_dir=args.pretrained_dir)
model.cuda()
print('Model successfully created.')
trainer = {'Synapse': trainer_synapse,}
trainer[dataset_name](args, model, snapshot_path)
Loading…
Cancel
Save