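"""Prompt-mixing and cross-attention transformer blocks.

Contents: a learnable Gaussian smoothing convolution (GaussianConv2d), the
prompt-mixing modules (PromptMLP, PromptMixer, PromptParser, OnePromptFormer),
and two-way cross-attention stacks (MixedUpScale, TwoWayTransformer,
TwoWayAttentionBlock, CrossAttentionBlock) built on a multi-head Attention
layer with optional channel downsampling.
"""
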
import math
from typing import Tuple, Type

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

# from functools import partial
from einops.layers.torch import Rearrange, Reduce

from .common import MLPBlock

# Return x unchanged if it is already a tuple, otherwise duplicate it into a pair.
pair = lambda x: x if isinstance(x, tuple) else (x, x)


def gaussian_kernel(size, mean, std):
    """Generates a 2D Gaussian kernel."""
    d = torch.distributions.Normal(mean, std)
    vals = d.log_prob(torch.arange(size).float())
    # Outer sum of log-probabilities gives a separable 2D kernel after exponentiation.
    grid = torch.exp(vals[:, None] + vals[None, :])
    grid /= grid.sum()
    return grid


class GaussianConv2d(nn.Module):
    """Convolution whose weights are initialised from a 2D Gaussian kernel."""

    def __init__(self, in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1, mean=0.0, std=1.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.mean = nn.Parameter(torch.tensor(mean), requires_grad=True)
        self.std = nn.Parameter(torch.tensor(std), requires_grad=True)
        self.weights = nn.Parameter(gaussian_kernel(kernel_size, self.mean, self.std), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(out_channels), requires_grad=True)

    def forward(self, x):
        # Broadcast the single 2D kernel to an (out_channels, in_channels, k, k) weight tensor.
        weight = self.weights.unsqueeze(0).unsqueeze(0).repeat(self.out_channels, self.in_channels, 1, 1)
        return F.conv2d(x, weight, bias=self.bias, stride=self.stride, padding=self.padding)


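# PromptMLP / PromptMixer mix information across the stacked (query, key, value)
# prompt embeddings rather than across tokens: stacking three (B, N, D) tensors
# gives a (3, B, N, D) tensor, Rearrange moves the stack axis last, and a small
# MLP maps the 3 stacked channels down to 1, yielding a fused (B, N, D) embedding.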
def PromptMLP(dim=3, expansion_factor=4, dropout=0., dense=nn.Linear):
    inner_dim = int(dim * expansion_factor)
    return nn.Sequential(
        dense(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(inner_dim, 1),
        nn.Dropout(dropout),
    )


class PromptMixer(nn.Module):
    def __init__(
        self,
        dim: int = 3,
        depth: int = 1,
        expansion_factor: int = 4,
        dropout: float = 0.,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.dim = dim
        self.expansion_factor = expansion_factor
        self.dropout = dropout
        self.layers = nn.Sequential(
            Rearrange('k b n d -> b n d k'),
            *[nn.Sequential(
                PromptMLP(dim, expansion_factor, dropout),
            ) for _ in range(depth)],
            # nn.LayerNorm(dim)  # b n d
        )

    def forward(self, q, k, v):
        qk = torch.stack([q, k, v])  # 3 x B x N x D
        res = self.layers(qk)        # B x N x D x 1
        return res.squeeze(-1)       # B x N x D


class PromptParser(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        token_num: int,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim

        self.pt_mix = PromptMixer()
        self.gauss = GaussianConv2d(in_channels=token_num)

        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        tmp_embedding: Tensor,
        prompt_embedding1: Tensor,
        prompt_embedding2: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # Ensure the prompt embeddings have the same batch size as tmp_embedding.
        if prompt_embedding1.size(0) != tmp_embedding.size(0):
            prompt_embedding1 = torch.repeat_interleave(prompt_embedding1, tmp_embedding.size(0), dim=0)
        if prompt_embedding2.size(0) != tmp_embedding.size(0):
            prompt_embedding2 = torch.repeat_interleave(prompt_embedding2, tmp_embedding.size(0), dim=0)

        pt_pe = prompt_embedding1 + prompt_embedding2
        # Mix the template and prompt embeddings, then form per-token outer products
        # with the image embedding to build a (B, N, D, D) attention map.
        etpp = self.pt_mix(tmp_embedding, prompt_embedding1, prompt_embedding2)
        att_m = torch.einsum('bncd, bndx -> bncx', etpp.unsqueeze(-1), image_embedding.unsqueeze(-2))
        # Smooth the attention map with the learnable Gaussian convolution.
        att_m = self.gauss(att_m)
        etq = torch.einsum('bncd, bndx -> bncx', image_embedding.unsqueeze(-1), (tmp_embedding + pt_pe).unsqueeze(-2))
        eg = torch.max(att_m * etq, etq)
        res = torch.einsum('bncx, bnx -> bnc', eg, tmp_embedding + pt_pe)
        return image_embedding, res


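# OnePromptFormer fuses template/prompt tokens with image tokens: PromptParser
# first builds fused template tokens (et); the image tokens attend to the
# external embedding (emb) to give es, et and es cross-attend to each other,
# and two final attention layers plus an MLP/LayerNorm head produce the output.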
class OnePromptFormer(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        prompt_embed_dim: int,
        token_num: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        self.layers = nn.ModuleList()

        self.nn = nn.Linear(embedding_dim, prompt_embed_dim)

        # Image-stream attention, MLP, and norms.
        self.attns1 = Attention(prompt_embed_dim, num_heads)
        self.attns2 = Attention(prompt_embed_dim, num_heads)
        self.mlps1 = MLPBlock(prompt_embed_dim, mlp_dim, activation)
        self.norms1 = nn.LayerNorm(prompt_embed_dim)
        self.norms2 = nn.LayerNorm(prompt_embed_dim)

        # Template-stream modules: prompt parser, attention, MLP, and norms.
        self.parser = PromptParser(embedding_dim=prompt_embed_dim, token_num=token_num)
        self.attnt1 = Attention(prompt_embed_dim, num_heads)
        self.mlpt1 = MLPBlock(prompt_embed_dim, mlp_dim, activation)
        self.normt1 = nn.LayerNorm(prompt_embed_dim)
        self.normt2 = nn.LayerNorm(prompt_embed_dim)

        # Merging attention and output head.
        self.attnm1 = Attention(prompt_embed_dim, num_heads)
        self.attnm2 = Attention(prompt_embed_dim, num_heads)

        self.final = nn.Sequential(
            MLPBlock(prompt_embed_dim, mlp_dim, activation),
            nn.LayerNorm(prompt_embed_dim),
        )

    def forward(
        self,
        emb: Tensor,
        image_embedding: Tensor,
        tmp_embedding: Tensor,
        prompt_embedding1: Tensor,
        prompt_embedding2: Tensor,
    ) -> Tensor:
        image_embedding, et = self.parser(image_embedding, tmp_embedding, prompt_embedding1, prompt_embedding2)

        # Cross-attend the image embedding with the external embedding and the template tokens.
        es = self.attns1(q=image_embedding, k=emb, v=emb)
        es_bk = es
        es = self.attns2(q=et, k=es, v=es)
        es = self.norms1(es + et)
        es = self.norms2(self.mlps1(es) + es)

        # Cross-attend the template tokens with the image-conditioned features.
        et = self.attnt1(q=es_bk, k=et, v=et)
        et = self.normt1(es_bk + et)
        et = self.normt2(self.mlpt1(et) + et)

        # Merge the two streams and apply the output head.
        e = self.attnm1(q=et, k=es, v=es)
        e = self.attnm2(q=e, k=e, v=e)
        e = self.final(e)

        return e


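# MixedUpScale and TwoWayTransformer share the same structure: a stack of
# attention blocks over (point, image) tokens followed by a final
# token-to-image attention and LayerNorm. MixedUpScale builds its stack from
# CrossAttentionBlock, TwoWayTransformer from TwoWayAttentionBlock.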
class MixedUpScale(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                CrossAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim // downsample_rate."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
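

if __name__ == "__main__":
    # Minimal shape-check sketch for the Attention layer; the sizes below are
    # illustrative only. Because of the relative import of MLPBlock above, this
    # file must be executed as part of its package (e.g. via `python -m ...`).
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    q = torch.randn(2, 16, 256)  # B x N_queries x C
    k = torch.randn(2, 64, 256)  # B x N_keys x C
    v = torch.randn(2, 64, 256)
    out = attn(q=q, k=k, v=v)
    print(out.shape)  # expected: torch.Size([2, 16, 256])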