You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
6.8 KiB
163 lines
6.8 KiB
import argparse
|
|
import logging
|
|
import os
|
|
import random
|
|
import sys
|
|
import time
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from tensorboardX import SummaryWriter
|
|
from torch.nn.modules.loss import CrossEntropyLoss
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import transforms
|
|
from torch.cuda.amp import GradScaler, autocast
|
|
|
|
from utils.dataset_synapse import Synapse_dataset, RandomGenerator
|
|
from utils.utils import powerset, one_hot_encoder, DiceLoss, val_single_volume
|
|
|
|
def inference(args, model, best_performance):
    """Evaluate the model on the Synapse test volumes and return mean Dice.

    Args:
        args: namespace providing volume_path, list_dir, num_classes,
            img_size and z_spacing.
        model: segmentation network to evaluate.
        best_performance: best mean Dice seen so far (only logged here,
            for comparison against the current score).

    Returns:
        Mean Dice score averaged over all test cases (numpy float).

    Side effects: logs progress, temporarily switches the model to eval
    mode and restores train mode before returning.
    """
    db_test = Synapse_dataset(base_dir=args.volume_path, split="test_vol",
                              list_dir=args.list_dir, nclass=args.num_classes)
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
    logging.info("{} test iterations per epoch".format(len(testloader)))

    model.eval()
    metric_list = 0.0
    # No gradients are needed for validation; no_grad avoids building
    # autograd graphs and cuts memory use during inference.
    with torch.no_grad():
        for i_batch, sampled_batch in tqdm(enumerate(testloader)):
            image, label = sampled_batch["image"], sampled_batch["label"]
            case_name = sampled_batch['case_name'][0]
            metric_i = val_single_volume(image, label, model,
                                         classes=args.num_classes,
                                         patch_size=[args.img_size, args.img_size],
                                         case=case_name, z_spacing=args.z_spacing)
            metric_list += np.array(metric_i)
    # Average the accumulated per-case metrics over the number of cases.
    metric_list = metric_list / len(db_test)

    performance = np.mean(metric_list, axis=0)
    logging.info('Testing performance in val model: mean_dice : %f, best_dice : %f'
                 % (performance, best_performance))
    # BUGFIX: the training loop continues after this call; restore train
    # mode so BatchNorm/Dropout behave correctly in subsequent epochs
    # (the original left the model in eval mode permanently).
    model.train()
    return performance
|
|
|
|
def trainer_synapse(args, model, snapshot_path):
    """Train `model` on the Synapse dataset with periodic validation.

    Args:
        args: namespace providing root_path, list_dir, num_classes, img_size,
            base_lr, batch_size, n_gpu, seed, max_epochs, supervision and the
            fields consumed by `inference` (volume_path, z_spacing).
        model: segmentation network; called as model(x, mode='train') and
            expected to return one output or a list of outputs.
        snapshot_path: directory for logs, TensorBoard events and checkpoints.

    Returns:
        The string "Training Finished!".

    Side effects: writes log.txt, TensorBoard scalars and .pth checkpoints
    ('last.pth' every epoch, 'best.pth' on improvement, periodic and final
    epoch snapshots) under snapshot_path.
    """
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s',
                        datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu

    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir,
                               split="train", nclass=args.num_classes,
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    def worker_init_fn(worker_id):
        # Give each DataLoader worker a distinct but reproducible seed.
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,
                             num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1 and args.n_gpu > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)

    model.train()
    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)

    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
    writer = SummaryWriter(snapshot_path + '/log')

    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = args.max_epochs * len(trainloader)
    logging.info("{} iterations per epoch. {} max iterations ".format(
        len(trainloader), max_iterations))
    best_performance = 0.0
    # Loop-invariant constants, hoisted out of the epoch/batch loops.
    save_interval = 50          # save a numbered snapshot every N epochs
    lr_ = base_lr               # constant LR (no decay schedule is used)
    w_ce, w_dice = 0.3, 0.7     # cross-entropy / Dice loss weights
    iterator = tqdm(range(max_epoch), ncols=70)

    for epoch_num in iterator:
        # Re-assert train mode: inference() at the end of the previous epoch
        # switched the model to eval mode.
        model.train()
        for i_batch, sampled_batch in enumerate(trainloader):
            image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            # BUGFIX: move batches to the selected device instead of the
            # original unconditional .cuda(), which crashed CPU-only runs
            # even though `device` was computed above.
            image_batch = image_batch.to(device)
            label_batch = label_batch.squeeze(1).to(device)

            P = model(image_batch, mode='train')
            if not isinstance(P, list):
                P = [P]
            # Decide the supervision scheme once, from the first forward pass.
            if epoch_num == 0 and i_batch == 0:
                n_outs = len(P)
                out_idxs = list(np.arange(n_outs))
                if args.supervision == 'mutation':
                    # Every non-empty combination of the model's outputs.
                    ss = [x for x in powerset(out_idxs)]
                elif args.supervision == 'deep_supervision':
                    # Each output supervised individually.
                    ss = [[x] for x in out_idxs]
                else:
                    # Only the final output.
                    ss = [[-1]]
                print(ss)

            loss = 0.0
            for s in ss:
                if not s:
                    # powerset() yields the empty subset; nothing to supervise.
                    continue
                # Sum the selected outputs into a single prediction map.
                iout = 0.0
                for idx in range(len(s)):
                    iout += P[s[idx]]
                loss_ce = ce_loss(iout, label_batch.long())
                loss_dice = dice_loss(iout, label_batch, softmax=True)
                loss += (w_ce * loss_ce + w_dice * loss_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE: the original per-iteration loop rewriting each
            # param_group's lr to base_lr was a no-op (the optimizer is
            # already constructed with lr=base_lr) and has been removed.

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)

            # BUGFIX: log only every 50 iterations; the original also logged
            # the same message unconditionally on every iteration.
            if iter_num % 50 == 0:
                logging.info('iteration %d, epoch %d : loss : %f, lr: %f'
                             % (iter_num, epoch_num, loss.item(), lr_))

        # Always keep the most recent weights.
        save_mode_path = os.path.join(snapshot_path, 'last.pth')
        torch.save(model.state_dict(), save_mode_path)

        performance = inference(args, model, best_performance)
        writer.add_scalar('info/val_mean_dice', performance, epoch_num)
        writer.add_scalar('info/val_best_dice', best_performance, epoch_num)

        if best_performance <= performance:
            best_performance = performance
            save_mode_path = os.path.join(snapshot_path, 'best.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))

        if (epoch_num + 1) % save_interval == 0:
            save_mode_path = os.path.join(snapshot_path,
                                          'epoch_' + str(epoch_num) + '.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))

        if epoch_num >= max_epoch - 1:
            save_mode_path = os.path.join(snapshot_path,
                                          'epoch_' + str(epoch_num) + '.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))
            iterator.close()
            break

    writer.close()
    return "Training Finished!"
|