import torch
import torch.nn as nn
import numpy as np
from medpy import metric
from scipy.ndimage import zoom
import seaborn as sns
from PIL import Image
import matplotlib.pyplot as plt
from segmentation_mask_overlay import overlay_masks
import matplotlib.colors as mcolors
import SimpleITK as sitk
import pandas as pd
from thop import profile
from thop import clever_format
from ptflops import get_model_complexity_info

def powerset(seq):
    """
    Returns all the subsets of this set. This is a generator.
    """
    if len(seq) <= 1:
        yield seq
        yield []
    else:
        for item in powerset(seq[1:]):
            yield [seq[0]] + item
            yield item
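# Example: the generator yields every subset of the input (as a list), e.g.
#   >>> list(powerset([1, 2]))
#   [[1, 2], [2], [1], []]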

def clip_gradient(optimizer, grad_clip):
    """
    For calibrating misaligned gradients via the gradient clipping technique.
    :param optimizer: optimizer whose parameter gradients are clipped in place
    :param grad_clip: threshold; gradients are clamped to [-grad_clip, grad_clip]
    :return:
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
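# Usage sketch (hypothetical training step; `optimizer` is any torch.optim optimizer):
#   loss.backward()
#   clip_gradient(optimizer, grad_clip=0.5)
#   optimizer.step()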

def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
    # Step decay: scale the initial learning rate once every `decay_epoch` epochs.
    # Setting the lr from init_lr (instead of multiplying the current lr in place)
    # keeps repeated calls from compounding the decay.
    decay = decay_rate ** (epoch // decay_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr * decay
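# Usage sketch (hypothetical values): called once per epoch, this keeps the lr
# at init_lr for epochs 0-29, init_lr * 0.1 for epochs 30-59, and so on.
#   adjust_lr(optimizer, init_lr=1e-4, epoch=epoch)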

class AvgMeter(object):
    def __init__(self, num=40):
        self.num = num
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.losses = []

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.losses.append(val)

    def show(self):
        # Mean over the most recent `num` recorded values.
        return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses) - self.num, 0):]))
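# Usage sketch (hypothetical loop): track a running loss and report the mean of
# the most recent `num` values; note `show()` stacks tensors, so `update` should
# be given scalar torch.Tensor values rather than plain floats.
#   loss_meter = AvgMeter()
#   for batch in loader:
#       loss = criterion(net(batch['image']), batch['label'])
#       loss_meter.update(loss)
#   print(loss_meter.show())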

def CalParams(model, input_tensor):
    """
    Usage:
        Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter)
    Necessity:
        from thop import profile
        from thop import clever_format
    :param model:
    :param input_tensor:
    :return:
    """
    flops, params = profile(model, inputs=(input_tensor,))
    flops, params = clever_format([flops, params], "%.3f")
    print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params))
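# Usage sketch (hypothetical shapes; model and input must be on the same device):
#   dummy = torch.randn(1, 3, 224, 224).cuda()
#   CalParams(model.cuda(), dummy)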

def one_hot_encoder(input_tensor, dataset, n_classes=None):
    tensor_list = []
    if dataset == 'MMWHS':
        # MMWHS stores labels as raw intensity values rather than 0..K-1 indices.
        label_values = [0, 205, 420, 500, 550, 600, 820, 850]
        for i in label_values:
            temp_prob = input_tensor == i
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
    else:
        for i in range(n_classes):
            temp_prob = input_tensor == i
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()
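# Shape example: for a (B, H, W) label map, the output is (B, K, H, W), where K
# is 8 for 'MMWHS' (one channel per raw intensity value) and n_classes otherwise.
#   labels = torch.randint(0, 4, (2, 256, 256))
#   onehot = one_hot_encoder(labels, dataset='other', n_classes=4)  # (2, 4, 256, 256)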

class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i
            tensor_list.append(temp_prob.unsqueeze(1))
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        class_wise_dice = []
        loss = 0.0
        for i in range(0, self.n_classes):
            dice = self._dice_loss(inputs[:, i], target[:, i])
            class_wise_dice.append(1.0 - dice.item())
            loss += dice * weight[i]
        return loss / self.n_classes
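# Usage sketch (hypothetical training step): a common pattern is to combine this
# with cross-entropy; the 0.5/0.5 weighting below is an assumption, not something
# this file prescribes.
#   ce_loss = nn.CrossEntropyLoss()
#   dice_loss = DiceLoss(n_classes=9)
#   logits = net(images)                                   # (B, 9, H, W)
#   loss = 0.5 * ce_loss(logits, labels.long()) \
#        + 0.5 * dice_loss(logits, labels, softmax=True)   # labels: (B, H, W)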

def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        jaccard = metric.binary.jc(pred, gt)
        asd = metric.binary.assd(pred, gt)
        return dice, hd95, jaccard, asd
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0, 1, 0
    else:
        return 0, 0, 0, 0

def calculate_dice_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        return dice
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1
    else:
        return 0
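# Both helpers binarize their inputs in place, so pass per-class boolean copies
# rather than the raw multi-class maps, e.g.:
#   dice_c3 = calculate_dice_percase(prediction == 3, label == 3)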

def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1, class_names=None):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    if class_names is None:
        mask_labels = np.arange(1, classes)
    else:
        mask_labels = class_names
    cmaps = mcolors.CSS4_COLORS
    my_colors = ['red', 'darkorange', 'yellow', 'forestgreen', 'blue', 'purple', 'magenta', 'cyan',
                 'deeppink', 'chocolate', 'olive', 'deepskyblue', 'darkviolet']
    cmap = {k: cmaps[k] for k in sorted(cmaps.keys()) if k in my_colors[:classes - 1]}
    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            slice_2d = image[ind, :, :]
            x, y = slice_2d.shape[0], slice_2d.shape[1]
            if x != patch_size[0] or y != patch_size[1]:
                slice_2d = zoom(slice_2d, (patch_size[0] / x, patch_size[1] / y), order=3)  # previously using order=0
            input_tensor = torch.from_numpy(slice_2d).unsqueeze(0).unsqueeze(0).float().cuda()
            net.eval()
            with torch.no_grad():
                P = net(input_tensor)
                outputs = P[-1]
                out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
                out = out.cpu().detach().numpy()
                if x != patch_size[0] or y != patch_size[1]:
                    pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
                else:
                    pred = out
                prediction[ind] = pred
                # === Key part of the memory-leak fix: START ===
                if test_save_path is not None:
                    lbl = label[ind, :, :]
                    masks = [lbl == i for i in range(1, classes)]
                    preds_o = [pred == i for i in range(1, classes)]
                    # Create the overlay figure objects
                    fig_gt = overlay_masks(image[ind, :, :], masks, labels=mask_labels, colors=cmap, mask_alpha=0.5)
                    fig_pred = overlay_masks(image[ind, :, :], preds_o, labels=mask_labels, colors=cmap, mask_alpha=0.5)
                    # Save the figures
                    fig_gt.savefig(test_save_path + '/' + case + '_' + str(ind) + '_gt.png', bbox_inches="tight", dpi=300)
                    fig_pred.savefig(test_save_path + '/' + case + '_' + str(ind) + '_pred.png', bbox_inches="tight", dpi=300)
                    # !!! The figures must be closed manually to release memory !!!
                    plt.close(fig_gt)
                    plt.close(fig_pred)
                # === Key part of the memory-leak fix: END ===
    else:
        input_tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().cuda()
        net.eval()
        with torch.no_grad():
            P = net(input_tensor)
            outputs = P[-1]
            out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
            prediction = out.cpu().detach().numpy()
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    if test_save_path is not None:
        img_itk = sitk.GetImageFromArray(image.astype(np.float32))
        prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
        lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
        img_itk.SetSpacing((1, 1, z_spacing))
        prd_itk.SetSpacing((1, 1, z_spacing))
        lab_itk.SetSpacing((1, 1, z_spacing))
        sitk.WriteImage(prd_itk, test_save_path + '/' + case + "_pred.nii.gz")
        sitk.WriteImage(img_itk, test_save_path + '/' + case + "_img.nii.gz")
        sitk.WriteImage(lab_itk, test_save_path + '/' + case + "_gt.nii.gz")
    return metric_list
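# Usage sketch (hypothetical names; one 3-D volume per case):
#   metrics = test_single_volume(image, label, net, classes=9,
#                                patch_size=[224, 224], test_save_path='./preds',
#                                case='case0001', z_spacing=1)
# `metrics` is a list of (dice, hd95, jaccard, asd) tuples, one per foreground
# class (classes 1..classes-1).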

def val_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            slice_2d = image[ind, :, :]
            x, y = slice_2d.shape[0], slice_2d.shape[1]
            if x != patch_size[0] or y != patch_size[1]:
                slice_2d = zoom(slice_2d, (patch_size[0] / x, patch_size[1] / y), order=3)  # previously using order=0
            input_tensor = torch.from_numpy(slice_2d).unsqueeze(0).unsqueeze(0).float().cuda()
            net.eval()
            with torch.no_grad():
                P = net(input_tensor)
                outputs = P[-1]
                out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
                out = out.cpu().detach().numpy()
                if x != patch_size[0] or y != patch_size[1]:
                    pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
                else:
                    pred = out
                prediction[ind] = pred
    else:
        input_tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().cuda()
        net.eval()
        with torch.no_grad():
            P = net(input_tensor)
            outputs = P[-1]
            out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
            prediction = out.cpu().detach().numpy()
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_dice_percase(prediction == i, label == i))
    return metric_list

def horizontal_flip(image):
    image = image[:, ::-1, :]
    return image

def vertical_flip(image):
    image = image[::-1, :, :]
    return image

def tta_model(model, image):
    # Test-time augmentation: average the predictions over the original image
    # and its horizontal/vertical flips, flipping each mask back before averaging.
    n_image = image
    h_image = horizontal_flip(image)
    v_image = vertical_flip(image)
    n_mask = model.predict(np.expand_dims(n_image, axis=0))[0]
    h_mask = model.predict(np.expand_dims(h_image, axis=0))[0]
    v_mask = model.predict(np.expand_dims(v_image, axis=0))[0]
    h_mask = horizontal_flip(h_mask)
    v_mask = vertical_flip(v_mask)
    mean_mask = (n_mask + h_mask + v_mask) / 3.0
    return mean_mask
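# Note: `model.predict` here is the Keras-style API; for a PyTorch model you
# would instead run a forward pass under torch.no_grad() on each flipped copy.
# Usage sketch (hypothetical Keras model, `image` as an HWC numpy array):
#   mean_mask = tta_model(keras_model, image)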

def cal_params_flops(model, size, logger):
    input_tensor = torch.randn(1, 3, size, size).cuda()
    flops, params = profile(model, inputs=(input_tensor,))
    print('flops', flops / 1e9)    # print FLOPs (in GFLOPs)
    print('params', params / 1e6)  # print parameter count (in millions)
    total = sum(p.numel() for p in model.parameters())
    print("Total params: %.2fM" % (total / 1e6))
    logger.info(f'flops: {flops/1e9}, params: {params/1e6}, Total params: {total/1e6:.4f}')
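# Usage sketch (hypothetical logger; expects a CUDA device and a model that
# accepts a 3-channel square input):
#   import logging
#   logger = logging.getLogger(__name__)
#   cal_params_flops(model.cuda(), size=256, logger=logger)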

# Example function to calculate and print GMACs and parameter count for a given model
def print_model_stats(model, input_size=(3, 224, 224)):
    # Print model parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Model created, param count: {total_params}')
    # Calculate GMACs using ptflops
    macs, params = get_model_complexity_info(model, input_size, as_strings=True, print_per_layer_stat=True)
    # Display GMACs and params
    print(f'Model: {macs} GMACs, {params} parameters')
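# Usage sketch: unlike cal_params_flops, this relies on ptflops and runs on
# whatever device the model is already on.
#   print_model_stats(model, input_size=(3, 224, 224))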