You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

175 lines
7.3 KiB

import argparse
import logging
import os
import random
import sys
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast
from utils.dataset_synapse import Synapse_dataset, RandomGenerator
from utils.utils import powerset, one_hot_encoder, DiceLoss, val_single_volume
import platform
# Base seed consumed by worker_init_fn. It is kept at module level (instead of
# being captured in a closure) so worker_init_fn remains a plain top-level
# function that pickle can reference, which "spawn"-style worker start-up
# (e.g. on Windows) requires.
# NOTE(review): with the "spawn" start method, worker processes re-import this
# module and therefore see this default value rather than the value assigned
# by trainer_synapse -- confirm whether that matters on your platforms.
_global_seed = 0


def worker_init_fn(worker_id):
    """Seed Python's and NumPy's global RNGs inside a DataLoader worker.

    Every worker receives the deterministic seed ``_global_seed + worker_id``,
    so random draws differ between workers but are fixed for a given base seed.

    Args:
        worker_id: Index of the worker process, supplied by ``DataLoader``.
    """
    worker_seed = worker_id + _global_seed
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(worker_seed)
def inference(args, model, best_performance):
    """Evaluate ``model`` on the Synapse ``test_vol`` split and return its score.

    Every test case is evaluated with ``val_single_volume`` (batch size 1).
    The per-case metric arrays are summed, divided by the number of cases, and
    then averaged with ``np.mean(..., axis=0)``.

    Args:
        args: Parsed options. Reads ``volume_path``, ``list_dir``,
            ``num_classes``, ``img_size`` and ``z_spacing``.
        model: Network to evaluate. It is switched to eval mode for the pass,
            and its previous train/eval mode is restored afterwards, even if
            evaluation raises.
        best_performance: Best score seen so far. Only used in the log line.

    Returns:
        The averaged metric, logged as ``mean_dice``.
    """
    db_test = Synapse_dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir,
                              nclass=args.num_classes)
    # Use num_workers=0 on Windows to avoid multiprocessing hang on exit
    num_workers = 0 if platform.system() == 'Windows' else 2
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only hosts.
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=num_workers,
                            pin_memory=torch.cuda.is_available())
    logging.info("{} test iterations per epoch".format(len(testloader)))

    # Remember the caller's mode. The training loop calls this between epochs,
    # and leaving the model in eval mode would make mode-dependent layers
    # (e.g. BatchNorm/Dropout, if present) behave as in evaluation during
    # subsequent training.
    was_training = model.training
    model.eval()
    metric_list = 0.0
    try:
        with torch.no_grad():
            for sampled_batch in tqdm(testloader, total=len(testloader)):
                image, label = sampled_batch["image"], sampled_batch["label"]
                case_name = sampled_batch['case_name'][0]
                metric_i = val_single_volume(image, label, model, classes=args.num_classes,
                                             patch_size=[args.img_size, args.img_size],
                                             case=case_name, z_spacing=args.z_spacing)
                metric_list += np.array(metric_i)
    finally:
        model.train(was_training)

    metric_list = metric_list / len(db_test)
    performance = np.mean(metric_list, axis=0)
    logging.info('Testing performance in val model: mean_dice : %f, best_dice : %f' % (performance, best_performance))
    return performance
def _supervision_subsets(n_outs, supervision):
    """Return the groups of output indices whose summed logits are each supervised."""
    out_idxs = list(np.arange(n_outs))
    if supervision == 'mutation':
        return list(powerset(out_idxs))
    if supervision == 'deep_supervision':
        return [[x] for x in out_idxs]
    # Default: supervise only the last output.
    return [[-1]]


def _supervised_loss(P, ss, label_batch, ce_loss, dice_loss, w_ce=0.3, w_dice=0.7):
    """Sum ``w_ce * CE + w_dice * Dice`` over every non-empty group of outputs in ``ss``."""
    loss = 0.0
    for s in ss:
        if not s:  # skip the empty group (e.g. from powerset)
            continue
        iout = 0.0
        for idx in s:
            iout += P[idx]
        loss_ce = ce_loss(iout, label_batch.long())
        loss_dice = dice_loss(iout, label_batch, softmax=True)
        loss += w_ce * loss_ce + w_dice * loss_dice
    return loss


