import math
from typing import List, Tuple, Type

import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange

from .common import LayerNorm2d
from .image_encoder import PatchEmbed
from .modules import CrossAttentionBlock, OnePromptFormer, TwoWayTransformer


class OnePromptDecoder(nn.Module):
    """Decoder that fuses raw-image and template embeddings, refines them with a
    stack of OnePromptFormer blocks, and predicts masks with an internal MaskDecoder."""

    def __init__(
        self,
        *,
        depth: int = 4,
        prompt_embed_dim: int = 256,
        embed_dim: int = 768,
        out_chans: int = 256,
        token_num: int,
        patch_size: int,
        mlp_dim: int = 1024,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.of = nn.ModuleList()
        self.deals = nn.ModuleList()

        # nlist = [4096, 4096, 4096, 4096]
        # embed_dim_list = [768, 256, 256, 256]

        self.updecode = MaskDecoder(
            transformer_dim=prompt_embed_dim,
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                # mlp_dim=2048,
                # num_heads=8,
                mlp_dim=256,
                num_heads=2,
            ),
        )

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        for i in range(depth):
            self.of.append(
                OnePromptFormer(
                    embedding_dim=prompt_embed_dim,
                    prompt_embed_dim=prompt_embed_dim,
                    token_num=token_num,
                    num_heads=2,
                    mlp_dim=mlp_dim,
                )
            )
            self.deals.append(
                Decode_Align(embed_dim=embed_dim, transformer_dim=prompt_embed_dim, stages=token_num - 1)
            )

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=prompt_embed_dim,
            embed_dim=out_chans,
        )

    def forward(
        self,
        skips_raw: list,
        skips_tmp: list,
        raw_emb: torch.Tensor,
        tmp_emb: torch.Tensor,
        pt1: torch.Tensor,
        pt2: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Fuse raw-image and template embeddings, then project through the neck
        x = raw_emb + tmp_emb
        x = self.neck(x.permute(0, 3, 1, 2))
        x = x.permute(0, 2, 3, 1)

        raw_emb = self.neck(raw_emb.permute(0, 3, 1, 2))
        # raw_emb = raw_emb.permute(0, 2, 3, 1)

        for u in range(self.depth):
            if u == 0:
                x, img_embed, tmp_embed, temp_pos, p1, p2 = self.deals[u](
                    x, skips_raw[-(u + 1)], skips_tmp[-(u + 1)], image_pe, pt1, pt2, dense_prompt_embeddings
                )
                p1 = p1 + temp_pos.flatten(2).permute(0, 2, 1)
                p2 = p2 + temp_pos.flatten(2).permute(0, 2, 1)
                img_embed = img_embed.flatten(2).permute(0, 2, 1)
                tmp_embed = tmp_embed.flatten(2).permute(0, 2, 1)
                x = x.flatten(2).permute(0, 2, 1)
                # print('tmp_embed size', tmp_embed.size())
                # print('temp_pos size', temp_pos.size())
                # print('p1 size', p1.size())
                # print('p2 size', p2.size())
            x = self.of[u](x, img_embed, tmp_embed, p1, p2)
            # print(x.size())
            x = rearrange(x, 'b (c1 c2) d -> b d c1 c2', c1=int(math.sqrt(x.size(1))))
            x = self.patch_embed(x)
            x = rearrange(x, 'b c1 c2 d -> b (c1 c2) d')

        # Resize image_pe to match the new spatial size of raw_emb
        image_pe = F.interpolate(image_pe, size=(raw_emb.shape[2], raw_emb.shape[3]), mode='bilinear', align_corners=False)

        # Select the correct mask or masks for output
        low_res_masks, iou_predictions = self.updecode(
            image_embeddings=raw_emb,
            image_pe=image_pe,
            mix_embeddings=x,
            multimask_output=multimask_output,
        )

        return low_res_masks, iou_predictions


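# Usage sketch (kept as a comment so importing this module has no side effects):
# the keyword-only arguments mirror the __init__ signature above; `token_num` and
# `patch_size` have no defaults and must be supplied. The values shown are assumed
# for illustration, not taken from any training configuration.
#
#   decoder = OnePromptDecoder(
#       depth=4,
#       prompt_embed_dim=256,
#       embed_dim=768,
#       out_chans=256,
#       token_num=5,     # assumed value
#       patch_size=2,    # assumed value
#       mlp_dim=1024,
#   )

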
class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        mix_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks and mask-quality scores from the image embeddings and the
        mixed prompt/image tokens, returning either a single mask or multiple masks
        depending on `multimask_output`."""
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            mix_embeddings=mix_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        mix_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(image_embeddings.size(0), -1, -1)
        # print("output_tokens", output_tokens.size())
        # print("mix_embeddings", mix_embeddings.size())
        tokens = torch.cat((output_tokens, mix_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if image_embeddings.shape[0] != tokens.shape[0]:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            src = image_embeddings
        # print("src size is", src.size())
        # print("dense_prompt_embeddings size is", dense_prompt_embeddings.size())
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


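# Shape-check sketch for MaskDecoder (kept as a comment). Assumption: any module
# whose call signature is transformer(src, pos_src, tokens) -> (hs, src_out), with
# hs of shape (B, N_tokens, C) and src_out of shape (B, H*W, C), can stand in for
# TwoWayTransformer; the _IdentityTransformer below is a test stub, not part of
# the model.
#
#   class _IdentityTransformer(nn.Module):
#       def forward(self, src, pos_src, tokens):
#           return tokens, src.flatten(2).permute(0, 2, 1)
#
#   dec = MaskDecoder(transformer_dim=256, transformer=_IdentityTransformer())
#   masks, iou = dec(
#       image_embeddings=torch.randn(1, 256, 64, 64),
#       image_pe=torch.randn(1, 256, 64, 64),
#       mix_embeddings=torch.randn(1, 16, 256),
#       multimask_output=True,
#   )
#   # masks: (1, 3, 256, 256), iou: (1, 3)

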
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x


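# Example (comment only): MLP(256, 256, 32, 3) builds Linear(256, 256) -> ReLU ->
# Linear(256, 256) -> ReLU -> Linear(256, 32), so an input of shape (B, 256) maps
# to (B, 32); with sigmoid_output=True the final activations are squashed to (0, 1).

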
class Decode_Align(nn.Module):
    def __init__(
        self,
        *,
        embed_dim: int,
        transformer_dim: int,
        stages: int = 4096,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim

        self.num_mask_tokens = stages
        self.p1_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
        self.p2_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
        self.layer = nn.Linear(embed_dim, transformer_dim)

    def forward(
        self,
        x: torch.Tensor,
        src_embeddings: torch.Tensor,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        pt1: torch.Tensor,
        pt2: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # Project skip features to the transformer dimension
        image_embeddings = self.layer(image_embeddings)
        src_embeddings = self.layer(src_embeddings)
        # x = self.layer(x)

        # Learned prompt tokens, prepended to the point-prompt tokens
        p1 = self.p1_tokens.weight.unsqueeze(0).expand(pt1.size(0), -1, -1)
        p2 = self.p2_tokens.weight.unsqueeze(0).expand(pt1.size(0), -1, -1)

        p1_tokens = torch.cat((p1, pt1), dim=1)
        p2_tokens = torch.cat((p2, pt2), dim=1)

        if image_embeddings.shape[0] != p1_tokens.shape[0]:
            src = torch.repeat_interleave(image_embeddings, p1_tokens.shape[0], dim=0)
        else:
            src = image_embeddings
        src = src.permute(0, 3, 1, 2)
        img = src_embeddings.permute(0, 3, 1, 2)
        x = x.permute(0, 3, 1, 2)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, p1_tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        return x, img, src, pos_src, p1_tokens, p2_tokens
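

if __name__ == "__main__":
    # Minimal smoke test (sketch). Because this file uses relative imports, run it
    # as a module (e.g. `python -m <package>.<this_module>`). The tensor sizes
    # below are illustrative assumptions, not values from a real configuration.
    align = Decode_Align(embed_dim=768, transformer_dim=256, stages=4)
    x = torch.randn(2, 16, 16, 256)          # mixed embedding, B x H x W x C
    skip_raw = torch.randn(2, 16, 16, 768)   # raw-image skip feature
    skip_tmp = torch.randn(2, 16, 16, 768)   # template skip feature
    image_pe = torch.randn(1, 256, 16, 16)   # positional encoding, B x C x H x W
    pt1 = torch.randn(2, 1, 256)             # prompt tokens
    pt2 = torch.randn(2, 1, 256)
    dense = torch.randn(2, 256, 16, 16)      # dense prompt embeddings
    outs = align(x, skip_raw, skip_tmp, image_pe, pt1, pt2, dense)
    print([tuple(o.shape) for o in outs])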