import argparse
import logging
import os
import random
import sys
import time
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from utils.dataset_synapse import Synapse_dataset, RandomGenerator
from utils.utils import powerset, one_hot_encoder, DiceLoss, val_single_volume


def worker_init_fn(worker_id, seed):
    """Seed Python's ``random`` module inside a DataLoader worker.

    Kept at module level and bound to ``seed`` with ``functools.partial`` by
    the caller (presumably so the callable stays picklable for worker
    processes -- confirm against the multiprocessing start method in use).

    Args:
        worker_id: Index of the worker, supplied by the DataLoader.
        seed: Base seed; each worker uses ``seed + worker_id`` so that the
            workers do not share the same random stream.
    """
    random.seed(seed + worker_id)


def inference(args, model, best_performance):
    """Run the model over the test volumes and save the predicted masks.

    For every sample of the ``test_vol`` split the arg-max prediction is
    written to ``<args.exp>/test_predictions/<case_name>_pred.npy``. No
    evaluation metric is computed (the log messages state that the test set
    has no labels), so ``best_performance`` is returned unchanged purely to
    stay compatible with the training loop.

    Note:
        The model is switched to eval mode and is left in eval mode on
        return; callers that keep training must call ``model.train()``.

    Args:
        args: Namespace providing ``volume_path``, ``list_dir``,
            ``num_classes`` and ``exp``.
        model: Network to evaluate; inputs are moved to the device of its
            first parameter.
        best_performance: Value passed straight through to the caller.

    Returns:
        ``best_performance``, unmodified.
    """
    db_test = Synapse_dataset(base_dir=args.volume_path, split="test_vol",
                              list_dir=args.list_dir, nclass=args.num_classes)
    testloader = DataLoader(db_test, batch_size=1, shuffle=False,
                            num_workers=1, pin_memory=True)
    logging.info("{} test iterations per epoch".format(len(testloader)))

    pred_save_path = os.path.join(args.exp, "test_predictions")
    os.makedirs(pred_save_path, exist_ok=True)

    model.eval()
    # Follow whatever device the model lives on (GPU or CPU); this does not
    # change between batches, so look it up once.
    device = next(model.parameters()).device

    with torch.no_grad():
        for i_batch, sampled_batch in tqdm(enumerate(testloader), total=len(testloader)):
            # Only the image is used; any label entry in the batch is ignored.
            image_batch = sampled_batch['image'].to(device, dtype=torch.float32)
            case_name = sampled_batch['case_name'][0].replace('.npy.h5', '')

            output = model(image_batch)
            # Multi-output models: keep only the last output.
            if isinstance(output, (list, tuple)):
                output = output[-1]

            # Arg-max over dim 1 (assumed to be the class dimension), then
            # drop all size-1 dimensions such as the batch dimension.
            pred = np.squeeze(torch.argmax(output, dim=1).cpu().numpy())

            save_file = os.path.join(pred_save_path, f"{case_name}_pred.npy")
            np.save(save_file, pred)
            logging.info(f"已保存预测结果:{save_file}")

    logging.info("=" * 50)
    logging.info("推理完成!预测结果已保存至:{}".format(pred_save_path))
    logging.info("注:测试集无label字段,未计算评估指标")
    logging.info("=" * 50)

    return best_performance