def trainer_synapse(args, model, snapshot_path):
    """Train ``model`` on the Synapse training split and checkpoint it.

    After every epoch the weights are written to ``last.pth`` and the model is
    evaluated with ``inference``. The weights are written to ``best.pth`` when
    the score is at least the best so far, and to ``epoch_<n>.pth`` every 50
    epochs and on the final epoch. Logs go to ``<snapshot_path>/log.txt`` and
    stdout. TensorBoard scalars go to ``<snapshot_path>/log``.

    Args:
        args: Parsed options. Reads ``base_lr``, ``num_classes``,
            ``batch_size``, ``n_gpu``, ``root_path``, ``list_dir``,
            ``img_size``, ``seed``, ``max_epochs`` and ``supervision``, plus
            the fields used by ``inference``.
        model: Network called as ``model(images, mode='train')``. It returns a
            tensor or a list of tensors (one per output head).
        snapshot_path: Existing directory for logs and checkpoints.

    Returns:
        The string ``"Training Finished!"``.
    """
    global _global_seed
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu
    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train", nclass=args.num_classes,
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))
    # Set global seed for worker init function
    _global_seed = args.seed

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use num_workers=0 on Windows to avoid multiprocessing pickle issues
    num_workers = 0 if platform.system() == 'Windows' else 8
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                             pin_memory=device.type == 'cuda', worker_init_fn=worker_init_fn)
    if torch.cuda.device_count() > 1 and args.n_gpu > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)

    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
    writer = SummaryWriter(snapshot_path + '/log')

    # Constant learning rate: the poly-decay schedule is intentionally not used,
    # so the optimizer keeps the lr it was constructed with.
    lr_ = base_lr
    save_interval = 50
    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = args.max_epochs * len(trainloader)
    logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))
    best_performance = 0.0
    # Output groups to supervise; depends on how many outputs the model returns,
    # so it is built lazily from the first forward pass.
    ss = None
    iterator = tqdm(range(max_epoch), ncols=70)
    try:
        for epoch_num in iterator:
            # inference() evaluates in eval mode, so re-enter train mode every epoch.
            model.train()
            for sampled_batch in trainloader:
                image_batch = sampled_batch['image'].to(device)
                label_batch = sampled_batch['label'].squeeze(1).to(device)
                P = model(image_batch, mode='train')
                if not isinstance(P, list):
                    P = [P]
                if ss is None:
                    ss = _supervision_subsets(len(P), args.supervision)
                    print(ss)

                loss = _supervised_loss(P, ss, label_batch, ce_loss, dice_loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                iter_num += 1
                # Log a plain float rather than the autograd tensor.
                loss_value = loss.item()
                writer.add_scalar('info/lr', lr_, iter_num)
                writer.add_scalar('info/total_loss', loss_value, iter_num)
                if iter_num % 50 == 0:
                    logging.info('iteration %d, epoch %d : loss : %f, lr: %f' % (iter_num, epoch_num, loss_value, lr_))

            logging.info('iteration %d, epoch %d : loss : %f, lr: %f' % (iter_num, epoch_num, loss_value, lr_))
            save_mode_path = os.path.join(snapshot_path, 'last.pth')
            torch.save(model.state_dict(), save_mode_path)

            performance = inference(args, model, best_performance)
            if best_performance <= performance:
                best_performance = performance
                save_mode_path = os.path.join(snapshot_path, 'best.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            is_last_epoch = epoch_num >= max_epoch - 1
            if (epoch_num + 1) % save_interval == 0 or is_last_epoch:
                save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))
            if is_last_epoch:
                break
    finally:
        iterator.close()
        writer.close()
    return "Training Finished!"