You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

203 lines
8.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from functools import partial
# Module-level worker initializer; bind `seed` with functools.partial so the
# resulting callable stays picklable for DataLoader worker processes.
def worker_init_fn(worker_id, seed):
    """Seed the random number generators of one DataLoader worker.

    Args:
        worker_id: index of the worker process, supplied by DataLoader.
        seed: base seed; every worker uses ``seed + worker_id`` so that the
            workers draw different random streams.
    """
    worker_seed = seed + worker_id
    random.seed(worker_seed)  # ensure each worker gets a different seed
    # DataLoader does not reseed NumPy's global RNG per worker, so forked
    # workers would otherwise share one NumPy state and repeat NumPy-based
    # augmentations. np.random.seed only accepts values in [0, 2**32).
    np.random.seed(worker_seed % 2**32)
import argparse
import logging
import os
import random
import sys
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast
from utils.dataset_synapse import Synapse_dataset, RandomGenerator
from utils.utils import powerset, one_hot_encoder, DiceLoss, val_single_volume
def inference(args, model, best_performance):
    """Run ``model`` over the ``test_vol`` split and save the raw predictions.

    The test split used here has no label field, so no metric is computed.
    Each case's arg-max prediction is written to
    ``<args.exp>/test_predictions/<case_name>_pred.npy``. ``best_performance``
    is returned unchanged so the training loop's bookkeeping keeps working.

    The model's train/eval mode is restored before returning. This lets the
    function be called between training epochs without leaving the model in
    eval mode.

    Args:
        args: namespace providing ``volume_path``, ``list_dir``,
            ``num_classes`` and ``exp``.
        model: network, possibly wrapped in ``nn.DataParallel``. It may return
            a tensor or a list/tuple of tensors; for a list/tuple, the last
            element is used.
        best_performance: value passed straight through to the return.

    Returns:
        ``best_performance``, unchanged.
    """
    db_test = Synapse_dataset(base_dir=args.volume_path,
                              split="test_vol",
                              list_dir=args.list_dir,
                              nclass=args.num_classes)
    testloader = DataLoader(db_test,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True)
    logging.info("{} test iterations per epoch".format(len(testloader)))

    pred_save_path = os.path.join(args.exp, "test_predictions")
    os.makedirs(pred_save_path, exist_ok=True)

    # Feed inputs to whatever device the model's weights live on (GPU or CPU).
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for sampled_batch in tqdm(testloader, total=len(testloader)):
                # Only the image is used; the split carries no real labels.
                image_batch = sampled_batch['image'].to(device, dtype=torch.float32)
                case_name = sampled_batch['case_name'][0].replace('.npy.h5', '')

                output = model(image_batch)
                # Multi-output models: keep the last output.
                if isinstance(output, (list, tuple)):
                    output = output[-1]

                # Arg-max over dim 1 (presumably the class dimension), then
                # drop singleton dims such as the batch of size 1.
                pred = np.squeeze(torch.argmax(output, dim=1).cpu().numpy())

                save_file = os.path.join(pred_save_path, f"{case_name}_pred.npy")
                np.save(save_file, pred)
                logging.info(f"已保存预测结果:{save_file}")
    finally:
        # Give the caller back the mode it had; the training loop relies on this.
        model.train(was_training)

    logging.info("=" * 50)
    logging.info("推理完成!预测结果已保存至:{}".format(pred_save_path))
    logging.info("测试集无label字段未计算评估指标")
    logging.info("=" * 50)
    return best_performance  # pass-through keeps the training loop's interface
def _supervision_subsets(n_outs, supervision):
    """Return the index subsets of model outputs that the loss is computed on.

    Args:
        n_outs: number of outputs the model returned.
        supervision: ``'mutation'`` uses whatever ``powerset`` yields for the
            output indices (presumably every subset, including the empty one,
            which the loss skips). ``'deep_supervision'`` uses each output on
            its own. Any other value uses only the last output.
    """
    out_idxs = list(range(n_outs))
    if supervision == 'mutation':
        return list(powerset(out_idxs))
    if supervision == 'deep_supervision':
        return [[idx] for idx in out_idxs]
    return [[-1]]


