import math
from abc import abstractmethod
from collections import OrderedDict
from copy import deepcopy
from typing import List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from batchgenerators.augmentations.utils import pad_nd_image
from einops import rearrange
from scipy.ndimage.filters import gaussian_filter
from torch.cuda.amp import autocast

from .common import LayerNorm2d, MLPBlock, Adapter
from .nn import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
    layer_norm,
)
from .utils import InitWeights_He, maybe_to_torch, no_op, sigmoid_helper, softmax_helper, to_cuda


class OnePromptEncoderViT(nn.Module):
|
|
def __init__(
|
|
self,
|
|
args,
|
|
img_size: int = 1024,
|
|
patch_size: int = 16,
|
|
in_chans: int = 3,
|
|
embed_dim: int = 768,
|
|
depth: int = 12,
|
|
num_heads: int = 12,
|
|
mlp_ratio: float = 4.0,
|
|
qkv_bias: bool = True,
|
|
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
|
act_layer: Type[nn.Module] = nn.GELU,
|
|
use_abs_pos: bool = True,
|
|
use_rel_pos: bool = False,
|
|
rel_pos_zero_init: bool = True,
|
|
window_size: int = 0,
|
|
global_attn_indexes: Tuple[int, ...] = (),
|
|
) -> None:
|
|
"""
|
|
Args:
|
|
img_size (int): Input image size.
|
|
patch_size (int): Patch size.
|
|
in_chans (int): Number of input image channels.
|
|
embed_dim (int): Patch embedding dimension.
|
|
depth (int): Depth of ViT.
|
|
num_heads (int): Number of attention heads in each ViT block.
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
|
norm_layer (nn.Module): Normalization layer.
|
|
act_layer (nn.Module): Activation layer.
|
|
use_abs_pos (bool): If True, use absolute positional embeddings.
|
|
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
|
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
|
window_size (int): Window size for window attention blocks.
|
|
global_attn_indexes (list): Indexes for blocks using global attention.
|
|
"""
|
|
super().__init__()
|
|
self.img_size = img_size
|
|
self.in_chans = in_chans
|
|
self.args = args
|
|
|
|
self.patch_embed = PatchEmbed(
|
|
kernel_size=(patch_size, patch_size),
|
|
stride=(patch_size, patch_size),
|
|
in_chans=in_chans,
|
|
embed_dim=embed_dim,
|
|
)
|
|
|
|
self.pos_embed: Optional[nn.Parameter] = None
|
|
if use_abs_pos:
|
|
# Initialize absolute positional embedding with pretrain image size.
|
|
self.pos_embed = nn.Parameter(
|
|
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
|
)
|
|
|
|
self.blocks = nn.ModuleList()
|
|
for i in range(depth):
|
|
block = Block(
|
|
dim=embed_dim,
|
|
num_heads=num_heads,
|
|
mlp_ratio=mlp_ratio,
|
|
qkv_bias=qkv_bias,
|
|
norm_layer=norm_layer,
|
|
act_layer=act_layer,
|
|
use_rel_pos=use_rel_pos,
|
|
rel_pos_zero_init=rel_pos_zero_init,
|
|
window_size=window_size if i not in global_attn_indexes else 0,
|
|
input_size=(img_size // patch_size, img_size // patch_size),
|
|
)
|
|
self.blocks.append(block)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
        skips = []
|
|
x = self.patch_embed(x)
|
|
if self.pos_embed is not None:
|
|
# print("x size is", x.size())
|
|
# print("self.pos_embed size is",self.pos_embed.size())
|
|
x = x + self.pos_embed
|
|
|
|
for blk in self.blocks:
|
|
x = blk(x)
|
|
# for i in range(b):
|
|
# skips[i].append(x[i,...])
|
|
skips.append(x)
|
|
|
|
return x, skips
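

# Illustrative sketch (not part of the original module): a minimal forward pass with a
# deliberately small configuration, assuming `.common` provides the SAM-style MLPBlock.
# The encoder returns the final feature map plus one skip tensor per transformer block.
def _demo_oneprompt_encoder():
    enc = OnePromptEncoderViT(args=None, img_size=64, patch_size=16, embed_dim=96, depth=2, num_heads=4)
    feats, skips = enc(torch.randn(1, 3, 64, 64))
    assert feats.shape == (1, 4, 4, 96)      # (B, H/patch, W/patch, embed_dim)
    assert len(skips) == 2                   # one entry per block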
|
|
|
|
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
|
class ImageEncoderViT(nn.Module):
|
|
def __init__(
|
|
self,
|
|
args,
|
|
img_size: int = 1024,
|
|
patch_size: int = 16,
|
|
in_chans: int = 3,
|
|
embed_dim: int = 768,
|
|
depth: int = 12,
|
|
num_heads: int = 12,
|
|
mlp_ratio: float = 4.0,
|
|
qkv_bias: bool = True,
|
|
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
|
act_layer: Type[nn.Module] = nn.GELU,
|
|
use_abs_pos: bool = True,
|
|
use_rel_pos: bool = False,
|
|
rel_pos_zero_init: bool = True,
|
|
window_size: int = 0,
|
|
global_attn_indexes: Tuple[int, ...] = (),
|
|
) -> None:
|
|
"""
|
|
Args:
|
|
img_size (int): Input image size.
|
|
patch_size (int): Patch size.
|
|
in_chans (int): Number of input image channels.
|
|
embed_dim (int): Patch embedding dimension.
|
|
depth (int): Depth of ViT.
|
|
num_heads (int): Number of attention heads in each ViT block.
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
|
norm_layer (nn.Module): Normalization layer.
|
|
act_layer (nn.Module): Activation layer.
|
|
use_abs_pos (bool): If True, use absolute positional embeddings.
|
|
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
|
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
|
window_size (int): Window size for window attention blocks.
|
|
global_attn_indexes (list): Indexes for blocks using global attention.
|
|
"""
|
|
super().__init__()
|
|
self.img_size = img_size
|
|
self.in_chans = in_chans
|
|
self.args = args
|
|
|
|
self.patch_embed = PatchEmbed(
|
|
kernel_size=(patch_size, patch_size),
|
|
stride=(patch_size, patch_size),
|
|
in_chans=in_chans,
|
|
embed_dim=embed_dim,
|
|
)
|
|
|
|
self.pos_embed: Optional[nn.Parameter] = None
|
|
if use_abs_pos:
|
|
# Initialize absolute positional embedding with pretrain image size.
|
|
self.pos_embed = nn.Parameter(
|
|
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
|
)
|
|
|
|
self.blocks = nn.ModuleList()
|
|
for i in range(depth):
|
|
block = Block(
|
|
|
|
dim=embed_dim,
|
|
num_heads=num_heads,
|
|
mlp_ratio=mlp_ratio,
|
|
qkv_bias=qkv_bias,
|
|
norm_layer=norm_layer,
|
|
act_layer=act_layer,
|
|
use_rel_pos=use_rel_pos,
|
|
rel_pos_zero_init=rel_pos_zero_init,
|
|
window_size=window_size if i not in global_attn_indexes else 0,
|
|
input_size=(img_size // patch_size, img_size // patch_size),
|
|
)
|
|
self.blocks.append(block)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x = self.patch_embed(x)
|
|
if self.pos_embed is not None:
|
|
x = x + self.pos_embed
|
|
|
|
for blk in self.blocks:
|
|
x = blk(x)
|
|
|
|
return x
|
|
|
|
|
|
class Block(nn.Module):
|
|
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
|
|
|
def __init__(
|
|
self,
|
|
dim: int,
|
|
num_heads: int,
|
|
mlp_ratio: float = 4.0,
|
|
qkv_bias: bool = True,
|
|
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
|
act_layer: Type[nn.Module] = nn.GELU,
|
|
use_rel_pos: bool = False,
|
|
rel_pos_zero_init: bool = True,
|
|
window_size: int = 0,
|
|
input_size: Optional[Tuple[int, int]] = None,
|
|
) -> None:
|
|
"""
|
|
Args:
|
|
dim (int): Number of input channels.
|
|
num_heads (int): Number of attention heads in each ViT block.
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
|
norm_layer (nn.Module): Normalization layer.
|
|
act_layer (nn.Module): Activation layer.
|
|
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
|
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
|
window_size (int): Window size for window attention blocks. If it equals 0, then
|
|
use global attention.
|
|
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
|
positional parameter size.
|
|
"""
|
|
super().__init__()
|
|
self.norm1 = norm_layer(dim)
|
|
self.attn = Attention(
|
|
dim,
|
|
num_heads=num_heads,
|
|
qkv_bias=qkv_bias,
|
|
use_rel_pos=use_rel_pos,
|
|
rel_pos_zero_init=rel_pos_zero_init,
|
|
input_size=input_size if window_size == 0 else (window_size, window_size),
|
|
)
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
|
|
|
|
self.window_size = window_size
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
shortcut = x
|
|
x = self.norm1(x)
|
|
# Window partition
|
|
if self.window_size > 0:
|
|
H, W = x.shape[1], x.shape[2]
|
|
x, pad_hw = window_partition(x, self.window_size)
|
|
|
|
x = self.attn(x)
|
|
# Reverse window partition
|
|
if self.window_size > 0:
|
|
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
|
|
|
x = shortcut + x
|
|
x = x + self.mlp(self.norm2(x))
|
|
|
|
return x
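

# Illustrative sketch (not part of the original module): window attention keeps the token
# grid shape; the sizes below are assumptions chosen only to keep the example small.
def _demo_block_window_attention():
    blk = Block(dim=96, num_heads=4, window_size=2, input_size=(8, 8))
    x = torch.randn(1, 8, 8, 96)
    assert blk(x).shape == x.shape           # partition -> attention -> unpartition round trip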
|
|
|
|
|
|
class Attention(nn.Module):
|
|
"""Multi-head Attention block with relative position embeddings."""
|
|
|
|
def __init__(
|
|
self,
|
|
dim: int,
|
|
num_heads: int = 8,
|
|
qkv_bias: bool = True,
|
|
use_rel_pos: bool = False,
|
|
rel_pos_zero_init: bool = True,
|
|
input_size: Optional[Tuple[int, int]] = None,
|
|
) -> None:
|
|
"""
|
|
Args:
|
|
dim (int): Number of input channels.
|
|
num_heads (int): Number of attention heads.
|
|
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
|
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
|
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
|
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
|
positional parameter size.
|
|
"""
|
|
super().__init__()
|
|
self.num_heads = num_heads
|
|
head_dim = dim // num_heads
|
|
self.scale = head_dim**-0.5
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
self.proj = nn.Linear(dim, dim)
|
|
|
|
self.use_rel_pos = use_rel_pos
|
|
if self.use_rel_pos:
|
|
assert (
|
|
input_size is not None
|
|
), "Input size must be provided if using relative positional encoding."
|
|
# initialize relative positional embeddings
|
|
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
|
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
B, H, W, _ = x.shape
|
|
# qkv with shape (3, B, nHead, H * W, C)
|
|
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
|
# q, k, v with shape (B * nHead, H * W, C)
|
|
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
|
|
|
attn = (q * self.scale) @ k.transpose(-2, -1)
|
|
|
|
if self.use_rel_pos:
|
|
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
|
x = self.proj(x)
|
|
|
|
return x
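

# Illustrative sketch (not part of the original module): with use_rel_pos=True the module
# learns 2*S - 1 relative offsets per axis (S = input_size); the values below are assumptions.
def _demo_attention_rel_pos():
    attn = Attention(dim=96, num_heads=4, use_rel_pos=True, input_size=(4, 4))
    assert attn.rel_pos_h.shape == (7, 24)   # (2*4 - 1 offsets, head_dim)
    assert attn(torch.randn(2, 4, 4, 96)).shape == (2, 4, 4, 96)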
|
|
|
|
|
|
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
|
"""
|
|
Partition into non-overlapping windows with padding if needed.
|
|
Args:
|
|
x (tensor): input tokens with [B, H, W, C].
|
|
window_size (int): window size.
|
|
|
|
Returns:
|
|
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
|
(Hp, Wp): padded height and width before partition
|
|
"""
|
|
B, H, W, C = x.shape
|
|
|
|
pad_h = (window_size - H % window_size) % window_size
|
|
pad_w = (window_size - W % window_size) % window_size
|
|
if pad_h > 0 or pad_w > 0:
|
|
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
|
Hp, Wp = H + pad_h, W + pad_w
|
|
|
|
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
|
return windows, (Hp, Wp)
|
|
|
|
|
|
def window_unpartition(
|
|
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
|
|
) -> torch.Tensor:
|
|
"""
|
|
Window unpartition into original sequences and removing padding.
|
|
Args:
|
|
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
|
window_size (int): window size.
|
|
pad_hw (Tuple): padded height and width (Hp, Wp).
|
|
hw (Tuple): original height and width (H, W) before padding.
|
|
|
|
Returns:
|
|
x: unpartitioned sequences with [B, H, W, C].
|
|
"""
|
|
Hp, Wp = pad_hw
|
|
H, W = hw
|
|
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
|
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
|
|
|
if Hp > H or Wp > W:
|
|
x = x[:, :H, :W, :].contiguous()
|
|
return x
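

# Illustrative sketch (not part of the original module): window_partition pads to a multiple
# of the window size and window_unpartition crops the padding away again, so the round trip
# is lossless. The sizes used here are assumptions for demonstration only.
def _demo_window_round_trip():
    x = torch.randn(1, 10, 10, 8)                                  # 10 is not a multiple of 4
    windows, pad_hw = window_partition(x, window_size=4)           # padded to 12x12 -> 3x3 windows
    assert windows.shape == (9, 4, 4, 8) and pad_hw == (12, 12)
    assert torch.equal(window_unpartition(windows, 4, pad_hw, (10, 10)), x)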
|
|
|
|
|
|
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Get relative positional embeddings according to the relative positions of
|
|
query and key sizes.
|
|
Args:
|
|
q_size (int): size of query q.
|
|
k_size (int): size of key k.
|
|
rel_pos (Tensor): relative position embeddings (L, C).
|
|
|
|
Returns:
|
|
Extracted positional embeddings according to relative positions.
|
|
"""
|
|
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
|
# Interpolate rel pos if needed.
|
|
if rel_pos.shape[0] != max_rel_dist:
|
|
# Interpolate rel pos.
|
|
rel_pos_resized = F.interpolate(
|
|
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
|
size=max_rel_dist,
|
|
mode="linear",
|
|
)
|
|
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
|
else:
|
|
rel_pos_resized = rel_pos
|
|
|
|
# Scale the coords with short length if shapes for q and k are different.
|
|
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
|
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
|
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
|
|
|
return rel_pos_resized[relative_coords.long()]
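

# Illustrative sketch (not part of the original module): when the stored table does not match
# 2*max(q_size, k_size) - 1 entries, it is linearly resized before the per-pair offsets are gathered.
def _demo_get_rel_pos_interpolation():
    rel_pos = torch.randn(7, 16)                       # table sized for a 4x4 grid (2*4 - 1 = 7)
    out = get_rel_pos(q_size=6, k_size=6, rel_pos=rel_pos)
    assert out.shape == (6, 6, 16)                     # resized to 11 offsets, gathered per (q, k) pair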
|
|
|
|
|
|
def add_decomposed_rel_pos(
|
|
attn: torch.Tensor,
|
|
q: torch.Tensor,
|
|
rel_pos_h: torch.Tensor,
|
|
rel_pos_w: torch.Tensor,
|
|
q_size: Tuple[int, int],
|
|
k_size: Tuple[int, int],
|
|
) -> torch.Tensor:
|
|
"""
|
|
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
|
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
|
Args:
|
|
attn (Tensor): attention map.
|
|
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
|
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
|
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
|
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
|
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
|
|
|
Returns:
|
|
attn (Tensor): attention map with added relative positional embeddings.
|
|
"""
|
|
q_h, q_w = q_size
|
|
k_h, k_w = k_size
|
|
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
|
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
|
|
|
B, _, dim = q.shape
|
|
r_q = q.reshape(B, q_h, q_w, dim)
|
|
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
|
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
|
|
|
attn = (
|
|
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
|
).view(B, q_h * q_w, k_h * k_w)
|
|
|
|
return attn
|
|
|
|
def closest_numbers(target):
|
|
a = int(target ** 0.5)
|
|
b = a + 1
|
|
while True:
|
|
if a * b == target:
|
|
return (a, b)
|
|
elif a * b < target:
|
|
b += 1
|
|
else:
|
|
a -= 1
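

# Illustrative note (not in the original): the loop walks outward from sqrt(target) until it
# finds a factor pair, e.g. closest_numbers(12) == (3, 4) and closest_numbers(50) == (5, 10).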
|
|
|
|
|
|
class PatchEmbed(nn.Module):
|
|
"""
|
|
Image to Patch Embedding.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
kernel_size: Tuple[int, int] = (16, 16),
|
|
stride: Tuple[int, int] = (16, 16),
|
|
padding: Tuple[int, int] = (0, 0),
|
|
in_chans: int = 3,
|
|
embed_dim: int = 768,
|
|
) -> None:
|
|
"""
|
|
Args:
|
|
kernel_size (Tuple): kernel size of the projection layer.
|
|
stride (Tuple): stride of the projection layer.
|
|
padding (Tuple): padding size of the projection layer.
|
|
in_chans (int): Number of input image channels.
|
|
embed_dim (int): Patch embedding dimension.
|
|
"""
|
|
super().__init__()
|
|
|
|
self.proj = nn.Conv2d(
|
|
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
|
)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
x = self.proj(x)
|
|
# B C H W -> B H W C
|
|
x = x.permute(0, 2, 3, 1)
|
|
return x
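

# Illustrative sketch (not part of the original module): the projection turns an NCHW image
# into an NHWC grid of patch tokens; the sizes below are assumptions for demonstration.
def _demo_patch_embed():
    pe = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=96)
    assert pe(torch.randn(2, 3, 64, 64)).shape == (2, 4, 4, 96)    # (B, H/16, W/16, C)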
|
|
|
|
|
|
class AttentionPool2d(nn.Module):
|
|
"""
|
|
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
spacial_dim: int,
|
|
embed_dim: int,
|
|
num_heads_channels: int,
|
|
output_dim: int = None,
|
|
):
|
|
super().__init__()
|
|
self.positional_embedding = nn.Parameter(
|
|
th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
|
|
)
|
|
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
|
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
|
self.num_heads = embed_dim // num_heads_channels
|
|
self.attention = QKVAttention(self.num_heads)
|
|
|
|
def forward(self, x):
|
|
b, c, *_spatial = x.shape
|
|
x = x.reshape(b, c, -1) # NC(HW)
|
|
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
|
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
|
x = self.qkv_proj(x)
|
|
x = self.attention(x)
|
|
x = self.c_proj(x)
|
|
return x[:, :, 0]
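

# Illustrative sketch (not part of the original module): attention pooling collapses the
# spatial grid to a single vector per image. It assumes `conv_nd` from `.nn` is the usual
# guided-diffusion convolution factory; the sizes are chosen only for the demo.
def _demo_attention_pool():
    pool = AttentionPool2d(spacial_dim=8, embed_dim=64, num_heads_channels=16, output_dim=32)
    assert pool(torch.randn(2, 64, 8, 8)).shape == (2, 32)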
|
|
|
|
|
|
class TimestepBlock(nn.Module):
|
|
"""
|
|
Any module where forward() takes timestep embeddings as a second argument.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def forward(self, x, emb):
|
|
"""
|
|
Apply the module to `x` given `emb` timestep embeddings.
|
|
"""
|
|
|
|
|
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|
"""
|
|
A sequential module that passes timestep embeddings to the children that
|
|
support it as an extra input.
|
|
"""
|
|
|
|
def forward(self, x, emb):
|
|
for layer in self:
|
|
if isinstance(layer, TimestepBlock):
|
|
x = layer(x, emb)
|
|
else:
|
|
x = layer(x)
|
|
return x
|
|
|
|
|
|
class Upsample(nn.Module):
|
|
"""
|
|
An upsampling layer with an optional convolution.
|
|
|
|
:param channels: channels in the inputs and outputs.
|
|
:param use_conv: a bool determining if a convolution is applied.
|
|
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
|
upsampling occurs in the inner-two dimensions.
|
|
"""
|
|
|
|
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
|
super().__init__()
|
|
self.channels = channels
|
|
self.out_channels = out_channels or channels
|
|
self.use_conv = use_conv
|
|
self.dims = dims
|
|
if use_conv:
|
|
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
|
|
|
def forward(self, x):
|
|
assert x.shape[1] == self.channels
|
|
if self.dims == 3:
|
|
x = F.interpolate(
|
|
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
|
)
|
|
else:
|
|
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
|
if self.use_conv:
|
|
x = self.conv(x)
|
|
return x
|
|
|
|
|
|
class Downsample(nn.Module):
|
|
"""
|
|
A downsampling layer with an optional convolution.
|
|
|
|
:param channels: channels in the inputs and outputs.
|
|
:param use_conv: a bool determining if a convolution is applied.
|
|
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
|
downsampling occurs in the inner-two dimensions.
|
|
"""
|
|
|
|
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
|
super().__init__()
|
|
self.channels = channels
|
|
self.out_channels = out_channels or channels
|
|
self.use_conv = use_conv
|
|
self.dims = dims
|
|
stride = 2 if dims != 3 else (1, 2, 2)
|
|
if use_conv:
|
|
self.op = conv_nd(
|
|
dims, self.channels, self.out_channels, 3, stride=stride, padding=1
|
|
)
|
|
else:
|
|
assert self.channels == self.out_channels
|
|
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
|
|
|
def forward(self, x):
|
|
assert x.shape[1] == self.channels
|
|
return self.op(x)
|
|
|
|
def conv_bn(inp, oup, stride):
|
|
return nn.Sequential(
|
|
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
|
nn.BatchNorm2d(oup),
|
|
nn.ReLU(inplace=True)
|
|
)
|
|
|
|
def conv_dw(inp, oup, stride):
|
|
return nn.Sequential(
|
|
# dw
|
|
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
|
|
nn.BatchNorm2d(inp),
|
|
nn.ReLU(inplace=True),
|
|
|
|
# pw
|
|
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
|
nn.BatchNorm2d(oup),
|
|
nn.ReLU(inplace=True),
|
|
)
|
|
|
|
class MobBlock(nn.Module):
|
|
    def __init__(self, ind):
|
|
super().__init__()
|
|
|
|
|
|
if ind == 0:
|
|
self.stage = nn.Sequential(
|
|
conv_bn(3, 32, 2),
|
|
conv_dw(32, 64, 1),
|
|
conv_dw(64, 128, 1),
|
|
conv_dw(128, 128, 1)
|
|
)
|
|
elif ind == 1:
|
|
self.stage = nn.Sequential(
|
|
conv_dw(128, 256, 2),
|
|
conv_dw(256, 256, 1)
|
|
)
|
|
elif ind == 2:
|
|
self.stage = nn.Sequential(
|
|
conv_dw(256, 256, 2),
|
|
conv_dw(256, 256, 1)
|
|
)
|
|
else:
|
|
self.stage = nn.Sequential(
|
|
conv_dw(256, 512, 2),
|
|
conv_dw(512, 512, 1),
|
|
conv_dw(512, 512, 1),
|
|
conv_dw(512, 512, 1),
|
|
conv_dw(512, 512, 1),
|
|
conv_dw(512, 512, 1)
|
|
)
|
|
|
|
    def forward(self, x):
|
|
return self.stage(x)
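

# Illustrative sketch (not part of the original module): the four MobileNet-style stages chain
# into a 3 -> 512 channel, 16x-downsampling backbone (the input size is an assumption for the demo).
def _demo_mob_stages():
    x = torch.randn(1, 3, 64, 64)
    for ind in range(4):
        x = MobBlock(ind)(x)
    assert x.shape == (1, 512, 4, 4)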
|
|
|
|
|
|
|
|
class ResBlock(TimestepBlock):
|
|
"""
|
|
A residual block that can optionally change the number of channels.
|
|
|
|
:param channels: the number of input channels.
|
|
:param emb_channels: the number of timestep embedding channels.
|
|
:param dropout: the rate of dropout.
|
|
:param out_channels: if specified, the number of out channels.
|
|
:param use_conv: if True and out_channels is specified, use a spatial
|
|
convolution instead of a smaller 1x1 convolution to change the
|
|
channels in the skip connection.
|
|
:param dims: determines if the signal is 1D, 2D, or 3D.
|
|
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
|
:param up: if True, use this block for upsampling.
|
|
:param down: if True, use this block for downsampling.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
channels,
|
|
emb_channels,
|
|
dropout,
|
|
out_channels=None,
|
|
use_conv=False,
|
|
use_scale_shift_norm=False,
|
|
dims=2,
|
|
use_checkpoint=False,
|
|
up=False,
|
|
down=False,
|
|
):
|
|
super().__init__()
|
|
self.channels = channels
|
|
self.emb_channels = emb_channels
|
|
self.dropout = dropout
|
|
self.out_channels = out_channels or channels
|
|
self.use_conv = use_conv
|
|
self.use_checkpoint = use_checkpoint
|
|
self.use_scale_shift_norm = use_scale_shift_norm
|
|
|
|
self.in_layers = nn.Sequential(
|
|
normalization(channels),
|
|
nn.SiLU(),
|
|
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
|
)
|
|
|
|
self.updown = up or down
|
|
|
|
if up:
|
|
self.h_upd = Upsample(channels, False, dims)
|
|
self.x_upd = Upsample(channels, False, dims)
|
|
elif down:
|
|
self.h_upd = Downsample(channels, False, dims)
|
|
self.x_upd = Downsample(channels, False, dims)
|
|
else:
|
|
self.h_upd = self.x_upd = nn.Identity()
|
|
|
|
self.emb_layers = nn.Sequential(
|
|
nn.SiLU(),
|
|
linear(
|
|
emb_channels,
|
|
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
|
),
|
|
)
|
|
self.out_layers = nn.Sequential(
|
|
normalization(self.out_channels),
|
|
nn.SiLU(),
|
|
nn.Dropout(p=dropout),
|
|
zero_module(
|
|
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
|
),
|
|
)
|
|
|
|
if self.out_channels == channels:
|
|
self.skip_connection = nn.Identity()
|
|
elif use_conv:
|
|
self.skip_connection = conv_nd(
|
|
dims, channels, self.out_channels, 3, padding=1
|
|
)
|
|
else:
|
|
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
|
|
|
def forward(self, x, emb):
|
|
"""
|
|
Apply the block to a Tensor, conditioned on a timestep embedding.
|
|
|
|
:param x: an [N x C x ...] Tensor of features.
|
|
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
|
:return: an [N x C x ...] Tensor of outputs.
|
|
"""
|
|
return checkpoint(
|
|
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
|
)
|
|
|
|
def _forward(self, x, emb):
|
|
if self.updown:
|
|
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
|
h = in_rest(x)
|
|
h = self.h_upd(h)
|
|
x = self.x_upd(x)
|
|
h = in_conv(h)
|
|
else:
|
|
h = self.in_layers(x)
|
|
emb_out = self.emb_layers(emb).type(h.dtype)
|
|
while len(emb_out.shape) < len(h.shape):
|
|
emb_out = emb_out[..., None]
|
|
if self.use_scale_shift_norm:
|
|
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
|
scale, shift = th.chunk(emb_out, 2, dim=1)
|
|
h = out_norm(h) * (1 + scale) + shift
|
|
h = out_rest(h)
|
|
else:
|
|
h = h + emb_out
|
|
h = self.out_layers(h)
|
|
return self.skip_connection(x) + h
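

# Illustrative sketch (not part of the original module): TimestepEmbedSequential forwards the
# timestep embedding only to TimestepBlock children. It assumes the `.nn` helpers behave like
# their guided-diffusion counterparts (normalization = 32-group GroupNorm, checkpoint a plain
# call when the flag is False); the channel counts are assumptions chosen to satisfy GroupNorm.
def _demo_resblock_with_timestep():
    seq = TimestepEmbedSequential(
        conv_nd(2, 3, 32, 3, padding=1),                      # plain layer: called as layer(x)
        ResBlock(channels=32, emb_channels=16, dropout=0.0),  # TimestepBlock: called as layer(x, emb)
    )
    out = seq(torch.randn(1, 3, 32, 32), torch.randn(1, 16))
    assert out.shape == (1, 32, 32, 32)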
|
|
|
|
|
|
class AttentionBlock(nn.Module):
|
|
"""
|
|
An attention block that allows spatial positions to attend to each other.
|
|
|
|
Originally ported from here, but adapted to the N-d case.
|
|
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
channels,
|
|
num_heads=1,
|
|
num_head_channels=-1,
|
|
use_checkpoint=False,
|
|
use_new_attention_order=False,
|
|
):
|
|
super().__init__()
|
|
self.channels = channels
|
|
if num_head_channels == -1:
|
|
self.num_heads = num_heads
|
|
else:
|
|
assert (
|
|
channels % num_head_channels == 0
|
|
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
|
self.num_heads = channels // num_head_channels
|
|
self.use_checkpoint = use_checkpoint
|
|
self.norm = normalization(channels)
|
|
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
|
if use_new_attention_order:
|
|
# split qkv before split heads
|
|
self.attention = QKVAttention(self.num_heads)
|
|
else:
|
|
# split heads before split qkv
|
|
self.attention = QKVAttentionLegacy(self.num_heads)
|
|
|
|
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
|
|
|
def forward(self, x):
|
|
return checkpoint(self._forward, (x,), self.parameters(), True)
|
|
|
|
def _forward(self, x):
|
|
b, c, *spatial = x.shape
|
|
x = x.reshape(b, c, -1)
|
|
qkv = self.qkv(self.norm(x))
|
|
h = self.attention(qkv)
|
|
h = self.proj_out(h)
|
|
return (x + h).reshape(b, c, *spatial)
|
|
|
|
|
|
def count_flops_attn(model, _x, y):
|
|
"""
|
|
A counter for the `thop` package to count the operations in an
|
|
attention operation.
|
|
Meant to be used like:
|
|
macs, params = thop.profile(
|
|
model,
|
|
inputs=(inputs, timestamps),
|
|
custom_ops={QKVAttention: QKVAttention.count_flops},
|
|
)
|
|
"""
|
|
b, c, *spatial = y[0].shape
|
|
num_spatial = int(np.prod(spatial))
|
|
# We perform two matmuls with the same number of ops.
|
|
# The first computes the weight matrix, the second computes
|
|
# the combination of the value vectors.
|
|
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
|
model.total_ops += th.DoubleTensor([matmul_ops])
|
|
|
|
|
|
class QKVAttentionLegacy(nn.Module):
|
|
"""
|
|
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
|
|
"""
|
|
|
|
def __init__(self, n_heads):
|
|
super().__init__()
|
|
self.n_heads = n_heads
|
|
|
|
def forward(self, qkv):
|
|
"""
|
|
Apply QKV attention.
|
|
|
|
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
|
:return: an [N x (H * C) x T] tensor after attention.
|
|
"""
|
|
bs, width, length = qkv.shape
|
|
assert width % (3 * self.n_heads) == 0
|
|
ch = width // (3 * self.n_heads)
|
|
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
|
scale = 1 / math.sqrt(math.sqrt(ch))
|
|
weight = th.einsum(
|
|
"bct,bcs->bts", q * scale, k * scale
|
|
) # More stable with f16 than dividing afterwards
|
|
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
|
a = th.einsum("bts,bcs->bct", weight, v)
|
|
return a.reshape(bs, -1, length)
|
|
|
|
@staticmethod
|
|
def count_flops(model, _x, y):
|
|
return count_flops_attn(model, _x, y)
|
|
|
|
|
|
class QKVAttention(nn.Module):
|
|
"""
|
|
A module which performs QKV attention and splits in a different order.
|
|
"""
|
|
|
|
def __init__(self, n_heads):
|
|
super().__init__()
|
|
self.n_heads = n_heads
|
|
|
|
def forward(self, qkv):
|
|
"""
|
|
Apply QKV attention.
|
|
|
|
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
|
:return: an [N x (H * C) x T] tensor after attention.
|
|
"""
|
|
bs, width, length = qkv.shape
|
|
assert width % (3 * self.n_heads) == 0
|
|
ch = width // (3 * self.n_heads)
|
|
q, k, v = qkv.chunk(3, dim=1)
|
|
scale = 1 / math.sqrt(math.sqrt(ch))
|
|
weight = th.einsum(
|
|
"bct,bcs->bts",
|
|
(q * scale).view(bs * self.n_heads, ch, length),
|
|
(k * scale).view(bs * self.n_heads, ch, length),
|
|
) # More stable with f16 than dividing afterwards
|
|
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
|
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
|
return a.reshape(bs, -1, length)
|
|
|
|
@staticmethod
|
|
def count_flops(model, _x, y):
|
|
return count_flops_attn(model, _x, y)
|
|
|
|
|
|
class EncoderUNetModel(nn.Module):
|
|
"""
|
|
The half UNet model with attention and timestep embedding.
|
|
|
|
For usage, see UNet.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
image_size,
|
|
in_channels,
|
|
model_channels,
|
|
out_channels,
|
|
num_res_blocks,
|
|
attention_resolutions,
|
|
dropout=0,
|
|
channel_mult=(1, 2, 4, 8),
|
|
conv_resample=True,
|
|
dims=2,
|
|
use_checkpoint=False,
|
|
use_fp16=False,
|
|
num_heads=1,
|
|
num_head_channels=-1,
|
|
num_heads_upsample=-1,
|
|
use_scale_shift_norm=False,
|
|
resblock_updown=False,
|
|
use_new_attention_order=False,
|
|
pool="adaptive",
|
|
):
|
|
super().__init__()
|
|
|
|
if num_heads_upsample == -1:
|
|
num_heads_upsample = num_heads
|
|
|
|
self.in_channels = in_channels
|
|
self.model_channels = model_channels
|
|
self.out_channels = out_channels
|
|
self.num_res_blocks = num_res_blocks
|
|
self.attention_resolutions = attention_resolutions
|
|
self.dropout = dropout
|
|
self.channel_mult = channel_mult
|
|
self.conv_resample = conv_resample
|
|
self.use_checkpoint = use_checkpoint
|
|
self.dtype = th.float16 if use_fp16 else th.float32
|
|
self.num_heads = num_heads
|
|
self.num_head_channels = num_head_channels
|
|
self.num_heads_upsample = num_heads_upsample
|
|
|
|
time_embed_dim = model_channels * 4
|
|
self.time_embed = nn.Sequential(
|
|
linear(model_channels, time_embed_dim),
|
|
nn.SiLU(),
|
|
linear(time_embed_dim, time_embed_dim),
|
|
)
|
|
|
|
self.input_blocks = nn.ModuleList(
|
|
[
|
|
TimestepEmbedSequential(
|
|
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
|
)
|
|
]
|
|
)
|
|
self._feature_size = model_channels
|
|
input_block_chans = [model_channels]
|
|
ch = model_channels
|
|
ds = 1
|
|
for level, mult in enumerate(channel_mult):
|
|
for _ in range(num_res_blocks):
|
|
layers = [
|
|
ResBlock(
|
|
ch,
|
|
time_embed_dim,
|
|
dropout,
|
|
out_channels=mult * model_channels,
|
|
dims=dims,
|
|
use_checkpoint=use_checkpoint,
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
)
|
|
]
|
|
ch = mult * model_channels
|
|
if ds in attention_resolutions:
|
|
layers.append(
|
|
AttentionBlock(
|
|
ch,
|
|
use_checkpoint=use_checkpoint,
|
|
num_heads=num_heads,
|
|
num_head_channels=num_head_channels,
|
|
use_new_attention_order=use_new_attention_order,
|
|
)
|
|
)
|
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
|
self._feature_size += ch
|
|
input_block_chans.append(ch)
|
|
if level != len(channel_mult) - 1:
|
|
out_ch = ch
|
|
self.input_blocks.append(
|
|
TimestepEmbedSequential(
|
|
ResBlock(
|
|
ch,
|
|
time_embed_dim,
|
|
dropout,
|
|
out_channels=out_ch,
|
|
dims=dims,
|
|
use_checkpoint=use_checkpoint,
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
down=True,
|
|
)
|
|
if resblock_updown
|
|
else Downsample(
|
|
ch, conv_resample, dims=dims, out_channels=out_ch
|
|
)
|
|
)
|
|
)
|
|
ch = out_ch
|
|
input_block_chans.append(ch)
|
|
ds *= 2
|
|
self._feature_size += ch
|
|
|
|
self.middle_block = TimestepEmbedSequential(
|
|
ResBlock(
|
|
ch,
|
|
time_embed_dim,
|
|
dropout,
|
|
dims=dims,
|
|
use_checkpoint=use_checkpoint,
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
),
|
|
AttentionBlock(
|
|
ch,
|
|
use_checkpoint=use_checkpoint,
|
|
num_heads=num_heads,
|
|
num_head_channels=num_head_channels,
|
|
use_new_attention_order=use_new_attention_order,
|
|
),
|
|
ResBlock(
|
|
ch,
|
|
time_embed_dim,
|
|
dropout,
|
|
dims=dims,
|
|
use_checkpoint=use_checkpoint,
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
),
|
|
)
|
|
self._feature_size += ch
|
|
self.pool = pool
|
|
        self.gap = nn.AvgPool2d((8, 8))  # global average pooling
|
|
self.cam_feature_maps = None
|
|
print('pool', pool)
|
|
if pool == "adaptive":
|
|
self.out = nn.Sequential(
|
|
normalization(ch),
|
|
nn.SiLU(),
|
|
nn.AdaptiveAvgPool2d((1, 1)),
|
|
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
|
nn.Flatten(),
|
|
)
|
|
elif pool == "attention":
|
|
assert num_head_channels != -1
|
|
self.out = nn.Sequential(
|
|
normalization(ch),
|
|
nn.SiLU(),
|
|
AttentionPool2d(
|
|
(image_size // ds), ch, num_head_channels, out_channels
|
|
),
|
|
)
|
|
elif pool == "spatial":
|
|
self.out = nn.Linear(256, self.out_channels)
|
|
|
|
elif pool == "spatial_v2":
|
|
self.out = nn.Sequential(
|
|
nn.Linear(self._feature_size, 2048),
|
|
normalization(2048),
|
|
nn.SiLU(),
|
|
nn.Linear(2048, self.out_channels),
|
|
)
|
|
else:
|
|
raise NotImplementedError(f"Unexpected {pool} pooling")
|
|
|
|
|
|
def forward(self, x, timesteps):
|
|
"""
|
|
Apply the model to an input batch.
|
|
|
|
:param x: an [N x C x ...] Tensor of inputs.
|
|
:param timesteps: a 1-D batch of timesteps.
|
|
:return: an [N x K] Tensor of outputs.
|
|
"""
|
|
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
|
|
|
results = []
|
|
h = x.type(self.dtype)
|
|
for module in self.input_blocks:
|
|
h = module(h, emb)
|
|
if self.pool.startswith("spatial"):
|
|
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
|
h = self.middle_block(h, emb)
|
|
|
|
|
|
if self.pool.startswith("spatial"):
|
|
self.cam_feature_maps = h
|
|
h = self.gap(h)
|
|
N = h.shape[0]
|
|
h = h.reshape(N, -1)
|
|
print('h1', h.shape)
|
|
return self.out(h)
|
|
else:
|
|
h = h.type(x.dtype)
|
|
self.cam_feature_maps = h
|
|
return self.out(h)
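

# Illustrative sketch (not part of the original module): a tiny classifier-style configuration,
# assuming the `.nn` helpers (timestep_embedding, normalization, conv_nd, linear, zero_module,
# checkpoint) match their guided-diffusion counterparts. All hyperparameters are demo assumptions.
def _demo_encoder_unet():
    model = EncoderUNetModel(
        image_size=32, in_channels=3, model_channels=32, out_channels=10,
        num_res_blocks=1, attention_resolutions=(), channel_mult=(1, 2), pool="adaptive",
    )
    logits = model(th.randn(2, 3, 32, 32), th.tensor([3, 7]))
    assert logits.shape == (2, 10)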
|
|
|
|
class NeuralNetwork(nn.Module):
|
|
def __init__(self):
|
|
super(NeuralNetwork, self).__init__()
|
|
|
|
def get_device(self):
|
|
if next(self.parameters()).device.type == "cpu":
|
|
return "cpu"
|
|
else:
|
|
return next(self.parameters()).device.index
|
|
|
|
def set_device(self, device):
|
|
if device == "cpu":
|
|
self.cpu()
|
|
else:
|
|
self.cuda(device)
|
|
|
|
def forward(self, x):
|
|
raise NotImplementedError
|
|
|
|
|
|
class SegmentationNetwork(NeuralNetwork):
|
|
def __init__(self):
|
|
super(NeuralNetwork, self).__init__()
|
|
|
|
# if we have 5 pooling then our patch size must be divisible by 2**5
|
|
self.input_shape_must_be_divisible_by = None # for example in a 2d network that does 5 pool in x and 6 pool
|
|
# in y this would be (32, 64)
|
|
|
|
        # we need to know this because we need to know if we are a 2d or a 3d network
|
|
self.conv_op = None # nn.Conv2d or nn.Conv3d
|
|
|
|
# this tells us how many channels we have in the output. Important for preallocation in inference
|
|
self.num_classes = None # number of channels in the output
|
|
|
|
        # depending on the loss, we do not hard code a nonlinearity into the architecture. To aggregate predictions
        # during inference, we need to apply the nonlinearity, however. So it is important to let the network know
        # what to apply in inference. For the most part this will be softmax
|
|
self.inference_apply_nonlin = lambda x: x # softmax_helper
|
|
|
|
        # This is for saving a gaussian importance map for inference. It weights voxels closer to the center
        # higher. Predictions at the borders are often less accurate and are thus downweighted. Creating these
        # Gaussians can be expensive, so it makes sense to save and reuse them.
|
|
self._gaussian_3d = self._patch_size_for_gaussian_3d = None
|
|
self._gaussian_2d = self._patch_size_for_gaussian_2d = None
|
|
|
|
def predict_3D(self, x: np.ndarray, do_mirroring: bool, mirror_axes: Tuple[int, ...] = (0, 1, 2),
|
|
use_sliding_window: bool = False,
|
|
step_size: float = 0.5, patch_size: Tuple[int, ...] = None, regions_class_order: Tuple[int, ...] = None,
|
|
use_gaussian: bool = False, pad_border_mode: str = "constant",
|
|
pad_kwargs: dict = None, all_in_gpu: bool = False,
|
|
verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
assert step_size <= 1, 'step_size must be smaller than 1. Otherwise there will be a gap between consecutive ' \
|
|
'predictions'
|
|
|
|
if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)
|
|
|
|
if pad_kwargs is None:
|
|
pad_kwargs = {'constant_values': 0}
|
|
|
|
# A very long time ago the mirror axes were (2, 3, 4) for a 3d network. This is just to intercept any old
|
|
# code that uses this convention
|
|
if len(mirror_axes):
|
|
if self.conv_op == nn.Conv2d:
|
|
if max(mirror_axes) > 1:
|
|
raise ValueError("mirror axes. duh")
|
|
if self.conv_op == nn.Conv3d:
|
|
if max(mirror_axes) > 2:
|
|
raise ValueError("mirror axes. duh")
|
|
|
|
if self.training:
|
|
print('WARNING! Network is in train mode during inference. This may be intended, or not...')
|
|
|
|
assert len(x.shape) == 4, "data must have shape (c,x,y,z)"
|
|
|
|
if mixed_precision:
|
|
context = autocast
|
|
else:
|
|
context = no_op
|
|
|
|
with context():
|
|
with torch.no_grad():
|
|
if self.conv_op == nn.Conv3d:
|
|
if use_sliding_window:
|
|
res = self._internal_predict_3D_3Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
|
|
regions_class_order, use_gaussian, pad_border_mode,
|
|
pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu,
|
|
verbose=verbose)
|
|
else:
|
|
res = self._internal_predict_3D_3Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
|
|
pad_border_mode, pad_kwargs=pad_kwargs, verbose=verbose)
|
|
elif self.conv_op == nn.Conv2d:
|
|
if use_sliding_window:
|
|
res = self._internal_predict_3D_2Dconv_tiled(x, patch_size, do_mirroring, mirror_axes, step_size,
|
|
regions_class_order, use_gaussian, pad_border_mode,
|
|
pad_kwargs, all_in_gpu, False)
|
|
else:
|
|
res = self._internal_predict_3D_2Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
|
|
pad_border_mode, pad_kwargs, all_in_gpu, False)
|
|
else:
|
|
raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")
|
|
|
|
return res
|
|
|
|
def predict_2D(self, x, do_mirroring: bool, mirror_axes: tuple = (0, 1, 2), use_sliding_window: bool = False,
|
|
step_size: float = 0.5, patch_size: tuple = None, regions_class_order: tuple = None,
|
|
use_gaussian: bool = False, pad_border_mode: str = "constant",
|
|
pad_kwargs: dict = None, all_in_gpu: bool = False,
|
|
verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
        assert step_size <= 1, 'step_size must be smaller than 1. Otherwise there will be a gap between consecutive ' \
|
|
'predictions'
|
|
|
|
if self.conv_op == nn.Conv3d:
|
|
raise RuntimeError("Cannot predict 2d if the network is 3d. Dummy.")
|
|
|
|
if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)
|
|
|
|
if pad_kwargs is None:
|
|
pad_kwargs = {'constant_values': 0}
|
|
|
|
# A very long time ago the mirror axes were (2, 3) for a 2d network. This is just to intercept any old
|
|
# code that uses this convention
|
|
if len(mirror_axes):
|
|
if max(mirror_axes) > 1:
|
|
raise ValueError("mirror axes. duh")
|
|
|
|
if self.training:
|
|
print('WARNING! Network is in train mode during inference. This may be intended, or not...')
|
|
|
|
assert len(x.shape) == 3, "data must have shape (c,x,y)"
|
|
|
|
if mixed_precision:
|
|
context = autocast
|
|
else:
|
|
context = no_op
|
|
|
|
with context():
|
|
with torch.no_grad():
|
|
if self.conv_op == nn.Conv2d:
|
|
if use_sliding_window:
|
|
res = self._internal_predict_2D_2Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
|
|
regions_class_order, use_gaussian, pad_border_mode,
|
|
pad_kwargs, all_in_gpu, verbose)
|
|
else:
|
|
res = self._internal_predict_2D_2Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
|
|
pad_border_mode, pad_kwargs, verbose)
|
|
else:
|
|
raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")
|
|
|
|
return res
|
|
|
|
@staticmethod
|
|
def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
|
|
tmp = np.zeros(patch_size)
|
|
center_coords = [i // 2 for i in patch_size]
|
|
sigmas = [i * sigma_scale for i in patch_size]
|
|
tmp[tuple(center_coords)] = 1
|
|
gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
|
|
gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1
|
|
gaussian_importance_map = gaussian_importance_map.astype(np.float32)
|
|
|
|
# gaussian_importance_map cannot be 0, otherwise we may end up with nans!
|
|
gaussian_importance_map[gaussian_importance_map == 0] = np.min(
|
|
gaussian_importance_map[gaussian_importance_map != 0])
|
|
|
|
return gaussian_importance_map
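
    # Illustrative note (not in the original): e.g. _get_gaussian((32, 32)) returns a float32 map
    # that peaks at 1.0 in the patch centre, decays towards the borders and is clipped away from
    # zero, so dividing the aggregated logits by the accumulated weights never divides by 0.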
|
|
|
|
@staticmethod
|
|
def _compute_steps_for_sliding_window(patch_size: Tuple[int, ...], image_size: Tuple[int, ...], step_size: float) -> List[List[int]]:
|
|
assert [i >= j for i, j in zip(image_size, patch_size)], "image size must be as large or larger than patch_size"
|
|
assert 0 < step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'
|
|
|
|
# our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
|
|
# 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
|
|
target_step_sizes_in_voxels = [i * step_size for i in patch_size]
|
|
|
|
num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, patch_size)]
|
|
|
|
steps = []
|
|
for dim in range(len(patch_size)):
|
|
# the highest step value for this dimension is
|
|
max_step_value = image_size[dim] - patch_size[dim]
|
|
if num_steps[dim] > 1:
|
|
actual_step_size = max_step_value / (num_steps[dim] - 1)
|
|
else:
|
|
actual_step_size = 99999999999 # does not matter because there is only one step at 0
|
|
|
|
steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]
|
|
|
|
steps.append(steps_here)
|
|
|
|
return steps
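
    # Illustrative note (not in the original): for image_size=(110,), patch_size=(64,) and
    # step_size=0.5 the target step is 32 voxels, num_steps = ceil((110 - 64) / 32) + 1 = 3, and
    # the overrun is spread evenly (actual step 46 / 2 = 23), giving start coordinates [0, 23, 46].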
|
|
|
|
def _internal_predict_3D_3Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
|
|
patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
|
|
pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
|
|
verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
|
|
# better safe than sorry
|
|
assert len(x.shape) == 4, "x must be (c, x, y, z)"
|
|
|
|
if verbose: print("step_size:", step_size)
|
|
if verbose: print("do mirror:", do_mirroring)
|
|
|
|
assert patch_size is not None, "patch_size cannot be None for tiled prediction"
|
|
|
|
# for sliding window inference the image must at least be as large as the patch size. It does not matter
|
|
# whether the shape is divisible by 2**num_pool as long as the patch size is
|
|
data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
|
|
data_shape = data.shape # still c, x, y, z
|
|
|
|
# compute the steps for sliding window
|
|
steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
|
|
num_tiles = len(steps[0]) * len(steps[1]) * len(steps[2])
|
|
|
|
if verbose:
|
|
print("data shape:", data_shape)
|
|
print("patch size:", patch_size)
|
|
print("steps (x, y, and z):", steps)
|
|
print("number of tiles:", num_tiles)
|
|
|
|
# we only need to compute that once. It can take a while to compute this due to the large sigma in
|
|
# gaussian_filter
|
|
if use_gaussian and num_tiles > 1:
|
|
if self._gaussian_3d is None or not all(
|
|
[i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_3d)]):
|
|
if verbose: print('computing Gaussian')
|
|
gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)
|
|
|
|
self._gaussian_3d = gaussian_importance_map
|
|
self._patch_size_for_gaussian_3d = patch_size
|
|
if verbose: print("done")
|
|
else:
|
|
if verbose: print("using precomputed Gaussian")
|
|
gaussian_importance_map = self._gaussian_3d
|
|
|
|
gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
|
|
|
|
#predict on cpu if cuda not available
|
|
if torch.cuda.is_available():
|
|
gaussian_importance_map = gaussian_importance_map.cuda(self.get_device(), non_blocking=True)
|
|
|
|
else:
|
|
gaussian_importance_map = None
|
|
|
|
if all_in_gpu:
|
|
            # If we run the inference in GPU only (meaning all tensors are allocated on the GPU, which reduces
            # CPU-GPU communication but requires more GPU memory) we need to preallocate a few things on the GPU
|
|
|
|
if use_gaussian and num_tiles > 1:
|
|
# half precision for the outputs should be good enough. If the outputs here are half, the
|
|
# gaussian_importance_map should be as well
|
|
gaussian_importance_map = gaussian_importance_map.half()
|
|
|
|
# make sure we did not round anything to 0
|
|
gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
|
|
gaussian_importance_map != 0].min()
|
|
|
|
add_for_nb_of_preds = gaussian_importance_map
|
|
else:
|
|
add_for_nb_of_preds = torch.ones(patch_size, device=self.get_device())
|
|
|
|
if verbose: print("initializing result array (on GPU)")
|
|
aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
|
|
device=self.get_device())
|
|
|
|
if verbose: print("moving data to GPU")
|
|
data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)
|
|
|
|
if verbose: print("initializing result_numsamples (on GPU)")
|
|
aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
|
|
device=self.get_device())
|
|
|
|
else:
|
|
if use_gaussian and num_tiles > 1:
|
|
add_for_nb_of_preds = self._gaussian_3d
|
|
else:
|
|
add_for_nb_of_preds = np.ones(patch_size, dtype=np.float32)
|
|
aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
|
|
aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
|
|
|
|
for x in steps[0]:
|
|
lb_x = x
|
|
ub_x = x + patch_size[0]
|
|
for y in steps[1]:
|
|
lb_y = y
|
|
ub_y = y + patch_size[1]
|
|
for z in steps[2]:
|
|
lb_z = z
|
|
ub_z = z + patch_size[2]
|
|
|
|
predicted_patch = self._internal_maybe_mirror_and_pred_3D(
|
|
data[None, :, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z], mirror_axes, do_mirroring,
|
|
gaussian_importance_map)[0]
|
|
|
|
if all_in_gpu:
|
|
predicted_patch = predicted_patch.half()
|
|
else:
|
|
predicted_patch = predicted_patch.cpu().numpy()
|
|
|
|
aggregated_results[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += predicted_patch
|
|
aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += add_for_nb_of_preds
|
|
|
|
        # we reverse the padding here (remember that we padded the input to be at least as large as the patch size)
|
|
slicer = tuple(
|
|
[slice(0, aggregated_results.shape[i]) for i in
|
|
range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
|
|
aggregated_results = aggregated_results[slicer]
|
|
aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]
|
|
|
|
# computing the class_probabilities by dividing the aggregated result with result_numsamples
|
|
aggregated_results /= aggregated_nb_of_predictions
|
|
del aggregated_nb_of_predictions
|
|
|
|
if regions_class_order is None:
|
|
predicted_segmentation = aggregated_results.argmax(0)
|
|
else:
|
|
if all_in_gpu:
|
|
class_probabilities_here = aggregated_results.detach().cpu().numpy()
|
|
else:
|
|
class_probabilities_here = aggregated_results
|
|
predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
|
|
for i, c in enumerate(regions_class_order):
|
|
predicted_segmentation[class_probabilities_here[i] > 0.5] = c
|
|
|
|
if all_in_gpu:
|
|
if verbose: print("copying results to CPU")
|
|
|
|
if regions_class_order is None:
|
|
predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
|
|
|
|
aggregated_results = aggregated_results.detach().cpu().numpy()
|
|
|
|
if verbose: print("prediction done")
|
|
return predicted_segmentation, aggregated_results
|
|
|
|
def _internal_predict_2D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
|
|
mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
|
|
pad_border_mode: str = "constant", pad_kwargs: dict = None,
|
|
verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
|
|
"""
|
|
This one does fully convolutional inference. No sliding window
|
|
"""
|
|
assert len(x.shape) == 3, "x must be (c, x, y)"
|
|
|
|
assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
|
|
'run _internal_predict_2D_2Dconv'
|
|
if verbose: print("do mirror:", do_mirroring)
|
|
|
|
data, slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
|
|
self.input_shape_must_be_divisible_by)
|
|
|
|
predicted_probabilities = self._internal_maybe_mirror_and_pred_2D(data[None], mirror_axes, do_mirroring,
|
|
None)[0]
|
|
|
|
slicer = tuple(
|
|
[slice(0, predicted_probabilities.shape[i]) for i in range(len(predicted_probabilities.shape) -
|
|
(len(slicer) - 1))] + slicer[1:])
|
|
predicted_probabilities = predicted_probabilities[slicer]
|
|
|
|
if regions_class_order is None:
|
|
predicted_segmentation = predicted_probabilities.argmax(0)
|
|
predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
|
|
predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
|
|
else:
|
|
predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
|
|
predicted_segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.float32)
|
|
for i, c in enumerate(regions_class_order):
|
|
predicted_segmentation[predicted_probabilities[i] > 0.5] = c
|
|
|
|
return predicted_segmentation, predicted_probabilities
|
|
|
|
def _internal_predict_3D_3Dconv(self, x: np.ndarray, min_size: Tuple[int, ...], do_mirroring: bool,
|
|
mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
|
|
pad_border_mode: str = "constant", pad_kwargs: dict = None,
|
|
verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
|
|
"""
|
|
This one does fully convolutional inference. No sliding window
|
|
"""
|
|
assert len(x.shape) == 4, "x must be (c, x, y, z)"
|
|
|
|
assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
|
|
'run _internal_predict_3D_3Dconv'
|
|
if verbose: print("do mirror:", do_mirroring)
|
|
|
|
data, slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
|
|
self.input_shape_must_be_divisible_by)
|
|
|
|
predicted_probabilities = self._internal_maybe_mirror_and_pred_3D(data[None], mirror_axes, do_mirroring,
|
|
None)[0]
|
|
|
|
slicer = tuple(
|
|
[slice(0, predicted_probabilities.shape[i]) for i in range(len(predicted_probabilities.shape) -
|
|
(len(slicer) - 1))] + slicer[1:])
|
|
predicted_probabilities = predicted_probabilities[slicer]
|
|
|
|
if regions_class_order is None:
|
|
predicted_segmentation = predicted_probabilities.argmax(0)
|
|
predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
|
|
predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
|
|
else:
|
|
predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
|
|
predicted_segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.float32)
|
|
for i, c in enumerate(regions_class_order):
|
|
predicted_segmentation[predicted_probabilities[i] > 0.5] = c
|
|
|
|
return predicted_segmentation, predicted_probabilities
|
|
|
|
def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.tensor], mirror_axes: tuple,
|
|
do_mirroring: bool = True,
|
|
mult: np.ndarray or torch.tensor = None) -> torch.tensor:
|
|
assert len(x.shape) == 5, 'x must be (b, c, x, y, z)'
|
|
|
|
# if cuda available:
|
|
# everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
|
|
# we now return a cuda tensor! Not numpy array!
|
|
|
|
x = maybe_to_torch(x)
|
|
result_torch = torch.zeros([1, self.num_classes] + list(x.shape[2:]),
|
|
dtype=torch.float)
|
|
|
|
if torch.cuda.is_available():
|
|
x = to_cuda(x, gpu_id=self.get_device())
|
|
result_torch = result_torch.cuda(self.get_device(), non_blocking=True)
|
|
|
|
if mult is not None:
|
|
mult = maybe_to_torch(mult)
|
|
if torch.cuda.is_available():
|
|
mult = to_cuda(mult, gpu_id=self.get_device())
|
|
|
|
if do_mirroring:
|
|
mirror_idx = 8
|
|
num_results = 2 ** len(mirror_axes)
|
|
else:
|
|
mirror_idx = 1
|
|
num_results = 1
|
|
|
|
for m in range(mirror_idx):
|
|
if m == 0:
|
|
pred = self.inference_apply_nonlin(self(x))
|
|
result_torch += 1 / num_results * pred
|
|
|
|
if m == 1 and (2 in mirror_axes):
|
|
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, ))))
|
|
result_torch += 1 / num_results * torch.flip(pred, (4,))
|
|
|
|
if m == 2 and (1 in mirror_axes):
|
|
pred = self.inference_apply_nonlin(self(torch.flip(x, (3, ))))
|
|
result_torch += 1 / num_results * torch.flip(pred, (3,))
|
|
|
|
if m == 3 and (2 in mirror_axes) and (1 in mirror_axes):
|
|
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3))))
|
|
result_torch += 1 / num_results * torch.flip(pred, (4, 3))
|
|
|
|
if m == 4 and (0 in mirror_axes):
|
|
pred = self.inference_apply_nonlin(self(torch.flip(x, (2, ))))
|
|
result_torch += 1 / num_results * torch.flip(pred, (2,))
|
|
|
|
if m == 5 and (0 in mirror_axes) and (2 in mirror_axes):
|
|
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 2))))
|
|
result_torch += 1 / num_results * torch.flip(pred, (4, 2))
|
|
|
|
if m == 6 and (0 in mirror_axes) and (1 in mirror_axes):
|
|
pred = self.inference_apply_nonlin(self(torch.flip(x, (3, 2))))
|
|
result_torch += 1 / num_results * torch.flip(pred, (3, 2))
|
|
|
|
if m == 7 and (0 in mirror_axes) and (1 in mirror_axes) and (2 in mirror_axes):
|
|
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3, 2))))
|
|
result_torch += 1 / num_results * torch.flip(pred, (4, 3, 2))
|
|
|
|
if mult is not None:
|
|
result_torch[:, :] *= mult
|
|
|
|
return result_torch
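
    # Illustrative note (not in the original): with mirror_axes=(0, 1, 2) all 8 flip combinations
    # are evaluated and each prediction is added with weight 1 / 2**len(mirror_axes); with fewer
    # axes the loop skips the combinations that use a disabled axis, so the weights still sum to 1.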
|
|
|
|
    def _internal_maybe_mirror_and_pred_2D(self, x: Union[np.ndarray, torch.Tensor], mirror_axes: tuple,
                                           do_mirroring: bool = True,
                                           mult: Union[np.ndarray, torch.Tensor] = None) -> torch.Tensor:
        # if cuda is available, everything in here takes place on the GPU. If x and mult are not yet
        # on the GPU, that is taken care of here. Note that we return a cuda tensor, not a numpy array!

        assert len(x.shape) == 4, 'x must be (b, c, x, y)'

        x = maybe_to_torch(x)
        result_torch = torch.zeros([x.shape[0], self.num_classes] + list(x.shape[2:]), dtype=torch.float)

        if torch.cuda.is_available():
            x = to_cuda(x, gpu_id=self.get_device())
            result_torch = result_torch.cuda(self.get_device(), non_blocking=True)

        if mult is not None:
            mult = maybe_to_torch(mult)
            if torch.cuda.is_available():
                mult = to_cuda(mult, gpu_id=self.get_device())

        if do_mirroring:
            mirror_idx = 4
            num_results = 2 ** len(mirror_axes)
        else:
            mirror_idx = 1
            num_results = 1

        for m in range(mirror_idx):
            if m == 0:
                pred = self.inference_apply_nonlin(self(x))
                result_torch += 1 / num_results * pred

            if m == 1 and (1 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (3,))))
                result_torch += 1 / num_results * torch.flip(pred, (3,))

            if m == 2 and (0 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (2,))))
                result_torch += 1 / num_results * torch.flip(pred, (2,))

            if m == 3 and (0 in mirror_axes) and (1 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (3, 2))))
                result_torch += 1 / num_results * torch.flip(pred, (3, 2))

        if mult is not None:
            result_torch[:, :] *= mult

        return result_torch

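    # Minimal usage sketch (hypothetical names/shapes): given a padded patch of shape
    # (1, c, x, y), e.g. patch = data[None, :, lb_x:ub_x, lb_y:ub_y],
    #     pred = self._internal_maybe_mirror_and_pred_2D(patch, (0, 1), True, gaussian_map)[0]
    # returns the flip-averaged class scores for that patch, optionally weighted by the
    # Gaussian importance map passed as mult. This is exactly how the tiled 2D prediction
    # below consumes it.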
    def _internal_predict_2D_2Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
                                          patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
                                          pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
                                          verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
        # better safe than sorry
        assert len(x.shape) == 3, "x must be (c, x, y)"

        if verbose: print("step_size:", step_size)
        if verbose: print("do mirror:", do_mirroring)

        assert patch_size is not None, "patch_size cannot be None for tiled prediction"

        # for sliding window inference the image must at least be as large as the patch size. It does not matter
        # whether the shape is divisible by 2**num_pool as long as the patch size is
        data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
        data_shape = data.shape  # still c, x, y

        # compute the steps for sliding window
        steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
        num_tiles = len(steps[0]) * len(steps[1])

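        # Worked example (assuming the usual nnU-Net-style step computation in
        # _compute_steps_for_sliding_window): with patch_size=(256, 256), a padded image of
        # (512, 700) and step_size=0.5, the window advances by roughly half a patch, giving
        # steps like [[0, 128, 256], [0, 111, 222, 333, 444]] and num_tiles = 3 * 5 = 15.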
        if verbose:
            print("data shape:", data_shape)
            print("patch size:", patch_size)
            print("steps (x and y):", steps)
            print("number of tiles:", num_tiles)

        # we only need to compute this once. It can take a while because of the large sigma used in
        # gaussian_filter
        if use_gaussian and num_tiles > 1:
            if self._gaussian_2d is None or not all(
                    [i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_2d)]):
                if verbose: print('computing Gaussian')
                gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)

                self._gaussian_2d = gaussian_importance_map
                self._patch_size_for_gaussian_2d = patch_size
            else:
                if verbose: print("using precomputed Gaussian")
                gaussian_importance_map = self._gaussian_2d

            gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
            if torch.cuda.is_available():
                gaussian_importance_map = gaussian_importance_map.cuda(self.get_device(), non_blocking=True)

        else:
            gaussian_importance_map = None

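        # Note on the Gaussian importance map: it weights voxels near the patch centre more than
        # voxels near the border (sigma ~ patch_size / 8), so overlapping tiles blend smoothly
        # instead of producing seams. It is cached in self._gaussian_2d and only recomputed when
        # the patch size changes.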
        if all_in_gpu:
            # If we run the inference entirely on the GPU (i.e. all tensors are allocated on the GPU,
            # which reduces CPU-GPU communication but requires more GPU memory), we need to
            # preallocate a few things on the GPU

            if use_gaussian and num_tiles > 1:
                # half precision for the outputs should be good enough. If the outputs here are half, the
                # gaussian_importance_map should be as well
                gaussian_importance_map = gaussian_importance_map.half()

                # make sure we did not round anything to 0
                gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
                    gaussian_importance_map != 0].min()

                add_for_nb_of_preds = gaussian_importance_map
            else:
                add_for_nb_of_preds = torch.ones(patch_size, device=self.get_device())

            if verbose: print("initializing result array (on GPU)")
            aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                             device=self.get_device())

            if verbose: print("moving data to GPU")
            data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)

            if verbose: print("initializing result_numsamples (on GPU)")
            aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
                                                       device=self.get_device())
        else:
            if use_gaussian and num_tiles > 1:
                add_for_nb_of_preds = self._gaussian_2d
            else:
                add_for_nb_of_preds = np.ones(patch_size, dtype=np.float32)
            aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
            aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)

        for x in steps[0]:
            lb_x = x
            ub_x = x + patch_size[0]
            for y in steps[1]:
                lb_y = y
                ub_y = y + patch_size[1]

                predicted_patch = self._internal_maybe_mirror_and_pred_2D(
                    data[None, :, lb_x:ub_x, lb_y:ub_y], mirror_axes, do_mirroring,
                    gaussian_importance_map)[0]

                if all_in_gpu:
                    predicted_patch = predicted_patch.half()
                else:
                    predicted_patch = predicted_patch.cpu().numpy()

                aggregated_results[:, lb_x:ub_x, lb_y:ub_y] += predicted_patch
                aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y] += add_for_nb_of_preds

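        # Each tile has added its (optionally Gaussian-weighted) class scores to aggregated_results
        # and the matching weights to aggregated_nb_of_predictions; dividing the two further below
        # gives the per-voxel weighted average across all overlapping tiles.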
        # we reverse the padding here (remember that we padded the input to be at least as large
        # as the patch size)
        slicer = tuple(
            [slice(0, aggregated_results.shape[i]) for i in
             range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
        aggregated_results = aggregated_results[slicer]
        aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]

        # computing the class_probabilities by dividing the aggregated result with result_numsamples
        class_probabilities = aggregated_results / aggregated_nb_of_predictions

        if regions_class_order is None:
            predicted_segmentation = class_probabilities.argmax(0)
        else:
            if all_in_gpu:
                class_probabilities_here = class_probabilities.detach().cpu().numpy()
            else:
                class_probabilities_here = class_probabilities
            predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
            for i, c in enumerate(regions_class_order):
                predicted_segmentation[class_probabilities_here[i] > 0.5] = c

        if all_in_gpu:
            if verbose: print("copying results to CPU")

            if regions_class_order is None:
                predicted_segmentation = predicted_segmentation.detach().cpu().numpy()

            class_probabilities = class_probabilities.detach().cpu().numpy()

        if verbose: print("prediction done")
        return predicted_segmentation, class_probabilities

    def _internal_predict_3D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
                                    mirror_axes: tuple = (0, 1), regions_class_order: tuple = None,
                                    pad_border_mode: str = "constant", pad_kwargs: dict = None,
                                    all_in_gpu: bool = False, verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        if all_in_gpu:
            raise NotImplementedError
        assert len(x.shape) == 4, "data must be c, x, y, z"
        predicted_segmentation = []
        softmax_pred = []
        for s in range(x.shape[1]):
            pred_seg, softmax_pres = self._internal_predict_2D_2Dconv(
                x[:, s], min_size, do_mirroring, mirror_axes, regions_class_order, pad_border_mode, pad_kwargs, verbose)
            predicted_segmentation.append(pred_seg[None])
            softmax_pred.append(softmax_pres[None])
        predicted_segmentation = np.vstack(predicted_segmentation)
        softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))
        return predicted_segmentation, softmax_pred

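    # _internal_predict_3D_2Dconv runs the 2D pipeline slice by slice over a 3D volume of shape
    # (c, x, y, z): each slice x[:, s] is predicted independently, the per-slice segmentations
    # are stacked back along the first spatial axis, and the stacked softmax outputs are
    # transposed to (num_classes, x, y, z).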
    def predict_3D_pseudo3D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
                                   mirror_axes: tuple = (0, 1), regions_class_order: tuple = None,
                                   pseudo3D_slices: int = 5, all_in_gpu: bool = False,
                                   pad_border_mode: str = "constant", pad_kwargs: dict = None,
                                   verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        if all_in_gpu:
            raise NotImplementedError
        assert len(x.shape) == 4, "data must be c, x, y, z"
        assert pseudo3D_slices % 2 == 1, "pseudo3D_slices must be odd"
        extra_slices = (pseudo3D_slices - 1) // 2

        shp_for_pad = np.array(x.shape)
        shp_for_pad[1] = extra_slices

        pad = np.zeros(shp_for_pad, dtype=np.float32)
        data = np.concatenate((pad, x, pad), 1)

        predicted_segmentation = []
        softmax_pred = []
        for s in range(extra_slices, data.shape[1] - extra_slices):
            d = data[:, (s - extra_slices):(s + extra_slices + 1)]
            d = d.reshape((-1, d.shape[-2], d.shape[-1]))
            pred_seg, softmax_pres = \
                self._internal_predict_2D_2Dconv(d, min_size, do_mirroring, mirror_axes,
                                                 regions_class_order, pad_border_mode, pad_kwargs, verbose)
            predicted_segmentation.append(pred_seg[None])
            softmax_pred.append(softmax_pres[None])
        predicted_segmentation = np.vstack(predicted_segmentation)
        softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))

        return predicted_segmentation, softmax_pred

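    # Pseudo-3D example: with pseudo3D_slices=5 (so extra_slices=2) and c input channels, the
    # volume is zero-padded by 2 slices at both ends and every target slice s is predicted from
    # the stack s-2 .. s+2, which is folded into the channel dimension so the 2D network sees
    # 5 * c input channels.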
    def _internal_predict_3D_2Dconv_tiled(self, x: np.ndarray, patch_size: Tuple[int, int], do_mirroring: bool,
                                          mirror_axes: tuple = (0, 1), step_size: float = 0.5,
                                          regions_class_order: tuple = None, use_gaussian: bool = False,
                                          pad_border_mode: str = "edge", pad_kwargs: dict = None,
                                          all_in_gpu: bool = False,
                                          verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        if all_in_gpu:
            raise NotImplementedError

        assert len(x.shape) == 4, "data must be c, x, y, z"

        predicted_segmentation = []
        softmax_pred = []

        for s in range(x.shape[1]):
            pred_seg, softmax_pres = self._internal_predict_2D_2Dconv_tiled(
                x[:, s], step_size, do_mirroring, mirror_axes, patch_size, regions_class_order, use_gaussian,
                pad_border_mode, pad_kwargs, all_in_gpu, verbose)

            predicted_segmentation.append(pred_seg[None])
            softmax_pred.append(softmax_pres[None])

        predicted_segmentation = np.vstack(predicted_segmentation)
        softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))

        return predicted_segmentation, softmax_pred


class ConvDropoutNormNonlin(nn.Module):
    """
    Conv -> Dropout -> Norm -> Nonlin block. Unlike the original ConvDropoutNormNonlin, the
    nonlinearity is taken from the `nonlin` argument instead of always being LeakyReLU.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv2d, conv_kwargs=None,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super(ConvDropoutNormNonlin, self).__init__()
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[
                'p'] > 0:
            self.dropout = self.dropout_op(**self.dropout_op_kwargs)
        else:
            self.dropout = None
        self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.lrelu = self.nonlin(**self.nonlin_kwargs)

    def forward(self, x):
        x = self.conv(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.lrelu(self.instnorm(x))


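# Minimal usage sketch (hypothetical shapes) for ConvDropoutNormNonlin above; ConvDropoutNonlinNorm
# below only swaps the order of norm and nonlin. With the default kwargs the block is a 3x3 conv
# (stride 1, padding 1) followed by Dropout2d, BatchNorm2d and LeakyReLU, so spatial size is kept.
#     block = ConvDropoutNormNonlin(32, 64)
#     y = block(torch.randn(2, 32, 128, 128))   # -> (2, 64, 128, 128)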
class ConvDropoutNonlinNorm(ConvDropoutNormNonlin):
    def forward(self, x):
        x = self.conv(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.instnorm(self.lrelu(x))


class StackedConvLayers(nn.Module):
    def __init__(self, input_feature_channels, output_feature_channels, num_convs,
                 conv_op=nn.Conv2d, conv_kwargs=None,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None, basic_block=ConvDropoutNormNonlin):
        '''
        Stacks num_convs ConvDropoutNormNonlin (or basic_block) layers. first_stride is only
        applied to the first layer in the stack; all other parameters affect every layer.
        :param input_feature_channels: number of input channels of the first conv
        :param output_feature_channels: number of channels produced by every conv in the stack
        :param num_convs: how many blocks to stack
        :param conv_op / conv_kwargs: convolution type and its kwargs
        :param norm_op / norm_op_kwargs: normalization type and its kwargs
        :param dropout_op / dropout_op_kwargs: dropout type and its kwargs
        :param nonlin / nonlin_kwargs: nonlinearity type and its kwargs
        :param first_stride: stride of the first conv (used for downsampling); None keeps conv_kwargs['stride']
        :param basic_block: block class to stack
        '''
        self.input_channels = input_feature_channels
        self.output_channels = output_feature_channels

        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        if first_stride is not None:
            self.conv_kwargs_first_conv = deepcopy(conv_kwargs)
            self.conv_kwargs_first_conv['stride'] = first_stride
        else:
            self.conv_kwargs_first_conv = conv_kwargs

        super(StackedConvLayers, self).__init__()
        self.blocks = nn.Sequential(
            *([basic_block(input_feature_channels, output_feature_channels, self.conv_op,
                           self.conv_kwargs_first_conv,
                           self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                           self.nonlin, self.nonlin_kwargs)] +
              [basic_block(output_feature_channels, output_feature_channels, self.conv_op,
                           self.conv_kwargs,
                           self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                           self.nonlin, self.nonlin_kwargs) for _ in range(num_convs - 1)]))

    def forward(self, x):
        return self.blocks(x)


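# Minimal usage sketch (hypothetical numbers): two stacked blocks, with the first conv strided
# for downsampling when first_stride is given.
#     stage = StackedConvLayers(32, 64, num_convs=2, first_stride=(2, 2))
#     y = stage(torch.randn(2, 32, 128, 128))   # -> (2, 64, 64, 64)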
def print_module_training_status(module):
    if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.Dropout3d, nn.Dropout2d, nn.Dropout,
                           nn.InstanceNorm3d, nn.InstanceNorm2d, nn.InstanceNorm1d,
                           nn.BatchNorm2d, nn.BatchNorm3d, nn.BatchNorm1d)):
        print(str(module), module.training)


class hwUpsample(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
        super(hwUpsample, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
                                         align_corners=self.align_corners)


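# Minimal usage sketch: hwUpsample is a thin nn.Module wrapper around F.interpolate, so it can
# live inside nn.Sequential / nn.ModuleList.
#     up = hwUpsample(scale_factor=2, mode='bilinear', align_corners=False)
#     y = up(torch.randn(1, 8, 32, 32))   # -> (1, 8, 64, 64)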
class OnePromptEncoderUnet(SegmentationNetwork):
    DEFAULT_BATCH_SIZE_3D = 2
    DEFAULT_PATCH_SIZE_3D = (64, 192, 160)
    SPACING_FACTOR_BETWEEN_STAGES = 2
    BASE_NUM_FEATURES_3D = 30
    MAX_NUMPOOL_3D = 999
    MAX_NUM_FILTERS_3D = 320

    DEFAULT_PATCH_SIZE_2D = (256, 256)
    BASE_NUM_FEATURES_2D = 30
    DEFAULT_BATCH_SIZE_2D = 50
    MAX_NUMPOOL_2D = 999
    MAX_FILTERS_2D = 480

    use_this_for_batch_size_computation_2D = 19739648
    use_this_for_batch_size_computation_3D = 520000000  # 505789440

    def __init__(self, input_channels, base_num_features, final_num_features, fea_size, num_pool, num_conv_per_stage=2,
                 feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None, highway=False, deep_supervision=False, anchor_out=False,
                 dropout_in_localization=False,
                 final_nonlin=sigmoid_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None,
                 conv_kernel_sizes=None,
                 upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False,
                 max_num_features=None, basic_block=ConvDropoutNormNonlin,
                 seg_output_use_bias=False):
        """
        basically more flexible than v1, architecture is the same

        Does this look complicated? Nah bro. Functionality > usability

        This does everything you need, including world peace.

        Questions? -> f.isensee@dkfz.de
        """
        super(OnePromptEncoderUnet, self).__init__()
        self.convolutional_upsampling = convolutional_upsampling
        self.convolutional_pooling = convolutional_pooling
        self.upscale_logits = upscale_logits
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}

        self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}

        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.weightInitializer = weightInitializer
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.dropout_op = dropout_op
        self.final_nonlin = final_nonlin
        self._deep_supervision = deep_supervision
        self.do_ds = deep_supervision
        self.anchor_out = anchor_out

        if conv_op == nn.Conv2d:
            pool_op = nn.MaxPool2d
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3)] * (num_pool + 1)
        elif conv_op == nn.Conv3d:
            pool_op = nn.MaxPool3d
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1)
        else:
            raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op))

        self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64)
        self.pool_op_kernel_sizes = pool_op_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes

        self.conv_pad_sizes = []
        for krnl in self.conv_kernel_sizes:
            self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])

        if max_num_features is None:
            if self.conv_op == nn.Conv3d:
                self.max_num_features = self.MAX_NUM_FILTERS_3D
            else:
                self.max_num_features = self.MAX_FILTERS_2D
        else:
            self.max_num_features = max_num_features

        self.conv_blocks_context = []
        self.conv_blocks_localization = []
        self.td = []
        self.al = []

        output_features = base_num_features
        input_features = input_channels

        for d in range(num_pool):
            # determine the first stride
            if d != 0 and self.convolutional_pooling:
                first_stride = pool_op_kernel_sizes[d - 1]
            else:
                first_stride = None

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[d]
            # add convolutions
            self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage,
                                                              self.conv_op, self.conv_kwargs, self.norm_op,
                                                              self.norm_op_kwargs, self.dropout_op,
                                                              self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs,
                                                              first_stride, basic_block=basic_block))
            self.al.append(nn.Linear(output_features, final_num_features))

            if not self.convolutional_pooling:
                self.td.append(pool_op(pool_op_kernel_sizes[d]))
            input_features = output_features
            output_features = int(np.round(output_features * feat_map_mul_on_downscale))

            output_features = min(output_features, self.max_num_features)

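        # Encoder path summary: every stage above multiplies the feature count by
        # feat_map_mul_on_downscale (capped at self.max_num_features) and reduces spatial
        # resolution either through a strided first conv (convolutional_pooling=True) or through
        # the pooling ops collected in self.td. self.al holds one linear projection per stage
        # that maps the stage's features to final_num_features for the skip outputs in forward().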
        # now the bottleneck.
        # determine the first stride
        if self.convolutional_pooling:
            first_stride = pool_op_kernel_sizes[-1]
        else:
            first_stride = None

        # the output of the last conv must match the number of features from the skip connection if we are not using
        # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be
        # done by the transposed conv
        # if self.convolutional_upsampling:
        #     final_num_features = output_features
        # else:
        #     final_num_features = self.conv_blocks_context[-1].output_channels

        self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
        self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
        self.conv_blocks_context.append(nn.Sequential(
            StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs,
                              self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
                              self.nonlin_kwargs, first_stride, basic_block=basic_block),
            StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs,
                              self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
                              self.nonlin_kwargs, basic_block=basic_block)))

        # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here
        if not dropout_in_localization:
            old_dropout_p = self.dropout_op_kwargs['p']
            self.dropout_op_kwargs['p'] = 0.0

        # # now lets build the localization pathway
        # for u in range(num_pool):
        #     nfeatures_from_skip = self.conv_blocks_context[
        #         -(2 + u)].output_channels  # self.conv_blocks_context[-1] is bottleneck, so start with -2
        #     n_features_after_tu_and_concat = nfeatures_from_skip * 2
        #
        #     self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)]
        #     self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)]
        #     self.conv_blocks_localization.append(nn.Sequential(
        #         StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1,
        #                           self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op,
        #                           self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block),
        #         StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs,
        #                           self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
        #                           self.nonlin, self.nonlin_kwargs, basic_block=basic_block)
        #     ))
        self.up = []
        for u in range(num_pool):
            self.up.append(nn.Upsample(size=(fea_size, fea_size), mode='bilinear'))

        if not dropout_in_localization:
            self.dropout_op_kwargs['p'] = old_dropout_p

        # register all modules properly
        # self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization)
        self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context)

        self.td = nn.ModuleList(self.td)
        self.up = nn.ModuleList(self.up)
        self.al = nn.ModuleList(self.al)

        if self.weightInitializer is not None:
            self.apply(self.weightInitializer)
            # self.apply(print_module_training_status)

    def forward(self, raw):
        skips_raw = []
        for d in range(len(self.conv_blocks_context) - 1):
            raw = self.conv_blocks_context[d](raw)
            raw_arch = self.up[d](raw)
            raw_arch = raw_arch.permute(0, 2, 3, 1)
            raw_arch = self.al[d](raw_arch)
            skips_raw.append(raw_arch)
            if not self.convolutional_pooling:
                raw = self.td[d](raw)

        raw = self.conv_blocks_context[-1](raw)
        raw = raw.permute(0, 2, 3, 1)

        return raw, skips_raw

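    # forward() returns (bottleneck, skips_raw): the bottleneck feature map permuted to
    # channels-last, i.e. (b, h, w, final_num_features), plus one channels-last skip tensor per
    # encoder stage, each upsampled to (fea_size, fea_size) and projected to final_num_features
    # by the corresponding linear layer in self.al.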
    @staticmethod
    def compute_approx_vram_consumption(patch_size, num_pool_per_axis, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes, deep_supervision=False,
                                        conv_per_stage=2):
        """
        This only applies for conv_per_stage=2 and convolutional_upsampling=True.
        Not the real VRAM consumption, just a constant term to which the VRAM consumption is
        approximately proportional (+ offset for parameter storage).
        :param deep_supervision:
        :param patch_size:
        :param num_pool_per_axis:
        :param base_num_features:
        :param max_num_features:
        :param num_modalities:
        :param num_classes:
        :param pool_op_kernel_sizes:
        :return:
        """
        if not isinstance(num_pool_per_axis, np.ndarray):
            num_pool_per_axis = np.array(num_pool_per_axis)

        npool = len(pool_op_kernel_sizes)

        map_size = np.array(patch_size)
        tmp = np.int64((conv_per_stage * 2 + 1) * np.prod(map_size, dtype=np.int64) * base_num_features +
                       num_modalities * np.prod(map_size, dtype=np.int64) +
                       num_classes * np.prod(map_size, dtype=np.int64))

        num_feat = base_num_features

        for p in range(npool):
            for pi in range(len(num_pool_per_axis)):
                map_size[pi] /= pool_op_kernel_sizes[p][pi]
            num_feat = min(num_feat * 2, max_num_features)
            num_blocks = (conv_per_stage * 2 + 1) if p < (npool - 1) else conv_per_stage  # conv_per_stage + conv_per_stage for the convs of encode/decode and 1 for transposed conv
            tmp += num_blocks * np.prod(map_size, dtype=np.int64) * num_feat
            if deep_supervision and p < (npool - 2):
                tmp += np.prod(map_size, dtype=np.int64) * num_classes
            # print(p, map_size, num_feat, tmp)
        return tmp
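    # Minimal usage sketch (hypothetical configuration): the returned value is only proportional
    # to the expected activation memory and is meant for comparing configurations, not as an
    # absolute VRAM estimate.
    #     OnePromptEncoderUnet.compute_approx_vram_consumption(
    #         patch_size=(256, 256), num_pool_per_axis=(5, 5), base_num_features=30,
    #         max_num_features=480, num_modalities=1, num_classes=2,
    #         pool_op_kernel_sizes=[(2, 2)] * 5)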