You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

95 lines
2.5 KiB

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
softmax_helper = lambda x: F.softmax(x, 1)
sigmoid_helper = lambda x: F.sigmoid(x)
class InitWeights_He(object):
def __init__(self, neg_slope=1e-2):
self.neg_slope = neg_slope
def __call__(self, module):
if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):
module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
if module.bias is not None:
module.bias = nn.init.constant_(module.bias, 0)
def maybe_to_torch(d):
if isinstance(d, list):
d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
elif not isinstance(d, torch.Tensor):
d = torch.from_numpy(d).float()
return d
def to_cuda(data, non_blocking=True, gpu_id=0):
if isinstance(data, list):
data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
else:
data = data.cuda(gpu_id, non_blocking=non_blocking)
return data
class no_op(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
def staple(a):
# a: n,c,h,w detach tensor
mvres = mv(a)
gap = 0.4
if gap > 0.02:
for i, s in enumerate(a):
r = s * mvres
res = r if i == 0 else torch.cat((res,r),0)
nres = mv(res)
gap = torch.mean(torch.abs(mvres - nres))
mvres = nres
a = res
return mvres
def allone(disc,cup):
disc = np.array(disc) / 255
cup = np.array(cup) / 255
res = np.clip(disc * 0.5 + cup,0,1) * 255
res = 255 - res
res = Image.fromarray(np.uint8(res))
return res
def dice_score(pred, targs):
pred = (pred>0).float()
return 2. * (pred*targs).sum() / (pred+targs).sum()
def mv(a):
# res = Image.fromarray(np.uint8(img_list[0] / 2 + img_list[1] / 2 ))
# res.show()
b = a.size(0)
return torch.sum(a, 0, keepdim=True) / b
def tensor_to_img_array(tensor):
image = tensor.cpu().detach().numpy()
image = np.transpose(image, [0, 2, 3, 1])
return image
def export(tar, img_path=None):
# image_name = image_name or "image.jpg"
c = tar.size(1)
if c == 3:
vutils.save_image(tar, fp = img_path)
else:
s = th.tensor(tar)[:,-1,:,:].unsqueeze(1)
s = th.cat((s,s,s),1)
vutils.save_image(s, fp = img_path)
def norm(t):
m, s, v = torch.mean(t), torch.std(t), torch.var(t)
return (t - m) / s