import collections
import logging
import math
import os
import pathlib
import random
import shutil
import sys
import tempfile
import time
import warnings
from collections import OrderedDict
from datetime import datetime
from typing import BinaryIO, List, Optional, Text, Tuple, Union

import dateutil.tz
import matplotlib.pyplot as plt
import numpy as np
import PIL
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image, ImageColor, ImageDraw, ImageFont
from torch import autograd
from torch.autograd import Function, Variable
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision.models import vgg19
from tqdm import tqdm

from monai.config import print_config
from monai.data import (
    CacheDataset,
    ThreadDataLoader,
    decollate_batch,
    load_decathlon_datalist,
    set_track_meta,
)
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
)

import cfg
from models.discriminator import Discriminator

# from precpt import run_precpt
# from siren_pytorch import SirenNet, SirenWrapper


args = cfg.parse_args()
device = torch.device('cuda', args.gpu_device)

def get_network(args, net, use_gpu=True, gpu_device=0, distribution=True):
    """Build and return the requested network, moved to the target device."""

    if net == 'oneprompt':
        from models.oneprompt import OnePredictor, one_model_registry
        from models.oneprompt.utils.transforms import ResizeLongestSide
        net = one_model_registry[args.baseline](args).to(device)
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    if use_gpu:
        # net = net.cuda(device=gpu_device)
        if distribution != 'none':
            net = torch.nn.DataParallel(
                net, device_ids=[int(id) for id in args.distributed.split(',')]
            )
        net = net.to(device=gpu_device)

    return net

def get_decath_loader(args):
    """Build MONAI train/val loaders for a Decathlon-style dataset split."""

    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-175,
                a_max=250,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(args.roi_size, args.roi_size, args.chunk),
                pos=1,
                neg=1,
                num_samples=args.num_sample,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[0],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[1],
                prob=0.10,
            ),
            RandFlipd(
                keys=["image", "label"],
                spatial_axis=[2],
                prob=0.10,
            ),
            RandRotate90d(
                keys=["image", "label"],
                prob=0.10,
                max_k=3,
            ),
            RandShiftIntensityd(
                keys=["image"],
                offsets=0.10,
                prob=0.50,
            ),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], ensure_channel_first=True),
            ScaleIntensityRanged(
                keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
        ]
    )

    data_dir = args.data_path
    split_JSON = "dataset_0.json"

    datasets = os.path.join(data_dir, split_JSON)
    datalist = load_decathlon_datalist(datasets, True, "training")
    val_files = load_decathlon_datalist(datasets, True, "validation")
    train_ds = CacheDataset(
        data=datalist,
        transform=train_transforms,
        cache_num=24,
        cache_rate=1.0,
        num_workers=8,
    )
    train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=args.b, shuffle=True)
    val_ds = CacheDataset(
        data=val_files, transform=val_transforms, cache_num=2, cache_rate=1.0, num_workers=0
    )
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)

    set_track_meta(False)

    return train_loader, val_loader, train_transforms, val_transforms, datalist, val_files

def cka_loss(gram_featureA, gram_featureB):
    """Linear CKA similarity between two Gram (feature) matrices:
    <vec(A), vec(B)> / (||A||_F * ||B||_F)."""
    scaled_hsic = torch.dot(torch.flatten(gram_featureA), torch.flatten(gram_featureB))
    normalization_x = gram_featureA.norm()
    normalization_y = gram_featureB.norm()
    return scaled_hsic / (normalization_x * normalization_y)

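# Usage sketch (illustrative, assumptions noted): compare two feature maps by
# their Gram matrices; `feat_a` and `feat_b` are assumed to be activations of
# shape [1, C, H, W] taken from two networks on the same input.
#
#   gram_a = gram_matrix(feat_a)
#   gram_b = gram_matrix(feat_b)
#   similarity = cka_loss(gram_a, gram_b)  # scalar, higher = more similar
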
class WarmUpLR(_LRScheduler):
    """Warmup learning rate scheduler.

    Args:
        optimizer: wrapped optimizer (e.g. SGD)
        total_iters: total number of iterations in the warmup phase
    """
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """During the first m batches, set the learning
        rate to base_lr * m / total_iters.
        """
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]

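# Usage sketch (illustrative): warm the LR up over the first epoch's batches.
# `net` and `train_loader` are assumed to exist in the calling training script.
#
#   optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
#   warmup_scheduler = WarmUpLR(optimizer, total_iters=len(train_loader))
#   for batch in train_loader:  # first epoch only
#       ...  # forward / backward / optimizer.step()
#       warmup_scheduler.step()
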
def gram_matrix(input):
    a, b, c, d = input.size()  # a = batch size (=1)
    # b = number of feature maps
    # (c, d) = dimensions of a feature map (N = c * d)

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map
    return G.div(a * b * c * d)

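# Usage sketch (illustrative): a Gram-based style loss between VGG features,
# in the spirit of neural style transfer. `feat_gen` and `feat_style` are
# assumed to be activations of shape [1, C, H, W] from the same layer.
#
#   style_loss = F.mse_loss(gram_matrix(feat_gen), gram_matrix(feat_style))
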
@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs
) -> torch.Tensor:
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if "range" in kwargs.keys():
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)  # in-place: plain clamp() would discard its result
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid

@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format (Optional): If omitted, the format to use is determined from the
            filename extension. If a file object was used instead of a filename,
            this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """

    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)

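# Usage sketch (illustrative): dump a normalized grid of generated samples.
# `fake_images` is assumed to be a [B, 3, H, W] tensor in an arbitrary range.
#
#   save_image(fake_images, 'samples.png', nrow=4, normalize=True)
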
def create_logger(log_dir, phase='train'):
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    return logger

def set_log_dir(root_dir, exp_name):
    path_dict = {}
    os.makedirs(root_dir, exist_ok=True)

    # set log path
    exp_path = os.path.join(root_dir, exp_name)
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    prefix = exp_path + '_' + timestamp
    os.makedirs(prefix)
    path_dict['prefix'] = prefix

    # set checkpoint path
    ckpt_path = os.path.join(prefix, 'Model')
    os.makedirs(ckpt_path)
    path_dict['ckpt_path'] = ckpt_path

    log_path = os.path.join(prefix, 'Log')
    os.makedirs(log_path)
    path_dict['log_path'] = log_path

    # set sample image path for fid calculation
    sample_path = os.path.join(prefix, 'Samples')
    os.makedirs(sample_path)
    path_dict['sample_path'] = sample_path

    return path_dict

def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth'):
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))

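# Usage sketch (illustrative): checkpoint at the end of each epoch. The dict
# keys shown are conventional, not a fixed schema; `net`, `optimizer`,
# `val_dice` and `best_dice` are assumed to exist in the training loop.
#
#   save_checkpoint({
#       'epoch': epoch,
#       'state_dict': net.state_dict(),
#       'optimizer': optimizer.state_dict(),
#   }, is_best=val_dice > best_dice, output_dir=args.path_helper['ckpt_path'])
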
class RunningStats:
    """Streaming mean/variance over a sliding window of the last WIN_SIZE values."""

    def __init__(self, WIN_SIZE):
        self.mean = 0
        self.run_var = 0
        self.WIN_SIZE = WIN_SIZE

        self.window = collections.deque(maxlen=WIN_SIZE)

    def clear(self):
        self.window.clear()
        self.mean = 0
        self.run_var = 0

    def is_full(self):
        return len(self.window) == self.WIN_SIZE

    def push(self, x):
        if len(self.window) == self.WIN_SIZE:
            # Adjusting variance
            x_removed = self.window.popleft()
            self.window.append(x)
            old_m = self.mean
            self.mean += (x - x_removed) / self.WIN_SIZE
            self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed)
        else:
            # Calculating first variance
            self.window.append(x)
            delta = x - self.mean
            self.mean += delta / len(self.window)
            self.run_var += delta * (x - self.mean)

    def get_mean(self):
        return self.mean if len(self.window) else 0.0

    def get_var(self):
        return self.run_var / len(self.window) if len(self.window) > 1 else 0.0

    def get_std(self):
        return math.sqrt(self.get_var())

    def get_all(self):
        return list(self.window)

    def __str__(self):
        return "Current window values: {}".format(list(self.window))

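# Usage sketch (illustrative): track a smoothed loss over the last 100 steps.
#
#   stats = RunningStats(WIN_SIZE=100)
#   for step_loss in losses:
#       stats.push(step_loss)
#   print(stats.get_mean(), stats.get_std())
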
def iou(outputs: np.ndarray, labels: np.ndarray):
    """Mean IoU over a batch of binary integer masks of shape [b, h, w]."""
    SMOOTH = 1e-6
    intersection = (outputs & labels).sum((1, 2))
    union = (outputs | labels).sum((1, 2))

    # SMOOTH avoids 0/0 when both masks are empty
    iou = (intersection + SMOOTH) / (union + SMOOTH)

    return iou.mean()

class DiceCoeff(Function):
    """Dice coeff for individual examples.

    Note: this uses the legacy (pre-1.3) autograd Function style, where
    ``forward``/``backward`` are instance methods; it is kept as-is for
    compatibility with the existing call sites.
    """

    def forward(self, input, target):
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Dice coeff for batches"""
    if input.is_cuda:
        s = torch.FloatTensor(1).to(device=input.device).zero_()
    else:
        s = torch.FloatTensor(1).zero_()

    for i, c in enumerate(zip(input, target)):
        s = s + DiceCoeff().forward(c[0], c[1])

    return s / (i + 1)

'''parameter'''
def para_image(w, h=None, img=None, mode='multi', seg=None, sd=None, batch=None,
               fft=False, channels=None, init=None):
    # Note: fft_image / pixel_image / init_image are image parameterizations
    # expected to be provided elsewhere (lucent-style `param` helpers).
    h = h or w
    batch = batch or 1
    ch = channels or 3
    shape = [batch, ch, h, w]
    param_f = fft_image if fft else pixel_image
    if init is not None:
        param_f = init_image
        params, maps_f = param_f(init)
    else:
        params, maps_f = param_f(shape, sd=sd)
    if mode == 'multi':
        output = to_valid_out(maps_f, img, seg)
    elif mode == 'seg':
        output = gene_out(maps_f, img)
    elif mode == 'raw':
        output = raw_out(maps_f, img)
    return params, output


def to_valid_out(maps_f, img, seg):  # multi-rater
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        maps = torch.nn.Softmax(dim=1)(maps)
        final_seg = torch.multiply(seg, maps).sum(dim=1, keepdim=True)
        return torch.cat((img, final_seg), 1)
        # return torch.cat((img, maps), 1)
    return inner


def gene_out(maps_f, img):  # pure seg
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        # maps = torch.nn.Sigmoid()(maps)
        return torch.cat((img, maps), 1)
    return inner


def raw_out(maps_f, img):  # raw
    def inner():
        maps = maps_f()
        maps = maps.to(device=img.device)
        # maps = torch.nn.Sigmoid()(maps)
        return maps
    return inner

class CompositeActivation(torch.nn.Module):

    def forward(self, x):
        x = torch.atan(x)
        # concatenate the scaled activation and its square along channels
        return torch.cat([x / 0.67, (x * x) / 0.6], 1)

def cppn(args, size, img=None, seg=None, batch=None, num_output_channels=1,
         num_hidden_channels=128, num_layers=8, activation_fn=CompositeActivation,
         normalize=False, device="cuda:0"):

    r = 3 ** 0.5

    coord_range = torch.linspace(-r, r, size)
    x = coord_range.view(-1, 1).repeat(1, coord_range.size(0))
    y = coord_range.view(1, -1).repeat(coord_range.size(0), 1)

    input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).repeat(batch, 1, 1, 1).to(device)

    layers = []
    kernel_size = 1
    for i in range(num_layers):
        out_c = num_hidden_channels
        in_c = out_c * 2  # * 2 for composite activation
        if i == 0:
            in_c = 2
        if i == num_layers - 1:
            out_c = num_output_channels
        layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
        if normalize:
            layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
        if i < num_layers - 1:
            layers.append(('actv{}'.format(i), activation_fn()))
        else:
            layers.append(('output', torch.nn.Sigmoid()))

    # Initialize model
    net = torch.nn.Sequential(OrderedDict(layers)).to(device)

    # Initialize weights
    def weights_init(module):
        if isinstance(module, torch.nn.Conv2d):
            torch.nn.init.normal_(module.weight, 0, np.sqrt(1 / module.in_channels))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
    net.apply(weights_init)

    # Set last conv2d layer's weights to 0
    torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)

    outimg = raw_out(lambda: net(input_tensor), img) if args.netype == 'raw' \
        else to_valid_out(lambda: net(input_tensor), img, seg)
    return net.parameters(), outimg

def get_siren(args):
    wrapper = get_network(args, 'siren', use_gpu=args.gpu,
                          gpu_device=torch.device('cuda', args.gpu_device),
                          distribution=args.distributed)
    '''load init weights'''
    checkpoint = torch.load('./logs/siren_train_init_2022_08_19_21_00_16/Model/checkpoint_best.pth')
    wrapper.load_state_dict(checkpoint['state_dict'], strict=False)
    '''end'''

    '''load prompt'''
    checkpoint = torch.load('./logs/vae_standard_refuge1_2022_08_21_17_56_49/Model/checkpoint500')
    vae = get_network(args, 'vae', use_gpu=args.gpu,
                      gpu_device=torch.device('cuda', args.gpu_device),
                      distribution=args.distributed)
    vae.load_state_dict(checkpoint['state_dict'], strict=False)
    '''end'''

    return wrapper, vae

def siren(args, wrapper, vae, img=None, seg=None, batch=None, num_output_channels=1,
          num_hidden_channels=128, num_layers=8, activation_fn=CompositeActivation,
          normalize=False, device="cuda:0"):
    vae_img = torchvision.transforms.Resize(64)(img)
    latent = vae.encoder(vae_img).view(-1).detach()
    outimg = raw_out(lambda: wrapper(latent=latent), img) if args.netype == 'raw' \
        else to_valid_out(lambda: wrapper(latent=latent), img, seg)
    # img = torch.randn(1, 3, 256, 256)
    # loss = wrapper(img)
    # loss.backward()

    # # after much training ...
    # # simply invoke the wrapper without passing in anything
    # pred_img = wrapper()  # (1, 3, 256, 256)
    return wrapper.parameters(), outimg

'''adversary'''
def render_vis(
    args,
    model,
    objective_f,
    real_img,
    param_f=None,
    optimizer=None,
    transforms=None,
    thresholds=(256,),
    verbose=True,
    preprocess=True,
    progress=True,
    show_image=True,
    save_image=False,
    image_name=None,
    show_inline=False,
    fixed_image_size=None,
    label=1,
    raw_img=None,
    prompt=None
):
    # Note: `param`, `transform`, `objectives` and `show` follow the
    # lucent-style visualization API and are expected to be importable
    # in the environment where this function is used.
    if label == 1:
        sign = 1
    elif label == 0:
        sign = -1
    else:
        print('label is wrong, label is', label)
        sign = 1  # fall back to maximizing the objective

    if args.reverse:
        sign = -sign
    if args.multilayer:
        sign = 1

    '''prepare'''
    now = datetime.now()
    date_time = now.strftime("%m-%d-%Y, %H:%M:%S")

    netD, optD = pre_d()
    '''end'''

    if param_f is None:
        param_f = lambda: param.image(128)
    # param_f is a function that should return two things
    # params - parameters to update, which we pass to the optimizer
    # image_f - a function that returns an image as a tensor
    params, image_f = param_f()

    if optimizer is None:
        optimizer = lambda params: torch.optim.Adam(params, lr=5e-1)
    optimizer = optimizer(params)

    if transforms is None:
        transforms = []
    transforms = transforms.copy()

    # Upsample images smaller than 224
    image_shape = image_f().shape

    if fixed_image_size is not None:
        new_size = fixed_image_size
    elif image_shape[2] < 224 or image_shape[3] < 224:
        new_size = 224
    else:
        new_size = None
    if new_size:
        transforms.append(
            torch.nn.Upsample(size=new_size, mode="bilinear", align_corners=True)
        )

    transform_f = transform.compose(transforms)

    hook = hook_model(model, image_f)
    objective_f = objectives.as_objective(objective_f)

    if verbose:
        model(transform_f(image_f()))
        print("Initial loss of ad: {:.3f}".format(objective_f(hook)))

    images = []
    try:
        for i in tqdm(range(1, max(thresholds) + 1), disable=(not progress)):
            optimizer.zero_grad()
            try:
                model(transform_f(image_f()))
            except RuntimeError as ex:
                if i == 1:
                    # Only display the warning message on the first
                    # iteration, no need to do that every iteration
                    warnings.warn(
                        "Some layers could not be computed because the size of the "
                        "image is not big enough. It is fine, as long as the non"
                        "computed layers are not used in the objective function"
                        f"(exception details: '{ex}')"
                    )
            if args.disc:
                '''dom loss part'''
                # content_img = raw_img
                # style_img = raw_img
                # precpt_loss = run_precpt(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, transform_f(image_f()))

                # WGAN-GP style discriminator updates
                for p in netD.parameters():
                    p.requires_grad = True
                for _ in range(args.drec):
                    netD.zero_grad()
                    real = real_img
                    fake = image_f()
                    # for _ in range(6):
                    #     errD, D_x, D_G_z1 = update_d(args, netD, optD, real, fake)

                    # label = torch.full((args.b,), 1., dtype=torch.float, device=device)
                    # label.fill_(1.)
                    # output = netD(fake).view(-1)
                    # errG = nn.BCELoss()(output, label)
                    # D_G_z2 = output.mean().item()
                    # dom_loss = err
                    one = torch.tensor(1, dtype=torch.float)
                    mone = one * -1
                    one = one.cuda(args.gpu_device)
                    mone = mone.cuda(args.gpu_device)

                    d_loss_real = netD(real)
                    d_loss_real = d_loss_real.mean()
                    d_loss_real.backward(mone)

                    d_loss_fake = netD(fake)
                    d_loss_fake = d_loss_fake.mean()
                    d_loss_fake.backward(one)

                    # Train with gradient penalty
                    gradient_penalty = calculate_gradient_penalty(netD, real.data, fake.data)
                    gradient_penalty.backward()

                    d_loss = d_loss_fake - d_loss_real + gradient_penalty
                    Wasserstein_D = d_loss_real - d_loss_fake
                    optD.step()

                # Generator update
                for p in netD.parameters():
                    p.requires_grad = False  # to avoid computation

                fake_images = image_f()
                g_loss = netD(fake_images)
                g_loss = -g_loss.mean()
                dom_loss = g_loss
                g_cost = -g_loss

                if i % 5 == 0:
                    print(f' loss_fake: {d_loss_fake}, loss_real: {d_loss_real}')
                    print(f'Generator g_loss: {g_loss}')
                '''end'''

            '''ssim loss'''

            '''end'''

            if args.disc:
                loss = sign * objective_f(hook) + args.pw * dom_loss
                # loss = args.pw * dom_loss
            else:
                loss = sign * objective_f(hook)
                # loss = args.pw * dom_loss

            loss.backward()

            # # video the images
            # if i % 5 == 0:
            #     print('1')
            #     image_name = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png'
            #     img_path = os.path.join(args.path_helper['sample_path'], str(image_name))
            #     export(image_f(), img_path)
            # # end
            # if i % 50 == 0:
            #     print('Loss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
            #           % (errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            optimizer.step()
            if i in thresholds:
                image = tensor_to_img_array(image_f())
                # if verbose:
                #     print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
                if save_image:
                    na = image_name[0].split('\\')[-1].split('.')[0] + '_' + str(i) + '.png'
                    na = date_time + na
                    outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path']
                    img_path = os.path.join(outpath, str(na))
                    export(image_f(), img_path)

                images.append(image)
    except KeyboardInterrupt:
        print("Interrupted optimization at step {:d}.".format(i))
        if verbose:
            print("Loss at step {}: {:.3f}".format(i, objective_f(hook)))
        images.append(tensor_to_img_array(image_f()))

    if save_image:
        na = image_name[0].split('\\')[-1].split('.')[0] + '.png'
        na = date_time + na
        outpath = args.quickcheck if args.quickcheck else args.path_helper['sample_path']
        img_path = os.path.join(outpath, str(na))
        export(image_f(), img_path)
    if show_inline:
        show(tensor_to_img_array(image_f()))
    elif show_image:
        view(image_f())
    return image_f()

def tensor_to_img_array(tensor):
    image = tensor.cpu().detach().numpy()
    image = np.transpose(image, [0, 2, 3, 1])
    return image

def view(tensor):
    image = tensor_to_img_array(tensor)
    assert len(image.shape) in [
        3,
        4,
    ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape)
    # Change dtype for PIL.Image
    image = (image * 255).astype(np.uint8)
    if len(image.shape) == 4:
        image = np.concatenate(image, axis=1)
    Image.fromarray(image).show()

def export(tensor, img_path=None):
    # image_name = image_name or "image.jpg"
    c = tensor.size(1)
    # if c == 7:
    #     for i in range(c):
    #         w_map = tensor[:, i, :, :].unsqueeze(1)
    #         w_map = tensor_to_img_array(w_map).squeeze()
    #         w_map = (w_map * 255).astype(np.uint8)
    #         image_name = image_name[0].split('/')[-1].split('.')[0] + str(i) + '.png'
    #         wheat = sns.heatmap(w_map, cmap='coolwarm')
    #         figure = wheat.get_figure()
    #         figure.savefig('./fft_maps/weightheatmap/' + str(image_name), dpi=400)
    #         figure = 0
    # else:
    if c == 3:
        vutils.save_image(tensor, fp=img_path)
    else:
        image = tensor[:, 0:3, :, :]
        w_map = tensor[:, -1, :, :].unsqueeze(1)
        image = tensor_to_img_array(image)
        w_map = 1 - tensor_to_img_array(w_map).squeeze()
        # w_map[w_map == 1] = 0
        assert len(image.shape) in [
            3,
            4,
        ], "Image should have 3 or 4 dimensions, invalid image shape {}".format(image.shape)
        # Change dtype for PIL.Image
        image = (image * 255).astype(np.uint8)
        w_map = (w_map * 255).astype(np.uint8)

        Image.fromarray(w_map, 'L').save(img_path)

class ModuleHook:
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
        self.module = None
        self.features = None

    def hook_fn(self, module, input, output):
        self.module = module
        self.features = output

    def close(self):
        self.hook.remove()

def hook_model(model, image_f):
    features = OrderedDict()

    # recursive hooking function
    def hook_layers(net, prefix=[]):
        if hasattr(net, "_modules"):
            for name, layer in net._modules.items():
                if layer is None:
                    # e.g. GoogLeNet's aux1 and aux2 layers
                    continue
                features["_".join(prefix + [name])] = ModuleHook(layer)
                hook_layers(layer, prefix=prefix + [name])

    hook_layers(model)

    def hook(layer):
        if layer == "input":
            out = image_f()
        elif layer == "labels":
            out = list(features.values())[-1].features
        else:
            assert layer in features, f"Invalid layer {layer}. Retrieve the list of layers with `lucent.modelzoo.util.get_model_layers(model)`."
            out = features[layer].features
        assert out is not None, "There are no saved feature maps. Make sure to put the model in eval mode, like so: `model.to(device).eval()`. See README for example."
        return out

    return hook

def vis_image(imgs, pred_masks, gt_masks, save_path, reverse=False, points=None):

    b, c, h, w = pred_masks.size()
    dev = pred_masks.get_device()
    row_num = min(b, 4)

    if torch.max(pred_masks) > 1 or torch.min(pred_masks) < 0:
        pred_masks = torch.sigmoid(pred_masks)

    if reverse:
        pred_masks = 1 - pred_masks
        gt_masks = 1 - gt_masks
    if c == 2:
        pred_disc, pred_cup = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w), pred_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_disc, gt_cup = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w), gt_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        tup = (imgs[:row_num, :, :, :], pred_disc[:row_num, :, :, :], pred_cup[:row_num, :, :, :], gt_disc[:row_num, :, :, :], gt_cup[:row_num, :, :, :])
        # compose = torch.cat((imgs[:row_num,:,:,:], pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]), 0)
        compose = torch.cat((pred_disc[:row_num, :, :, :], pred_cup[:row_num, :, :, :], gt_disc[:row_num, :, :, :], gt_cup[:row_num, :, :, :]), 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
    else:
        imgs = torchvision.transforms.Resize((h, w))(imgs)
        if imgs.size(1) == 1:
            imgs = imgs[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        pred_masks = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_masks = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        if points is not None:
            for i in range(b):
                # rescale the click coordinates from input resolution to output resolution
                if args.thd:
                    p = np.round(points.cpu() / args.roi_size * args.out_size).to(dtype=torch.int)
                else:
                    p = np.round(points.cpu() / args.image_size * args.out_size).to(dtype=torch.int)
                # gt_masks[i,:,points[i,0]-5:points[i,0]+5,points[i,1]-5:points[i,1]+5] = torch.Tensor([255, 0, 0]).to(dtype=torch.float32, device=torch.device('cuda:' + str(dev)))
                gt_masks[i, 0, p[i, 0] - 5:p[i, 0] + 5, p[i, 1] - 5:p[i, 1] + 5] = 0.5
                gt_masks[i, 1, p[i, 0] - 5:p[i, 0] + 5, p[i, 1] - 5:p[i, 1] + 5] = 0.1
                gt_masks[i, 2, p[i, 0] - 5:p[i, 0] + 5, p[i, 1] - 5:p[i, 1] + 5] = 0.4
        tup = (imgs[:row_num, :, :, :], pred_masks[:row_num, :, :, :], gt_masks[:row_num, :, :, :])
        # compose = torch.cat((imgs[:row_num,:,:,:], pred_disc[:row_num,:,:,:], pred_cup[:row_num,:,:,:], gt_disc[:row_num,:,:,:], gt_cup[:row_num,:,:,:]), 0)
        compose = torch.cat(tup, 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)

    return

def eval_seg(pred, true_mask_p, threshold):
    '''
    threshold: an int or a tuple of ints
    masks: [b, 2, h, w]
    pred: [b, 2, h, w]
    '''
    b, c, h, w = pred.size()
    if c == 2:
        iou_d, iou_c, disc_dice, cup_dice = 0, 0, 0, 0
        for th in threshold:
            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')
            cup_pred = vpred_cpu[:, 1, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].squeeze(1).cpu().numpy().astype('int32')
            cup_mask = gt_vmask_p[:, 1, :, :].squeeze(1).cpu().numpy().astype('int32')

            '''iou for numpy'''
            iou_d += iou(disc_pred, disc_mask)
            iou_c += iou(cup_pred, cup_mask)

            '''dice for torch'''
            disc_dice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()
            cup_dice += dice_coeff(vpred[:, 1, :, :], gt_vmask_p[:, 1, :, :]).item()

        return iou_d / len(threshold), iou_c / len(threshold), disc_dice / len(threshold), cup_dice / len(threshold)
    else:
        eiou, edice = 0, 0
        for th in threshold:
            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].squeeze(1).cpu().numpy().astype('int32')

            '''iou for numpy'''
            eiou += iou(disc_pred, disc_mask)

            '''dice for torch'''
            edice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()

        return eiou / len(threshold), edice / len(threshold)

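# Usage sketch (illustrative): average IoU/Dice over a sweep of thresholds
# for single-channel predictions; `pred` and `gt` are assumed to be
# [b, 1, h, w] tensors with values in [0, 1].
#
#   eiou, edice = eval_seg(pred, gt, threshold=(0.1, 0.3, 0.5, 0.7, 0.9))
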
# @objectives.wrap_objective()
def dot_compare(layer, batch=1, cossim_pow=0):
    def inner(T):
        dot = (T(layer)[batch] * T(layer)[0]).sum()
        mag = torch.sqrt(torch.sum(T(layer)[0] ** 2))
        cossim = dot / (1e-6 + mag)
        return -dot * cossim ** cossim_pow
    return inner

def init_D(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def pre_d():
    netD = Discriminator(3).to(device)
    # netD.apply(init_D)
    beta1 = 0.5
    dis_lr = 0.00002
    optimizerD = optim.Adam(netD.parameters(), lr=dis_lr, betas=(beta1, 0.999))
    return netD, optimizerD

def update_d(args, netD, optimizerD, real, fake):
    criterion = nn.BCELoss()

    netD.zero_grad()  # clear stale gradients before accumulating real/fake losses
    label = torch.full((args.b,), 1., dtype=torch.float, device=device)
    output = netD(real).view(-1)
    # Calculate loss on all-real batch
    errD_real = criterion(output, label)
    # Calculate gradients for D in backward pass
    errD_real.backward()
    D_x = output.mean().item()

    label.fill_(0.)
    # Classify all fake batch with D
    output = netD(fake.detach()).view(-1)
    # Calculate D's loss on the all-fake batch
    errD_fake = criterion(output, label)
    # Calculate the gradients for this batch, accumulated (summed) with previous gradients
    errD_fake.backward()
    D_G_z1 = output.mean().item()
    # Compute error of D as sum over the fake and the real batches
    errD = errD_real + errD_fake
    # Update D
    optimizerD.step()

    return errD, D_x, D_G_z1

def calculate_gradient_penalty(netD, real_images, fake_images):
    eta = torch.FloatTensor(args.b, 1, 1, 1).uniform_(0, 1)
    eta = eta.expand(args.b, real_images.size(1), real_images.size(2), real_images.size(3)).to(device=device)

    # random point on the line between each real/fake pair
    interpolated = (eta * real_images + ((1 - eta) * fake_images)).to(device=device)

    # define it to calculate gradient
    interpolated = Variable(interpolated, requires_grad=True)

    # calculate probability of interpolated examples
    prob_interpolated = netD(interpolated)

    # calculate gradients of probabilities with respect to examples
    gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                              grad_outputs=torch.ones(
                                  prob_interpolated.size()).to(device=device),
                              create_graph=True, retain_graph=True)[0]

    grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10
    return grad_penalty

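# Usage sketch (illustrative): one WGAN-GP critic step, mirroring the inline
# loop in render_vis; `real` and `fake` are assumed to be [args.b, 3, H, W]
# batches on `device`.
#
#   netD, optD = pre_d()
#   netD.zero_grad()
#   d_loss = netD(fake).mean() - netD(real).mean() \
#       + calculate_gradient_penalty(netD, real, fake.detach())
#   d_loss.backward()
#   optD.step()
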
def random_click(mask, point_labels=1, inout=1):
    """Sample a random click from the pixels of `mask` equal to `inout`."""
    indices = np.argwhere(mask == inout)
    return indices[np.random.randint(len(indices))]

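# Usage sketch (illustrative): draw a positive click from a binary numpy mask
# and use it as a point prompt.
#
#   pt = random_click(np.array(mask), point_labels=1, inout=1)  # (row, col)
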
def generate_click_prompt(img, msk, pt_label=1):
    # return: prompt, prompt mask
    pt_list = []
    msk_list = []
    b, c, h, w, d = msk.size()
    msk = msk[:, 0, :, :, :]
    for i in range(d):
        pt_list_s = []
        msk_list_s = []
        for j in range(b):
            msk_s = msk[j, :, :, i]
            indices = torch.nonzero(msk_s)
            if indices.size(0) == 0:
                # empty slice: fall back to a random point in [0, h) x [0, h)
                random_index = torch.randint(0, h, (2,)).to(device=msk.device)
                new_s = msk_s
            else:
                random_index = random.choice(indices)
                label = msk_s[random_index[0], random_index[1]]
                # convert bool tensor to float
                new_s = (msk_s == label).to(dtype=torch.float)
                # new_s[msk_s == label] = 1
            pt_list_s.append(random_index)
            msk_list_s.append(new_s)
        pts = torch.stack(pt_list_s, dim=0)
        msks = torch.stack(msk_list_s, dim=0)
        pt_list.append(pts)
        msk_list.append(msks)
    pt = torch.stack(pt_list, dim=-1)
    msk = torch.stack(msk_list, dim=-1)

    msk = msk.unsqueeze(1)

    return img, pt, msk  # [b, 2, d], [b, c, h, w, d]