You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

163 lines
6.8 KiB

import argparse
import logging
import os
import random
import sys
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast
from utils.dataset_synapse import Synapse_dataset, RandomGenerator
from utils.utils import powerset, one_hot_encoder, DiceLoss, val_single_volume
def inference(args, model, best_performance):
    """Evaluate ``model`` on the Synapse ``test_vol`` split and return its summary score.

    Every test volume is scored with ``val_single_volume``; the per-volume metric
    arrays are averaged over volumes and then reduced with ``np.mean(..., axis=0)``
    to the single number that is logged as ``mean_dice``.

    Args:
        args: Run configuration. Reads ``volume_path``, ``list_dir``,
            ``num_classes``, ``img_size`` and ``z_spacing``.
        model: Network to evaluate. It is put in eval mode for the duration of
            the call and then restored to whatever train/eval mode it had before,
            so callers that validate mid-training are not left in eval mode.
        best_performance: Best score seen so far; only used in the log line.

    Returns:
        The averaged metric reduced over axis 0 (presumably a NumPy scalar when
        ``val_single_volume`` returns one value per class — TODO confirm).

    Raises:
        ValueError: if the test split contains no volumes (the original code
            failed here with a bare ZeroDivisionError).
    """
    db_test = Synapse_dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir,
                              nclass=args.num_classes)
    if len(db_test) == 0:
        raise ValueError("Synapse test split is empty; cannot compute validation performance")
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
    logging.info("{} test iterations per epoch".format(len(testloader)))

    was_training = model.training
    model.eval()
    metric_sum = 0.0
    try:
        # The metrics are only accumulated as NumPy values below, so no autograd
        # graph is needed for these forward passes.
        with torch.no_grad():
            # Iterating the loader directly lets tqdm pick up its length as total.
            for sampled_batch in tqdm(testloader):
                image, label = sampled_batch["image"], sampled_batch["label"]
                case_name = sampled_batch['case_name'][0]
                metric_i = val_single_volume(image, label, model, classes=args.num_classes,
                                             patch_size=[args.img_size, args.img_size],
                                             case=case_name, z_spacing=args.z_spacing)
                metric_sum += np.array(metric_i)
    finally:
        # Put the model back in the caller's mode even if evaluation fails.
        model.train(was_training)

    metric_list = metric_sum / len(db_test)
    performance = np.mean(metric_list, axis=0)
    logging.info('Testing performance in val model: mean_dice : %f, best_dice : %f',
                 performance, best_performance)
    return performance
def _supervision_subsets(n_outs, supervision):
    """Return the lists of output indices whose summed outputs are each supervised.

    ``'mutation'`` uses every index list yielded by ``powerset`` over the output
    indices, ``'deep_supervision'`` uses each output on its own, and any other
    value supervises only the last output (index ``-1``).
    """
    out_idxs = list(range(n_outs))
    if supervision == 'mutation':
        return list(powerset(out_idxs))
    if supervision == 'deep_supervision':
        return [[idx] for idx in out_idxs]
    return [[-1]]


def _combined_loss(outputs, label_batch, subsets, ce_loss, dice_loss, w_ce=0.3, w_dice=0.7):
    """Sum ``w_ce * CE + w_dice * Dice`` over every non-empty index list in ``subsets``.

    For each index list, the selected entries of ``outputs`` are added together
    and the sum is scored against ``label_batch``. Empty index lists are skipped.
    """
    loss = 0.0
    for subset in subsets:
        if not subset:
            continue
        logits = sum(outputs[idx] for idx in subset)
        loss_ce = ce_loss(logits, label_batch.long())
        loss_dice = dice_loss(logits, label_batch, softmax=True)
        loss += w_ce * loss_ce + w_dice * loss_dice
    return loss


def _save_model(model, snapshot_path, filename, log=True):
    """Save ``model.state_dict()`` to ``snapshot_path/filename``; optionally log the path."""
    save_mode_path = os.path.join(snapshot_path, filename)
    torch.save(model.state_dict(), save_mode_path)
    if log:
        logging.info("save model to {}".format(save_mode_path))
    return save_mode_path


def trainer_synapse(args, model, snapshot_path):
    """Train ``model`` on the Synapse ``train`` split, validating after every epoch.

    Logs go to ``<snapshot_path>/log.txt`` (and stdout) and TensorBoard scalars
    to ``<snapshot_path>/log``. After each epoch the weights are written to
    ``last.pth``; ``best.pth`` is overwritten whenever the score returned by
    ``inference`` is at least the best seen so far; ``epoch_<n>.pth`` is written
    every ``save_interval`` epochs and after the final epoch.

    Args:
        args: Run configuration. Reads ``base_lr``, ``num_classes``,
            ``batch_size``, ``n_gpu``, ``root_path``, ``list_dir``, ``img_size``,
            ``seed``, ``max_epochs`` and ``supervision`` (plus the fields that
            ``inference`` reads).
        model: Network to train. It is called as ``model(images, mode='train')``
            and may return a single output or a list of outputs.
        snapshot_path: Directory for logs and checkpoints (assumed to already
            exist — TODO confirm with the caller).

    Returns:
        The string ``"Training Finished!"``.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu

    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train", nclass=args.num_classes,
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn)
    if len(trainloader) == 0:
        raise ValueError("Synapse training split is empty; nothing to train on")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1 and args.n_gpu > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)

    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
    # Constant learning rate: the poly decay base_lr * (1 - iter/max_iter) ** 0.9
    # is deliberately not used, so the optimizer keeps base_lr throughout.
    lr_ = base_lr
    writer = SummaryWriter(snapshot_path + '/log')

    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = max_epoch * len(trainloader)
    logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))
    best_performance = 0.0
    save_interval = 50
    subsets = None  # built lazily from the number of outputs of the first forward pass

    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        # inference() switches the model to eval mode; re-enable training mode
        # every epoch so that later epochs do not train in eval mode.
        model.train()
        for sampled_batch in trainloader:
            image_batch = sampled_batch['image'].to(device)
            # squeeze(1) drops dim 1 when it has size 1 (presumably a channel axis).
            label_batch = sampled_batch['label'].squeeze(1).to(device)

            outputs = model(image_batch, mode='train')
            if not isinstance(outputs, list):
                outputs = [outputs]
            if subsets is None:
                subsets = _supervision_subsets(len(outputs), args.supervision)
                print(subsets)

            loss = _combined_loss(outputs, label_batch, subsets, ce_loss, dice_loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num += 1
            loss_value = loss.item()
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss_value, iter_num)
            if iter_num % 50 == 0:
                logging.info('iteration %d, epoch %d : loss : %f, lr: %f', iter_num, epoch_num, loss_value, lr_)

        logging.info('iteration %d, epoch %d : loss : %f, lr: %f', iter_num, epoch_num, loss_value, lr_)
        _save_model(model, snapshot_path, 'last.pth', log=False)

        performance = inference(args, model, best_performance)
        writer.add_scalar('info/val_mean_dice', performance, epoch_num)
        if best_performance <= performance:
            best_performance = performance
            _save_model(model, snapshot_path, 'best.pth')
        # Logged after the update so the curve shows the best score including this epoch.
        writer.add_scalar('info/val_best_dice', best_performance, epoch_num)

        is_last_epoch = epoch_num >= max_epoch - 1
        # A single save covers both triggers, so the last epoch is never written twice.
        if (epoch_num + 1) % save_interval == 0 or is_last_epoch:
            _save_model(model, snapshot_path, 'epoch_' + str(epoch_num) + '.pth')

    iterator.close()
    writer.close()
    return "Training Finished!"