|
|
from functools import partial
|
|
|
# Module-level worker_init_fn (picklable for multi-process DataLoader); takes seed and worker_id
|
|
|
def worker_init_fn(worker_id, seed):
    """Seed the per-worker RNGs for a DataLoader worker process.

    Defined at module level (and bound via ``functools.partial``) so it is
    picklable for multi-process data loading.

    Args:
        worker_id: index of the DataLoader worker (supplied by PyTorch).
        seed: base seed; offset by ``worker_id`` so each worker differs.
    """
    random.seed(seed + worker_id)  # distinct Python RNG stream per worker
    # FIX: also seed NumPy — otherwise every worker inherits the parent's
    # identical NumPy RNG state and NumPy-based augmentations repeat.
    np.random.seed(seed + worker_id)
|
|
|
import argparse
|
|
|
import logging
|
|
|
import os
|
|
|
import random
|
|
|
import sys
|
|
|
import time
|
|
|
import numpy as np
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from tensorboardX import SummaryWriter
|
|
|
from torch.nn.modules.loss import CrossEntropyLoss
|
|
|
from torch.utils.data import DataLoader
|
|
|
from torchvision import transforms
|
|
|
from torch.cuda.amp import GradScaler, autocast
|
|
|
|
|
|
from utils.dataset_synapse import Synapse_dataset, RandomGenerator
|
|
|
from utils.utils import powerset, one_hot_encoder, DiceLoss, val_single_volume
|
|
|
|
|
|
|
|
|
def inference(args, model, best_performance):
    """Predict segmentation masks for the test volumes and save them as .npy.

    The test split has no usable label field, so no metric is computed; the
    incoming ``best_performance`` is returned unchanged to stay compatible
    with the training loop's best-checkpoint logic.

    Args:
        args: namespace providing ``volume_path``, ``list_dir``,
            ``num_classes`` and ``exp`` (output directory root).
        model: segmentation network; may return a list/tuple of multi-scale
            outputs, in which case the final one is used.
        best_performance: current best metric, passed through untouched.

    Returns:
        ``best_performance``, unchanged.
    """
    db_test = Synapse_dataset(base_dir=args.volume_path,
                              split="test_vol",
                              list_dir=args.list_dir,
                              nclass=args.num_classes)

    testloader = DataLoader(db_test,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True)
    logging.info("{} test iterations per epoch".format(len(testloader)))

    pred_save_path = os.path.join(args.exp, "test_predictions")
    os.makedirs(pred_save_path, exist_ok=True)

    # Single eval switch (the original called model.eval() twice).
    model.eval()
    with torch.no_grad():
        for i_batch, sampled_batch in tqdm(enumerate(testloader), total=len(testloader)):
            # Only the image is used; the placeholder label is ignored.
            image_batch = sampled_batch['image']
            case_name = sampled_batch['case_name'][0].replace('.npy.h5', '')

            # Follow the model's device so this works on both GPU and CPU.
            device = next(model.parameters()).device
            image_batch = image_batch.to(device, dtype=torch.float32)

            output = model(image_batch)
            # Multi-scale models return a list/tuple; keep the final output.
            if isinstance(output, (list, tuple)):
                output = output[-1]

            # argmax over the class dimension, then (1, H, W) -> (H, W).
            pred = torch.argmax(output, dim=1).cpu().numpy()
            pred = np.squeeze(pred)

            save_file = os.path.join(pred_save_path, f"{case_name}_pred.npy")
            np.save(save_file, pred)
            logging.info(f"已保存预测结果:{save_file}")

    logging.info("=" * 50)
    logging.info("推理完成!预测结果已保存至:{}".format(pred_save_path))
    logging.info("注:测试集无label字段,未计算评估指标")
    logging.info("=" * 50)

    # Returned unchanged so the caller's checkpoint bookkeeping still works.
    return best_performance
|
|
|
|
|
|
def trainer_synapse(args, model, snapshot_path):
    """Train ``model`` on the Synapse dataset, validating once per epoch.

    Args:
        args: namespace with (at least) ``root_path``, ``list_dir``,
            ``num_classes``, ``batch_size``, ``n_gpu``, ``base_lr``,
            ``img_size``, ``seed``, ``max_epochs``, ``supervision``, plus the
            fields ``inference`` reads (``volume_path``, ``exp``).
        model: segmentation network; ``model(x, mode='train')`` may return a
            list of multi-scale outputs.
        snapshot_path: directory for logs and checkpoints.

    Returns:
        The string ``"Training Finished!"``.
    """
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu

    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train", nclass=args.num_classes,
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    # Bind the seed so the picklable module-level worker_init_fn can be used
    # with multi-process data loading.
    worker_init_fn_partial = partial(worker_init_fn, seed=args.seed)
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn_partial)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1 and args.n_gpu > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)
    model.train()

    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)

    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
    writer = SummaryWriter(snapshot_path + '/log')

    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = args.max_epochs * len(trainloader)
    logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)

    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):
            image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            # FIX: move tensors to the detected device instead of hard-coding
            # .cuda(), so training also runs on CPU-only machines.
            image_batch = image_batch.to(device)
            label_batch = label_batch.squeeze(1).to(device)

            P = model(image_batch, mode='train')
            if not isinstance(P, list):
                P = [P]

            # Decide once (first batch of first epoch) which output subsets
            # receive supervision; `ss` then persists across iterations.
            if epoch_num == 0 and i_batch == 0:
                n_outs = len(P)
                out_idxs = list(np.arange(n_outs))
                if args.supervision == 'mutation':
                    ss = [x for x in powerset(out_idxs)]
                elif args.supervision == 'deep_supervision':
                    ss = [[x] for x in out_idxs]
                else:
                    ss = [[-1]]
                print(ss)

            loss = 0.0
            w_ce, w_dice = 0.3, 0.7  # CE / Dice loss weights

            for s in ss:
                iout = 0.0
                if s == []:
                    continue
                # Sum the selected multi-scale outputs before the loss.
                for idx in range(len(s)):
                    iout += P[s[idx]]
                loss_ce = ce_loss(iout, label_batch[:].long())
                loss_dice = dice_loss(iout, label_batch, softmax=True)
                loss += (w_ce * loss_ce + w_dice * loss_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Constant learning rate (poly decay intentionally disabled).
            lr_ = base_lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss, iter_num)

            # FIX: log only every 50 iterations; the original additionally
            # logged the same message unconditionally every iteration.
            if iter_num % 50 == 0:
                logging.info('iteration %d, epoch %d : loss : %f, lr: %f' % (iter_num, epoch_num, loss.item(), lr_))

        # Always refresh the rolling "last" checkpoint at epoch end.
        save_mode_path = os.path.join(snapshot_path, 'last.pth')
        torch.save(model.state_dict(), save_mode_path)

        performance = inference(args, model, best_performance)
        # FIX: inference() switches the model to eval mode and never restores
        # it; without this, every epoch after the first trains with
        # BatchNorm/Dropout frozen.
        model.train()

        save_interval = 50  # extra checkpoint cadence (in epochs)

        if best_performance <= performance:
            best_performance = performance
            save_mode_path = os.path.join(snapshot_path, 'best.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))

        if (epoch_num + 1) % save_interval == 0:
            save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))

        if epoch_num >= max_epoch - 1:
            save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))
            iterator.close()
            break

    writer.close()
    return "Training Finished!"
|