def trainer_synapse(args, model, snapshot_path):
    """Train ``model`` on the Synapse training split and checkpoint it.

    Each epoch: train over the whole loader with a weighted CE + Dice loss
    (0.3 / 0.7) summed over the configured combinations of model outputs,
    save ``last.pth``, run :func:`inference`, save ``best.pth`` whenever the
    returned performance is not worse than the best so far, and save
    ``epoch_<n>.pth`` every 50 epochs and on the final epoch.

    Args:
        args: Namespace providing ``base_lr``, ``num_classes``,
            ``batch_size``, ``n_gpu``, ``root_path``, ``list_dir``,
            ``img_size``, ``seed``, ``max_epochs``, ``supervision`` and the
            fields read by :func:`inference`.
        model: Network to train; called as ``model(images, mode='train')``.
        snapshot_path: Directory receiving ``log.txt``, the TensorBoard
            ``log`` directory and the checkpoints.

    Returns:
        The string ``"Training Finished!"``.
    """
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    base_lr = args.base_lr
    num_classes = args.num_classes
    batch_size = args.batch_size * args.n_gpu

    db_train = Synapse_dataset(
        base_dir=args.root_path, list_dir=args.list_dir, split="train",
        nclass=args.num_classes,
        transform=transforms.Compose(
            [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    # Bind the seed so the DataLoader can call the hook with worker_id only.
    worker_init_fn_partial = partial(worker_init_fn, seed=args.seed)
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,
                             num_workers=8, pin_memory=True,
                             worker_init_fn=worker_init_fn_partial)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1 and args.n_gpu > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)

    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
    writer = SummaryWriter(snapshot_path + '/log')

    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = args.max_epochs * len(trainloader)
    logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))

    best_performance = 0.0
    save_interval = 50
    w_ce, w_dice = 0.3, 0.7
    ss = None  # output-index subsets to supervise; built on the first batch

    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        # inference() leaves the model in eval mode at the end of every
        # epoch, so training mode must be re-enabled here.
        model.train()

        for i_batch, sampled_batch in enumerate(trainloader):
            image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            # Use the selected device rather than a hard-coded .cuda() so the
            # CPU fallback chosen above actually works.
            image_batch = image_batch.to(device)
            label_batch = label_batch.squeeze(1).to(device)

            P = model(image_batch, mode='train')
            if not isinstance(P, list):
                P = [P]

            if ss is None:
                # The number of outputs is only known after a forward pass.
                out_idxs = list(np.arange(len(P)))
                if args.supervision == 'mutation':
                    ss = [x for x in powerset(out_idxs)]
                elif args.supervision == 'deep_supervision':
                    ss = [[x] for x in out_idxs]
                else:
                    ss = [[-1]]  # supervise the last output only
                print(ss)

            loss = 0.0
            for s in ss:
                if s == []:
                    continue
                # Supervise the element-wise sum of the selected outputs.
                iout = sum(P[idx] for idx in s)
                loss_ce = ce_loss(iout, label_batch.long())
                loss_dice = dice_loss(iout, label_batch, softmax=True)
                loss += (w_ce * loss_ce + w_dice * loss_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Constant learning rate; a polynomial decay
            # (base_lr * (1 - iter_num / max_iterations) ** 0.9) was
            # deliberately not used.
            lr_ = base_lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_

            iter_num = iter_num + 1
            loss_value = loss.item()
            writer.add_scalar('info/lr', lr_, iter_num)
            writer.add_scalar('info/total_loss', loss_value, iter_num)

            if iter_num % 50 == 0:
                logging.info('iteration %d, epoch %d : loss : %f, lr: %f' % (iter_num, epoch_num, loss_value, lr_))

        # End-of-epoch summary using the loss of the last batch.
        logging.info('iteration %d, epoch %d : loss : %f, lr: %f' % (iter_num, epoch_num, loss.item(), lr_))

        save_mode_path = os.path.join(snapshot_path, 'last.pth')
        torch.save(model.state_dict(), save_mode_path)

        performance = inference(args, model, best_performance)

        if best_performance <= performance:
            best_performance = performance
            save_mode_path = os.path.join(snapshot_path, 'best.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))

        is_last_epoch = epoch_num >= max_epoch - 1
        # Periodic checkpoint, plus one on the final epoch (written once even
        # when the final epoch is also an interval epoch).
        if (epoch_num + 1) % save_interval == 0 or is_last_epoch:
            save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))

        if is_last_epoch:
            iterator.close()
            break

    writer.close()
    return "Training Finished!"