import math
from abc import abstractmethod
from collections import OrderedDict
from copy import deepcopy
from typing import List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from batchgenerators.augmentations.utils import pad_nd_image
from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated
from torch.cuda.amp import autocast

from .common import LayerNorm2d, MLPBlock, Adapter
from .utils import (
    InitWeights_He,
    maybe_to_torch,
    no_op,
    sigmoid_helper,
    softmax_helper,
    to_cuda,
)
from .nn import (
checkpoint,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
layer_norm,
)
class OnePromptEncoderViT(nn.Module):
def __init__(
self,
args,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.in_chans = in_chans
self.args = args
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        skips = []
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
            # collect the feature map after every block as a skip connection
            skips.append(x)
        return x, skips
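# Shape walk-through (illustrative, assuming the default img_size=1024, patch_size=16, embed_dim=768):
#   x: (B, 3, 1024, 1024) -> patch_embed -> (B, 64, 64, 768); pos_embed has the matching (1, 64, 64, 768) layout.
#   Every Block preserves the (B, 64, 64, 768) shape, and `skips` collects the output of each block,
#   so len(skips) == depth and skips[-1] is identical to the returned feature map.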
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
args,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.in_chans = in_chans
self.args = args
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
return x
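# Example configuration (a hedged sketch; the values mirror a SAM ViT-B style setup and are
# assumptions, not taken from this file):
#   encoder = ImageEncoderViT(args, embed_dim=768, depth=12, num_heads=12,
#                             window_size=14, global_attn_indexes=(2, 5, 8, 11))
# Blocks whose index appears in global_attn_indexes get window_size=0 (global attention);
# all other blocks use windowed attention.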
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
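# Round-trip example (illustrative): with window_size=14 on a (B, 64, 64, C) feature map,
#   windows, (Hp, Wp) = window_partition(x, 14)   # pads 64 -> 70, so Hp == Wp == 70
#   windows.shape == (B * 25, 14, 14, C)          # 5 x 5 windows per image
#   window_unpartition(windows, 14, (70, 70), (64, 64))  # crops the padding and recovers x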
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
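# closest_numbers(n) searches outward from sqrt(n) for a factor pair (a, b) with a * b == n,
# falling back to (1, n) when n is prime; e.g. closest_numbers(12) -> (3, 4).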
def closest_numbers(target):
a = int(target ** 0.5)
b = a + 1
while True:
if a * b == target:
return (a, b)
elif a * b < target:
b += 1
else:
a -= 1
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
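# Example (illustrative): with the default 16x16 kernel and stride, PatchEmbed maps
# (B, 3, 1024, 1024) -> conv -> (B, 768, 64, 64) -> permute -> (B, 64, 64, 768).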
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
        output_dim: Optional[int] = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
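# Note: AttentionPool2d flattens (B, C, H, W) to (B, C, H*W), prepends the spatial mean as an
# extra token, runs one layer of QKV attention, and returns the pooled token x[:, :, 0] of shape
# (B, output_dim or embed_dim). H*W must equal spacial_dim**2 for the positional embedding to line up.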
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=1
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
def conv_dw(inp, oup, stride):
return nn.Sequential(
# dw
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True),
# pw
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
)
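# conv_dw is the MobileNet-style depthwise-separable block: a 3x3 depthwise convolution
# (groups=inp) followed by a 1x1 pointwise convolution, each with BatchNorm and ReLU.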
class MobBlock(nn.Module):
def __init__(self,ind):
super().__init__()
if ind == 0:
self.stage = nn.Sequential(
conv_bn(3, 32, 2),
conv_dw(32, 64, 1),
conv_dw(64, 128, 1),
conv_dw(128, 128, 1)
)
elif ind == 1:
self.stage = nn.Sequential(
conv_dw(128, 256, 2),
conv_dw(256, 256, 1)
)
elif ind == 2:
self.stage = nn.Sequential(
conv_dw(256, 256, 2),
conv_dw(256, 256, 1)
)
else:
self.stage = nn.Sequential(
conv_dw(256, 512, 2),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1)
)
def forward(self,x):
return self.stage(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
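# Note on conditioning: with use_scale_shift_norm=True the timestep embedding is projected to
# 2 * out_channels and split into (scale, shift), which condition the block FiLM-style as
# h = out_norm(h) * (1 + scale) + shift; otherwise emb_out is simply added to h.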
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True)
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
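# QKVAttention and QKVAttentionLegacy differ only in how the (N, 3*H*C, T) qkv tensor is split:
# QKVAttention chunks into q, k, v first and reshapes per head afterwards, while the legacy
# variant reshapes to (N*H, 3*C, T) first and then splits. Both return an (N, H*C, T) tensor.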
class EncoderUNetModel(nn.Module):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.pool = pool
self.gap = nn.AvgPool2d((8, 8)) #global average pooling
self.cam_feature_maps = None
print('pool', pool)
if pool == "adaptive":
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
nn.AdaptiveAvgPool2d((1, 1)),
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == "attention":
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
AttentionPool2d(
(image_size // ds), ch, num_head_channels, out_channels
),
)
elif pool == "spatial":
self.out = nn.Linear(256, self.out_channels)
elif pool == "spatial_v2":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
nn.SiLU(),
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
def forward(self, x, timesteps):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith("spatial"):
self.cam_feature_maps = h
h = self.gap(h)
N = h.shape[0]
h = h.reshape(N, -1)
print('h1', h.shape)
return self.out(h)
else:
h = h.type(x.dtype)
self.cam_feature_maps = h
return self.out(h)
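# Usage sketch (hedged; the channel/resolution values below are illustrative assumptions):
#   model = EncoderUNetModel(image_size=64, in_channels=3, model_channels=128, out_channels=1000,
#                            num_res_blocks=2, attention_resolutions=(8, 16), pool="adaptive")
#   logits = model(x, timesteps)   # x: (N, 3, 64, 64), timesteps: (N,) -> logits: (N, 1000)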
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
def get_device(self):
if next(self.parameters()).device.type == "cpu":
return "cpu"
else:
return next(self.parameters()).device.index
def set_device(self, device):
if device == "cpu":
self.cpu()
else:
self.cuda(device)
def forward(self, x):
raise NotImplementedError
class SegmentationNetwork(NeuralNetwork):
def __init__(self):
super(NeuralNetwork, self).__init__()
        # if we have 5 pooling operations, the patch size must be divisible by 2**5
        self.input_shape_must_be_divisible_by = None  # e.g. a 2d network with 5 pools in x and 6 pools in y would need (32, 64)
        # conv_op tells us whether this is a 2d or a 3d network
        self.conv_op = None  # nn.Conv2d or nn.Conv3d
# this tells us how many channels we have in the output. Important for preallocation in inference
self.num_classes = None # number of channels in the output
        # depending on the loss, we do not hard-code a nonlinearity into the architecture. To aggregate
        # predictions during inference we still need to apply one, so the network needs to know what
        # nonlinearity to apply at inference time. For the most part this will be softmax.
self.inference_apply_nonlin = lambda x: x # softmax_helper
        # This is for saving a Gaussian importance map for inference. It weights voxels that are closer to the
        # center higher, because predictions at the borders are often less accurate. Creating these Gaussians
        # can be expensive, so it makes sense to save and reuse them.
self._gaussian_3d = self._patch_size_for_gaussian_3d = None
self._gaussian_2d = self._patch_size_for_gaussian_2d = None
def predict_3D(self, x: np.ndarray, do_mirroring: bool, mirror_axes: Tuple[int, ...] = (0, 1, 2),
use_sliding_window: bool = False,
step_size: float = 0.5, patch_size: Tuple[int, ...] = None, regions_class_order: Tuple[int, ...] = None,
use_gaussian: bool = False, pad_border_mode: str = "constant",
pad_kwargs: dict = None, all_in_gpu: bool = False,
verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
torch.cuda.empty_cache()
assert step_size <= 1, 'step_size must be smaller than 1. Otherwise there will be a gap between consecutive ' \
'predictions'
if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)
if pad_kwargs is None:
pad_kwargs = {'constant_values': 0}
# A very long time ago the mirror axes were (2, 3, 4) for a 3d network. This is just to intercept any old
# code that uses this convention
if len(mirror_axes):
if self.conv_op == nn.Conv2d:
if max(mirror_axes) > 1:
raise ValueError("mirror axes. duh")
if self.conv_op == nn.Conv3d:
if max(mirror_axes) > 2:
raise ValueError("mirror axes. duh")
if self.training:
print('WARNING! Network is in train mode during inference. This may be intended, or not...')
assert len(x.shape) == 4, "data must have shape (c,x,y,z)"
if mixed_precision:
context = autocast
else:
context = no_op
with context():
with torch.no_grad():
if self.conv_op == nn.Conv3d:
if use_sliding_window:
res = self._internal_predict_3D_3Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
regions_class_order, use_gaussian, pad_border_mode,
pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu,
verbose=verbose)
else:
res = self._internal_predict_3D_3Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
pad_border_mode, pad_kwargs=pad_kwargs, verbose=verbose)
elif self.conv_op == nn.Conv2d:
if use_sliding_window:
res = self._internal_predict_3D_2Dconv_tiled(x, patch_size, do_mirroring, mirror_axes, step_size,
regions_class_order, use_gaussian, pad_border_mode,
pad_kwargs, all_in_gpu, False)
else:
res = self._internal_predict_3D_2Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
pad_border_mode, pad_kwargs, all_in_gpu, False)
else:
raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")
return res
def predict_2D(self, x, do_mirroring: bool, mirror_axes: tuple = (0, 1, 2), use_sliding_window: bool = False,
step_size: float = 0.5, patch_size: tuple = None, regions_class_order: tuple = None,
use_gaussian: bool = False, pad_border_mode: str = "constant",
pad_kwargs: dict = None, all_in_gpu: bool = False,
verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
torch.cuda.empty_cache()
        assert step_size <= 1, 'step_size must be smaller than 1. Otherwise there will be a gap between consecutive ' \
                               'predictions'
if self.conv_op == nn.Conv3d:
raise RuntimeError("Cannot predict 2d if the network is 3d. Dummy.")
if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)
if pad_kwargs is None:
pad_kwargs = {'constant_values': 0}
# A very long time ago the mirror axes were (2, 3) for a 2d network. This is just to intercept any old
# code that uses this convention
if len(mirror_axes):
if max(mirror_axes) > 1:
raise ValueError("mirror axes. duh")
if self.training:
print('WARNING! Network is in train mode during inference. This may be intended, or not...')
assert len(x.shape) == 3, "data must have shape (c,x,y)"
if mixed_precision:
context = autocast
else:
context = no_op
with context():
with torch.no_grad():
if self.conv_op == nn.Conv2d:
if use_sliding_window:
res = self._internal_predict_2D_2Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
regions_class_order, use_gaussian, pad_border_mode,
pad_kwargs, all_in_gpu, verbose)
else:
res = self._internal_predict_2D_2Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
pad_border_mode, pad_kwargs, verbose)
else:
raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")
return res
@staticmethod
def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
tmp = np.zeros(patch_size)
center_coords = [i // 2 for i in patch_size]
sigmas = [i * sigma_scale for i in patch_size]
tmp[tuple(center_coords)] = 1
gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1
gaussian_importance_map = gaussian_importance_map.astype(np.float32)
# gaussian_importance_map cannot be 0, otherwise we may end up with nans!
gaussian_importance_map[gaussian_importance_map == 0] = np.min(
gaussian_importance_map[gaussian_importance_map != 0])
return gaussian_importance_map
@staticmethod
def _compute_steps_for_sliding_window(patch_size: Tuple[int, ...], image_size: Tuple[int, ...], step_size: float) -> List[List[int]]:
        assert all(i >= j for i, j in zip(image_size, patch_size)), "image size must be as large or larger than patch_size"
assert 0 < step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'
# our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
# 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
target_step_sizes_in_voxels = [i * step_size for i in patch_size]
num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, patch_size)]
steps = []
for dim in range(len(patch_size)):
# the highest step value for this dimension is
max_step_value = image_size[dim] - patch_size[dim]
if num_steps[dim] > 1:
actual_step_size = max_step_value / (num_steps[dim] - 1)
else:
actual_step_size = 99999999999 # does not matter because there is only one step at 0
steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]
steps.append(steps_here)
return steps
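    # Worked example for the comment above: image_size=(110,), patch_size=(64,), step_size=0.5
    # -> target step = 32 voxels, num_steps = ceil((110 - 64) / 32) + 1 = 3,
    #    actual_step_size = (110 - 64) / 2 = 23, so steps = [[0, 23, 46]].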
def _internal_predict_3D_3Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
# better safe than sorry
assert len(x.shape) == 4, "x must be (c, x, y, z)"
if verbose: print("step_size:", step_size)
if verbose: print("do mirror:", do_mirroring)
assert patch_size is not None, "patch_size cannot be None for tiled prediction"
# for sliding window inference the image must at least be as large as the patch size. It does not matter
# whether the shape is divisible by 2**num_pool as long as the patch size is
data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
data_shape = data.shape # still c, x, y, z
# compute the steps for sliding window
steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
num_tiles = len(steps[0]) * len(steps[1]) * len(steps[2])
if verbose:
print("data shape:", data_shape)
print("patch size:", patch_size)
print("steps (x, y, and z):", steps)
print("number of tiles:", num_tiles)
# we only need to compute that once. It can take a while to compute this due to the large sigma in
# gaussian_filter
if use_gaussian and num_tiles > 1:
if self._gaussian_3d is None or not all(
[i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_3d)]):
if verbose: print('computing Gaussian')
gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)
self._gaussian_3d = gaussian_importance_map
self._patch_size_for_gaussian_3d = patch_size
if verbose: print("done")
else:
if verbose: print("using precomputed Gaussian")
gaussian_importance_map = self._gaussian_3d
gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
#predict on cpu if cuda not available
if torch.cuda.is_available():
gaussian_importance_map = gaussian_importance_map.cuda(self.get_device(), non_blocking=True)
else:
gaussian_importance_map = None
if all_in_gpu:
            # If we run the inference entirely on the GPU (all tensors are allocated on the GPU, which reduces
            # CPU-GPU communication but requires more GPU memory) we need to preallocate a few things on the GPU
if use_gaussian and num_tiles > 1:
# half precision for the outputs should be good enough. If the outputs here are half, the
# gaussian_importance_map should be as well
gaussian_importance_map = gaussian_importance_map.half()
# make sure we did not round anything to 0
gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
gaussian_importance_map != 0].min()
add_for_nb_of_preds = gaussian_importance_map
else:
add_for_nb_of_preds = torch.ones(patch_size, device=self.get_device())
if verbose: print("initializing result array (on GPU)")
aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
device=self.get_device())
if verbose: print("moving data to GPU")
data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)
if verbose: print("initializing result_numsamples (on GPU)")
aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
device=self.get_device())
else:
if use_gaussian and num_tiles > 1:
add_for_nb_of_preds = self._gaussian_3d
else:
add_for_nb_of_preds = np.ones(patch_size, dtype=np.float32)
aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
for x in steps[0]:
lb_x = x
ub_x = x + patch_size[0]
for y in steps[1]:
lb_y = y
ub_y = y + patch_size[1]
for z in steps[2]:
lb_z = z
ub_z = z + patch_size[2]
predicted_patch = self._internal_maybe_mirror_and_pred_3D(
data[None, :, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z], mirror_axes, do_mirroring,
gaussian_importance_map)[0]
if all_in_gpu:
predicted_patch = predicted_patch.half()
else:
predicted_patch = predicted_patch.cpu().numpy()
aggregated_results[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += predicted_patch
aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += add_for_nb_of_preds
        # we reverse the padding here (remember that we padded the input to be at least as large as the patch size)
slicer = tuple(
[slice(0, aggregated_results.shape[i]) for i in
range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
aggregated_results = aggregated_results[slicer]
aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]
# computing the class_probabilities by dividing the aggregated result with result_numsamples
aggregated_results /= aggregated_nb_of_predictions
del aggregated_nb_of_predictions
if regions_class_order is None:
predicted_segmentation = aggregated_results.argmax(0)
else:
if all_in_gpu:
class_probabilities_here = aggregated_results.detach().cpu().numpy()
else:
class_probabilities_here = aggregated_results
predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
for i, c in enumerate(regions_class_order):
predicted_segmentation[class_probabilities_here[i] > 0.5] = c
if all_in_gpu:
if verbose: print("copying results to CPU")
if regions_class_order is None:
predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
aggregated_results = aggregated_results.detach().cpu().numpy()
if verbose: print("prediction done")
return predicted_segmentation, aggregated_results
def _internal_predict_2D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
pad_border_mode: str = "constant", pad_kwargs: dict = None,
verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""
This one does fully convolutional inference. No sliding window
"""
assert len(x.shape) == 3, "x must be (c, x, y)"
assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
'run _internal_predict_2D_2Dconv'
if verbose: print("do mirror:", do_mirroring)
data, slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
self.input_shape_must_be_divisible_by)
predicted_probabilities = self._internal_maybe_mirror_and_pred_2D(data[None], mirror_axes, do_mirroring,
None)[0]
slicer = tuple(
[slice(0, predicted_probabilities.shape[i]) for i in range(len(predicted_probabilities.shape) -
(len(slicer) - 1))] + slicer[1:])
predicted_probabilities = predicted_probabilities[slicer]
if regions_class_order is None:
predicted_segmentation = predicted_probabilities.argmax(0)
predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
else:
predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
predicted_segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.float32)
for i, c in enumerate(regions_class_order):
predicted_segmentation[predicted_probabilities[i] > 0.5] = c
return predicted_segmentation, predicted_probabilities
def _internal_predict_3D_3Dconv(self, x: np.ndarray, min_size: Tuple[int, ...], do_mirroring: bool,
mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
pad_border_mode: str = "constant", pad_kwargs: dict = None,
verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
"""
This one does fully convolutional inference. No sliding window
"""
assert len(x.shape) == 4, "x must be (c, x, y, z)"
assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
'run _internal_predict_3D_3Dconv'
if verbose: print("do mirror:", do_mirroring)
data, slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
self.input_shape_must_be_divisible_by)
predicted_probabilities = self._internal_maybe_mirror_and_pred_3D(data[None], mirror_axes, do_mirroring,
None)[0]
slicer = tuple(
[slice(0, predicted_probabilities.shape[i]) for i in range(len(predicted_probabilities.shape) -
(len(slicer) - 1))] + slicer[1:])
predicted_probabilities = predicted_probabilities[slicer]
if regions_class_order is None:
predicted_segmentation = predicted_probabilities.argmax(0)
predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
else:
predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
predicted_segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.float32)
for i, c in enumerate(regions_class_order):
predicted_segmentation[predicted_probabilities[i] > 0.5] = c
return predicted_segmentation, predicted_probabilities
    def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.Tensor], mirror_axes: tuple,
                                           do_mirroring: bool = True,
                                           mult: Optional[Union[np.ndarray, torch.Tensor]] = None) -> torch.Tensor:
assert len(x.shape) == 5, 'x must be (b, c, x, y, z)'
# if cuda available:
# everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
# we now return a cuda tensor! Not numpy array!
x = maybe_to_torch(x)
result_torch = torch.zeros([1, self.num_classes] + list(x.shape[2:]),
dtype=torch.float)
if torch.cuda.is_available():
x = to_cuda(x, gpu_id=self.get_device())
result_torch = result_torch.cuda(self.get_device(), non_blocking=True)
if mult is not None:
mult = maybe_to_torch(mult)
if torch.cuda.is_available():
mult = to_cuda(mult, gpu_id=self.get_device())
if do_mirroring:
mirror_idx = 8
num_results = 2 ** len(mirror_axes)
else:
mirror_idx = 1
num_results = 1
for m in range(mirror_idx):
if m == 0:
pred = self.inference_apply_nonlin(self(x))
result_torch += 1 / num_results * pred
if m == 1 and (2 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, ))))
result_torch += 1 / num_results * torch.flip(pred, (4,))
if m == 2 and (1 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (3, ))))
result_torch += 1 / num_results * torch.flip(pred, (3,))
if m == 3 and (2 in mirror_axes) and (1 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3))))
result_torch += 1 / num_results * torch.flip(pred, (4, 3))
if m == 4 and (0 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (2, ))))
result_torch += 1 / num_results * torch.flip(pred, (2,))
if m == 5 and (0 in mirror_axes) and (2 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 2))))
result_torch += 1 / num_results * torch.flip(pred, (4, 2))
if m == 6 and (0 in mirror_axes) and (1 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (3, 2))))
result_torch += 1 / num_results * torch.flip(pred, (3, 2))
if m == 7 and (0 in mirror_axes) and (1 in mirror_axes) and (2 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3, 2))))
result_torch += 1 / num_results * torch.flip(pred, (4, 3, 2))
if mult is not None:
result_torch[:, :] *= mult
return result_torch
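    # Note: the loop above performs test-time augmentation by mirroring. With mirror_axes=(0, 1, 2)
    # all 2**3 = 8 flip combinations are evaluated; each prediction is flipped back and accumulated
    # with weight 1/num_results before the optional Gaussian weighting (`mult`) is applied.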
    def _internal_maybe_mirror_and_pred_2D(self, x: Union[np.ndarray, torch.Tensor], mirror_axes: tuple,
                                           do_mirroring: bool = True,
                                           mult: Optional[Union[np.ndarray, torch.Tensor]] = None) -> torch.Tensor:
# if cuda available:
# everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
# we now return a cuda tensor! Not numpy array!
assert len(x.shape) == 4, 'x must be (b, c, x, y)'
x = maybe_to_torch(x)
result_torch = torch.zeros([x.shape[0], self.num_classes] + list(x.shape[2:]), dtype=torch.float)
if torch.cuda.is_available():
x = to_cuda(x, gpu_id=self.get_device())
result_torch = result_torch.cuda(self.get_device(), non_blocking=True)
if mult is not None:
mult = maybe_to_torch(mult)
if torch.cuda.is_available():
mult = to_cuda(mult, gpu_id=self.get_device())
if do_mirroring:
mirror_idx = 4
num_results = 2 ** len(mirror_axes)
else:
mirror_idx = 1
num_results = 1
for m in range(mirror_idx):
if m == 0:
pred = self.inference_apply_nonlin(self(x))
result_torch += 1 / num_results * pred
if m == 1 and (1 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (3, ))))
result_torch += 1 / num_results * torch.flip(pred, (3, ))
if m == 2 and (0 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (2, ))))
result_torch += 1 / num_results * torch.flip(pred, (2, ))
if m == 3 and (0 in mirror_axes) and (1 in mirror_axes):
pred = self.inference_apply_nonlin(self(torch.flip(x, (3, 2))))
result_torch += 1 / num_results * torch.flip(pred, (3, 2))
if mult is not None:
result_torch[:, :] *= mult
return result_torch
def _internal_predict_2D_2Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
# better safe than sorry
assert len(x.shape) == 3, "x must be (c, x, y)"
if verbose: print("step_size:", step_size)
if verbose: print("do mirror:", do_mirroring)
assert patch_size is not None, "patch_size cannot be None for tiled prediction"
# for sliding window inference the image must at least be as large as the patch size. It does not matter
# whether the shape is divisible by 2**num_pool as long as the patch size is
data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
data_shape = data.shape # still c, x, y
# compute the steps for sliding window
steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
num_tiles = len(steps[0]) * len(steps[1])
if verbose:
print("data shape:", data_shape)
print("patch size:", patch_size)
print("steps (x, y, and z):", steps)
print("number of tiles:", num_tiles)
# we only need to compute that once. It can take a while to compute this due to the large sigma in
# gaussian_filter
if use_gaussian and num_tiles > 1:
if self._gaussian_2d is None or not all(
[i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_2d)]):
if verbose: print('computing Gaussian')
gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)
self._gaussian_2d = gaussian_importance_map
self._patch_size_for_gaussian_2d = patch_size
else:
if verbose: print("using precomputed Gaussian")
gaussian_importance_map = self._gaussian_2d
gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
if torch.cuda.is_available():
gaussian_importance_map = gaussian_importance_map.cuda(self.get_device(), non_blocking=True)
else:
gaussian_importance_map = None
if all_in_gpu:
            # If we run the inference entirely on the GPU (all tensors are allocated on the GPU, which reduces
            # CPU-GPU communication but requires more GPU memory) we need to preallocate a few things on the GPU
if use_gaussian and num_tiles > 1:
# half precision for the outputs should be good enough. If the outputs here are half, the
# gaussian_importance_map should be as well
gaussian_importance_map = gaussian_importance_map.half()
# make sure we did not round anything to 0
gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
gaussian_importance_map != 0].min()
add_for_nb_of_preds = gaussian_importance_map
else:
add_for_nb_of_preds = torch.ones(patch_size, device=self.get_device())
if verbose: print("initializing result array (on GPU)")
aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
device=self.get_device())
if verbose: print("moving data to GPU")
data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)
if verbose: print("initializing result_numsamples (on GPU)")
aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
device=self.get_device())
else:
if use_gaussian and num_tiles > 1:
add_for_nb_of_preds = self._gaussian_2d
else:
add_for_nb_of_preds = np.ones(patch_size, dtype=np.float32)
aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
for x in steps[0]:
lb_x = x
ub_x = x + patch_size[0]
for y in steps[1]:
lb_y = y
ub_y = y + patch_size[1]
predicted_patch = self._internal_maybe_mirror_and_pred_2D(
data[None, :, lb_x:ub_x, lb_y:ub_y], mirror_axes, do_mirroring,
gaussian_importance_map)[0]
if all_in_gpu:
predicted_patch = predicted_patch.half()
else:
predicted_patch = predicted_patch.cpu().numpy()
aggregated_results[:, lb_x:ub_x, lb_y:ub_y] += predicted_patch
aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y] += add_for_nb_of_preds
        # we reverse the padding here (remember that we padded the input to be at least as large as the patch size)
slicer = tuple(
[slice(0, aggregated_results.shape[i]) for i in
range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
aggregated_results = aggregated_results[slicer]
aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]
# computing the class_probabilities by dividing the aggregated result with result_numsamples
class_probabilities = aggregated_results / aggregated_nb_of_predictions
if regions_class_order is None:
predicted_segmentation = class_probabilities.argmax(0)
else:
if all_in_gpu:
class_probabilities_here = class_probabilities.detach().cpu().numpy()
else:
class_probabilities_here = class_probabilities
predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
for i, c in enumerate(regions_class_order):
predicted_segmentation[class_probabilities_here[i] > 0.5] = c
if all_in_gpu:
if verbose: print("copying results to CPU")
if regions_class_order is None:
predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
class_probabilities = class_probabilities.detach().cpu().numpy()
if verbose: print("prediction done")
return predicted_segmentation, class_probabilities
def _internal_predict_3D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
mirror_axes: tuple = (0, 1), regions_class_order: tuple = None,
pad_border_mode: str = "constant", pad_kwargs: dict = None,
all_in_gpu: bool = False, verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
if all_in_gpu:
raise NotImplementedError
assert len(x.shape) == 4, "data must be c, x, y, z"
predicted_segmentation = []
softmax_pred = []
for s in range(x.shape[1]):
pred_seg, softmax_pres = self._internal_predict_2D_2Dconv(
x[:, s], min_size, do_mirroring, mirror_axes, regions_class_order, pad_border_mode, pad_kwargs, verbose)
predicted_segmentation.append(pred_seg[None])
softmax_pred.append(softmax_pres[None])
predicted_segmentation = np.vstack(predicted_segmentation)
softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))
return predicted_segmentation, softmax_pred
def predict_3D_pseudo3D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
mirror_axes: tuple = (0, 1), regions_class_order: tuple = None,
pseudo3D_slices: int = 5, all_in_gpu: bool = False,
pad_border_mode: str = "constant", pad_kwargs: dict = None,
verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
if all_in_gpu:
raise NotImplementedError
assert len(x.shape) == 4, "data must be c, x, y, z"
assert pseudo3D_slices % 2 == 1, "pseudo3D_slices must be odd"
extra_slices = (pseudo3D_slices - 1) // 2
shp_for_pad = np.array(x.shape)
shp_for_pad[1] = extra_slices
pad = np.zeros(shp_for_pad, dtype=np.float32)
data = np.concatenate((pad, x, pad), 1)
predicted_segmentation = []
softmax_pred = []
for s in range(extra_slices, data.shape[1] - extra_slices):
d = data[:, (s - extra_slices):(s + extra_slices + 1)]
d = d.reshape((-1, d.shape[-2], d.shape[-1]))
pred_seg, softmax_pres = \
self._internal_predict_2D_2Dconv(d, min_size, do_mirroring, mirror_axes,
regions_class_order, pad_border_mode, pad_kwargs, verbose)
predicted_segmentation.append(pred_seg[None])
softmax_pred.append(softmax_pres[None])
predicted_segmentation = np.vstack(predicted_segmentation)
softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))
return predicted_segmentation, softmax_pred
def _internal_predict_3D_2Dconv_tiled(self, x: np.ndarray, patch_size: Tuple[int, int], do_mirroring: bool,
mirror_axes: tuple = (0, 1), step_size: float = 0.5,
regions_class_order: tuple = None, use_gaussian: bool = False,
pad_border_mode: str = "edge", pad_kwargs: dict =None,
all_in_gpu: bool = False,
verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
if all_in_gpu:
raise NotImplementedError
assert len(x.shape) == 4, "data must be c, x, y, z"
predicted_segmentation = []
softmax_pred = []
for s in range(x.shape[1]):
pred_seg, softmax_pres = self._internal_predict_2D_2Dconv_tiled(
x[:, s], step_size, do_mirroring, mirror_axes, patch_size, regions_class_order, use_gaussian,
pad_border_mode, pad_kwargs, all_in_gpu, verbose)
predicted_segmentation.append(pred_seg[None])
softmax_pred.append(softmax_pres[None])
predicted_segmentation = np.vstack(predicted_segmentation)
softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))
return predicted_segmentation, softmax_pred
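# The tiled predictors above optionally weight every patch with a Gaussian importance map
# (use_gaussian / add_for_nb_of_preds), so voxels near a patch border contribute less than voxels
# near its centre. The helper below is a minimal standalone sketch of how such a map is commonly
# precomputed (it follows the same recipe as nnU-Net's _get_gaussian); the function name and the
# sigma_scale default are illustrative assumptions, not part of the class above.
def _example_gaussian_importance_map_2d(patch_size: Tuple[int, int], sigma_scale: float = 1. / 8) -> np.ndarray:
    # start from a single impulse at the patch centre ...
    tmp = np.zeros(patch_size, dtype=np.float32)
    center = tuple(s // 2 for s in patch_size)
    tmp[center] = 1
    # ... and blur it, so the weight decays towards the patch borders
    sigmas = [s * sigma_scale for s in patch_size]
    gaussian = gaussian_filter(tmp, sigmas, mode='constant', cval=0)
    gaussian = gaussian / gaussian.max()
    # avoid exact zeros so the later division by the accumulated weights stays finite
    gaussian[gaussian == 0] = gaussian[gaussian != 0].min()
    return gaussian.astype(np.float32)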
class ConvDropoutNormNonlin(nn.Module):
"""
fixes a bug in ConvDropoutNormNonlin where lrelu was used regardless of nonlin. Bad.
"""
def __init__(self, input_channels, output_channels,
conv_op=nn.Conv2d, conv_kwargs=None,
norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
nonlin=nn.LeakyReLU, nonlin_kwargs=None):
super(ConvDropoutNormNonlin, self).__init__()
if nonlin_kwargs is None:
nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
if dropout_op_kwargs is None:
dropout_op_kwargs = {'p': 0.5, 'inplace': True}
if norm_op_kwargs is None:
norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
if conv_kwargs is None:
conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
self.nonlin_kwargs = nonlin_kwargs
self.nonlin = nonlin
self.dropout_op = dropout_op
self.dropout_op_kwargs = dropout_op_kwargs
self.norm_op_kwargs = norm_op_kwargs
self.conv_kwargs = conv_kwargs
self.conv_op = conv_op
self.norm_op = norm_op
self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[
'p'] > 0:
self.dropout = self.dropout_op(**self.dropout_op_kwargs)
else:
self.dropout = None
self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
self.lrelu = self.nonlin(**self.nonlin_kwargs)
def forward(self, x):
x = self.conv(x)
if self.dropout is not None:
x = self.dropout(x)
return self.lrelu(self.instnorm(x))
class ConvDropoutNonlinNorm(ConvDropoutNormNonlin):
def forward(self, x):
x = self.conv(x)
if self.dropout is not None:
x = self.dropout(x)
return self.instnorm(self.lrelu(x))
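# Minimal usage sketch (illustrative, not called elsewhere in this file): the two blocks above
# only differ in the order of norm and nonlin (conv -> dropout -> norm -> nonlin vs.
# conv -> dropout -> nonlin -> norm). The instance-norm configuration below is an assumption
# chosen to match a typical nnU-Net setup; with p=0 the dropout layer is skipped entirely.
def _example_conv_dropout_norm_nonlin(in_ch: int = 32, out_ch: int = 64) -> nn.Module:
    return ConvDropoutNormNonlin(
        in_ch, out_ch,
        conv_op=nn.Conv2d,
        conv_kwargs={'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True},
        norm_op=nn.InstanceNorm2d, norm_op_kwargs={'eps': 1e-5, 'affine': True},
        dropout_op=nn.Dropout2d, dropout_op_kwargs={'p': 0.0, 'inplace': True},
        nonlin=nn.LeakyReLU, nonlin_kwargs={'negative_slope': 1e-2, 'inplace': True})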
class StackedConvLayers(nn.Module):
def __init__(self, input_feature_channels, output_feature_channels, num_convs,
conv_op=nn.Conv2d, conv_kwargs=None,
norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None, basic_block=ConvDropoutNormNonlin):
'''
Stacks num_convs basic_block layers (ConvDropoutNormNonlin by default). first_stride is only
applied to the first layer in the stack; all other parameters affect every layer.
:param input_feature_channels: number of input channels
:param output_feature_channels: number of output channels (all layers in the stack)
:param num_convs: number of stacked conv blocks
:param conv_op: convolution class (e.g. nn.Conv2d or nn.Conv3d)
:param conv_kwargs: kwargs for conv_op (kernel_size, stride, padding, dilation, bias)
:param norm_op: normalization class
:param norm_op_kwargs: kwargs for norm_op
:param dropout_op: dropout class
:param dropout_op_kwargs: kwargs for dropout_op
:param nonlin: nonlinearity class
:param nonlin_kwargs: kwargs for nonlin
:param first_stride: stride of the first conv in the stack (used for downsampling)
:param basic_block: block class to stack
'''
self.input_channels = input_feature_channels
self.output_channels = output_feature_channels
if nonlin_kwargs is None:
nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
if dropout_op_kwargs is None:
dropout_op_kwargs = {'p': 0.5, 'inplace': True}
if norm_op_kwargs is None:
norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
if conv_kwargs is None:
conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
self.nonlin_kwargs = nonlin_kwargs
self.nonlin = nonlin
self.dropout_op = dropout_op
self.dropout_op_kwargs = dropout_op_kwargs
self.norm_op_kwargs = norm_op_kwargs
self.conv_kwargs = conv_kwargs
self.conv_op = conv_op
self.norm_op = norm_op
if first_stride is not None:
self.conv_kwargs_first_conv = deepcopy(conv_kwargs)
self.conv_kwargs_first_conv['stride'] = first_stride
else:
self.conv_kwargs_first_conv = conv_kwargs
super(StackedConvLayers, self).__init__()
self.blocks = nn.Sequential(
*([basic_block(input_feature_channels, output_feature_channels, self.conv_op,
self.conv_kwargs_first_conv,
self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
self.nonlin, self.nonlin_kwargs)] +
[basic_block(output_feature_channels, output_feature_channels, self.conv_op,
self.conv_kwargs,
self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
self.nonlin, self.nonlin_kwargs) for _ in range(num_convs - 1)]))
def forward(self, x):
return self.blocks(x)
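# Usage sketch (illustrative): one encoder stage built from StackedConvLayers. Passing
# first_stride=(2, 2) lets the first conv of the stack do the downsampling, i.e. the
# "convolutional pooling" path that OnePromptEncoderUnet below can use when
# convolutional_pooling=True; the channel counts and kwargs here are assumptions for demonstration.
def _example_encoder_stage(in_ch: int = 32, out_ch: int = 64, num_convs: int = 2) -> StackedConvLayers:
    return StackedConvLayers(
        in_ch, out_ch, num_convs,
        conv_op=nn.Conv2d,
        conv_kwargs={'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True},
        norm_op=nn.InstanceNorm2d, norm_op_kwargs={'eps': 1e-5, 'affine': True},
        dropout_op=nn.Dropout2d, dropout_op_kwargs={'p': 0.0, 'inplace': True},
        nonlin=nn.LeakyReLU, nonlin_kwargs={'negative_slope': 1e-2, 'inplace': True},
        first_stride=(2, 2))  # for even H, W the output spatial size is (H // 2, W // 2)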
def print_module_training_status(module):
    if isinstance(module, (nn.Conv2d, nn.Conv3d,
                           nn.Dropout, nn.Dropout2d, nn.Dropout3d,
                           nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                           nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        print(str(module), module.training)
class hwUpsample(nn.Module):
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
super(hwUpsample, self).__init__()
self.align_corners = align_corners
self.mode = mode
self.scale_factor = scale_factor
self.size = size
def forward(self, x):
    # F.interpolate only accepts align_corners for the interpolating modes; with 'nearest'
    # (the default mode here) a non-None value would raise a ValueError, so drop it in that case
    align_corners = self.align_corners if self.mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
    return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
                                     align_corners=align_corners)
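# Usage sketch (illustrative): hwUpsample is a thin module wrapper around F.interpolate.
# Resizing stage features to a fixed spatial size is what OnePromptEncoderUnet below does per
# encoder stage (there via nn.Upsample(size=(fea_size, fea_size), mode='bilinear')).
def _example_resize_to_fixed_size(feat: torch.Tensor, size: int = 64) -> torch.Tensor:
    # (B, C, H, W) -> (B, C, size, size)
    return hwUpsample(size=(size, size), mode='bilinear', align_corners=False)(feat)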
class OnePromptEncoderUnet(SegmentationNetwork):
DEFAULT_BATCH_SIZE_3D = 2
DEFAULT_PATCH_SIZE_3D = (64, 192, 160)
SPACING_FACTOR_BETWEEN_STAGES = 2
BASE_NUM_FEATURES_3D = 30
MAX_NUMPOOL_3D = 999
MAX_NUM_FILTERS_3D = 320
DEFAULT_PATCH_SIZE_2D = (256, 256)
BASE_NUM_FEATURES_2D = 30
DEFAULT_BATCH_SIZE_2D = 50
MAX_NUMPOOL_2D = 999
MAX_FILTERS_2D = 480
use_this_for_batch_size_computation_2D = 19739648
use_this_for_batch_size_computation_3D = 520000000 # 505789440
def __init__(self, input_channels, base_num_features, final_num_features, fea_size, num_pool, num_conv_per_stage=2,
feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
nonlin=nn.LeakyReLU, nonlin_kwargs=None, highway = False, deep_supervision=False, anchor_out=False, dropout_in_localization=False,
final_nonlin=sigmoid_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None,
conv_kernel_sizes=None,
upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False,
max_num_features=None, basic_block=ConvDropoutNormNonlin,
seg_output_use_bias=False):
"""
basically more flexible than v1, architecture is the same
Does this look complicated? Nah bro. Functionality > usability
This does everything you need, including world peace.
Questions? -> f.isensee@dkfz.de
"""
super(OnePromptEncoderUnet, self).__init__()
self.convolutional_upsampling = convolutional_upsampling
self.convolutional_pooling = convolutional_pooling
self.upscale_logits = upscale_logits
if nonlin_kwargs is None:
nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
if dropout_op_kwargs is None:
dropout_op_kwargs = {'p': 0.5, 'inplace': True}
if norm_op_kwargs is None:
norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}
self.nonlin = nonlin
self.nonlin_kwargs = nonlin_kwargs
self.dropout_op_kwargs = dropout_op_kwargs
self.norm_op_kwargs = norm_op_kwargs
self.weightInitializer = weightInitializer
self.conv_op = conv_op
self.norm_op = norm_op
self.dropout_op = dropout_op
self.final_nonlin = final_nonlin
self._deep_supervision = deep_supervision
self.do_ds = deep_supervision
self.anchor_out = anchor_out
if conv_op == nn.Conv2d:
pool_op = nn.MaxPool2d
if pool_op_kernel_sizes is None:
pool_op_kernel_sizes = [(2, 2)] * num_pool
if conv_kernel_sizes is None:
conv_kernel_sizes = [(3, 3)] * (num_pool + 1)
elif conv_op == nn.Conv3d:
pool_op = nn.MaxPool3d
if pool_op_kernel_sizes is None:
pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
if conv_kernel_sizes is None:
conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1)
else:
raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op))
self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64)
self.pool_op_kernel_sizes = pool_op_kernel_sizes
self.conv_kernel_sizes = conv_kernel_sizes
self.conv_pad_sizes = []
for krnl in self.conv_kernel_sizes:
self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])
if max_num_features is None:
if self.conv_op == nn.Conv3d:
self.max_num_features = self.MAX_NUM_FILTERS_3D
else:
self.max_num_features = self.MAX_FILTERS_2D
else:
self.max_num_features = max_num_features
self.conv_blocks_context = []
self.conv_blocks_localization = []
self.td = []
self.al = []
output_features = base_num_features
input_features = input_channels
for d in range(num_pool):
# determine the first stride
if d != 0 and self.convolutional_pooling:
first_stride = pool_op_kernel_sizes[d - 1]
else:
first_stride = None
self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
self.conv_kwargs['padding'] = self.conv_pad_sizes[d]
# add convolutions
self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage,
self.conv_op, self.conv_kwargs, self.norm_op,
self.norm_op_kwargs, self.dropout_op,
self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs,
first_stride, basic_block=basic_block))
self.al.append(nn.Linear(output_features, final_num_features))
if not self.convolutional_pooling:
self.td.append(pool_op(pool_op_kernel_sizes[d]))
input_features = output_features
output_features = int(np.round(output_features * feat_map_mul_on_downscale))
output_features = min(output_features, self.max_num_features)
# now the bottleneck.
# determine the first stride
if self.convolutional_pooling:
first_stride = pool_op_kernel_sizes[-1]
else:
first_stride = None
# the output of the last conv must match the number of features from the skip connection if we are not using
# convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be
# done by the transposed conv
# if self.convolutional_upsampling:
# final_num_features = output_features
# else:
# final_num_features = self.conv_blocks_context[-1].output_channels
self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
self.conv_blocks_context.append(nn.Sequential(
StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs,
self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
self.nonlin_kwargs, first_stride, basic_block=basic_block),
StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs,
self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
self.nonlin_kwargs, basic_block=basic_block)))
# if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here
if not dropout_in_localization:
old_dropout_p = self.dropout_op_kwargs['p']
self.dropout_op_kwargs['p'] = 0.0
# # now lets build the localization pathway
# for u in range(num_pool):
# nfeatures_from_skip = self.conv_blocks_context[
# -(2 + u)].output_channels # self.conv_blocks_context[-1] is bottleneck, so start with -2
# n_features_after_tu_and_concat = nfeatures_from_skip * 2
# self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)]
# self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)]
# self.conv_blocks_localization.append(nn.Sequential(
# StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1,
# self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op,
# self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block),
# StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs,
# self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
# self.nonlin, self.nonlin_kwargs, basic_block=basic_block)
# ))
self.up = []
for u in range(num_pool):
self.up.append(nn.Upsample(size=(fea_size, fea_size), mode='bilinear', align_corners=False))  # explicit align_corners silences the interpolate default warning; output unchanged
if not dropout_in_localization:
self.dropout_op_kwargs['p'] = old_dropout_p
# register all modules properly
# self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization)
self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context)
self.td = nn.ModuleList(self.td)
self.up = nn.ModuleList(self.up)
self.al = nn.ModuleList(self.al)
if self.weightInitializer is not None:
self.apply(self.weightInitializer)
# self.apply(print_module_training_status)
def forward(self, raw):
skips_raw = []
for d in range(len(self.conv_blocks_context) - 1):
raw = self.conv_blocks_context[d](raw)
raw_arch = self.up[d](raw)
raw_arch = raw_arch.permute(0, 2, 3, 1)
raw_arch = self.al[d](raw_arch)
skips_raw.append(raw_arch)
if not self.convolutional_pooling:
raw = self.td[d](raw)
raw = self.conv_blocks_context[-1](raw)
raw = raw.permute(0, 2, 3, 1)
return raw, skips_raw
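    # Usage sketch for forward() above (illustrative sizes): with input_channels=3,
    # base_num_features=32, final_num_features=256, fea_size=64, num_pool=4 and an input of
    # shape (B, 3, 256, 256), forward() returns
    #   raw:       (B, 16, 16, 256)        - bottleneck features, channels-last
    #   skips_raw: 4 tensors of (B, 64, 64, 256) - one per encoder stage, each resized to
    #              (fea_size, fea_size) and linearly projected to final_num_features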
@staticmethod
def compute_approx_vram_consumption(patch_size, num_pool_per_axis, base_num_features, max_num_features,
num_modalities, num_classes, pool_op_kernel_sizes, deep_supervision=False,
conv_per_stage=2):
"""
This only applies for num_conv_per_stage and convolutional_upsampling=True
not real vram consumption. just a constant term to which the vram consumption will be approx proportional
(+ offset for parameter storage)
:param deep_supervision:
:param patch_size:
:param num_pool_per_axis:
:param base_num_features:
:param max_num_features:
:param num_modalities:
:param num_classes:
:param pool_op_kernel_sizes:
:return:
"""
if not isinstance(num_pool_per_axis, np.ndarray):
num_pool_per_axis = np.array(num_pool_per_axis)
npool = len(pool_op_kernel_sizes)
map_size = np.array(patch_size)
tmp = np.int64((conv_per_stage * 2 + 1) * np.prod(map_size, dtype=np.int64) * base_num_features +
num_modalities * np.prod(map_size, dtype=np.int64) +
num_classes * np.prod(map_size, dtype=np.int64))
num_feat = base_num_features
for p in range(npool):
for pi in range(len(num_pool_per_axis)):
map_size[pi] /= pool_op_kernel_sizes[p][pi]
num_feat = min(num_feat * 2, max_num_features)
num_blocks = (conv_per_stage * 2 + 1) if p < (npool - 1) else conv_per_stage # conv_per_stage + conv_per_stage for the convs of encode/decode and 1 for transposed conv
tmp += num_blocks * np.prod(map_size, dtype=np.int64) * num_feat
if deep_supervision and p < (npool - 2):
tmp += np.prod(map_size, dtype=np.int64) * num_classes
# print(p, map_size, num_feat, tmp)
return tmp
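    # Worked example (illustrative numbers): for patch_size=(256, 256), num_pool_per_axis=(5, 5),
    # base_num_features=30, max_num_features=480, num_modalities=1, num_classes=2,
    # pool_op_kernel_sizes=[(2, 2)] * 5, deep_supervision=False and conv_per_stage=2, the terms are
    # 10,027,008 (full resolution) + 4,915,200 + 2,457,600 + 1,228,800 + 614,400 + 61,440
    # (downsampled stages) = 19,304,448, i.e. the same order of magnitude as
    # use_this_for_batch_size_computation_2D defined above.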