def _multi_output_loss(outputs, labels, subsets, ce_loss, dice_loss, w_ce=0.3, w_dice=0.7):
    """Sum the weighted CE + Dice loss over subsets of the model outputs.

    For every non-empty subset, the selected output logits are added
    element-wise. The sum is scored against ``labels`` with
    ``w_ce * CE + w_dice * Dice``, and the per-subset losses are added up.
    """
    loss = 0.0
    for subset in subsets:
        if not subset:
            continue
        logits = sum(outputs[idx] for idx in subset)
        loss_ce = ce_loss(logits, labels.long())
        loss_dice = dice_loss(logits, labels, softmax=True)
        loss += w_ce * loss_ce + w_dice * loss_dice
    return loss


def trainer_synapse(args, model, snapshot_path):
    """Train ``model`` on the Synapse training split and write checkpoints.

    Args:
        args: parsed options. Fields read here: ``base_lr``, ``num_classes``,
            ``batch_size``, ``n_gpu``, ``root_path``, ``list_dir``,
            ``img_size``, ``seed``, ``max_epochs`` and ``supervision``, plus
            whatever ``inference`` reads.
        model: network called as ``model(images, mode='train')``. It may
            return one logits tensor or a list/tuple of them.
        snapshot_path: directory that receives ``log.txt``, TensorBoard logs
            under ``log/``, and ``last.pth`` / ``best.pth`` / ``epoch_*.pth``.

    Returns:
        The string ``"Training Finished!"``.
    """
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    base_lr = args.base_lr
    num_classes = args.num_classes
    # Total batch is scaled by n_gpu (presumably args.batch_size is per GPU).
    batch_size = args.batch_size * args.n_gpu

    db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
                               nclass=args.num_classes,
                               transform=transforms.Compose(
                                   [RandomGenerator(output_size=[args.img_size, args.img_size])]))
    print("The length of train set is: {}".format(len(db_train)))

    # A module-level function bound with partial can be pickled to worker
    # processes; a closure nested in this function could not.
    trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8,
                             pin_memory=True,
                             worker_init_fn=partial(worker_init_fn, seed=args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1 and args.n_gpu > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(device)

    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
    writer = SummaryWriter(snapshot_path + '/log')

    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = args.max_epochs * len(trainloader)
    logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))
    best_performance = 0.0
    save_interval = 50
    subsets = None  # built from the first batch, once the number of outputs is known

    iterator = tqdm(range(max_epoch), ncols=70)
    try:
        for epoch_num in iterator:
            # inference() below switches the model to eval mode, so re-enter
            # training mode at the start of every epoch.
            model.train()
            for sampled_batch in trainloader:
                # Use the selected device instead of a hard-coded .cuda(), so
                # the CPU fallback chosen above actually works.
                image_batch = sampled_batch['image'].to(device)
                label_batch = sampled_batch['label'].squeeze(1).to(device)

                P = model(image_batch, mode='train')
                P = list(P) if isinstance(P, (list, tuple)) else [P]
                if subsets is None:
                    subsets = _supervision_subsets(len(P), args.supervision)
                    print(subsets)

                loss = _multi_output_loss(P, label_batch, subsets, ce_loss, dice_loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Constant learning rate: the poly schedule
                # base_lr * (1 - iter_num / max_iterations) ** 0.9 is deliberately
                # not used. The assignment loop is kept as the hook for a schedule.
                lr_ = base_lr
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_

                iter_num += 1
                writer.add_scalar('info/lr', lr_, iter_num)
                writer.add_scalar('info/total_loss', loss.item(), iter_num)
                if iter_num % 50 == 0:
                    logging.info('iteration %d, epoch %d : loss : %f, lr: %f' % (iter_num, epoch_num, loss.item(), lr_))

            # End-of-epoch summary (loss of the last batch).
            logging.info('iteration %d, epoch %d : loss : %f, lr: %f' % (iter_num, epoch_num, loss.item(), lr_))

            save_mode_path = os.path.join(snapshot_path, 'last.pth')
            torch.save(model.state_dict(), save_mode_path)

            performance = inference(args, model, best_performance)
            # NOTE(review): inference() currently returns best_performance
            # unchanged (the test split has no labels), so this comparison is
            # always true and best.pth is rewritten every epoch.
            if best_performance <= performance:
                best_performance = performance
                save_mode_path = os.path.join(snapshot_path, 'best.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if (epoch_num + 1) % save_interval == 0:
                save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if epoch_num >= max_epoch - 1:
                save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))
                iterator.close()
                break
    finally:
        # Flush and close TensorBoard logs even if training raises.
        writer.close()
    return "Training Finished!"