- Add core training scripts (train.py, val.py, function.py, utils.py)
- Add model implementations (oneprompt, unet, tag, etc.)
- Add configuration files (configs/, conf/)
- Add utility scripts and dependencies
- Add .gitignore to exclude logs, checkpoints and cache files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
parent
91a9803e8a
commit
60b9c98523
@ -0,0 +1,7 @@
{
  "permissions": {
    "allow": [
      "Bash(git add:*)"
    ]
  }
}
@ -0,0 +1,54 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
*.egg
.eggs/
dist/
build/

# Virtual environments
.env
.venv
env/
venv/
ENV/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Jupyter Notebook
.ipynb_checkpoints/

# Logs and outputs
logs/
runs/

# Model checkpoints (large files)
checkpoint/
*.pth
*.pt
*.ckpt
*.bin
*.safetensors

# Data files (usually large)
data/
datasets/

# OS generated files
.DS_Store
Thumbs.db
desktop.ini

# Temporary files
*.tmp
*.temp
*.bak
@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Medicine Token

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@ -0,0 +1,111 @@
# One-Prompt to Segment All Medical Images

One-Prompt to Segment All Medical Images, or One-Prompt for short, combines the strengths of one-shot and interactive methods. At inference time, given just one prompted sample, it can adeptly handle an unseen task in a single forward pass.

The method is elaborated in the paper [One-Prompt to Segment All Medical Images](https://arxiv.org/abs/2305.10300).


## A Quick Overview
<img width="800" height="580" src="https://github.com/KidsWithTokens/one-prompt/blob/main/figs/oneprompt.png">


## Requirement

Install the environment:

``conda env create -f environment.yml``

``conda activate oneprompt``

## Dataset
### Download the open-source datasets
We collected 78 **open-source** datasets for training and testing the model. The datasets and their download links are listed [here](https://drive.google.com/file/d/1iXFm9M1ocrWNkEIthWUWnZYY2-1l-qya/view?usp=share_link).

### Download the prompts
The prompts corresponding to the datasets can be downloaded [here](https://drive.google.com/file/d/1cNv2WW_Cv2NYzpt90vvELaweM5ltIe8n/view?usp=share_link). Each prompt is saved as a JSON message with the format ``{DATASET_NAME, SAMPLE_INDEX, PROMPT_TYPE, PROMPT_CONTENT}``.

## Train
Run ``python train.py -net oneprompt -mod one_adpt -exp_name basic_exp -b 64 -dataset oneprompt -data_path *../data* -baseline 'unet'``

## Test Examples

### Melanoma Segmentation from Skin Images (2D)

1. Download ISIC dataset part 1 from https://challenge.isic-archive.com/data/. Then put the CSV files in "./data/isic" under your data path. Your dataset folder under "your_data_path" should look like:

ISIC/

   ISBI2016_ISIC_Part1_Test_Data/...

   ISBI2016_ISIC_Part1_Training_Data/...

   ISBI2016_ISIC_Part1_Test_GroundTruth.csv

   ISBI2016_ISIC_Part1_Training_GroundTruth.csv

2. Run: ``python val.py -net oneprompt -mod one_adpt -exp_name One-ISIC -weights *weight_path* -b 1 -dataset isic -data_path ../dataset/isic -vis 10 -baseline 'unet'``

Change "data_path" for your own setup; you can set "exp_name" to anything you want.

You can decrease the ``image size`` or the batch size ``b`` if you run out of memory.

3. Evaluation: the code can automatically evaluate the model on the test set during training; set "--val_freq" to control how many epochs to wait between evaluations. You can also run val.py for an independent evaluation.

4. Result visualization: set the "--vis" parameter to control how many epochs to wait between result visualizations during training or evaluation.

By default, everything is saved at ``./logs/``.

### REFUGE: Optic-disc Segmentation from Fundus Images (2D)
The [REFUGE](https://refuge.grand-challenge.org/) dataset contains 1200 fundus images with optic disc/cup segmentations and clinical glaucoma labels.

1. Download the dataset manually from [here](https://huggingface.co/datasets/realslimman/REFUGE-MultiRater/tree/main), or using command lines:

``git lfs install``

``git clone git@hf.co:datasets/realslimman/REFUGE-MultiRater``

Unzip and move the dataset to the target folder:

``unzip ./REFUGE-MultiRater.zip``

``mv REFUGE-MultiRater ./data``

2. For training the adapter, run: ``python val.py -net oneprompt -mod one_adpt -exp_name One-REFUGE -weights *weight_path* -b 1 -baseline 'unet' -dataset REFUGE -data_path ./data/REFUGE-MultiRater``

You can set "exp_name" to anything you want.

You can decrease the ``image size`` or the batch size ``b`` if you run out of memory.

## Run on your own dataset
It is simple to run oneprompt on other datasets. Just write another dataset class following those in ``./dataset.py``. You only need to make sure it returns a dict with:

{
   'image': A tensor holding the images, with size [C, H, W] for 2D images and [C, H, W, D] for 3D data.
            D is the depth of the 3D volume; C is the channel of a scan/frame, which is commonly 1 for CT, MRI, and US data.
            For, say, a color surgical video, D would be the number of time frames and C would be 3 for an RGB frame.

   'label': The target masks. The same size as the images, except possibly the resolution (H and W).

   'p_label': The prompt label deciding positive/negative prompts. To simplify, you can always set it to 1 if you don't need the negative-prompt function.

   'pt': The prompt. E.g., a click prompt should be [x of click, y of click], one click per scan/frame when using 3D data.

   'image_meta_dict': Optional. If you want to save/visualize the results, put the name of the image in it under the key 'filename_or_obj'.

   ...(others as you want)
}
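A minimal 2D dataset class following this contract might look like the sketch below. The file layout, image size, and the naive one-click-inside-the-mask prompt heuristic are assumptions for illustration, not part of the repo:

```python
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image

class MyBinarySegDataset(Dataset):
    """Hypothetical 2D dataset returning the dict One-Prompt expects."""

    def __init__(self, img_dir, mask_dir, size=256):
        self.img_dir, self.mask_dir, self.size = img_dir, mask_dir, size
        self.names = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        img = Image.open(os.path.join(self.img_dir, name)).convert('RGB').resize((self.size, self.size))
        mask = Image.open(os.path.join(self.mask_dir, name)).convert('L').resize((self.size, self.size))
        image = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0   # [C, H, W]
        label = (torch.from_numpy(np.array(mask)).float() / 255.0).unsqueeze(0)    # [1, H, W]
        # naive prompt: one random click inside the foreground mask
        ys, xs = torch.nonzero(label[0], as_tuple=True)
        if len(xs) > 0:
            j = torch.randint(len(xs), (1,)).item()
            pt = np.array([xs[j].item(), ys[j].item()], dtype=np.float32)
        else:
            pt = np.array([self.size // 2, self.size // 2], dtype=np.float32)
        return {
            'image': image,
            'label': label,
            'p_label': 1,                # always a positive prompt
            'pt': pt,                    # [x of click, y of click]
            'image_meta_dict': {'filename_or_obj': name},
        }
```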

## Cite
```
@InProceedings{Wu_2024_CVPR,
    author    = {Wu, Junde and Xu, Min},
    title     = {One-Prompt to Segment All Medical Images},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {11302-11312}
}
```
@ -0,0 +1,14 @@
""" dynamically load settings

author baiyu
"""
import conf.global_settings as settings


class Settings:
    def __init__(self, settings):
        for attr in dir(settings):
            if attr.isupper():
                setattr(self, attr, getattr(settings, attr))


settings = Settings(settings)
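The class simply copies every UPPERCASE attribute of the settings module onto an object, so values resolve as attributes (`settings.EPOCH`, `settings.LOG_DIR`, ...). A self-contained sketch of the same pattern, with a `SimpleNamespace` standing in for `conf.global_settings` and hypothetical values:

```python
import types

class Settings:
    def __init__(self, settings_module):
        # Copy every UPPERCASE attribute from the module-like object.
        for attr in dir(settings_module):
            if attr.isupper():
                setattr(self, attr, getattr(settings_module, attr))

# Stand-in for conf.global_settings (values are placeholders):
fake_module = types.SimpleNamespace(EPOCH=30000, LOG_DIR='runs', _private='skipped')
settings = Settings(fake_module)
```

Lowercase and dunder names are skipped, so only the constant-style configuration survives the copy.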
@ -0,0 +1,51 @@

import os
from datetime import datetime

#CIFAR100 dataset path (python version)
#CIFAR100_PATH = '/nfs/private/cifar100/cifar-100-python'

#mean and std of cifar100 dataset
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

GLAUCOMA_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
GLAUCOMA_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

MASK_TRAIN_MEAN = (2.654204690220496/255)
MASK_TRAIN_STD = (21.46473779720519/255)

#CIFAR100_TEST_MEAN = (0.5088964127604166, 0.48739301317401956, 0.44194221124387256)
#CIFAR100_TEST_STD = (0.2682515741720801, 0.2573637364478126, 0.2770957707973042)

#directory to save weights files
CHECKPOINT_PATH = 'checkpoint'

#total training epochs
EPOCH = 30000
step_size = 10
i = 1
MILESTONES = []
while i * 5 <= EPOCH:
    MILESTONES.append(i * step_size)
    i += 1

#initial learning rate
#INIT_LR = 0.1

#time when the script was run
TIME_NOW = datetime.now().isoformat()

#tensorboard log dir
LOG_DIR = 'runs'

#save a weights file every SAVE_EPOCH epochs
SAVE_EPOCH = 10
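The `while` loop above builds the LR-schedule milestones. Note an apparent quirk: the loop condition checks `i * 5` while the appended value uses `step_size = 10`, so the last milestone (60000) exceeds `EPOCH` itself. The loop is equivalent to this one-liner:

```python
EPOCH = 30000
step_size = 10
# Same values the while loop produces: 10, 20, ..., 60000
MILESTONES = [i * step_size for i in range(1, EPOCH // 5 + 1)]
```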
@ -0,0 +1,54 @@
# One-Prompt Medical Image Segmentation configuration file
# Project: One-Prompt to Segment All Medical Images (CVPR 2024)

project:
  name: "one-prompt-segmentation"
  version: "1.0.0"
  description: "One prompt to segment all medical images"
  paper: "https://arxiv.org/abs/2305.10300"

# Data configuration
data:
  dataset: "polyp"            # dataset type: polyp, isic, refuge
  data_path: "/root/wangtao/paper_reapppearence/data/TestDataset"
  train_ratio: 0.8
  batch_size: 1
  num_workers: 4

# Model configuration
model:
  net: "oneprompt"            # network type
  baseline: "unet"            # baseline model: unet, resnet
  mod: "one_adpt"             # module type
  image_size: 256             # input image size
  out_size: 256               # output size
  patch_size: 16              # patch size (must equal 2^num_pool)
  dim: 256                    # embedding dimension
  depth: 1                    # Transformer depth
  heads: 16                   # number of attention heads
  mlp_dim: 1024               # MLP dimension

# Training configuration
training:
  epochs: 100                 # number of training epochs
  learning_rate: 0.0001
  optimizer: "adam"
  weight_decay: 0.0
  scheduler:
    name: "step"              # learning-rate scheduler
    step_size: 10
    gamma: 0.5                # decay factor
  early_stopping_patience: 20
  gradient_clip: 1.0

# Validation configuration
validation:
  val_freq: 5                 # validation frequency
  vis_freq: 50                # visualization frequency

# Logging configuration
logging:
  log_dir: "logs"
  tensorboard: true
  save_best: true
  checkpoint_freq: 10
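A small loader for this file might look like the following sketch. The file name `config.yaml` is an assumption (the repo itself parses command-line flags via `cfg.parse_args`), and PyYAML is taken from `environment.yml`:

```python
import yaml  # PyYAML, listed in environment.yml

def load_config(path):
    """Parse the YAML config into a plain nested dict."""
    with open(path) as f:
        return yaml.safe_load(f)

# Hypothetical usage:
# cfg = load_config('config.yaml')
# cfg['model']['image_size']  -> 256
```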
@ -0,0 +1,34 @@
name: oneprompt
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.11
  - pytorch=2.1.1
  - torchvision=0.16.1
  - torchaudio=2.1.1
  - pytorch-cuda=12.1
  - numpy=1.26.0
  - pandas=2.1.1
  - pillow=10.0.1
  - matplotlib=3.8.0
  - scikit-image=0.22.0
  - scikit-learn=1.2.2
  - scipy=1.11.4
  - tqdm
  - pyyaml
  - tensorboardx
  - transformers=4.32.1
  - timm=0.9.12
  - accelerate=0.24.1
  - huggingface_hub
  - pip
  - pip:
      - monai==1.3.0
      - einops==0.7.0
      - opencv-python==4.8.1.78
      - kornia==0.4.1
      - nibabel
      - batchgenerators
(binary image added: 1.0 MiB)
@ -0,0 +1,334 @@

import os
import sys
import argparse
import time
import shutil
import tempfile
from datetime import datetime
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from skimage import io
from PIL import Image
from tensorboardX import SummaryWriter
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt
import pytorch_ssim

from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
)

import cfg
from conf import settings
from utils import *
#from dataset import *
#from models.discriminatorlayer import discriminator


args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
pos_weight = torch.ones([1]).cuda(device=GPUdevice) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1, 11, (args.b, 7))

torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []


def train_one(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=50):
    hard = 0
    epoch_loss = 0
    ind = 0
    # train mode
    net.train()
    optimizer.zero_grad()

    # unwrap a possible DataParallel wrapper
    model = net.module if hasattr(net, 'module') else net

    epoch_loss = 0
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:

        for pack in train_loader:
            # actual size of the current batch
            current_b = pack['image'].size(0)

            if ind == 0:
                tmp_img = pack['image'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
                tmp_mask = pack['label'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
                if 'pt' not in pack:
                    tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
                else:
                    pt = pack['pt']
                    point_labels = pack['p_label']

                if point_labels[0] != -1:
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    # take only the first point and repeat it to the current batch size
                    coords_torch = coords_torch[0:1, :].repeat(current_b, 1)
                    labels_torch = labels_torch[0:1].repeat(current_b)
                    coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
                    tmp_pt = (coords_torch, labels_torch)
            else:
                # resize the template batch to match the current batch size
                if tmp_img.size(0) != current_b:
                    tmp_img = tmp_img[0:1].repeat(current_b, 1, 1, 1)
                    tmp_mask = tmp_mask[0:1].repeat(current_b, 1, 1, 1)
                    if 'tmp_pt' in dir():
                        coords_torch = tmp_pt[0][0:1].repeat(current_b, 1, 1)
                        labels_torch = tmp_pt[1][0:1].repeat(current_b, 1)
                        tmp_pt = (coords_torch, labels_torch)

            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)

            name = pack['image_meta_dict']['filename_or_obj']

            # click prompts for the current batch
            if 'pt' in pack:
                pt = pack['pt']
                point_labels = pack['p_label']
                if point_labels[0] != -1:
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
                    pt = (coords_torch, labels_torch)

            if args.thd:
                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w ')

                imgs = imgs.repeat(1, 3, 1, 1)
                point_labels = torch.ones(imgs.size(0))

                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            showp = pt

            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
            imgs = imgs.to(dtype=mask_type, device=GPUdevice)

            # mixed-precision training
            with torch.amp.autocast('cuda'):
                with torch.no_grad():
                    # run the encoders without gradients to save GPU memory
                    imge, skips = model.image_encoder(imgs)
                    timge, tskips = model.image_encoder(tmp_img)

                    # imge= net.image_encoder(imgs)
                    p1, p2, se, de = model.prompt_encoder(
                        points=pt,
                        boxes=None,
                        doodles=None,
                        masks=None,
                    )

                # free unused intermediate buffers
                torch.cuda.empty_cache()

                pred, _ = model.mask_decoder(
                    skips_raw=skips,
                    skips_tmp=tskips,
                    raw_emb=imge,
                    tmp_emb=timge,
                    pt1=p1,
                    pt2=p2,
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )

                # resize the prediction to match the target
                if pred.shape[-2:] != masks.shape[-2:]:
                    pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)

                loss = lossfunc(pred, masks)

            # skip the step on nan/inf loss
            if torch.isnan(loss) or torch.isinf(loss):
                optimizer.zero_grad()
                pbar.set_postfix(**{'loss (batch)': 'nan/inf skipped'})
                pbar.update()
                ind += 1
                continue

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()

            scaler.scale(loss).backward()
            # gradient clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            '''vis images'''
            if vis:
                if ind % vis == 0:
                    namecat = 'Train'
                    for na in name:
                        namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                    vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False)

            pbar.update()

    return loss
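The AMP update in `train_one` (scale, backward, unscale, clip, step, update) can be reduced to a minimal, CPU-safe sketch; the tiny linear model and random data are placeholders:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(8, 1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
# GradScaler degrades to a no-op when CUDA is unavailable.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

x = torch.randn(16, 8, device=device)
y = torch.randn(16, 1, device=device)
w_before = model.weight.detach().clone()

with torch.autocast(device_type=device.type, enabled=torch.cuda.is_available()):
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()   # backward pass on the scaled loss
scaler.unscale_(opt)            # bring gradients back to fp32 before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(opt)                # skipped automatically on inf/nan gradients
scaler.update()
opt.zero_grad()
```

Unscaling before `clip_grad_norm_` matters: clipping scaled gradients would use the wrong norm.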

def validation_one(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    # eval mode
    net.eval()

    # unwrap a possible DataParallel wrapper
    model = net.module if hasattr(net, 'module') else net

    mask_type = torch.float32
    n_val = len(val_loader)  # the number of batches
    ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
    rater_res = [(0, 0, 0, 0) for _ in range(6)]
    tot = 0
    hard = 0
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    device = GPUdevice

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:

        for ind, pack in enumerate(val_loader):
            if ind == 0:
                tmp_img = pack['image'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(args.b, 1, 1, 1)
                tmp_mask = pack['label'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(args.b, 1, 1, 1)
                if 'pt' not in pack:
                    tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
                else:
                    pt = pack['pt']
                    point_labels = pack['p_label']

                if point_labels[0] != -1:
                    # point_coords = onetrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
                    pt = (coords_torch, labels_torch)

            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)

            name = pack['image_meta_dict']['filename_or_obj']

            showp = pt

            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
                #true_mask_ave = cons_tensor(true_mask_ave)
            imgs = imgs.to(dtype=mask_type, device=GPUdevice)

            '''test'''
            with torch.no_grad():
                imge, skips = model.image_encoder(imgs)
                timge, tskips = model.image_encoder(tmp_img)

                p1, p2, se, de = model.prompt_encoder(
                    points=pt,
                    boxes=None,
                    doodles=None,
                    masks=None,
                )
                pred, _ = model.mask_decoder(
                    skips_raw=skips,
                    skips_tmp=tskips,
                    raw_emb=imge,
                    tmp_emb=timge,
                    pt1=p1,
                    pt2=p2,
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )

                # resize the prediction to match the target
                if pred.shape[-2:] != masks.shape[-2:]:
                    pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)

                tot += lossfunc(pred, masks)

            '''vis images'''
            if args.vis and ind % args.vis == 0:
                namecat = 'Test'
                for na in name:
                    img_name = na.split('/')[-1].split('.')[0]
                    namecat = namecat + img_name + '+'
                vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False)

            temp = eval_seg(pred, masks, threshold)
            mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

            pbar.update()

    return tot / n_val, tuple([a / n_val for a in mix_res])
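`eval_seg` comes from `utils.py` and is not shown in this commit; a plausible stand-in that averages Dice over the same binarization thresholds could look like this (the function name and the exact metric set are assumptions):

```python
import torch

def dice_at_thresholds(pred_logits, target, thresholds=(0.1, 0.3, 0.5, 0.7, 0.9)):
    """Mean Dice score over several binarization thresholds
    (hypothetical stand-in for eval_seg)."""
    probs = torch.sigmoid(pred_logits)
    scores = []
    for t in thresholds:
        binary = (probs > t).float()
        inter = (binary * target).sum()
        denom = binary.sum() + target.sum()
        scores.append((2 * inter / denom.clamp(min=1e-6)).item())
    return sum(scores) / len(scores)
```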
@ -0,0 +1,11 @@
#!/bin/bash
git config --global user.name "Wu Junde"
git config --global user.email "izzy843794947@gmail.com"
# Add changes to the staging area
git add .

# Commit changes with a default message
git commit -m "update"

# Push changes to the remote repository
git push origin master
@ -0,0 +1,84 @@
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np

# class Discriminator(nn.Module):
#     def __init__(self, ngpu, nc = 3, ndf = 64):
#         super(Discriminator, self).__init__()
#         self.ngpu = ngpu
#         self.main = nn.Sequential(
#             # input is (nc) x 64 x 64
#             nn.Conv2d(nc, ndf, 4, 4, 1, bias=False),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. (ndf) x 32 x 32
#             nn.Conv2d(ndf, ndf * 2, 4, 4, 1, bias=False),
#             nn.BatchNorm2d(ndf * 2),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. (ndf*2) x 16 x 16
#             nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ndf * 4),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. (ndf*4) x 8 x 8
#             nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ndf * 8),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. (ndf*8) x 4 x 4
#             nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
#             nn.Sigmoid()
#         )

#     def forward(self, input):
#         return self.main(input)


class Discriminator(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Filters [256, 512, 1024]
        # Input_dim = channels (Cx32x32)
        # Output_dim = 1
        self.main_module = nn.Sequential(
            # Batch normalization is omitted in the critic because the penalized
            # training objective (WGAN with gradient penalty) is no longer valid
            # in that setting: we penalize the norm of the critic's gradient with
            # respect to each input independently, not the entire batch.
            # There is no good & fast implementation of layer normalization,
            # so per-instance normalization nn.InstanceNorm2d() is used instead.
            # Image (Cx32x32)
            nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (256x16x16)
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (512x8x8)
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(1024, affine=True),
            nn.LeakyReLU(0.2, inplace=True))
        # output of main module --> State (1024x4x4)

        self.output = nn.Sequential(
            # The output of D is no longer a probability; we do not apply sigmoid at the output of D.
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0))

    def forward(self, x):
        x = self.main_module(x)
        return self.output(x)

    def feature_extraction(self, x):
        # Use the discriminator for feature extraction, then flatten to a vector of 16384
        x = self.main_module(x)
        return x.view(-1, 1024*4*4)
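Tracing the strides shows this critic expects 32x32 inputs (32 → 16 → 8 → 4, then the final 4x4 conv collapses to 1x1), so `feature_extraction` flattens to exactly 1024*4*4 = 16384. A self-contained shape check using a condensed copy of the critic above:

```python
import torch
import torch.nn as nn

def make_critic(channels=3):
    # Condensed copy of the WGAN-GP critic above (no sigmoid on the output).
    def block(cin, cout):
        return [nn.Conv2d(cin, cout, 4, 2, 1),
                nn.InstanceNorm2d(cout, affine=True),
                nn.LeakyReLU(0.2, inplace=True)]
    return nn.Sequential(*block(channels, 256), *block(256, 512),
                         *block(512, 1024), nn.Conv2d(1024, 1, 4, 1, 0))

critic = make_critic()
scores = critic(torch.randn(2, 3, 32, 32))   # one unbounded score per image
```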
@ -0,0 +1,360 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
__version__ = "0.5.1"
|
||||
from .utils import (
|
||||
GlobalParams,
|
||||
BlockArgs,
|
||||
BlockDecoder,
|
||||
efficientnet,
|
||||
get_model_params,
|
||||
)
|
||||
|
||||
|
||||
from .utils import (
|
||||
round_filters,
|
||||
round_repeats,
|
||||
drop_connect,
|
||||
get_same_padding_conv2d,
|
||||
get_same_padding_conv2d_freeze,
|
||||
get_model_params,
|
||||
efficientnet_params,
|
||||
load_pretrained_weights,
|
||||
Swish,
|
||||
MemoryEfficientSwish,
|
||||
gram_matrix,
|
||||
)
|
||||
|
||||
|
||||
class MBConvBlock(nn.Module):
|
||||
"""
|
||||
Mobile Inverted Residual Bottleneck Block
|
||||
|
||||
Args:
|
||||
block_args (namedtuple): BlockArgs, see above
|
||||
global_params (namedtuple): GlobalParam, see above
|
||||
|
||||
Attributes:
|
||||
has_se (bool): Whether the block contains a Squeeze and Excitation layer.
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, global_params):
|
||||
super().__init__()
|
||||
self._block_args = block_args
|
||||
self._bn_mom = 1 - global_params.batch_norm_momentum
|
||||
self._bn_eps = global_params.batch_norm_epsilon
|
||||
self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
|
||||
self.id_skip = block_args.id_skip # skip connection and drop connect
|
||||
|
||||
# Get static or dynamic convolution depending on image size
|
||||
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
|
||||
|
||||
# Expansion phase
|
||||
inp = self._block_args.input_filters # number of input channels
|
||||
oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
|
||||
if self._block_args.expand_ratio != 1:
|
||||
self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
|
||||
self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
|
||||
|
||||
# Depthwise convolution phase
|
||||
k = self._block_args.kernel_size
|
||||
s = self._block_args.stride
|
||||
self._depthwise_conv = Conv2d(
|
||||
in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
|
||||
kernel_size=k, stride=s, bias=False)
|
||||
self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
|
||||
|
||||
# Squeeze and Excitation layer, if desired
|
||||
if self.has_se:
|
||||
num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
|
||||
self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
|
||||
self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
|
||||
|
||||
# Output phase
|
||||
final_oup = self._block_args.output_filters
|
||||
self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
|
||||
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
|
||||
self._swish = MemoryEfficientSwish()
|
||||
|
    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
            x = torch.sigmoid(x_squeezed) * x

        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()

class MBConvBlock_freeze(nn.Module):
    """
    Mobile Inverted Residual Bottleneck Block

    Args:
        block_args (namedtuple): BlockArgs, see above
        global_params (namedtuple): GlobalParam, see above

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
    """

    def __init__(self, block_args, index, device, global_params):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect

        self.Conv2d = get_same_padding_conv2d_freeze(image_size=global_params.image_size)

        s = self._block_args.stride
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        # Output phase
        final_oup = self._block_args.output_filters
        self._swish = MemoryEfficientSwish()
        self.oup = oup
        self.s = s
        self.block_name = '_blocks.{:d}'.format(index)
        self.device = device
    def forward(self, inputs, weights, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param weights: dict of externally supplied parameters, keyed by module name
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self.Conv2d(x, weights[self.block_name + '._expand_conv.weight'])
            x = F.batch_norm(x, torch.zeros(x.data.size()[1]).to(self.device),
                             torch.ones(x.data.size()[1]).to(self.device),
                             weights[self.block_name + '._bn0.weight'],
                             weights[self.block_name + '._bn0.bias'],
                             training=True)
        x = self.Conv2d(x, weights[self.block_name + '._depthwise_conv.weight'], groups=self.oup, stride=self.s)
        x = F.batch_norm(x, torch.zeros(x.data.size()[1]).to(self.device),
                         torch.ones(x.data.size()[1]).to(self.device),
                         weights[self.block_name + '._bn1.weight'],
                         weights[self.block_name + '._bn1.bias'],
                         training=True)

        # Squeeze and Excitation: squeeze to a per-channel vector, then excite
        # back to channel scale factors (mirrors the MBConvBlock path above)
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self.Conv2d(x_squeezed, weights[self.block_name + '._se_reduce.weight'],
                                     weights[self.block_name + '._se_reduce.bias'])
            x_squeezed = self.Conv2d(self._swish(x_squeezed), weights[self.block_name + '._se_expand.weight'],
                                     weights[self.block_name + '._se_expand.bias'])
            x = torch.sigmoid(x_squeezed) * x

        x = self.Conv2d(x, weights[self.block_name + '._project_conv.weight'])
        x = F.batch_norm(x, torch.zeros(x.data.size()[1]).to(self.device),
                         torch.ones(x.data.size()[1]).to(self.device),
                         weights[self.block_name + '._bn2.weight'],
                         weights[self.block_name + '._bn2.bias'],
                         training=True)

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()

class EfficientNet(nn.Module):
    """
    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods

    Args:
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')

    """

    def __init__(self, device, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args
        # This variant takes no `type` argument, so the decoder-variant branches
        # below are inactive and the default regression head is used.
        self.type = None
        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Stem
        in_channels = 4  # 4-channel input (not plain RGB)
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params))
            if block_args.num_repeat > 1:
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self.conv_reg = nn.Conv2d(1792, 1, 1)
        if self.type == 'big_map' or self.type == 'img':
            self.conv_transe1 = nn.Conv2d(1792, 448, 1)
            self.bn_transe1 = nn.BatchNorm2d(num_features=448, momentum=bn_mom, eps=bn_eps)
            self.conv_transe2 = nn.Conv2d(448, 112, 1)
            self.bn_transe2 = nn.BatchNorm2d(num_features=112, momentum=bn_mom, eps=bn_eps)
            if self.type == 'big_map':
                self.conv_transe_mask = nn.Conv2d(112, 1, 1)
                self.deconv_big = nn.ConvTranspose2d(1792, 1, 5, stride=4)  # transpose
            if self.type == 'img':
                self.conv_transe3 = nn.Conv2d(112, 3, 1)
                self.deconv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4)  # transpose
        elif self.type == 'deconv_map' or self.type == 'deconv_img':
            self.conv_big_reg = nn.ConvTranspose2d(1792, 1, 5, stride=4)  # transpose
            self.conv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4)  # transpose
        else:
            self.conv_reg = nn.Conv2d(1792, 1, 1)

        self.relu = nn.ReLU()
        self.up_double = nn.Upsample(scale_factor=2, mode='bilinear')
        self._fc = nn.Linear(out_channels, 1)
        self._swish = MemoryEfficientSwish()
        self.sig = nn.Sigmoid()
        self.device = device
    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """ Returns output of the final convolution layer """

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs, weights=None):
        """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
        bs = inputs.size(0)
        # Convolution layers
        x = self.extract_features(inputs)
        # Pooling and final linear layer
        x = self._avg_pooling(x)
        x = x.view(bs, -1)
        x = self._dropout(x)
        x = self._fc(x)

        return x
    @classmethod
    def from_name(cls, model_name, device, override_params=None):
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(device, blocks_args, global_params)

    @classmethod
    def from_pretrained(cls, model_name, device, num_classes=1000, in_channels=3):
        model = cls.from_name(model_name, device, override_params={'num_classes': num_classes})
        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
        """ Validates model name. Note that pretrained weights are only available for
        the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
        num_models = 4 if also_need_pretrained_weights else 8
        valid_models = ['efficientnet-b' + str(i) for i in range(num_models)]
        if model_name not in valid_models:
            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))

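In `extract_features` above, the drop connect rate grows linearly with block index: block `idx` uses `drop_connect_rate * idx / len(self._blocks)`, so shallow blocks are dropped less often than deep ones. A small sketch of that schedule (the helper name is ours, not part of this file):

```python
def drop_connect_schedule(base_rate, num_blocks):
    """Per-block drop connect rates, scaled linearly with depth,
    mirroring the loop in EfficientNet.extract_features."""
    return [base_rate * float(idx) / num_blocks for idx in range(num_blocks)]

# With 16 blocks and a base rate of 0.2, the first block is never
# dropped and the last is dropped just under 20% of the time.
rates = drop_connect_schedule(0.2, 16)
```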
@ -0,0 +1,307 @@
import torch
from torch import nn
from torch.nn import functional as F

__version__ = "0.5.1"
from .utils import (
    GlobalParams,
    BlockArgs,
    BlockDecoder,
    efficientnet,
    get_model_params,
    round_filters,
    round_repeats,
    drop_connect,
    get_same_padding_conv2d,
    efficientnet_params,
    load_pretrained_weights,
    Swish,
    MemoryEfficientSwish,
    gram_matrix,
)

class MBConvBlock(nn.Module):
    """
    Mobile Inverted Residual Bottleneck Block

    Args:
        block_args (namedtuple): BlockArgs, see above
        global_params (namedtuple): GlobalParam, see above

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
    """

    def __init__(self, block_args, global_params):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Expansion phase
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Output phase
        final_oup = self._block_args.output_filters
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()
    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
            x = torch.sigmoid(x_squeezed) * x

        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()

class EfficientNet(nn.Module):
    """
    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods

    Args:
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')

    """

    def __init__(self, type, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args
        self.type = type
        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Stem
        in_channels = 5  # label (1) + natural image (3) + segmentation (1), see forward()
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params))
            if block_args.num_repeat > 1:
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, 1)
        self._swish = MemoryEfficientSwish()
        self.conv_reg = nn.Conv2d(1792, 1, 1)
        if self.type == 'big_map' or self.type == 'img':
            self.conv_transe1 = nn.Conv2d(1792, 448, 1)
            self.bn_transe1 = nn.BatchNorm2d(num_features=448, momentum=bn_mom, eps=bn_eps)
            self.conv_transe2 = nn.Conv2d(448, 112, 1)
            self.bn_transe2 = nn.BatchNorm2d(num_features=112, momentum=bn_mom, eps=bn_eps)
            if self.type == 'big_map':
                self.conv_transe_mask = nn.Conv2d(112, 1, 1)
                self.deconv_big = nn.ConvTranspose2d(1792, 1, 5, stride=4)  # transpose
            if self.type == 'img':
                self.conv_transe3 = nn.Conv2d(112, 3, 1)
                self.deconv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4)  # transpose
        elif self.type == 'deconv_map' or self.type == 'deconv_img':
            self.conv_big_reg = nn.ConvTranspose2d(1792, 1, 5, stride=4)  # transpose
            self.conv_img = nn.ConvTranspose2d(1792, 3, 5, stride=4)  # transpose
        else:
            self.conv_reg = nn.Conv2d(1792, 1, 1)

        self.relu = nn.ReLU()
        self.up_double = nn.Upsample(scale_factor=2, mode='bilinear')
        self.sig = nn.Sigmoid()
    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """ Returns output of the final convolution layer """

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x
    def forward(self, seg, label, natural):
        label = label.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(seg.size())

        x = torch.cat((label, natural, seg), 1)  # concatenated input
        bs = seg.size(0)
        # Convolution layers
        x = self.extract_features(x)
        if self.type == 'map':
            reg = self.conv_reg(x)
            reg = self.sig(reg)
        elif self.type == 'big_map':
            reg = self.up_double(x)  # 12*14
            reg = self.relu(reg)
            reg = self.conv_transe1(reg)  # 448
            reg = self.bn_transe1(reg)

            reg = self.up_double(reg)  # 24*28
            reg = self.relu(reg)
            reg = self.conv_transe2(reg)  # 112
            reg = self.bn_transe2(reg)

            reg = self.conv_transe_mask(reg)  # 1
            reg = self.sig(reg)
        elif self.type == 'img':
            reg = self.up_double(x)  # 12*14
            reg = self.relu(reg)
            reg = self.conv_transe1(reg)  # 448
            reg = self.bn_transe1(reg)

            reg = self.up_double(reg)  # 24*28
            reg = self.relu(reg)
            reg = self.conv_transe2(reg)  # 112
            reg = self.bn_transe2(reg)

            reg = self.conv_transe3(reg)  # 3
            reg = self.sig(reg)
        elif self.type == 'deconv_map':
            reg = self.conv_big_reg(x)
            reg = self.sig(reg)
        elif self.type == 'deconv_img':
            reg = self.conv_img(x)
            reg = self.sig(reg)
        elif self.type == 'feature':
            reg = gram_matrix(x - x.mean(0, True))

        return reg
    @classmethod
    def from_name(cls, model_name, type, override_params=None):
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(type, blocks_args, global_params)

    @classmethod
    def from_pretrained(cls, model_name, type, num_classes=1000, in_channels=3):
        model = cls.from_name(model_name, type, override_params={'num_classes': num_classes})
        load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
        """ Validates model name. Note that pretrained weights are only available for
        the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
        num_models = 4 if also_need_pretrained_weights else 8
        valid_models = ['efficientnet-b' + str(i) for i in range(num_models)]
        if model_name not in valid_models:
            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))

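The `deconv` heads above use `nn.ConvTranspose2d(1792, C, 5, stride=4)` with no padding, and a transposed convolution maps an input of length `n` to `(n - 1) * stride - 2 * padding + kernel` along each spatial dimension, so these layers upsample a feature map roughly 4x. A quick check of that arithmetic (the helper name is ours):

```python
def conv_transpose_out(n, kernel=5, stride=4, padding=0):
    """Output length of nn.ConvTranspose2d along one spatial dimension."""
    return (n - 1) * stride - 2 * padding + kernel

# The 12x14 feature map mentioned in the forward-pass comments
# becomes 49x57 after the stride-4 transposed convolution.
print(conv_transpose_out(12), conv_transpose_out(14))
```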
@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearBottleNeck(nn.Module):

    def __init__(self, in_channels, out_channels, stride, t=6, class_num=1):
        super().__init__()

        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * t, 1),
            nn.BatchNorm2d(in_channels * t),
            nn.ReLU6(inplace=True),

            nn.Conv2d(in_channels * t, in_channels * t, 3, stride=stride, padding=1, groups=in_channels * t),
            nn.BatchNorm2d(in_channels * t),
            nn.ReLU6(inplace=True),

            nn.Conv2d(in_channels * t, out_channels, 1),
            nn.BatchNorm2d(out_channels)
        )

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        residual = self.residual(x)

        if self.stride == 1 and self.in_channels == self.out_channels:
            residual += x

        return residual


class ImplicitNet(nn.Module):

    def __init__(self, class_num=1):
        super().__init__()

        self.pre = nn.Sequential(
            nn.Conv2d(5, 32, 1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )

        self.stage1 = LinearBottleNeck(32, 16, 1, 1)
        self.stage2 = self._make_stage(2, 16, 24, 2, 6)
        self.stage3 = self._make_stage(3, 24, 32, 2, 6)
        self.stage4 = self._make_stage(4, 32, 64, 2, 6)
        self.stage5 = self._make_stage(3, 64, 96, 1, 6)
        self.stage6 = self._make_stage(3, 96, 160, 1, 6)
        self.stage7 = LinearBottleNeck(160, 320, 1, 6)

        self.conv1 = nn.Sequential(
            nn.Conv2d(320, 1280, 1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True)
        )

        self.conv2 = nn.Conv2d(1280, class_num, 1)

        self.sigmoid = nn.Sigmoid()
def forward(self, seg, label, natural):
|
||||
label = label.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(seg.size())
|
||||
|
||||
x = torch.cat((label,natural,seg),1) # concated input
|
||||
x = self.pre(x)
|
||||
x = self.stage1(x)
|
||||
x = self.stage2(x)
|
||||
x = self.stage3(x)
|
||||
x = self.stage4(x)
|
||||
x = self.stage5(x)
|
||||
x = self.stage6(x)
|
||||
x = self.stage7(x)
|
||||
x = self.conv1(x)
|
||||
#x = F.adaptive_avg_pool2d(x, 1)
|
||||
x = self.conv2(x) # (b,h/s,w/s,1)
|
||||
x = self.sigmoid(x)
|
||||
return x
|
||||
|
||||
def _make_stage(self, repeat, in_channels, out_channels, stride, t):
|
||||
layers = []
|
||||
layers.append(LinearBottleNeck(in_channels, out_channels, stride, t))
|
||||
|
||||
while repeat - 1:
|
||||
layers.append(LinearBottleNeck(out_channels, out_channels, 1, t))
|
||||
repeat -= 1
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
|
||||
|
||||
def implicitnet():
|
||||
return ImplicitNet()
|
||||
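`_make_stage` emits one strided bottleneck followed by `repeat - 1` stride-1 bottlenecks, so each stage downsamples at most once and keeps `out_channels` from the first layer onward. The same plan, sketched as data (the helper name is ours):

```python
def stage_plan(repeat, in_channels, out_channels, stride):
    """(in, out, stride) for each LinearBottleNeck a stage would build,
    mirroring ImplicitNet._make_stage."""
    layers = [(in_channels, out_channels, stride)]
    layers += [(out_channels, out_channels, 1)] * (repeat - 1)
    return layers

# stage3 above: 3 repeats, 24 -> 32 channels, stride 2 on the first layer only.
plan = stage_plan(3, 24, 32, 2)
```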
@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .build_oneprompt import (
    build_one_vit_h,
    build_one_vit_l,
    build_one_vit_b,
    one_model_registry,
)
from .predictor import OnePredictor
from .automatic_mask_generator import OneAutomaticMaskGenerator
@ -0,0 +1,368 @@
import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

from typing import Any, Dict, List, Optional, Tuple

from .modeling import OnePrompt
from .predictor import OnePredictor
from .utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


class OneAutomaticMaskGenerator:
    def __init__(
        self,
        model: OnePrompt,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a OnePrompt model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for a model with a ViT-H backbone.

        Arguments:
          model (OnePrompt): The model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
        """
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = OnePredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
           list(dict(str, any)): A list over records for masks. Each record is
             a dict containing the following keys:
               segmentation (dict(str, any) or np.ndarray): The mask. If
                 output_mode='binary_mask', is an array of shape HW. Otherwise,
                 is a dictionary containing the RLE.
               bbox (list(float)): The box around the mask, in XYWH format.
               area (int): The area in pixels of the mask.
               predicted_iou (float): The model's own prediction of the mask's
                 quality. This is filtered by the pred_iou_thresh parameter.
               point_coords (list(list(float))): The point coordinates input
                 to the model to generate this mask.
               stability_score (float): A measure of the mask's quality. This
                 is filtered on using the stability_score_thresh parameter.
               crop_box (list(float)): The crop of the image used to generate
                 the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            curr_anns.append(ann)

        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Get points for this crop
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        orig_h, orig_w = orig_size

        # Run model on this batch
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _ = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )

        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data
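The stability filter in `_process_batch` above calls `calculate_stability_score`, which is defined elsewhere in the repository. A minimal numpy sketch of what that score measures (the function name and signature here are illustrative, not the project's own):

```python
import numpy as np

def stability_score(mask_logits, mask_threshold=0.0, offset=1.0):
    # IoU between the mask binarized at (threshold + offset) and at
    # (threshold - offset). A score near 1.0 means nudging the cutoff
    # barely changes the mask, i.e. the mask is "stable".
    intersection = (mask_logits > (mask_threshold + offset)).sum()
    union = (mask_logits > (mask_threshold - offset)).sum()
    return intersection / max(union, 1)

# A mask whose logits sit far from the threshold is perfectly stable:
confident = np.full((8, 8), 5.0)
print(stability_score(confident))  # -> 1.0
```

Masks whose logits hover near the threshold score well below 1.0 and are dropped by `stability_score_thresh`.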
@@ -0,0 +1,139 @@

from functools import partial
from pathlib import Path
import urllib.request
import torch
from collections import OrderedDict

from .modeling import (
    OnePrompt,
    OnePromptDecoder,
    PromptEncoder,
    OnePromptEncoderViT,
    OnePromptEncoderUnet,
    CrossAttentionBlock,
)


def build_one_vit_h(args=None, checkpoint=None):
    return _build_one(
        args,
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


def build_one_vit_l(args, checkpoint=None):
    return _build_one(
        args,
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_one_vit_b(args, checkpoint=None):
    return _build_one(
        args,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


def build_one_unet(args, checkpoint=None):
    return _build_one(
        args,
        encoder_embed_dim=256,
        encoder_depth=4,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


one_model_registry = {
    "default": build_one_vit_h,
    "unet": build_one_unet,
    "vit_h": build_one_vit_h,
    "vit_l": build_one_vit_l,
    "vit_b": build_one_vit_b,
}


def _build_one(
    args,
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = args.dim
    image_size = args.image_size
    vit_patch_size = args.patch_size
    image_embedding_size = image_size // vit_patch_size
    one = OnePrompt(
        args,
        image_encoder=OnePromptEncoderUnet(
            input_channels=3,
            base_num_features=encoder_embed_dim // 2,
            final_num_features=encoder_embed_dim,
            fea_size=image_embedding_size,
            num_pool=encoder_depth,
        ) if args.baseline == 'unet' else
        OnePromptEncoderViT(
            args=args,
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=OnePromptDecoder(
            depth=4,
            prompt_embed_dim=prompt_embed_dim,
            embed_dim=encoder_embed_dim,
            out_chans=prompt_embed_dim,
            token_num=int(image_embedding_size * image_embedding_size),
            patch_size=vit_patch_size,
            mlp_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    one.eval()

    if checkpoint is not None:
        checkpoint = Path(checkpoint)
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        if args.image_size != 1024:
            # Drop patch-embed weights, whose shapes depend on the image size,
            # then load the remaining params.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if "image_encoder.patch_embed" not in k:
                    new_state_dict[k] = v
        else:
            new_state_dict = state_dict

        one.load_state_dict(new_state_dict, strict=False)
    return one
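`one_model_registry` above maps string keys to builder functions so callers can select an architecture by name. A toy, dependency-free sketch of the same dispatch pattern (the stand-in builders below are illustrative, not the real `_build_one`):

```python
# Stand-in builders that just record their configuration.
def build_vit_b(args=None, checkpoint=None):
    return {"embed_dim": 768, "depth": 12, "checkpoint": checkpoint}

def build_vit_h(args=None, checkpoint=None):
    return {"embed_dim": 1280, "depth": 32, "checkpoint": checkpoint}

model_registry = {
    "default": build_vit_h,
    "vit_h": build_vit_h,
    "vit_b": build_vit_b,
}

# Callers pick an architecture by key, as with one_model_registry:
model = model_registry["vit_b"](checkpoint="weights.pth")
print(model["embed_dim"])  # -> 768
```

The registry keeps call sites free of if/elif chains: adding a new backbone only means registering one more builder.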
@@ -0,0 +1,8 @@


from .oneprompt import OnePrompt
from .image_encoder import OnePromptEncoderViT, OnePromptEncoderUnet
from .mask_decoder import OnePromptDecoder
from .prompt_encoder import PromptEncoder
from .modules import CrossAttentionBlock

@@ -0,0 +1,59 @@

import torch
import torch.nn as nn

from typing import Type


class Adapter(nn.Module):
    def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True):
        super().__init__()
        self.skip_connect = skip_connect
        D_hidden_features = int(D_features * mlp_ratio)
        self.act = act_layer()
        self.D_fc1 = nn.Linear(D_features, D_hidden_features)
        self.D_fc2 = nn.Linear(D_hidden_features, D_features)

    def forward(self, x):
        # x is (BT, HW+1, D)
        xs = self.D_fc1(x)
        xs = self.act(xs)
        xs = self.D_fc2(xs)
        if self.skip_connect:
            x = x + xs
        else:
            x = xs
        return x


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
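`LayerNorm2d` normalizes over the channel axis of an NCHW tensor, unlike `nn.LayerNorm`, which normalizes over the trailing dimensions. A numpy re-implementation of the same math, handy for sanity-checking the layer:

```python
import numpy as np

def layer_norm_2d(x, weight, bias, eps=1e-6):
    # x is NCHW; normalize each spatial location across its channels,
    # mirroring LayerNorm2d.forward above.
    u = x.mean(axis=1, keepdims=True)
    s = ((x - u) ** 2).mean(axis=1, keepdims=True)
    x = (x - u) / np.sqrt(s + eps)
    return weight[None, :, None, None] * x + bias[None, :, None, None]

x = np.random.randn(2, 8, 4, 4)
y = layer_norm_2d(x, np.ones(8), np.zeros(8))
# After normalization, every (n, h, w) column has ~zero mean across channels.
print(np.abs(y.mean(axis=1)).max() < 1e-5)  # -> True
```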
@@ -0,0 +1,236 @@

"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from . import logger

INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()


def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]


def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params


def zero_master_grads(master_params):
    for param in master_params:
        param.grad = None


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


def param_grad_or_zeros(param):
    if param.grad is not None:
        return param.grad.data.detach()
    else:
        return th.zeros_like(param)


class MixedPrecisionTrainer:
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
    return (value == float("inf")) or (value == -float("inf")) or (value != value)
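The overflow handling in `_optimize_fp16` implements dynamic loss scaling: on a NaN/inf gradient norm, the log2 loss scale drops by 1 and the optimizer step is skipped; on a healthy step, the scale creeps up by `fp16_scale_growth`. The policy in isolation (pure Python, names are illustrative):

```python
def step_loss_scale(lg_loss_scale, grad_norm, fp16_scale_growth=1e-3):
    # Mirrors _optimize_fp16: returns (new lg scale, whether the step ran).
    overflowed = (grad_norm != grad_norm) or grad_norm in (float("inf"), -float("inf"))
    if overflowed:
        return lg_loss_scale - 1, False  # halve the scale, skip this update
    return lg_loss_scale + fp16_scale_growth, True  # grow the scale slowly

lg = 20.0
lg, stepped = step_loss_scale(lg, float("nan"))  # overflow: scale backs off
print(lg, stepped)  # -> 19.0 False
lg, stepped = step_loss_scale(lg, 1.3)           # healthy step: slow growth
```

Back-off is fast (a factor of 2 per overflow) while growth is slow, so the scale settles just under the largest value that keeps fp16 gradients finite.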
File diff suppressed because it is too large
@@ -0,0 +1,324 @@
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from .common import LayerNorm2d
|
||||
from .modules import CrossAttentionBlock, OnePromptFormer, TwoWayTransformer
|
||||
from einops import rearrange
|
||||
import math
|
||||
from .image_encoder import PatchEmbed
|
||||
|
||||
class OnePromptDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
depth: int = 4,
|
||||
prompt_embed_dim: int = 256,
|
||||
embed_dim: int = 768,
|
||||
out_chans: int = 256,
|
||||
token_num: int,
|
||||
patch_size: int,
|
||||
mlp_dim: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.of = nn.ModuleList()
|
||||
self.deals = nn.ModuleList()
|
||||
|
||||
# nlist = [4096, 4096, 4096, 4096]
|
||||
# embed_dim_list = [768, 256, 256, 256]
|
||||
|
||||
self.updecode = MaskDecoder(
|
||||
transformer_dim = prompt_embed_dim,
|
||||
num_multimask_outputs=3,
|
||||
transformer= TwoWayTransformer(
|
||||
depth=2,
|
||||
embedding_dim=prompt_embed_dim,
|
||||
# mlp_dim=2048,
|
||||
# num_heads=8,
|
||||
mlp_dim=256,
|
||||
num_heads=2,
|
||||
)
|
||||
)
|
||||
|
||||
self.neck = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dim,
|
||||
out_chans,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
nn.Conv2d(
|
||||
out_chans,
|
||||
out_chans,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
)
|
||||
|
||||
for i in range(depth):
|
||||
self.of.append(
|
||||
OnePromptFormer(
|
||||
embedding_dim = prompt_embed_dim,
|
||||
prompt_embed_dim = prompt_embed_dim,
|
||||
token_num = token_num,
|
||||
num_heads = 2,
|
||||
mlp_dim = mlp_dim
|
||||
)
|
||||
)
|
||||
|
||||
self.deals.append(
|
||||
Decode_Align(embed_dim=embed_dim, transformer_dim=prompt_embed_dim, stages=token_num-1)
|
||||
)
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
kernel_size=(patch_size, patch_size),
|
||||
stride=(patch_size, patch_size),
|
||||
in_chans=prompt_embed_dim,
|
||||
embed_dim=out_chans,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
skips_raw: list,
|
||||
skips_tmp: list,
|
||||
raw_emb: torch.Tensor,
|
||||
tmp_emb: torch.Tensor,
|
||||
pt1: torch.Tensor,
|
||||
pt2: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = raw_emb + tmp_emb
|
||||
x = self.neck(x.permute(0, 3, 1, 2))
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
|
||||
raw_emb = self.neck(raw_emb.permute(0, 3, 1, 2))
|
||||
# raw_emb = raw_emb.permute(0, 2, 3, 1)
|
||||
|
||||
for u in range(self.depth):
|
||||
if u == 0:
|
||||
x, img_embed, tmp_embed, temp_pos, p1, p2= self.deals[u](x, skips_raw[-(u + 1)], skips_tmp[-(u + 1)], image_pe, pt1, pt2, dense_prompt_embeddings)
|
||||
p1 = p1 + temp_pos.flatten(2).permute(0, 2, 1)
|
||||
p2 = p2 + temp_pos.flatten(2).permute(0, 2, 1)
|
||||
img_embed = img_embed.flatten(2).permute(0, 2, 1)
|
||||
tmp_embed = tmp_embed.flatten(2).permute(0, 2, 1)
|
||||
x = x.flatten(2).permute(0, 2, 1)
|
||||
# print('tmp_embed size', tmp_embed.size())
|
||||
# print('temp_pos size', temp_pos.size())
|
||||
# print('p1 size', p1.size())
|
||||
# print('p2 size', p2.size())
|
||||
x = self.of[u](x,img_embed, tmp_embed, p1, p2)
|
||||
# print(x.size())
|
||||
x = rearrange(x,'b (c1 c2) d -> b d c1 c2', c1 = int(math.sqrt(x.size(1))))
|
||||
x = self.patch_embed(x)
|
||||
x = rearrange(x,'b c1 c2 d-> b (c1 c2) d')
|
||||
# Select the correct mask or masks for output
|
||||
low_res_masks, iou_predictions = self.updecode(
|
||||
image_embeddings=raw_emb,
|
||||
image_pe=image_pe,
|
||||
mix_embeddings=x,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
|
||||
return low_res_masks, iou_predictions
|
||||
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[
|
||||
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
||||
for i in range(self.num_mask_tokens)
|
||||
]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(
|
||||
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
mix_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
masks, iou_pred = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
mix_embeddings=mix_embeddings,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
if multimask_output:
|
||||
mask_slice = slice(1, None)
|
||||
else:
|
||||
mask_slice = slice(0, 1)
|
||||
masks = masks[:, mask_slice, :, :]
|
||||
iou_pred = iou_pred[:, mask_slice]
|
||||
|
||||
# Prepare output
|
||||
return masks, iou_pred
|
||||
|
||||
    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        mix_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(image_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, mix_embeddings), dim=1)

        # Expand per-image data in the batch direction to be per-mask
        if image_embeddings.shape[0] != tokens.shape[0]:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            src = image_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py  # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        # ReLU after every layer except the last
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)  # F.sigmoid is deprecated
        return x


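As a standalone sanity check on the layer-sizing rule above (hidden dims repeated `num_layers - 1` times, ReLU on every layer but the last), here is a NumPy sketch; the dimensions and random weights are illustrative only:

```python
import numpy as np

def mlp_forward(x, weights, biases):
    """ReLU after every layer except the last, mirroring MLP.forward above."""
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w + b
        if i < len(weights) - 1:
            x = np.maximum(x, 0.0)
    return x

rng = np.random.default_rng(0)
input_dim, hidden_dim, output_dim, num_layers = 8, 16, 4, 3
# Same zip([input_dim] + h, h + [output_dim]) sizing as the module above
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
weights = [rng.standard_normal((n, k)) for n, k in zip(dims[:-1], dims[1:])]
biases = [np.zeros(k) for k in dims[1:]]

out = mlp_forward(rng.standard_normal((5, input_dim)), weights, biases)
print(out.shape)  # (5, 4)
```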
class Decode_Align(nn.Module):
    def __init__(
        self,
        *,
        embed_dim: int,
        transformer_dim: int,
        stages: int = 4096,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim

        self.num_mask_tokens = stages
        self.p1_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
        self.p2_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
        self.layer = nn.Linear(embed_dim, transformer_dim)

    def forward(
        self,
        x: torch.Tensor,
        src_embeddings: torch.Tensor,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        pt1: torch.Tensor,
        pt2: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:  # returns six tensors, not two
        image_embeddings = self.layer(image_embeddings)
        src_embeddings = self.layer(src_embeddings)

        p1 = self.p1_tokens.weight.unsqueeze(0).expand(pt1.size(0), -1, -1)
        p2 = self.p2_tokens.weight.unsqueeze(0).expand(pt1.size(0), -1, -1)

        p1_tokens = torch.cat((p1, pt1), dim=1)
        p2_tokens = torch.cat((p2, pt2), dim=1)

        if image_embeddings.shape[0] != p1_tokens.shape[0]:
            src = torch.repeat_interleave(image_embeddings, p1_tokens.shape[0], dim=0)
        else:
            src = image_embeddings
        src = src.permute(0, 3, 1, 2)
        img = src_embeddings.permute(0, 3, 1, 2)
        x = x.permute(0, 3, 1, 2)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, p1_tokens.shape[0], dim=0)

        return x, img, src, pos_src, p1_tokens, p2_tokens
@ -0,0 +1,520 @@
import math
from typing import Tuple, Type

import numpy as np
import torch
import torch.nn.functional as F
from einops.layers.torch import Rearrange, Reduce
from torch import Tensor, nn

from .common import MLPBlock

pair = lambda x: x if isinstance(x, tuple) else (x, x)

def gaussian_kernel(size, mean, std):
    """Generates a normalized 2D Gaussian kernel."""
    d = torch.distributions.Normal(mean, std)
    vals = d.log_prob(torch.arange(size).float())
    # Outer sum of log-probs == log of the outer product of 1D densities
    grid = torch.exp(vals[:, None] + vals[None, :])
    grid = grid / grid.sum()
    return grid


class GaussianConv2d(nn.Module):
    """A 2D convolution whose kernel is a learnable Gaussian."""

    def __init__(self, in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1, mean=0.0, std=1.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.mean = nn.Parameter(torch.tensor(mean), requires_grad=True)
        self.std = nn.Parameter(torch.tensor(std), requires_grad=True)
        self.weights = nn.Parameter(gaussian_kernel(kernel_size, self.mean, self.std), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(out_channels), requires_grad=True)

    def forward(self, x):
        # Broadcast the single Gaussian kernel across all input/output channels
        weight = self.weights.unsqueeze(0).unsqueeze(0).repeat(self.out_channels, self.in_channels, 1, 1)
        return F.conv2d(x, weight, bias=self.bias, stride=self.stride, padding=self.padding)


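The `log_prob` trick above builds a separable 2D kernel by summing 1D normal log-densities before exponentiating. The same construction in plain NumPy (kernel size and std are illustrative):

```python
import numpy as np

def gaussian_kernel_np(size, mean, std):
    """2D Gaussian kernel from 1D normal log-densities, normalized to sum to 1."""
    xs = np.arange(size, dtype=float)
    # log-density of N(mean, std^2) at integer positions 0..size-1
    log_p = -0.5 * ((xs - mean) / std) ** 2 - np.log(std * np.sqrt(2 * np.pi))
    # exp(log_p_i + log_p_j) == p_i * p_j: a separable 2D Gaussian
    grid = np.exp(log_p[:, None] + log_p[None, :])
    return grid / grid.sum()

k = gaussian_kernel_np(3, mean=1.0, std=1.0)
print(k.shape, round(k.sum(), 6))  # (3, 3) 1.0
```

Centering the mean on the middle tap (here `mean=1.0` for size 3) makes the kernel symmetric with its peak at the center.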
def PromptMLP(dim=3, expansion_factor=4, dropout=0., dense=nn.Linear):
    inner_dim = int(dim * expansion_factor)
    return nn.Sequential(
        dense(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(inner_dim, 1),
        nn.Dropout(dropout),
    )


class PromptMixer(nn.Module):
    def __init__(
        self,
        dim: int = 3,
        depth: int = 1,
        expansion_factor: int = 4,
        dropout: float = 0.,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.dim = dim
        self.expansion_factor = expansion_factor
        self.dropout = dropout
        self.layers = nn.Sequential(
            Rearrange('k b n d -> b n d k'),
            *[nn.Sequential(
                PromptMLP(dim, expansion_factor, dropout),
            ) for _ in range(depth)],
        )

    def forward(self, q, k, v):
        qkv = torch.stack([q, k, v])  # 3 x B x N x D
        res = self.layers(qkv)        # B x N x D x 1
        return res.squeeze(-1)        # B x N x D


class PromptParser(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        token_num: int,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.token_num = token_num

        self.pt_mix = PromptMixer()
        # Use a fixed, small channel count to avoid exhausting GPU memory
        self.gauss = GaussianConv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)

        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        tmp_embedding: Tensor,
        prompt_embedding1: Tensor,
        prompt_embedding2: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        pt_pe = prompt_embedding1 + prompt_embedding2
        etpp = self.pt_mix(tmp_embedding, prompt_embedding1, prompt_embedding2)

        # Memory-efficient outer product via bmm, equivalent to
        # einsum('bncd, bndx -> bncx', etpp.unsqueeze(-1), image_embedding.unsqueeze(-2))
        b, n, d = etpp.shape
        att_m = torch.bmm(etpp.view(b * n, d, 1), image_embedding.view(b * n, 1, d)).view(b, n, d, d)

        # Gaussian blurring of att_m is disabled here to save GPU memory

        tmp_pe = tmp_embedding + pt_pe
        etq = torch.bmm(image_embedding.view(b * n, d, 1), tmp_pe.view(b * n, 1, d)).view(b, n, d, d)

        eg = torch.max(att_m * etq, etq)
        res = torch.einsum('bncx, bnx -> bnc', eg, tmp_pe)
        return image_embedding, res


class OnePromptFormer(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        prompt_embed_dim: int,
        token_num: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        self.layers = nn.ModuleList()

        self.nn = nn.Linear(embedding_dim, prompt_embed_dim)

        self.attns1 = Attention(prompt_embed_dim, num_heads)
        self.attns2 = Attention(prompt_embed_dim, num_heads)
        self.mlps1 = MLPBlock(prompt_embed_dim, mlp_dim, activation)
        self.norms1 = nn.LayerNorm(prompt_embed_dim)
        self.norms2 = nn.LayerNorm(prompt_embed_dim)

        self.parser = PromptParser(embedding_dim=prompt_embed_dim, token_num=token_num)
        self.attnt1 = Attention(prompt_embed_dim, num_heads)
        self.mlpt1 = MLPBlock(prompt_embed_dim, mlp_dim, activation)
        self.normt1 = nn.LayerNorm(prompt_embed_dim)
        self.normt2 = nn.LayerNorm(prompt_embed_dim)

        self.attnm1 = Attention(prompt_embed_dim, num_heads)
        self.attnm2 = Attention(prompt_embed_dim, num_heads)

        self.final = nn.Sequential(
            MLPBlock(prompt_embed_dim, mlp_dim, activation),
            nn.LayerNorm(prompt_embed_dim),
        )

    def forward(
        self,
        emb: Tensor,
        image_embedding: Tensor,
        tmp_embedding: Tensor,
        prompt_embedding1: Tensor,
        prompt_embedding2: Tensor,
    ) -> Tensor:
        image_embedding, et = self.parser(image_embedding, tmp_embedding, prompt_embedding1, prompt_embedding2)
        es = self.attns1(q=image_embedding, k=emb, v=emb)
        es_bk = es
        es = self.attns2(q=et, k=es, v=es)
        es = self.norms1(es + et)
        es = self.norms2(self.mlps1(es) + es)

        et = self.attnt1(q=es_bk, k=et, v=et)
        et = self.normt1(es_bk + et)
        # Use the template-branch MLP/norm here; the original reused the
        # es-branch modules (mlps1/norms2), leaving mlpt1/normt2 dead.
        et = self.normt2(self.mlpt1(et) + et)

        e = self.attnm1(q=et, k=es, v=es)
        e = self.attnm2(q=e, k=e, v=e)
        e = self.final(e)

        return e


class MixedUpScale(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                CrossAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


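Both transformers above first turn the spatial image embedding into a token sequence via `flatten(2).permute(0, 2, 1)`. The same reshaping sketched in NumPy, with illustrative shapes, to show which token maps to which pixel:

```python
import numpy as np

b, c, h, w = 2, 6, 4, 4
image_embedding = np.arange(b * c * h * w, dtype=float).reshape(b, c, h, w)

# flatten(2).permute(0, 2, 1): the H x W grid becomes a token axis of length H*W
tokens = image_embedding.reshape(b, c, h * w).transpose(0, 2, 1)
print(tokens.shape)  # (2, 16, 6)

# Token i corresponds to spatial location (i // w, i % w)
i = 5
assert np.array_equal(tokens[:, i, :], image_embedding[:, :, i // w, i % w])
```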
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        # Note: the unused `depth` parameter was dropped; MixedUpScale
        # instantiates this block without it, which would otherwise raise
        # a TypeError for the missing required argument.
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim // downsample_rate."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
@ -0,0 +1,173 @@
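A NumPy sketch of the multi-head arithmetic in `Attention.forward` above (the linear projections are omitted; head count and dimensions are illustrative):

```python
import numpy as np

def attention_np(q, k, v, num_heads):
    """Scaled dot-product attention with head split/recombine, as above."""
    b, n, c = q.shape

    def split(x):  # B x N x C -> B x N_heads x N x C_per_head
        return x.reshape(b, -1, num_heads, c // num_heads).transpose(0, 2, 1, 3)

    q, k, v = split(q), split(k), split(v)
    attn = q @ k.transpose(0, 1, 3, 2) / np.sqrt(c // num_heads)
    # Numerically stable softmax over the key axis
    attn = np.exp(attn - attn.max(axis=-1, keepdims=True))
    attn = attn / attn.sum(axis=-1, keepdims=True)
    out = attn @ v
    # B x N_heads x N x C_per_head -> B x N x C
    return out.transpose(0, 2, 1, 3).reshape(b, n, c)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))
out = attention_np(x, x, x, num_heads=2)
print(out.shape)  # (2, 5, 8)
```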
"""
|
||||
Various utilities for neural networks.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * th.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
def layer_norm(shape, *args, **kwargs):
|
||||
|
||||
return nn.LayerNorm(shape, *args, **kwargs)
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


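The in-place `mul_(rate).add_(src, alpha=1 - rate)` in `update_ema` above is the standard EMA recurrence `targ = rate * targ + (1 - rate) * src`, sketched here with plain floats:

```python
# EMA of a constant source 0.0 starting from target 1.0: the target
# decays geometrically toward the source at factor `rate` per step.
rate = 0.99
targ, src = 1.0, 0.0
for _ in range(3):
    targ = rate * targ + (1 - rate) * src
print(round(targ, 6))  # 0.970299 == 0.99 ** 3
```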
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


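A NumPy sketch of `timestep_embedding` above; note the cos/sin halves and the single zero column appended when `dim` is odd:

```python
import math
import numpy as np

def timestep_embedding_np(timesteps, dim, max_period=10000):
    """Sinusoidal timestep embeddings, mirroring the torch version above."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/max_period
    freqs = np.exp(-math.log(max_period) * np.arange(half) / half)
    args = timesteps[:, None].astype(float) * freqs[None]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb

emb = timestep_embedding_np(np.array([0, 1, 10]), dim=6)
print(emb.shape)  # (3, 6)
print(emb[0])     # t=0: cos half is all 1, sin half is all 0
```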
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
@ -0,0 +1,132 @@
import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import OnePromptEncoderViT
from .mask_decoder import OnePromptDecoder
from .prompt_encoder import PromptEncoder


class OnePrompt(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        args,
        image_encoder: OnePromptEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: OnePromptDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        super().__init__()
        self.args = args
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        template_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        template_images = torch.stack([self.preprocess(x["image"]) for x in template_input], dim=0)
        r_emb, r_list = self.image_encoder(input_images)
        t_emb, t_list = self.image_encoder(template_images)

        outputs = []
        # Iterate per image; the loop variables are renamed so they do not
        # shadow the batched sequences they are drawn from.
        for image_record, r_l, t_l, r_e, t_e in zip(batched_input, r_list, t_list, r_emb, t_emb):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            p1, p2, sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                skips_raw=r_l,
                skips_tmp=t_l,
                raw_emb=r_e,
                tmp_emb=t_e,
                pt1=p1,
                pt2=p2,
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad on the bottom and right up to the encoder's square input size
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
@ -0,0 +1,219 @@
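The normalize-and-pad preprocessing above, sketched in NumPy. The 1024 target size matches SAM-style encoders' `img_size` but is an assumption here, as are the example image dimensions:

```python
import numpy as np

def preprocess_np(x, img_size=1024,
                  mean=(123.675, 116.28, 103.53),
                  std=(58.395, 57.12, 57.375)):
    """x: C x H x W image array. Normalize per channel, zero-pad to a square."""
    mean = np.asarray(mean)[:, None, None]
    std = np.asarray(std)[:, None, None]
    x = (x - mean) / std
    c, h, w = x.shape
    out = np.zeros((c, img_size, img_size))
    out[:, :h, :w] = x  # pad on the bottom and right, as F.pad does above
    return out

img = np.full((3, 700, 500), 128.0)
out = preprocess_np(img)
print(out.shape)  # (3, 1024, 1024)
```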
import numpy as np
import torch
from torch import nn

from typing import Any, Optional, Tuple, Type

from .common import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to the mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension.
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as
            input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 6  # pos/neg point/doodle + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)

        self.not_a_point_embed = nn.Embedding(1, embed_dim)
        self.not_a_doodle_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape
|
||||
1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
pad: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Embeds point prompts."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
||||
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
||||
points = torch.cat([points, padding_point], dim=1)
|
||||
labels = torch.cat([labels, padding_label], dim=1)
|
||||
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
||||
return point_embedding[:, 0, :], point_embedding[:, 1, :]
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds box prompts."""
|
||||
boxes = boxes + 0.5 # Shift to center of pixel
|
||||
coords = boxes.reshape(-1, 2, 2)
|
||||
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
||||
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
||||
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
||||
return corner_embedding[:, 0, :], corner_embedding[:, 1, :]
|
||||
|
||||
def _embed_doodles(
|
||||
self,
|
||||
doodles: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Embeds doodle prompts."""
|
||||
doodles = doodles + 0.5 # Shift to center of pixel
|
||||
doodle_embedding = self.pe_layer.forward_with_coords(doodles, self.input_image_size)
|
||||
doodle_embedding[labels == -1] = 0.0
|
||||
doodle_embedding[labels == -1] += self.not_a_doodle_embed.weight
|
||||
doodle_embedding[labels == 0] += self.point_embeddings[4].weight
|
||||
doodle_embedding[labels == 1] += self.point_embeddings[5].weight
|
||||
return doodle_embedding[:, 0, :], doodle_embedding[:, 1, :]
|
||||
|
||||
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds mask inputs."""
|
||||
mask_embedding = self.mask_downscaling(masks)
|
||||
return mask_embedding, mask_embedding
|
||||
|
||||
def _get_batch_size(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""
|
||||
Gets the batch size of the output given the batch size of the input prompts.
|
||||
"""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
return boxes.shape[0]
|
||||
elif masks is not None:
|
||||
return masks.shape[0]
|
||||
else:
|
||||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
doodles: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
||||
if points is not None:
|
||||
coords, labels = points
|
||||
p1, p2 = self._embed_points(coords, labels, pad=(boxes is None))
|
||||
p1 = torch.cat([sparse_embeddings, p1.unsqueeze(1)], dim=1)
|
||||
p2 = torch.cat([sparse_embeddings, p2.unsqueeze(1)], dim=1)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, p1, p2], dim=1)
|
||||
if boxes is not None:
|
||||
p1, p2 = self._embed_boxes(boxes)
|
||||
p1 = torch.cat([sparse_embeddings, p1.unsqueeze(1)], dim=1)
|
||||
p2 = torch.cat([sparse_embeddings, p2.unsqueeze(1)], dim=1)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, p1, p2], dim=1)
|
||||
if doodles is not None:
|
||||
coords, labels = doodles
|
||||
p1, p2 = self._embed_doodles(coords, labels, pad=(boxes is None))
|
||||
p1 = torch.cat([sparse_embeddings, p1.unsqueeze(1)], dim=1)
|
||||
p2 = torch.cat([sparse_embeddings, p2.unsqueeze(1)], dim=1)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, p1, p2], dim=1)
|
||||
if masks is not None:
|
||||
p1, p2 = self._embed_masks(masks)
|
||||
dense_embeddings = p1
|
||||
else:
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
||||
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
||||
)
|
||||
return p1, p2, sparse_embeddings, dense_embeddings
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
Positional encoding using random spatial frequencies.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
self.register_buffer(
|
||||
"positional_encoding_gaussian_matrix",
|
||||
scale * torch.randn((2, num_pos_feats)),
|
||||
)
|
||||
|
||||
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
"""Positionally encode points that are normalized to [0,1]."""
|
||||
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
||||
coords = 2 * coords - 1
|
||||
coords = coords @ self.positional_encoding_gaussian_matrix
|
||||
coords = 2 * np.pi * coords
|
||||
# outputs d_1 x ... x d_n x C shape
|
||||
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
||||
|
||||
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Generate positional encoding for a grid of the specified size."""
|
||||
h, w = size
|
||||
device: Any = self.positional_encoding_gaussian_matrix.device
|
||||
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
||||
y_embed = grid.cumsum(dim=0) - 0.5
|
||||
x_embed = grid.cumsum(dim=1) - 0.5
|
||||
y_embed = y_embed / h
|
||||
x_embed = x_embed / w
|
||||
|
||||
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
||||
return pe.permute(2, 0, 1) # C x H x W
|
||||
|
||||
def forward_with_coords(
|
||||
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
||||
) -> torch.Tensor:
|
||||
"""Positionally encode points that are not normalized to [0,1]."""
|
||||
coords = coords_input.clone()
|
||||
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
||||
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
||||
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
||||
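The random spatial-frequency encoding above maps a coordinate in [0, 1]^2 to [-1, 1], projects it through a fixed Gaussian matrix, scales by 2π, and concatenates sin and cos features. As a minimal pure-Python sketch of just that math (the matrix size and sample coordinate below are made up for illustration, standing in for the module's 2 x num_pos_feats buffer):

```python
import math
import random

def pe_encode(x, y, gaussian_matrix):
    """Encode one (x, y) point in [0, 1]^2 with random spatial frequencies.

    gaussian_matrix is a list of (gx, gy) frequency pairs, a stand-in for
    the 2 x num_pos_feats buffer in PositionEmbeddingRandom.
    """
    # Map [0, 1] -> [-1, 1], project onto each random frequency, scale by 2*pi
    cx, cy = 2 * x - 1, 2 * y - 1
    proj = [2 * math.pi * (cx * gx + cy * gy) for gx, gy in gaussian_matrix]
    # Concatenate sin and cos features, as _pe_encoding does
    return [math.sin(p) for p in proj] + [math.cos(p) for p in proj]

random.seed(0)
matrix = [(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(4)]
feats = pe_encode(0.25, 0.75, matrix)
print(len(feats))  # 8 features: 4 sin + 4 cos
```

Because the frequencies are fixed at construction time, the same point always maps to the same feature vector, which is what lets the dense grid encoding and the per-point encoding stay consistent.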
@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        args,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.args = args
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
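The preprocess method above pads only the right and bottom edges so every input becomes img_size x img_size, which is why postprocess_masks can later remove the padding with a simple crop. A sketch of just the padding arithmetic, with the default SAM img_size of 1024 as an assumption:

```python
def pad_to_square(h, w, img_size=1024):
    """Return F.pad-style (left, right, top, bottom) amounts that grow an
    (h, w) image to (img_size, img_size), padding only right and bottom."""
    padh = img_size - h
    padw = img_size - w
    # torch.nn.functional.pad takes (left, right, top, bottom) for the
    # last two dimensions, so (0, padw, 0, padh) pads right and bottom only.
    return (0, padw, 0, padh)

print(pad_to_square(768, 1024))  # (0, 0, 0, 256)
```

Because the image content stays anchored at the top-left corner, `masks[..., :input_size[0], :input_size[1]]` in postprocess_masks recovers exactly the unpadded region.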
File diff suppressed because it is too large
@ -0,0 +1,94 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from PIL import Image


softmax_helper = lambda x: F.softmax(x, 1)
sigmoid_helper = lambda x: torch.sigmoid(x)


class InitWeights_He(object):
    def __init__(self, neg_slope=1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
            if module.bias is not None:
                module.bias = nn.init.constant_(module.bias, 0)


def maybe_to_torch(d):
    if isinstance(d, list):
        d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
    elif not isinstance(d, torch.Tensor):
        d = torch.from_numpy(d).float()
    return d


def to_cuda(data, non_blocking=True, gpu_id=0):
    if isinstance(data, list):
        data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
    else:
        data = data.cuda(gpu_id, non_blocking=non_blocking)
    return data


class no_op(object):
    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass


def staple(a):
    # a: n,c,h,w detached tensor
    mvres = mv(a)
    gap = 0.4
    if gap > 0.02:
        for i, s in enumerate(a):
            r = s * mvres
            res = r if i == 0 else torch.cat((res, r), 0)
        nres = mv(res)
        gap = torch.mean(torch.abs(mvres - nres))
        mvres = nres
        a = res
    return mvres


def allone(disc, cup):
    disc = np.array(disc) / 255
    cup = np.array(cup) / 255
    res = np.clip(disc * 0.5 + cup, 0, 1) * 255
    res = 255 - res
    res = Image.fromarray(np.uint8(res))
    return res


def dice_score(pred, targs):
    pred = (pred > 0).float()
    return 2.0 * (pred * targs).sum() / (pred + targs).sum()


def mv(a):
    # Average over the batch dimension, keeping it as size 1.
    b = a.size(0)
    return torch.sum(a, 0, keepdim=True) / b


def tensor_to_img_array(tensor):
    image = tensor.cpu().detach().numpy()
    image = np.transpose(image, [0, 2, 3, 1])
    return image


def export(tar, img_path=None):
    c = tar.size(1)
    if c == 3:
        vutils.save_image(tar, fp=img_path)
    else:
        # Replicate the last channel to three channels before saving
        s = torch.tensor(tar)[:, -1, :, :].unsqueeze(1)
        s = torch.cat((s, s, s), 1)
        vutils.save_image(s, fp=img_path)


def norm(t):
    m, s = torch.mean(t), torch.std(t)
    return (t - m) / s
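dice_score above computes 2|P∩T| / (|P| + |T|) on thresholded tensors. The same arithmetic on plain binary lists, as a standalone sketch (the empty-input convention of returning 1.0 is an assumption, not taken from the tensor version):

```python
def dice(pred, targ):
    """Dice coefficient for two equal-length binary sequences.

    Mirrors dice_score's arithmetic: 2 * intersection / (|pred| + |targ|).
    Returns 1.0 when both inputs are all-zero (assumed convention).
    """
    inter = sum(p * t for p, t in zip(pred, targ))
    total = sum(pred) + sum(targ)
    return 2.0 * inter / total if total else 1.0

print(dice([1, 1, 0, 0], [1, 0, 0, 0]))  # 2*1 / (2+1) = 0.666...
```

Note the tensor version divides by zero when both masks are empty; the explicit `if total` guard here sidesteps that edge case.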
@ -0,0 +1,264 @@
import numpy as np
import torch

from .modeling import OnePrompt

from typing import Optional, Tuple

from .utils.transforms import ResizeLongestSide


class OnePredictor:
    def __init__(
        self,
        one_model: OnePrompt,
    ) -> None:
        """
        Uses the model to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Arguments:
          one_model (OnePrompt): The model to use for mask prediction.
        """
        super().__init__()
        self.model = one_model
        self.transform = ResizeLongestSide(one_model.image_encoder.img_size)
        self.reset_image()

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            image = image[..., ::-1]

        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        self.set_torch_image(input_image_torch, image.shape[:2])

    @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.features = self.model.image_encoder(input_image)
        self.is_image_set = True

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for this model, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        # Transform input prompts
        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
            mask_input_torch = mask_input_torch[None, :, :, :]

        masks, iou_predictions, low_res_masks = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
        )

        masks_np = masks[0].detach().cpu().numpy()
        iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
        low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np

    @torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for this model, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H, W) are
        the embedding spatial dimensions (typically C=256, H=W=64).
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None
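ResizeLongestSide is imported from .utils.transforms and not shown in this diff; the usual longest-side resize scales both coordinates of every prompt by the same factor so points stay aligned with the resized image. A standalone sketch of that coordinate math (the helper name and default target length are assumptions for illustration):

```python
def scale_coords(coords, orig_h, orig_w, long_side=1024):
    """Rescale (x, y) pixel coords from an (orig_h, orig_w) image into the
    frame of the same image resized so its longest side equals long_side."""
    scale = long_side / max(orig_h, orig_w)
    # One scale factor for both axes preserves the aspect ratio
    return [(x * scale, y * scale) for x, y in coords]

print(scale_coords([(100, 200)], orig_h=2048, orig_w=2048))  # [(50.0, 100.0)]
```

Because the image is later padded rather than stretched to a square, this single uniform scale is the only transform point prompts need before being handed to `predict`.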
@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@ -0,0 +1,346 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any, Dict, Generator, ItemsView, List, Tuple
|
||||
|
||||
|
||||
class MaskData:
|
||||
"""
|
||||
A structure for storing masks and their related data in batched format.
|
||||
Implements basic filtering and concatenation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
for v in kwargs.values():
|
||||
assert isinstance(
|
||||
v, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats = dict(**kwargs)
|
||||
|
||||
def __setitem__(self, key: str, item: Any) -> None:
|
||||
assert isinstance(
|
||||
item, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats[key] = item
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._stats[key]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._stats[key]
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
return self._stats.items()
|
||||
|
||||
def filter(self, keep: torch.Tensor) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if v is None:
|
||||
self._stats[k] = None
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = v[keep.detach().cpu().numpy()]
|
||||
elif isinstance(v, list) and keep.dtype == torch.bool:
|
||||
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = [v[i] for i in keep]
|
||||
else:
|
||||
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
||||
|
||||
def cat(self, new_stats: "MaskData") -> None:
|
||||
for k, v in new_stats.items():
|
||||
if k not in self._stats or self._stats[k] is None:
|
||||
self._stats[k] = deepcopy(v)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = self._stats[k] + deepcopy(v)
|
||||
else:
|
||||
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v.detach().cpu().numpy()
|
||||
|
||||
|
||||
def is_box_near_crop_edge(
|
||||
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
|
||||
) -> torch.Tensor:
|
||||
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
|
||||
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
|
||||
box_xywh = deepcopy(box_xyxy)
|
||||
box_xywh[2] = box_xywh[2] - box_xywh[0]
|
||||
box_xywh[3] = box_xywh[3] - box_xywh[1]
|
||||
return box_xywh
|
||||
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
||||
assert len(args) > 0 and all(
|
||||
len(a) == len(args[0]) for a in args
|
||||
), "Batched iteration must have inputs of all the same size."
|
||||
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
|
||||
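The batching arithmetic above is plain Python: ceil-divide the length by the batch size, then slice every argument in lockstep. A standalone sketch with illustrative inputs (the list values here are made up for the demo):

```python
# Standalone copy of the batch_iterator logic, run on small lists.
from typing import Any, Generator, List


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    assert len(args) > 0 and all(len(a) == len(args[0]) for a in args)
    # Ceil division without math.ceil: add one extra batch if there is a remainder.
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


# Five items with batch_size=2 -> three batches; the last one is ragged.
batches = list(batch_iterator(2, [1, 2, 3, 4, 5], ["a", "b", "c", "d", "e"]))
```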
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Encodes masks to an uncompressed RLE, in the format expected by
    pycocotools.
    """
    # Put in fortran order and flatten h,w
    b, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 1).flatten(1)

    # Compute change indices
    diff = tensor[:, 1:] ^ tensor[:, :-1]
    change_indices = diff.nonzero()

    # Encode run length
    out = []
    for i in range(b):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
        cur_idxs = torch.cat(
            [
                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
                cur_idxs + 1,
                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
            ]
        )
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        counts = [] if tensor[i, 0] == 0 else [0]
        counts.extend(btw_idxs.detach().cpu().tolist())
        out.append({"size": [h, w], "counts": counts})
    return out


def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Compute a binary mask from an uncompressed RLE."""
    h, w = rle["size"]
    mask = np.empty(h * w, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity ^= True
    mask = mask.reshape(w, h)
    return mask.transpose()  # Put in C order


def area_from_rle(rle: Dict[str, Any]) -> int:
    return sum(rle["counts"][1::2])
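In this uncompressed RLE format, `counts` alternates off-run and on-run lengths (starting with an off run, possibly of length 0), stored in Fortran (column-major) order. A minimal NumPy decode, mirroring `rle_to_mask`, on a hand-made 2x3 example:

```python
import numpy as np


def rle_to_mask(rle):
    # Decode alternating off/on run lengths into a flat boolean buffer,
    # then undo the Fortran-order flattening with reshape + transpose.
    h, w = rle["size"]
    mask = np.empty(h * w, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity ^= True
    return mask.reshape(w, h).transpose()


# 1 off, 2 on, 1 off, 2 on over a 2x3 grid (example values, not from the repo).
rle = {"size": [2, 3], "counts": [1, 2, 1, 2]}
mask = rle_to_mask(rle)
# area_from_rle sums only the odd-indexed ("on") counts, matching mask.sum().
area = sum(rle["counts"][1::2])
```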
def calculate_stability_score(
    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
    """
    Computes the stability score for a batch of masks. The stability
    score is the IoU between the binary masks obtained by thresholding
    the predicted mask logits at high and low values.
    """
    # One mask is always contained inside the other.
    # Save memory by preventing unnecessary cast to torch.int64
    intersections = (
        (masks > (mask_threshold + threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
    )
    unions = (
        (masks > (mask_threshold - threshold_offset))
        .sum(-1, dtype=torch.int1int32)
    )
    return intersections / unions
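Because the high-threshold mask is always contained in the low-threshold mask, the IoU reduces to (high-threshold area) / (low-threshold area). A NumPy sketch on a single 3x3 logit map (values invented for the demo):

```python
import numpy as np

# One 3x3 map of mask logits; threshold 0.0 with a stability offset of 1.0.
logits = np.array([[[-2.0, 0.5, 3.0],
                    [ 1.5, 2.5, -1.0],
                    [ 0.2, -0.3, 4.0]]])
thr, off = 0.0, 1.0

inter = (logits > thr + off).sum(axis=(-1, -2))  # area of the strict mask
union = (logits > thr - off).sum(axis=(-1, -2))  # area of the loose mask
score = inter / union                            # containment makes this an IoU
```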
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points
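The half-cell `offset` centers each point in its grid cell, so no point lands on the image border. Tracing the same arithmetic for `n_per_side=2`:

```python
import numpy as np

n = 2
offset = 1 / (2 * n)                        # 0.25: half a cell width
side = np.linspace(offset, 1 - offset, n)   # [0.25, 0.75]
px = np.tile(side[None, :], (n, 1))         # x varies along each row
py = np.tile(side[:, None], (1, n))         # y varies down each column
points = np.stack([px, py], axis=-1).reshape(-1, 2)  # (x, y) pairs, row-major
```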
def build_all_layer_point_grids(
    n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
    """Generates point grids for all crop layers."""
    points_by_layer = []
    for i in range(n_layers + 1):
        n_points = int(n_per_side / (scale_per_layer**i))
        points_by_layer.append(build_point_grid(n_points))
    return points_by_layer


def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generates a list of crop boxes of different sizes. Each layer
    has (2**i)**2 boxes for the ith layer.
    """
    crop_boxes, layer_idxs = [], []
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    # Original image
    crop_boxes.append([0, 0, im_w, im_h])
    layer_idxs.append(0)

    def crop_len(orig_len, n_crops, overlap):
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    for i_layer in range(n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYXY format
        for x0, y0 in product(crop_box_x0, crop_box_y0):
            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    return crop_boxes, layer_idxs
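The `crop_len` formula picks a crop size so that `n_crops` crops, each overlapping its neighbor by `overlap` pixels, tile the full side. Working the arithmetic for layer 1 of a hypothetical 1024x1024 image with `overlap_ratio=0.5`:

```python
import math


def crop_len(orig_len, n_crops, overlap):
    # Each pair of adjacent crops shares `overlap` pixels, so n_crops crops
    # of this length cover overlap*(n_crops-1) + orig_len total pixels.
    return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))


n = 2 ** (0 + 1)                         # layer 1 -> 2 crops per side
overlap = int(0.5 * 1024 * (2 / n))      # 512
w = crop_len(1024, n, overlap)           # ceil((512 + 1024) / 2) = 768
x0s = [int((w - overlap) * i) for i in range(n)]
# Crops span [0, 768] and [256, 1024]; their intersection is exactly 512 wide.
```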
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
    # Check if boxes has a channel dimension
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    return boxes + offset


def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0]], device=points.device)
    # Check if points has a channel dimension
    if len(points.shape) == 3:
        offset = offset.unsqueeze(1)
    return points + offset


def uncrop_masks(
    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
    x0, y0, x1, y1 = crop_box
    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
        return masks
    # Coordinate transform masks
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    pad = (x0, pad_x - x0, y0, pad_y - y0)
    return torch.nn.functional.pad(masks, pad, value=0)


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
    from pycocotools import mask as mask_utils  # type: ignore

    h, w = uncompressed_rle["size"]
    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
    return rle


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
    """
    # torch.max below raises an error on empty inputs, just skip in this case
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize shape to CxHxW
    shape = masks.shape
    h, w = shape[-2:]
    if len(shape) > 2:
        masks = masks.flatten(0, -3)
    else:
        masks = masks.unsqueeze(0)

    # Get top and bottom edges
    in_height, _ = torch.max(masks, dim=-1)
    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    in_height_coords = in_height_coords + h * (~in_height)
    top_edges, _ = torch.min(in_height_coords, dim=-1)

    # Get left and right edges
    in_width, _ = torch.max(masks, dim=-2)
    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    in_width_coords = in_width_coords + w * (~in_width)
    left_edges, _ = torch.min(in_width_coords, dim=-1)

    # If the mask is empty the right edge will be to the left of the left edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    out = out * (~empty_filter).unsqueeze(-1)

    # Return to original shape
    if len(shape) > 2:
        out = out.reshape(*shape[:-2], 4)
    else:
        out = out[0]

    return out
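`batched_mask_to_box` recovers the inclusive XYXY box of each mask using only min/max reductions (no data-dependent control flow, which matters for ONNX tracing). A single-mask NumPy analog with a made-up 5x6 mask, just to show what box it computes:

```python
import numpy as np

mask = np.zeros((5, 6), dtype=bool)
mask[1:4, 2:5] = True  # a 3x3 blob: rows 1..3, columns 2..4

ys, xs = np.nonzero(mask)
# Inclusive XYXY box around the foreground pixels, as in batched_mask_to_box.
box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
```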
@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import Tuple

from ..modeling import Sam
from .amg import calculate_stability_score


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export
    script for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds
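The reweighting in `select_masks` replaces an `if` on `num_points` with arithmetic so the exported graph stays traceable: `(num_points - 2.5) * 1000` is a large negative bump on token 0 for few points and a large positive one for many. A NumPy trace with invented IoU predictions and 4 mask tokens:

```python
import numpy as np

num_tokens = 4
iou_preds = np.array([[0.90, 0.70, 0.80, 0.60]])      # made-up predictions
reweight = np.array([[1000] + [0] * (num_tokens - 1)])  # only token 0 is bumped


def best_token(num_points):
    # Same formula as select_masks: argmax over the reweighted scores.
    score = iou_preds + (num_points - 2.5) * reweight
    return int(score.argmax(axis=1)[0])


single_click = best_token(1)  # token 0 pushed far down -> best multimask token wins
multi_click = best_token(3)   # token 0 pushed far up -> single-mask token wins
```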
    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks
@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore

from copy import deepcopy
from typing import Tuple


class ResizeLongestSide:
    """
    Resizes images so the longest side equals 'target_length', and provides
    matching methods for resizing coordinates and boxes. Handles both numpy
    arrays and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length)
        new_coords = np.empty_like(coords)
        new_coords[..., 0] = coords[..., 0] * (new_w / old_w)
        new_coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return new_coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # Expects an image in BCHW format. May not exactly match apply_image.
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(neww := neww, newh + 0.5) if False else int(newh + 0.5)
        return (newh, neww)
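`get_preprocess_shape` is pure arithmetic: scale both sides by `target / longest`, then round half-up. A standalone copy run on example sizes (1024 is SAM's usual target length, used here only as an illustrative value):

```python
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
    # Scale so the longest side becomes exactly long_side_length,
    # rounding the other side half-up.
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    return (int(newh + 0.5), int(neww + 0.5))


landscape = get_preprocess_shape(600, 800, 1024)   # width is the long side
portrait = get_preprocess_shape(1024, 512, 1024)   # already at target height
```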
@ -0,0 +1,163 @@
"""resnet in pytorch

[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.

    Deep Residual Learning for Image Recognition
    https://arxiv.org/abs/1512.03385v1
"""

import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    """Basic block for resnet 18 and resnet 34"""

    # BasicBlock and BottleNeck blocks have different output sizes;
    # the class attribute `expansion` is used to distinguish them.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        # shortcut
        self.shortcut = nn.Sequential()

        # if the shortcut output dimension does not match the residual function,
        # use a 1x1 convolution to match the dimension
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class BottleNeck(nn.Module):
    """Residual block for resnet with more than 50 layers"""

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class ResNet(nn.Module):

    def __init__(self, block, num_block, num_classes=1):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        # note: unlike the CIFAR reference (where conv2_x keeps stride 1 for
        # small inputs), every stage here downsamples with stride 2
        self.conv2_x = self._make_layer(block, 64, num_block[0], 2)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one resnet stage ("layer" here means a stage of residual
        blocks, e.g. conv2_x, not a single network layer); one stage may
        contain more than one residual block.

        Args:
            block: block type, BasicBlock or BottleNeck
            out_channels: output channel depth of this stage
            num_blocks: how many blocks per stage
            stride: the stride of the first block of this stage

        Return:
            a resnet stage as an nn.Sequential
        """

        # the first block of a stage may have stride 1 or 2;
        # all remaining blocks always have stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output


def resnet18():
    """return a ResNet 18 object"""
    return ResNet(BasicBlock, [2, 2, 2, 2])


def resnet34():
    """return a ResNet 34 object"""
    return ResNet(BasicBlock, [3, 4, 6, 3])


def resnet50():
    """return a ResNet 50 object"""
    return ResNet(BottleNeck, [3, 4, 6, 3])


def resnet101():
    """return a ResNet 101 object"""
    return ResNet(BottleNeck, [3, 4, 23, 3])


def resnet152():
    """return a ResNet 152 object"""
    return ResNet(BottleNeck, [3, 8, 36, 3])
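The block counts passed to `ResNet` determine the nominal depth: each `BasicBlock` contributes 2 weighted convolutions and each `BottleNeck` 3, plus the stem convolution and the final fully connected layer. A quick sanity check of that arithmetic:

```python
# Nominal depth = convs_per_block * total_blocks + stem conv + fc layer.
def nominal_depth(convs_per_block, blocks):
    return convs_per_block * sum(blocks) + 2


d18 = nominal_depth(2, [2, 2, 2, 2])    # BasicBlock counts for resnet18
d50 = nominal_depth(3, [3, 4, 6, 3])    # BottleNeck counts for resnet50
d152 = nominal_depth(3, [3, 8, 36, 3])  # BottleNeck counts for resnet152
```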
@ -0,0 +1,171 @@
"""senet in pytorch

[1] Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu

    Squeeze-and-Excitation Networks
    https://arxiv.org/abs/1709.01507
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicResidualSEBlock(nn.Module):

    expansion = 1

    def __init__(self, in_channels, out_channels, stride, r=16):
        super().__init__()

        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels * self.expansion, 3, padding=1),
            nn.BatchNorm2d(out_channels * self.expansion),
            nn.ReLU(inplace=True)
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion),
            nn.Sigmoid()
        )

    def forward(self, x):
        shortcut = self.shortcut(x)
        residual = self.residual(x)

        squeeze = self.squeeze(residual)
        squeeze = squeeze.view(squeeze.size(0), -1)
        excitation = self.excitation(squeeze)
        excitation = excitation.view(residual.size(0), residual.size(1), 1, 1)

        x = residual * excitation.expand_as(residual) + shortcut

        return F.relu(x)


class BottleneckResidualSEBlock(nn.Module):

    expansion = 4

    def __init__(self, in_channels, out_channels, stride, r=16):
        super().__init__()

        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels * self.expansion, 1),
            nn.BatchNorm2d(out_channels * self.expansion),
            nn.ReLU(inplace=True)
        )

        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(out_channels * self.expansion, out_channels * self.expansion // r),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels * self.expansion // r, out_channels * self.expansion),
            nn.Sigmoid()
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, 1, stride=stride),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        shortcut = self.shortcut(x)

        residual = self.residual(x)
        squeeze = self.squeeze(residual)
        squeeze = squeeze.view(squeeze.size(0), -1)
        excitation = self.excitation(squeeze)
        excitation = excitation.view(residual.size(0), residual.size(1), 1, 1)

        x = residual * excitation.expand_as(residual) + shortcut

        return F.relu(x)


class SEResNet(nn.Module):

    def __init__(self, block, block_num, class_num=1):
        super().__init__()

        self.in_channels = 64

        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.stage1 = self._make_stage(block, block_num[0], 64, 1)
        self.stage2 = self._make_stage(block, block_num[1], 128, 2)
        self.stage3 = self._make_stage(block, block_num[2], 256, 2)
        # 512 (not 516) keeps the channel progression 64-128-256-512 consistent
        self.stage4 = self._make_stage(block, block_num[3], 512, 2)

        self.linear = nn.Linear(self.in_channels, class_num)

    def forward(self, x):
        x = self.pre(x)

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)

        x = self.linear(x)

        return x

    def _make_stage(self, block, num, out_channels, stride):

        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels * block.expansion

        while num - 1:
            layers.append(block(self.in_channels, out_channels, 1))
            num -= 1

        return nn.Sequential(*layers)


def seresnet18():
    return SEResNet(BasicResidualSEBlock, [2, 2, 2, 2])


def seresnet34():
    return SEResNet(BasicResidualSEBlock, [3, 4, 6, 3])


def seresnet50():
    return SEResNet(BottleneckResidualSEBlock, [3, 4, 6, 3])


def seresnet101():
    return SEResNet(BottleneckResidualSEBlock, [3, 4, 23, 3])


def seresnet152():
    return SEResNet(BottleneckResidualSEBlock, [3, 8, 36, 3])
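The squeeze-and-excitation path reduces each channel to a scalar (global average pooling), maps it through a small bottleneck MLP, and uses a sigmoid gate to rescale the channels. A NumPy sketch of just the squeeze and gating steps, with the two FC layers deliberately omitted and the feature values invented for illustration:

```python
import numpy as np


def sigmoid(z):
    return 1 / (1 + np.exp(-z))


# A 2-channel, 2x2 feature map (channel, height, width).
feat = np.array([[[1.0, 3.0], [1.0, 3.0]],
                 [[2.0, 2.0], [2.0, 2.0]]])

# Squeeze: global average pool collapses each channel to one scalar.
squeeze = feat.mean(axis=(1, 2))

# Excitation (FC layers skipped here): sigmoid gate in (0, 1) per channel,
# broadcast back over the spatial dimensions to rescale the channel.
gate = sigmoid(squeeze)[:, None, None]
out = feat * gate
```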
@ -0,0 +1,89 @@
"""squeezenet in pytorch
"""
import torch
import torch.nn as nn


class Fire(nn.Module):

    def __init__(self, in_channel, out_channel, squeeze_channel):
        super().__init__()
        self.squeeze = nn.Sequential(
            nn.Conv2d(in_channel, squeeze_channel, 1),
            nn.BatchNorm2d(squeeze_channel),
            nn.ReLU(inplace=True)
        )

        self.expand_1x1 = nn.Sequential(
            nn.Conv2d(squeeze_channel, int(out_channel / 2), 1),
            nn.BatchNorm2d(int(out_channel / 2)),
            nn.ReLU(inplace=True)
        )

        self.expand_3x3 = nn.Sequential(
            nn.Conv2d(squeeze_channel, int(out_channel / 2), 3, padding=1),
            nn.BatchNorm2d(int(out_channel / 2)),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.squeeze(x)
        # concatenate the 1x1 and 3x3 expand paths along the channel dimension
        x = torch.cat([
            self.expand_1x1(x),
            self.expand_3x3(x)
        ], 1)

        return x


class SqueezeNet(nn.Module):

    """SqueezeNet with simple bypass"""
    def __init__(self, class_num=100):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 96, 3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )

        self.fire2 = Fire(96, 128, 16)
        self.fire3 = Fire(128, 128, 16)
        self.fire4 = Fire(128, 256, 32)
        self.fire5 = Fire(256, 256, 32)
        self.fire6 = Fire(256, 384, 48)
        self.fire7 = Fire(384, 384, 48)
        self.fire8 = Fire(384, 512, 64)
        self.fire9 = Fire(512, 512, 64)

        self.conv10 = nn.Conv2d(512, class_num, 1)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.stem(x)

        f2 = self.fire2(x)
        f3 = self.fire3(f2) + f2
        f4 = self.fire4(f3)
        f4 = self.maxpool(f4)

        f5 = self.fire5(f4) + f4
        f6 = self.fire6(f5)
        f7 = self.fire7(f6) + f6
        f8 = self.fire8(f7)
        f8 = self.maxpool(f8)

        f9 = self.fire9(f8)
        c10 = self.conv10(f9)

        x = self.avg(c10)
        x = x.view(x.size(0), -1)

        return x


def squeezenet(class_num=1):
    return SqueezeNet(class_num=class_num)
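The Fire module's shape contract is worth making explicit: the squeeze layer narrows the input to `squeeze_channel`, then the 1x1 and 3x3 expand branches each produce `out_channel // 2` channels, so their concatenation restores `out_channel`. The helper below is a hypothetical illustration of that arithmetic, not part of the repository.

```python
def fire_channels(in_channel, out_channel, squeeze_channel):
    """Channel widths through a Fire module: (squeezed, concatenated output)."""
    expand_1x1 = out_channel // 2
    expand_3x3 = out_channel // 2
    return squeeze_channel, expand_1x1 + expand_3x3

# fire2 in the network above: Fire(96, 128, 16)
print(fire_channels(96, 128, 16))   # (16, 128)
# fire9: Fire(512, 512, 64)
print(fire_channels(512, 512, 64))  # (64, 512)
```

Note that odd `out_channel` values would silently lose a channel to the integer halving, which is why all Fire configurations above use even widths.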
@ -0,0 +1 @@
from .tag import *
@ -0,0 +1,412 @@
import math
import torch.nn.init as init
from timm.models.registry import register_model
from timm.models.layers import DropPath

from .tag_layers import *


class PatchEmbed(nn.Module):
    def __init__(self, stride, has_mask=False, in_ch=0, out_ch=0):
        super(PatchEmbed, self).__init__()
        self.to_token = nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, stride=stride, groups=in_ch)
        self.proj = nn.Linear(in_ch, out_ch, bias=False)
        self.has_mask = has_mask

    def process_mask(self, x, mask, H, W):
        if mask is None and self.has_mask:
            mask = x.new_zeros((1, 1, H, W))
        if mask is not None:
            H_mask, W_mask = mask.shape[-2:]
            if H_mask != H or W_mask != W:
                mask = F.interpolate(mask, (H, W), mode='nearest')
        return mask

    def forward(self, x, mask):
        """
        Args:
            x: [B, C, H, W]
            mask: [B, 1, H, W] if exists, else None
        Returns:
            out: [B, out_H * out_W, out_C]
            H, W: output height & width
            mask: [B, 1, out_H, out_W] if exists, else None
        """
        out = self.to_token(x)
        B, C, H, W = out.shape
        mask = self.process_mask(out, mask, H, W)
        out = rearrange(out, "b c h w -> b (h w) c").contiguous()
        out = self.proj(out)
        return out, H, W, mask


class Encoder(nn.Module):
    def __init__(self, dim, num_parts=64, num_enc_heads=1, drop_path=0.1, act=nn.GELU, has_ffn=True):
        super(Encoder, self).__init__()
        self.num_heads = num_enc_heads
        self.enc_attn = AnyAttention(dim, num_enc_heads)
        self.drop_path = DropPath(drop_prob=drop_path) if drop_path else nn.Identity()
        self.reason = SimpleReasoning(num_parts, dim)
        self.enc_ffn = Mlp(dim, hidden_features=dim, act_layer=act) if has_ffn else None

    def forward(self, feats, parts=None, qpos=None, kpos=None, mask=None):
        """
        Args:
            feats: [B, patch_num * patch_size, C]
            parts: [B, N, C]
            qpos: [B, N, 1, C]
            kpos: [B, patch_num * patch_size, C]
            mask: [B, 1, patch_num, patch_size] if exists, else None
        Returns:
            parts: [B, N, C]
        """
        attn_out = self.enc_attn(q=parts, k=feats, v=feats, qpos=qpos, kpos=kpos, mask=mask)
        parts = parts + self.drop_path(attn_out)
        parts = self.reason(parts)
        if self.enc_ffn is not None:
            parts = parts + self.drop_path(self.enc_ffn(parts))
        return parts


class Decoder(nn.Module):
    def __init__(self, dim, num_heads=8, patch_size=7, ffn_exp=3, act=nn.GELU, drop_path=0.1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.attn1 = AnyAttention(dim, num_heads)
        self.attn2 = AnyAttention(dim, num_heads)
        self.rel_pos = FullRelPos(patch_size, patch_size, dim // num_heads)
        self.ffn1 = Mlp(dim, hidden_features=dim * ffn_exp, act_layer=act, norm_layer=Norm)
        self.ffn2 = Mlp(dim, hidden_features=dim * ffn_exp, act_layer=act, norm_layer=Norm)
        self.drop_path = DropPath(drop_path)

    def forward(self, x, parts=None, qpos=None, kpos=None, mask=None, P=0):
        """
        Args:
            x: [B, patch_num * patch_size, C]
            parts: [B, N, C]
            qpos: [B, N, 1, C]
            kpos: [B, N, 1, C]
            mask: [B, 1, patch_num, patch_size] if exists, else None
            P: patch_num
        Returns:
            feat: [B, patch_num, patch_size, C]
        """
        dec_mask = None if mask is None else rearrange(mask.squeeze(1), "b h w -> b (h w) 1 1")
        out = self.attn1(q=x, k=parts, v=parts, qpos=qpos, kpos=kpos, mask=dec_mask)
        out = x + self.drop_path(out)
        out = out + self.drop_path(self.ffn1(out))

        # per-patch local attention is disabled in this variant:
        # out = rearrange(out, "b (p k) c -> (b p) k c", p=P)
        # local_out = self.attn2(q=out, k=out, v=out, mask=mask, rel_pos=self.rel_pos)
        # out = out + self.drop_path(local_out)
        # out = out + self.drop_path(self.ffn2(out))
        return rearrange(out, "b (p k) c -> b p k c", p=P)


class TAGBlock(nn.Module):
    def __init__(self, dim, ffn_exp=4, drop_path=0.1, patch_size=7, num_heads=1, num_enc_heads=1, num_parts=0):
        super(TAGBlock, self).__init__()
        # self.encoder = Encoder(dim, num_parts=num_parts, num_enc_heads=num_enc_heads, drop_path=drop_path)
        self.decoder = Decoder(dim, num_heads=num_heads, patch_size=patch_size, ffn_exp=ffn_exp, drop_path=drop_path)

    def forward(self, x, parts=None, qpos=None, kpos=None, mask=None):
        """
        Args:
            x: [B, patch_num, patch_size, C]
            parts: [B, N, C]
            qpos: [B, N, 1, C]
            kpos: [B, N, 1, C]
            mask: [B, 1, patch_num, patch_size] if exists, else None
        Returns:
            feats: [B, patch_num, patch_size, C]
            parts: [B, N, C]
            qpos: [B, N, 1, C]
            mask: [B, 1, patch_num, patch_size] if exists, else None
        """
        P = x.shape[1]
        x = rearrange(x, "b p k c -> b (p k) c")
        feats = self.decoder(x, parts=parts, qpos=qpos, kpos=kpos, mask=mask, P=P)
        return feats, parts, qpos, mask


class Stage(nn.Module):
    def __init__(self, in_ch, out_ch, num_blocks, patch_size=7, num_heads=1, num_enc_heads=1, stride=1, num_parts=0,
                 last_np=0, last_enc=False, drop_path=0.1, has_mask=None, ffn_exp=3):
        super(Stage, self).__init__()
        if isinstance(drop_path, float):
            drop_path = [drop_path for _ in range(num_blocks)]
        self.patch_size = patch_size
        self.rpn_qpos = nn.Parameter(torch.Tensor(1, num_parts, 1, out_ch // num_heads))
        self.rpn_kpos = nn.Parameter(torch.Tensor(1, num_parts, 1, out_ch // num_heads))

        self.proj_p = PatchEmbed(stride, has_mask=has_mask, in_ch=in_ch, out_ch=out_ch)
        self.proj_x = PatchEmbed(stride, has_mask=has_mask, in_ch=in_ch, out_ch=out_ch)
        # self.proj_token = nn.Sequential(
        #     nn.Conv1d(last_np, num_parts, 1, bias=False) if last_np != num_parts else nn.Identity(),
        #     nn.Linear(in_ch, out_ch),
        #     Norm(out_ch)
        # )
        self.proj_token = None
        self.proj_norm = Norm(out_ch)
        blocks = [
            TAGBlock(out_ch,
                     patch_size=patch_size,
                     num_heads=num_heads,
                     num_enc_heads=num_enc_heads,
                     num_parts=num_parts,
                     ffn_exp=ffn_exp,
                     drop_path=drop_path[i])
            for i in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(blocks)
        self.last_enc = Encoder(dim=out_ch,
                                num_enc_heads=num_enc_heads,
                                num_parts=num_parts,
                                drop_path=drop_path[-1],
                                has_ffn=False) if last_enc else None
        self._init_weights()

    def _init_weights(self):
        init.kaiming_uniform_(self.rpn_qpos, a=math.sqrt(5))
        trunc_normal_(self.rpn_qpos, std=.02)
        init.kaiming_uniform_(self.rpn_kpos, a=math.sqrt(5))
        trunc_normal_(self.rpn_kpos, std=.02)

    def to_patch(self, x, patch_size, H, W, mask=None):
        x = rearrange(x, "b (h w) c -> b h w c", h=H)
        pad_l = pad_t = 0
        pad_r = int(math.ceil(W / patch_size)) * patch_size - W
        pad_b = int(math.ceil(H / patch_size)) * patch_size - H
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        if mask is not None:
            mask = F.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=1)
        x = rearrange(x, "b (sh kh) (sw kw) c -> b (sh sw) (kh kw) c", kh=patch_size, kw=patch_size)
        if mask is not None:
            mask = rearrange(mask, "b c (sh kh) (sw kw) -> b c (kh kw) (sh sw)", kh=patch_size, kw=patch_size)
        return x, mask, H + pad_b, W + pad_r

    def to_part(self, x, mask=None):
        x, H, W, mask = self.proj_p(x, mask=mask)
        x = self.proj_norm(x)
        if self.proj_token is not None:
            # proj_token is disabled above; when enabled it re-projects the token sequence
            x = self.proj_token(x)
        ori_H, ori_W = H, W
        x, mask, H, W = self.to_patch(x, self.patch_size, H, W, mask)
        P = x.shape[1]
        x = rearrange(x, "b p k c -> b (p k) c")

        return x

    def forward(self, x, p, mask=None):
        """
        Args:
            x: [B, C, H, W]
            p: [B, C, H, W], the feature map the parts are projected from
            mask: [B, 1, H, W] if exists, else None
        Returns:
            x: [B, out_C, out_H, out_W]
            parts: [B, out_N, out_C]
            mask: [B, 1, out_H, out_W] if exists else None
        """
        parts = self.to_part(p, mask=mask)
        x, H, W, mask = self.proj_x(x, mask=mask)
        x = self.proj_norm(x)
        if self.proj_token is not None:
            parts = self.proj_token(parts)

        rpn_qpos, rpn_kpos = self.rpn_qpos, self.rpn_kpos
        rpn_qpos = rpn_qpos.expand(x.shape[0], -1, -1, -1)
        rpn_kpos = rpn_kpos.expand(x.shape[0], -1, -1, -1)

        ori_H, ori_W = H, W
        x, mask, H, W = self.to_patch(x, self.patch_size, H, W, mask)
        for blk in self.blocks:
            # x: [B, K, P, C]
            x, parts, rpn_qpos, mask = blk(x,
                                           parts=parts,
                                           qpos=rpn_qpos,
                                           kpos=rpn_kpos,
                                           mask=mask)

        dec_mask = None if mask is None else rearrange(mask.squeeze(1), "b h w -> b 1 1 (h w)")
        if self.last_enc is not None:
            x = rearrange(x, "b p k c -> b (p k) c")
            rpn_out = self.last_enc(x, parts=parts, qpos=rpn_qpos, mask=dec_mask)
            return rpn_out
        else:
            x = rearrange(x, "b (sh sw) (kh kw) c -> b c (sh kh) (sw kw)", kh=self.patch_size, sh=H // self.patch_size)
            x = x[:, :, :ori_H, :ori_W]
            return x


class TAG(nn.Module):
    def __init__(self,
                 in_chans=3,
                 inplanes=64,
                 num_layers=(3, 4, 6, 3),
                 num_chs=(256, 512, 1024, 2048),
                 num_strides=(1, 2, 2, 2),
                 num_classes=1000,
                 num_heads=(1, 1, 1, 1),
                 num_parts=(1, 1, 1, 1),
                 patch_sizes=(1, 1, 1, 1),
                 drop_path=0.1,
                 num_enc_heads=(1, 1, 1, 1),
                 act=nn.GELU,
                 ffn_exp=3,
                 no_pos_wd=False,
                 has_last_encoder=False,
                 pretrained=False,
                 **ret_args):
        super(TAG, self).__init__()
        self.depth = len(num_layers)
        self.no_pos_wd = no_pos_wd

        self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, padding=3, stride=2, bias=False)
        self.norm1 = nn.BatchNorm2d(inplanes)
        self.act = act()
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.rpn_tokens = nn.Parameter(torch.Tensor(1, num_parts[0], inplanes))

        drop_path_ratios = torch.linspace(0, drop_path, sum(num_layers))
        last_chs = [inplanes, *num_chs[:-1]]
        last_nps = [num_parts[0], *num_parts[:-1]]

        for i, n_l in enumerate(num_layers):
            stage_ratios = [drop_path_ratios[sum(num_layers[:i]) + did] for did in range(n_l)]
            setattr(self,
                    "layer_{}".format(i),
                    Stage(last_chs[i],
                          num_chs[i],
                          n_l,
                          stride=num_strides[i],
                          num_heads=num_heads[i],
                          num_enc_heads=num_enc_heads[i],
                          patch_size=patch_sizes[i],
                          drop_path=stage_ratios,
                          ffn_exp=ffn_exp,
                          num_parts=num_parts[i],
                          last_np=last_nps[i],
                          last_enc=has_last_encoder and i == len(num_layers) - 1)
                    )

        if has_last_encoder:
            self.last_fc = nn.Linear(num_chs[-1], num_classes)
        else:
            self.last_linear = nn.Conv2d(num_chs[-1], num_chs[-1], kernel_size=1, bias=False)
            self.last_norm = nn.BatchNorm2d(num_chs[-1])
            self.pool2 = nn.AdaptiveAvgPool2d(1)
            self.last_fc = nn.Linear(num_chs[-1], num_classes)

        self.has_last_encoder = has_last_encoder
        self._init_weights(pretrained=pretrained)

    @torch.jit.ignore
    def no_weight_decay(self):
        skip_pattern = ['rel_pos'] if self.no_pos_wd else []
        no_wd_layers = set()
        for name, param in self.named_parameters():
            for skip_name in skip_pattern:
                if skip_name in name:
                    no_wd_layers.add(name)
        return no_wd_layers

    def _init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            state_dict = torch.load(pretrained, map_location=torch.device("cpu"))
            if "state_dict" in state_dict.keys():
                state_dict = state_dict["state_dict"]
            self.load_state_dict(state_dict, strict=True)
            return

        init.kaiming_uniform_(self.rpn_tokens, a=math.sqrt(5))
        trunc_normal_(self.rpn_tokens, std=.02)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                if not torch.sum(m.weight.data == 0).item() == m.num_features:  # zero gamma
                    m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act(out)
        out = self.pool1(out)

        B, _, H, W = out.shape
        rpn_tokens, mask = self.rpn_tokens.expand(x.shape[0], -1, -1), None
        for i in range(self.depth):
            layer = getattr(self, "layer_{}".format(i))
            out, rpn_tokens, mask = layer(out, rpn_tokens, mask=mask)

        if self.has_last_encoder:
            out = self.act(out)
            out = out.mean(1)
        else:
            out = self.last_linear(out)
            out = self.last_norm(out)
            out = self.act(out)
            out = self.pool2(out)
            out = out.squeeze()
        out = self.last_fc(out).squeeze()
        return out.view(out.size(0), -1)


@register_model
def TAG_mobile(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(48, 96, 192, 384), patch_sizes=[8, 7, 7, 7], num_heads=[1, 2, 4, 8],
                     num_enc_heads=[1, 2, 4, 8], num_parts=[16, 16, 16, 32], num_layers=[1, 1, 1, 1], ffn_exp=3,
                     has_last_encoder=True, drop_path=0., **cfg)
    return TAG(pretrained=pretrained, **model_cfg)


@register_model
def TAG_tiny(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(64, 128, 256, 512), patch_sizes=[8, 7, 7, 7], num_heads=[1, 2, 4, 8],
                     num_enc_heads=[1, 2, 4, 8], num_parts=[32, 32, 32, 32], num_layers=[1, 1, 2, 1], ffn_exp=3,
                     has_last_encoder=True, drop_path=0.1, **cfg)
    return TAG(pretrained=pretrained, **model_cfg)


@register_model
def TAG_small(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(96, 192, 384, 768), patch_sizes=[8, 7, 7, 7], num_heads=[3, 6, 12, 24],
                     num_enc_heads=[1, 3, 6, 12], num_parts=[64, 64, 64, 64], num_layers=[1, 1, 3, 1], ffn_exp=3,
                     has_last_encoder=True, drop_path=0.1, **cfg)
    return TAG(pretrained=pretrained, **model_cfg)


@register_model
def TAG_medium(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(96, 192, 384, 768), patch_sizes=[8, 7, 7, 7], num_heads=[3, 6, 12, 24],
                     num_enc_heads=[1, 3, 6, 12], num_parts=[64, 64, 64, 128], num_layers=[1, 1, 8, 1], ffn_exp=3,
                     has_last_encoder=False, drop_path=0.2, **cfg)
    return TAG(pretrained=pretrained, **model_cfg)


@register_model
def TAG_base(pretrained=False, **cfg):
    model_cfg = dict(inplanes=64, num_chs=(128, 256, 512, 1024), patch_sizes=[8, 7, 7, 7], num_heads=[4, 8, 16, 32],
                     num_enc_heads=[1, 4, 8, 16], num_parts=[64, 64, 128, 128], num_layers=[1, 1, 8, 1], ffn_exp=3,
                     has_last_encoder=False, drop_path=0.3, **cfg)
    return TAG(pretrained=pretrained, **model_cfg)
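The padding arithmetic inside `Stage.to_patch` is the part of this file most worth checking by hand: the feature map is padded on the right and bottom so that H and W become multiples of `patch_size` before being split into non-overlapping windows. The helper below is a standalone sketch of that computation (it is not part of the repository).

```python
import math

def patch_padding(H, W, patch_size):
    """Replay the pad_r/pad_b computation in Stage.to_patch."""
    pad_r = math.ceil(W / patch_size) * patch_size - W
    pad_b = math.ceil(H / patch_size) * patch_size - H
    out_h, out_w = H + pad_b, W + pad_r
    num_patches = (out_h // patch_size) * (out_w // patch_size)
    return pad_b, pad_r, num_patches

# a 56x56 map with patch_size 7 needs no padding; 57x57 is padded up to 63x63
print(patch_padding(56, 56, 7))  # (0, 0, 64)
print(patch_padding(57, 57, 7))  # (6, 6, 81)
```

The `x[:, :, :ori_H, :ori_W]` crop at the end of `Stage.forward` is the inverse step: it discards exactly these padded rows and columns after the blocks run.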
@ -0,0 +1,138 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_


Norm = nn.LayerNorm


def apply_pos(tensor, pos, num_heads):
    if pos is None:
        return tensor
    elif len(tensor.shape) != len(pos.shape):
        tensor = rearrange(tensor, "b n (g c) -> b n g c", g=num_heads)
        tensor = tensor + pos
        tensor = rearrange(tensor, "b n g c -> b n (g c)")
    else:
        tensor = tensor + pos

    return tensor


class FullRelPos(nn.Module):
    def __init__(self, h, w, dim, drop_ratio=0.):
        super(FullRelPos, self).__init__()
        self.h, self.w = h, w
        self.rel_emb_h = nn.Parameter(torch.Tensor(2 * h - 1, dim // 2))  # offsets in [-(h - 1), h - 1]
        self.rel_emb_w = nn.Parameter(torch.Tensor(2 * w - 1, dim // 2))  # offsets in [-(w - 1), w - 1]

        # get relative coordinates of the q-k index table
        coords_h = torch.arange(h)
        coords_w = torch.arange(w)
        self.rel_idx_h = coords_h[None, :] - coords_h[:, None]
        self.rel_idx_w = coords_w[None, :] - coords_w[:, None]
        self.rel_idx_h += h - 1
        self.rel_idx_w += w - 1

        nn.init.normal_(self.rel_emb_w, std=dim ** -0.5)
        nn.init.normal_(self.rel_emb_h, std=dim ** -0.5)
        trunc_normal_(self.rel_emb_w, std=.02)
        trunc_normal_(self.rel_emb_h, std=.02)
        self.drop_ratio = drop_ratio

    def forward(self, q, attn):
        abs_pos_h = self.rel_emb_h[self.rel_idx_h.view(-1)]
        abs_pos_w = self.rel_emb_w[self.rel_idx_w.view(-1)]
        abs_pos_h = rearrange(abs_pos_h, "(q k) c -> q k c", q=self.h)  # [qh, kh, c]
        abs_pos_w = rearrange(abs_pos_w, "(q k) c -> q k c", q=self.w)  # [qw, kw, c]

        q = rearrange(q, "b (qh qw) g (n c) -> b qh qw g n c", qh=self.h, qw=self.w, n=2)
        logits_h = torch.einsum("b h w g c, h k c -> b h w g k", q[..., 0, :], abs_pos_h)
        logits_w = torch.einsum("b h w g c, w k c -> b h w g k", q[..., 1, :], abs_pos_w)
        logits_h = rearrange(logits_h, "b h w g k -> b (h w) g k 1")
        logits_w = rearrange(logits_w, "b h w g k -> b (h w) g 1 k")

        attn = rearrange(attn, "b q g (kh kw) -> b q g kh kw", kh=self.h, kw=self.w)
        attn += logits_h
        attn += logits_w
        return rearrange(attn, "b q g h w -> b q g (h w)")


class SimpleReasoning(nn.Module):
    def __init__(self, np, dim):
        super(SimpleReasoning, self).__init__()
        self.norm = Norm(dim)
        self.linear = nn.Conv1d(np, np, kernel_size=1, bias=False)

    def forward(self, x):
        tokens = self.norm(x)
        tokens = self.linear(tokens)
        return x + tokens


class AnyAttention(nn.Module):
    def __init__(self, dim, num_heads, qkv_bias=False):
        super(AnyAttention, self).__init__()
        self.norm_q, self.norm_k, self.norm_v = Norm(dim), Norm(dim), Norm(dim)
        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, dim, bias=qkv_bias)

        self.scale = (dim / num_heads) ** (-0.5)
        self.num_heads = num_heads
        self.proj = nn.Linear(dim, dim)

    def get_qkv(self, q, k, v, qpos, kpos):
        q = apply_pos(q, qpos, self.num_heads)
        k = apply_pos(k, kpos, self.num_heads)
        v = apply_pos(v, None, 0)
        q, k, v = self.norm_q(q), self.norm_k(k), self.norm_v(v)
        q, k, v = self.to_q(q), self.to_k(k), self.to_v(v)
        return q, k, v

    def forward(self, q=None, k=None, v=None, qpos=None, kpos=None, mask=None, rel_pos=None):
        q, k, v = self.get_qkv(q, k, v, qpos, kpos)

        # split channels into heads
        q = rearrange(q, "b n (g c) -> b n g c", g=self.num_heads)
        k = rearrange(k, "b n (g c) -> b n g c", g=self.num_heads)
        v = rearrange(v, "b n (g c) -> b n g c", g=self.num_heads)

        # attention matrix calculation
        attn = torch.einsum("b q g c, b k g c -> b q g k", q, k)
        if rel_pos is not None:
            attn = rel_pos(q, attn)
        attn *= self.scale
        if mask is not None:
            attn = attn.masked_fill(mask.bool(), value=float('-inf'))
        attn = F.softmax(attn, dim=-1)
        if mask is not None:
            attn = attn.masked_fill(mask.bool(), value=0)
        out = torch.einsum("b q g k, b k g c -> b q g c", attn, v.float())
        out = rearrange(out, "b q g c -> b q (g c)")
        out = self.proj(out)
        return out


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        # guard against None before int(): fall back to in_features
        hidden_features = int(hidden_features) if hidden_features else in_features
        self.norm = norm_layer(in_features)
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
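The index bookkeeping in `FullRelPos.__init__` can be reproduced without tensors. Each (query, key) pair along one axis has an offset in `[-(h - 1), h - 1]`, and adding `h - 1` maps it onto a row of the `(2h - 1)`-entry embedding table. The function below is a pure-Python sketch of that table (illustrative only, not part of the repository).

```python
def rel_index_table(h):
    """rel_idx[i][j] = (j - i) + (h - 1), as built in FullRelPos.__init__."""
    return [[(j - i) + (h - 1) for j in range(h)] for i in range(h)]

table = rel_index_table(3)
print(table)  # [[2, 3, 4], [1, 2, 3], [0, 1, 2]]
```

Every entry lands in `[0, 2h - 2]`, so the lookup `rel_emb_h[rel_idx_h.view(-1)]` in the forward pass never indexes out of bounds.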
@ -0,0 +1,4 @@
from typing import List, Callable, Union, Any, TypeVar, Tuple
# from torch import tensor as Tensor

Tensor = TypeVar('torch.tensor')
@ -0,0 +1 @@
from .unet_model import TransUNet
@ -0,0 +1,214 @@
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, inplanes=3, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # fetch from the standard URL, consistent with the other variants
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
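The condition guarding the `downsample` branch in `_make_layer` is the one piece of logic in this file that is easy to get wrong: a 1x1 projection is inserted whenever the stage changes spatial stride or channel width, which means every stage's first block except `layer1` with `BasicBlock`. The predicate below is a hypothetical standalone restatement of that rule, not code from the repository.

```python
def needs_downsample(inplanes, planes, stride, expansion):
    """Mirror the condition in ResNet._make_layer for the first block of a stage."""
    return stride != 1 or inplanes != planes * expansion

# resnet18 (BasicBlock, expansion=1): layer1 keeps 64 channels at stride 1
print(needs_downsample(64, 64, 1, 1))   # False
# resnet50 (Bottleneck, expansion=4): layer1 expands 64 -> 256 even at stride 1
print(needs_downsample(64, 64, 1, 4))   # True
# every stride-2 stage needs one regardless of channels
print(needs_downsample(64, 128, 2, 1))  # True
```

This is why the bottleneck variants attach a projection to `layer1` while the basic-block variants do not.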
@ -0,0 +1,599 @@
""" Full assembly of the parts to form the complete network """
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from tag.tag import Stage
from .unet_parts import *
from torch import nn
import torch
from .res_net import resnet34, resnet18, resnet50, resnet101, resnet152, BasicBlock, Bottleneck, ResNet
import torch.nn.functional as F
from torch.autograd import Variable
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math


class SaveFeatures():
    """Capture a module's forward output, keyed by device index (multi-GPU safe)."""
    features = None

    def __init__(self, m):
        self._outputs_lists = {}
        self.mymodule = m
        # Keep the hook handle so remove() can detach it later.
        self.hook = m.register_forward_hook(hook=self.save_output_hook)

    def save_output_hook(self, _, input, output):
        self._outputs_lists[input[0].device.index] = output
        self.features = self._outputs_lists

    def forward(self, x) -> list:
        self._outputs_lists[x.device.index] = []
        self.mymodule(x)
        return self._outputs_lists[x.device.index]

    def remove(self):
        self.hook.remove()

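`SaveFeatures` relies on PyTorch's forward-hook mechanism: a callback registered on a module fires with `(module, input, output)` each time that module runs. A dependency-free sketch of the same capture pattern (the `TinyModule` and `Capture` names are illustrative, not part of this repo):

```python
# Minimal stand-in for a module that calls its registered forward hooks
# after computing its output, mimicking torch.nn.Module's hook contract.
class TinyModule:
    def __init__(self, fn):
        self.fn = fn
        self.hooks = []

    def register_forward_hook(self, hook):
        self.hooks.append(hook)
        return self

    def __call__(self, x):
        out = self.fn(x)
        for h in self.hooks:
            h(self, (x,), out)  # (module, input tuple, output), as in PyTorch
        return out


class Capture:
    """Stores the most recent output of the hooked module."""
    def __init__(self, m):
        self.features = None
        m.register_forward_hook(self.save)

    def save(self, module, input, output):
        self.features = output


double = TinyModule(lambda x: 2 * x)
cap = Capture(double)
double(21)
print(cap.features)  # 42
```

The real class additionally keys captures by `device.index` so that each GPU replica under `DataParallel` reads back its own activations.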
class UnetStageBlock(nn.Module):
    def __init__(self, stage, up_in, x_in, n_out, ratio):
        super().__init__()
        up_out = x_out = n_out // 2
        self.x_conv = nn.Conv2d(x_in, x_out, 1)
        self.g_fc = nn.Linear(7, x_out * 2)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.stage = stage

        self.bn = nn.BatchNorm2d(n_out)

        self.pointwise = nn.Conv2d(14, n_out, kernel_size=1)
        self.depthwise = nn.Conv2d(n_out, n_out, kernel_size=3, stride=ratio, padding=1, groups=up_out)

    def forward(self, up_p, x_p, give):
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p, x_p], dim=1)
        res = self.bn(F.relu(cat_p))
        g_p = self.depthwise(self.pointwise(give))
        res = self.stage(g_p, res)
        return res


class UnetBlock(nn.Module):
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = x_out = n_out // 2
        self.x_conv = nn.Conv2d(x_in, x_out, 1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)

        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p, x_p], dim=1)
        res = self.bn(F.relu(cat_p))
        return res

class TransUNet(nn.Module):

    def __init__(self, args, resnet='resnet34', num_classes=2, pretrained=False,
                 in_chans=3,
                 inplanes=64,
                 num_layers=(3, 4, 6, 3),
                 num_chs=(256, 512, 1024, 2048),
                 num_strides=(1, 2, 2, 2),
                 num_heads=(1, 1, 1, 1),
                 num_parts=(1, 1, 1, 1),
                 patch_sizes=(1, 1, 1, 1),
                 drop_path=0.1,
                 num_enc_heads=(1, 1, 1, 1),
                 act=nn.GELU,
                 ffn_exp=3,
                 has_last_encoder=False
                 ):
        super().__init__()

        ''' ~~~~~ For the embedding transformer ~~~~~ '''
        cut, lr_cut = [8, 6]

        dim = args.dim              # dim of transformer sequence, D of E
        img_size = 8                # embedding 8*8*512
        channels = 512
        patch_size = args.patch_size
        depth = args.depth
        heads = args.heads
        mlp_dim = args.mlp_dim
        dim_head = 64

        assert img_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (img_size // patch_size) ** 2
        patch_dim = channels * patch_size * patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim)
        )
        '''~~~~~ End of embedding transformer ~~~~~'''

        # unet and goinnet parameters
        if resnet == 'resnet34':
            base_model = resnet34
        elif resnet == 'resnet18':
            base_model = resnet18
        elif resnet == 'resnet50':
            base_model = resnet50
        elif resnet == 'resnet101':
            base_model = resnet101
        elif resnet == 'resnet152':
            base_model = resnet152
        else:
            raise Exception('The Resnet model only accepts resnet18, resnet34, resnet50, '
                            'resnet101 and resnet152')

        '''define the stages for goinnet giving'''
        last_chs = (256, 256, 256, 256)
        num_chs = (256, 256, 256, 256)
        down_samples = (2, 4, 8, 16)
        n_l = 1
        stage_list = []
        for i in range(4):
            stage_list.append(
                Stage(last_chs[i],
                      num_chs[i],
                      n_l,
                      num_heads=num_heads[i],      # 1,2,4,8
                      num_parts=(patch_sizes[i] ** 2 * (args.image_size // down_samples[i] // patch_sizes[i]) ** 2),
                      patch_size=patch_sizes[i],   # 8,8,8,8
                      drop_path=drop_path,         # 0.05
                      ffn_exp=ffn_exp,             # mlp hidden features
                      last_enc=has_last_encoder and i == len(num_layers) - 1)
            )
        self.stages = nn.ModuleList(stage_list)

        patch_s = 8
        separator_list = []
        for i in range(7):
            separator_list.append(
                Stage(256,
                      2,
                      n_l,
                      num_heads=1,
                      num_parts=(patch_s ** 2 * (args.image_size // 2 // patch_s) ** 2),
                      patch_size=patch_s,
                      drop_path=drop_path,
                      ffn_exp=ffn_exp
                      ))
        self.s_list = nn.ModuleList(separator_list)

        layers = list(base_model(pretrained=pretrained).children())[:cut]
        self.check_layer = layers
        base_layers = nn.Sequential(*layers)
        self.rn = base_layers

        self.num_classes = num_classes
        self.sfs = [SaveFeatures(base_layers[i]) for i in [2, 4, 5, 6]]
        self.up1 = UnetStageBlock(self.stages[3], 512, 256, 256, 16)
        self.up2 = UnetStageBlock(self.stages[2], 256, 128, 256, 8)
        self.up3 = UnetStageBlock(self.stages[1], 256, 64, 256, 4)
        self.up4 = UnetStageBlock(self.stages[0], 256, 64, 256, 2)

        self.up5 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)

        self.pred1 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred2 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred3 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred4 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred5 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred6 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred7 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)

        self.munet = MUNet(args)

        '''~~~ self definition ~~~'''
        self.fc = nn.Linear(7, dim)
        self.tr_conv = nn.ConvTranspose2d(512, 512, 2, stride=2)

    def forward(self, x, cond, mod='train'):
        aux = {}
        mergf = []
        img = x
        x = F.relu(self.rn(x))  # x = [b_size, 512, 8, 8] for resnet34
        emb = x

        if mod == 'shuffle':
            self.up1.eval()
            self.up2.eval()
            self.up3.eval()
            self.up4.eval()
            self.up5.eval()

        '''~~~ 0: agg ~~~'''
        x = self.up1(x, self.sfs[3].features[x.device.index], cond)
        mergf.append(x)
        x = self.up2(x, self.sfs[2].features[x.device.index], cond)
        mergf.append(x)
        x = self.up3(x, self.sfs[1].features[x.device.index], cond)
        mergf.append(x)
        x = self.up4(x, self.sfs[0].features[x.device.index], cond)
        fea = x
        output = self.up5(x)
        '''end'''

        if mod == 'shuffle':
            aux['mergfs'] = mergf
            return output, aux

        ave, maps = self.munet(img, output.detach())

        '''~~~ 0: ENDs ~~~'''

        pred_stack = torch.stack(maps)          # 7,b,c,w,w
        pred_stack_t = torch.sigmoid(pred_stack)

        self_pred = pred_stack_t * torch.div(pred_stack_t, torch.sum(pred_stack_t, dim=0, keepdim=True))  # 7,b,c,w,w
        self_pred = rearrange(self_pred, "a b c h w -> b (a c) h w").contiguous()  # b,7c,w,w
        cond = self_pred
        aux['maps'] = maps
        aux['cond'] = cond
        aux['mergfs'] = mergf
        aux['emb'] = emb
        return output, aux

    def close(self):
        for sf in self.sfs:
            sf.remove()


class MUNet(nn.Module):

    def __init__(self, args, resnet='resnet34', num_classes=2, pretrained=False):
        super().__init__()

        cut, lr_cut = [8, 6]

        # unet and goinnet parameters
        if resnet == 'resnet34':
            base_model = resnet34
        elif resnet == 'resnet18':
            base_model = resnet18
        elif resnet == 'resnet50':
            base_model = resnet50
        elif resnet == 'resnet101':
            base_model = resnet101
        elif resnet == 'resnet152':
            base_model = resnet152
        else:
            raise Exception('The Resnet model only accepts resnet18, resnet34, resnet50, '
                            'resnet101 and resnet152')

        layers = list(base_model(pretrained=pretrained, inplanes=5).children())[:cut]
        self.check_layer = layers
        base_layers = nn.Sequential(*layers)
        self.rn = base_layers

        self.num_classes = num_classes
        self.sfs = [SaveFeatures(base_layers[i]) for i in [2, 4, 5, 6]]
        self.up1 = UnetBlock(512, 256, 256)
        self.up2 = UnetBlock(256, 128, 256)
        self.up3 = UnetBlock(256, 64, 256)
        self.up4 = UnetBlock(256, 64, 256)

        self.up5 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)

        self.pred1 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred2 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred3 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred4 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred5 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred6 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred7 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)

    def forward(self, x, heatmap):
        x = torch.cat((x, heatmap), 1)
        x = F.relu(self.rn(x))

        '''~~~ 0: Decoder ~~~'''
        x = self.up1(x, self.sfs[3].features[x.device.index])
        cond3 = x
        x = self.up2(x, self.sfs[2].features[x.device.index])
        cond2 = x
        x = self.up3(x, self.sfs[1].features[x.device.index])
        cond1 = x
        x = self.up4(x, self.sfs[0].features[x.device.index])
        cond0 = x
        fea = x
        output = self.up5(x)
        '''~~~ 0: ENDs ~~~'''
        out1 = self.pred1(fea)
        out2 = self.pred2(fea)
        out3 = self.pred3(fea)
        out4 = self.pred4(fea)
        out5 = self.pred5(fea)
        out6 = self.pred6(fea)
        out7 = self.pred7(fea)
        return (out1 + out2 + out3 + out4 + out5 + out6 + out7) / 7, [out1, out2, out3, out4, out5, out6, out7]

    def close(self):
        for sf in self.sfs:
            sf.remove()


class UNet(nn.Module):

    def __init__(self, args, resnet='resnet34', num_classes=2, pretrained=False):
        super().__init__()

        cut, lr_cut = [8, 6]

        # unet and goinnet parameters
        if resnet == 'resnet34':
            base_model = resnet34
        elif resnet == 'resnet18':
            base_model = resnet18
        elif resnet == 'resnet50':
            base_model = resnet50
        elif resnet == 'resnet101':
            base_model = resnet101
        elif resnet == 'resnet152':
            base_model = resnet152
        else:
            raise Exception('The Resnet model only accepts resnet18, resnet34, resnet50, '
                            'resnet101 and resnet152')

        layers = list(base_model(pretrained=pretrained, inplanes=3).children())[:cut]
        self.check_layer = layers
        base_layers = nn.Sequential(*layers)
        self.rn = base_layers

        self.num_classes = num_classes
        self.sfs = [SaveFeatures(base_layers[i]) for i in [2, 4, 5, 6]]
        self.up1 = UnetBlock(512, 256, 256)
        self.up2 = UnetBlock(256, 128, 256)
        self.up3 = UnetBlock(256, 64, 256)
        self.up4 = UnetBlock(256, 64, 256)

        self.up5 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)

        self.pred1 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred2 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred3 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred4 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred5 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred6 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)
        self.pred7 = nn.ConvTranspose2d(256, self.num_classes, 2, stride=2)

    def forward(self, x):
        x = F.relu(self.rn(x))

        '''~~~ 0: Decoder ~~~'''
        x = self.up1(x, self.sfs[3].features[x.device.index])
        x = self.up2(x, self.sfs[2].features[x.device.index])
        x = self.up3(x, self.sfs[1].features[x.device.index])
        x = self.up4(x, self.sfs[0].features[x.device.index])
        fea = x
        output = self.up5(x)
        '''~~~ 0: ENDs ~~~'''
        return output

    def close(self):
        for sf in self.sfs:
            sf.remove()


class GoinNet(nn.Module):
    def __init__(self,
                 args,
                 resnet,
                 in_chans=3,
                 inplanes=64,
                 num_layers=(3, 4, 6, 3),
                 num_chs=(256, 512, 1024, 2048),
                 num_strides=(1, 2, 2, 2),
                 num_heads=(1, 1, 1, 1),
                 num_parts=(1, 1, 1, 1),
                 patch_sizes=(1, 1, 1, 1),
                 drop_path=0.1,
                 num_enc_heads=(1, 1, 1, 1),
                 act=nn.GELU,
                 ffn_exp=3,
                 has_last_encoder=False,
                 pretrained=False
                 ):
        self.inplanes = 64
        super(GoinNet, self).__init__()

        cut, lr_cut = [8, 6]
        self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        last_chs = (64, 64, 128, 256)
        num_chs = (64, 64, 128, 256)
        down_samples = (2, 4, 8, 16)
        n_l = 1

        # Build one Stage per level; forward() indexes self.stages[0..3].
        stage_list = []
        for i in range(4):
            stage_list.append(
                Stage(last_chs[i],
                      num_chs[i],
                      n_l,
                      num_heads=num_heads[i],      # 1,2,4,8
                      num_parts=(patch_sizes[i] ** 2 * (args.image_size // down_samples[i] // patch_sizes[i]) ** 2),
                      patch_size=patch_sizes[i],   # 8,8,8,8
                      drop_path=drop_path,         # 0.05
                      ffn_exp=ffn_exp,             # mlp hidden features
                      last_enc=has_last_encoder and i == len(num_layers) - 1))
        self.stages = nn.ModuleList(stage_list)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, img, x, p):
        x = torch.cat((img, x), 1)
        x = F.relu(self.rn(x))

        x = self.stages[0](p[0].features[x.device.index], self.sfs[0].features[x.device.index])
        turn0 = x

        x = self.stages[1](p[1].features[x.device.index], self.sfs[1].features[x.device.index])
        turn1 = x

        x = self.stages[2](p[2].features[x.device.index], self.sfs[2].features[x.device.index])
        turn2 = x

        x = self.stages[3](p[3].features[x.device.index], self.sfs[3].features[x.device.index])
        turn3 = x

        return x, [turn0, turn1, turn2, turn3]


class PromptUNet(nn.Module):

    def __init__(
            self,
            args,
            embedding_dim: int,
            token_num: int,
            num_heads: int,
            mlp_dim: int,
            resnet='resnet34',
            pretrained=False
    ):
        super().__init__()

        cut, lr_cut = [8, 6]

        # unet and goinnet parameters
        if resnet == 'resnet34':
            base_model = resnet34
        elif resnet == 'resnet18':
            base_model = resnet18
        elif resnet == 'resnet50':
            base_model = resnet50
        elif resnet == 'resnet101':
            base_model = resnet101
        elif resnet == 'resnet152':
            base_model = resnet152
        else:
            raise Exception('The Resnet model only accepts resnet18, resnet34, resnet50, '
                            'resnet101 and resnet152')

        layers = list(base_model(pretrained=pretrained, inplanes=3).children())[:cut]
        self.check_layer = layers
        base_layers = nn.Sequential(*layers)
        self.rn = base_layers

        self.embedding_dim = embedding_dim
        self.token_num = token_num
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        self.sfs = [SaveFeatures(base_layers[i]) for i in [2, 4, 5, 6]]

        self.decoder = nn.ModuleList()

        for i in range(4):
            self.decoder.append(
                OnePromptFormer(
                    embedding_dim=self.embedding_dim,
                    token_num=self.token_num,
                    num_heads=self.num_heads,
                    mlp_dim=self.mlp_dim
                )
            )

        self.decoder.append(MaskDecoder())

    def forward(self, x):
        x = F.relu(self.rn(x))

        '''~~~ 0: Decoder ~~~'''
        x = self.up1(x, self.sfs[3].features[x.device.index])
        x = self.up2(x, self.sfs[2].features[x.device.index])
        x = self.up3(x, self.sfs[1].features[x.device.index])
        x = self.up4(x, self.sfs[0].features[x.device.index])
        fea = x
        output = self.up5(x)
        '''~~~ 0: ENDs ~~~'''
        return output

    def close(self):
        for sf in self.sfs:
            sf.remove()

@ -0,0 +1,75 @@
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is BCHW; pad x1 so it matches x2's spatial size
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

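`Up.forward` splits the size mismatch asymmetrically: `diff // 2` pixels on the left/top and the remainder on the right/bottom, so odd differences still align the skip connection. The arithmetic in isolation (`pad_amounts` is an illustrative helper, not defined in this repo):

```python
def pad_amounts(target, current):
    """Left/right (or top/bottom) padding amounts used by Up.forward."""
    diff = target - current
    return diff // 2, diff - diff // 2

print(pad_amounts(64, 63))  # (0, 1): an odd difference puts the extra pixel on the right
print(pad_amounts(64, 60))  # (2, 2): an even difference splits symmetrically
```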
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

@ -0,0 +1,398 @@
"""
This file contains helper functions for building the model and for loading model parameters.
These helper functions are built to mirror those in the official TensorFlow implementation.
"""

import re
import math
import collections
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo

########################################################################
############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
########################################################################


# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple('GlobalParams', [
    'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
    'num_classes', 'width_coefficient', 'depth_coefficient',
    'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])

# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
    'expand_ratio', 'id_skip', 'stride', 'se_ratio'])

# Change namedtuple defaults
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)

# nnunet pkg
softmax_helper = lambda x: F.softmax(x, 1)
sigmoid_helper = lambda x: torch.sigmoid(x)


class InitWeights_He(object):
    def __init__(self, neg_slope=1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
            if module.bias is not None:
                module.bias = nn.init.constant_(module.bias, 0)


def maybe_to_torch(d):
    if isinstance(d, list):
        d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
    elif not isinstance(d, torch.Tensor):
        d = torch.from_numpy(d).float()
    return d


def to_cuda(data, non_blocking=True, gpu_id=0):
    if isinstance(data, list):
        data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
    else:
        data = data.cuda(gpu_id, non_blocking=non_blocking)
    return data


class no_op(object):
    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass


class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


def round_filters(filters, global_params):
    """ Calculate and round number of filters based on width multiplier. """
    multiplier = global_params.width_coefficient
    if not multiplier:
        return filters
    divisor = global_params.depth_divisor
    min_depth = global_params.min_depth
    filters *= multiplier
    min_depth = min_depth or divisor
    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding down by more than 10%
        new_filters += divisor
    return int(new_filters)
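`round_filters` scales the channel count by the width coefficient, then snaps it to a multiple of `depth_divisor`, bumping up whenever snapping would shrink it by more than 10%. The same rounding without the `GlobalParams` plumbing (`scale_and_snap` is an illustrative stand-in):

```python
def scale_and_snap(filters, multiplier, divisor=8):
    """Mirror of round_filters' rounding logic."""
    filters *= multiplier
    # Snap to the nearest multiple of divisor, with a floor of divisor itself.
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # never round down by more than 10%
        new_filters += divisor
    return int(new_filters)

print(scale_and_snap(32, 1.0))      # 32: already a multiple of 8
print(scale_and_snap(32, 1.2))      # 40: 38.4 snaps to the nearest multiple of 8
print(scale_and_snap(44, 1.0, 32))  # 64: snapping 44 down to 32 loses >10%, so bump up
```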


def round_repeats(repeats, global_params):
    """ Round number of block repeats based on depth multiplier. """
    multiplier = global_params.depth_coefficient
    if not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))


def drop_connect(inputs, p, training):
    """ Drop connect: randomly zero whole samples with probability p, rescaling survivors. """
    if not training:
        return inputs
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)
    output = inputs / keep_prob * binary_tensor
    return output
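`drop_connect` keeps a sample exactly when `floor(keep_prob + U) == 1` for a uniform draw `U` in `[0, 1)`, i.e. with probability `keep_prob`, and rescales survivors by `1 / keep_prob` so the expected output equals the input. The per-sample math with the random draw made explicit (`drop_connect_scalar` is an illustrative stand-in):

```python
def drop_connect_scalar(x, p, u):
    """One sample's drop connect, with u standing in for the uniform draw."""
    keep_prob = 1.0 - p
    # floor(keep_prob + u): 1 if u >= p (keep), 0 otherwise (drop).
    # int() truncation matches floor here since keep_prob + u >= 0.
    binary = float(int(keep_prob + u))
    return x / keep_prob * binary

print(drop_connect_scalar(2.0, 0.5, 0.6))  # 4.0: kept (0.5 + 0.6 >= 1), rescaled by 1/0.5
print(drop_connect_scalar(2.0, 0.5, 0.3))  # 0.0: dropped (0.5 + 0.3 < 1)
```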


def get_same_padding_conv2d(image_size=None):
    """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
        Static padding is necessary for ONNX exporting of models. """
    if image_size is None:
        return Conv2dDynamicSamePadding
    else:
        return partial(Conv2dStaticSamePadding, image_size=image_size)


def get_same_padding_conv2d_freeze(image_size=None):
    """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
        Static padding is necessary for ONNX exporting of models. """
    if image_size is None:
        return Conv2dStaticSamePadding_freeze
    else:
        return partial(Conv2dStaticSamePadding_freeze, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
    """ 2D Convolutions like TensorFlow, for a dynamic image size """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dStaticSamePadding(nn.Conv2d):
    """ 2D Convolutions like TensorFlow, for a fixed image size """

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it
        assert image_size is not None
        ih, iw = image_size if type(image_size) == list else [image_size, image_size]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
        else:
            self.static_padding = Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x
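Both padding variants compute TensorFlow-style "SAME" padding: the output size is `ceil(in / stride)`, and the total padding is whatever makes the convolution produce exactly that size. The core formula, dependency-free (`same_pad` is an illustrative helper):

```python
import math

def same_pad(in_size, kernel, stride=1, dilation=1):
    """Total TF-'SAME' padding along one dimension (the caller splits it left/right)."""
    out_size = math.ceil(in_size / stride)
    effective_kernel = (kernel - 1) * dilation + 1
    return max((out_size - 1) * stride + effective_kernel - in_size, 0)

print(same_pad(224, 3, stride=2))  # 1: 224 -> 112 with a 3x3 kernel needs 1 pixel total
print(same_pad(224, 5))            # 4: at stride 1, 'SAME' pads kernel - 1
```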
|
||||
def Conv2dStaticSamePadding_freeze(inputs, weight, bias=None, image_size=None, stride=(1,1), padding=0, dilation=(1,1), groups=1):
|
||||
""" 2D Convolutions like TensorFlow, for a fixed image size"""
|
||||
|
||||
if type(stride) == int:
|
||||
stride = [stride] * 2
|
||||
else:
|
||||
stride = stride if len(stride) == 2 else [stride[0]] * 2
|
||||
# Calculate padding based on image size and save it
|
||||
assert image_size is not None
|
||||
ih, iw = image_size if type(image_size) == list else [image_size, image_size]
|
||||
kh, kw = weight.size()[-2:]
|
||||
sh, sw = stride
|
||||
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
||||
pad_h = max((oh - 1) * stride[0] + (kh - 1) * dilation[0] + 1 - ih, 0)
|
||||
pad_w = max((ow - 1) * stride[1] + (kw - 1) * dilation[1] + 1 - iw, 0)
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
|
||||
else:
|
||||
static_padding = Identity()
|
||||
|
||||
x = static_padding(inputs)
|
||||
x = F.conv2d(x, weight, bias, stride, padding, dilation, groups)
|
||||
return x
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self, ):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
########################################################################
|
||||
############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
|
||||
########################################################################
|
||||
|
||||
|
||||
def efficientnet_params(model_name):
|
||||
""" Map EfficientNet model name to parameter coefficients. """
|
||||
params_dict = {
|
||||
# Coefficients: width,depth,res,dropout
|
||||
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
|
||||
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
|
||||
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
|
||||
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
|
||||
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
|
||||
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
|
||||
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
|
||||
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
|
||||
}
|
||||
return params_dict[model_name]
|
||||
|
||||
|
||||
class BlockDecoder(object):
    """ Block Decoder for readability, straight from the official TensorFlow repository """

    @staticmethod
    def _decode_block_string(block_string):
        """ Gets a block through a string notation of arguments. """
        assert isinstance(block_string, str)

        ops = block_string.split('_')
        options = {}
        for op in ops:
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride
        assert (('s' in options and len(options['s']) == 1) or
                (len(options['s']) == 2 and options['s'][0] == options['s'][1]))

        return BlockArgs(
            kernel_size=int(options['k']),
            num_repeat=int(options['r']),
            input_filters=int(options['i']),
            output_filters=int(options['o']),
            expand_ratio=int(options['e']),
            id_skip=('noskip' not in block_string),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=[int(options['s'][0])])

    @staticmethod
    def _encode_block_string(block):
        """ Encodes a block to a string. """
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            # decode stores a single-element stride list, so repeat it
            's%d%d' % (block.stride[0], block.stride[0]),
            'e%s' % block.expand_ratio,
            'i%d' % block.input_filters,
            'o%d' % block.output_filters
        ]
        if block.se_ratio is not None and 0 < block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        if block.id_skip is False:
            args.append('noskip')
        return '_'.join(args)

    @staticmethod
    def decode(string_list):
        """
        Decodes a list of string notations to specify blocks inside the network.

        :param string_list: a list of strings, each string is a notation of block
        :return: a list of BlockArgs namedtuples of block args
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(BlockDecoder._decode_block_string(block_string))
        return blocks_args

    @staticmethod
    def encode(blocks_args):
        """
        Encodes a list of BlockArgs to a list of strings.

        :param blocks_args: a list of BlockArgs namedtuples of block args
        :return: a list of strings, each string is a notation of block
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(BlockDecoder._encode_block_string(block))
        return block_strings

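The `_decode_block_string` parser splits each underscore-separated token into a letter key and a numeric value using `re.split`. A self-contained sketch of just that parsing step (the helper name `parse_block_string` is ours; the real method additionally builds a `BlockArgs` namedtuple):

```python
import re

def parse_block_string(block_string):
    """Extract the key/value options out of a block string, as
    _decode_block_string does before constructing BlockArgs."""
    options = {}
    for op in block_string.split('_'):
        # '(\d.*)' keeps the numeric tail as its own capture group,
        # e.g. 'se0.25' -> ['se', '0.25', '']
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value
    return options

opts = parse_block_string('r1_k3_s11_e1_i32_o16_se0.25')
print(opts)
# {'r': '1', 'k': '3', 's': '11', 'e': '1', 'i': '32', 'o': '16', 'se': '0.25'}
```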
def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
                 drop_connect_rate=0.2, image_size=None, num_classes=1000):
    """ Creates an efficientnet model. """

    blocks_args = [
        'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    blocks_args = BlockDecoder.decode(blocks_args)

    global_params = GlobalParams(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        drop_connect_rate=drop_connect_rate,
        # data_format='channels_last',  # removed, this is always true in PyTorch
        num_classes=num_classes,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None,
        image_size=image_size,
    )

    return blocks_args, global_params


def get_model_params(model_name, override_params):
    """ Get the block args and global params for a given model. """
    if model_name.startswith('efficientnet'):
        w, d, s, p = efficientnet_params(model_name)
        # note: all models have drop connect rate = 0.2
        blocks_args, global_params = efficientnet(
            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
    else:
        raise NotImplementedError('model name is not pre-defined: %s' % model_name)
    if override_params:
        # ValueError will be raised here if override_params has fields not included in global_params.
        global_params = global_params._replace(**override_params)
    return blocks_args, global_params

url_map = {
    'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth',
    'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth',
    'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth',
    'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth',
    'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth',
    'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth',
    'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth',
    'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
}


def load_pretrained_weights(model, model_name, load_fc=True):
    """ Loads pretrained weights, and downloads if loading for the first time. """
    state_dict = model_zoo.load_url(url_map[model_name])
    if load_fc:
        model.load_state_dict(state_dict)
    else:
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        res = model.load_state_dict(state_dict, strict=False)
        assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
    print('Loaded pretrained weights for {}'.format(model_name))

def gram_matrix(input):
    a, b, c, d = input.size()  # a = batch size (=1)
    # b = number of feature maps
    # (c, d) = dimensions of a feature map (N = c * d)

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map.
    return G.div(a * b * c * d)

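The Gram matrix computed above is just the matrix of inner products between flattened feature maps, divided by the total element count. A tiny NumPy sketch of the same arithmetic (NumPy stands in for the torch ops; the data is made up):

```python
import numpy as np

# A batch of a=1 stacks with b=2 feature maps of size c=d=2: the Gram
# matrix is (a*b) x (a*b), normalized by the a*b*c*d elements.
feats = np.arange(8, dtype=float).reshape(1, 2, 2, 2)
a, b, c, d = feats.shape
F_ = feats.reshape(a * b, c * d)          # one row per feature map
G = F_ @ F_.T / (a * b * c * d)           # inner products, normalized
print(G.shape)  # (2, 2)
```

Note the result is always symmetric, which is why style losses built on it are insensitive to the spatial arrangement of features.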
@ -0,0 +1,159 @@
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from torch import nn
from torch.nn import functional as F
# from .types_ import *


class VanillaVAE(nn.Module):
    def __init__(self, args,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims=None,
                 **kwargs) -> None:
        super(VanillaVAE, self).__init__()

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        if latent_dim is None:
            latent_dim = 512

        self.latent_dim = latent_dim

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # the stride-2 encoder reduces a 64x64 input to a 2x2 feature map,
        # hence the factor of 4 (= 2 * 2) on the flattened features
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input):
        """
        Encodes the input by passing it through the encoder network
        and returns the latent code.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) Latent codes [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Only the mean of the latent Gaussian distribution is used here;
        # the log-variance head (fc_var) is kept but unused.
        mu = self.fc_mu(result)
        # log_var = self.fc_var(result)

        return mu

    def decode(self, z):
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # 512 = hidden_dims[-1] before reversal; 2x2 is the encoder output size
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    # def reparameterize(self, mu, logvar):
    #     """
    #     Reparameterization trick to sample from N(mu, var) using
    #     N(0, 1).
    #     :param mu: (Tensor) Mean of the latent Gaussian [B x D]
    #     :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
    #     :return: (Tensor) [B x D]
    #     """
    #     std = torch.exp(0.5 * logvar)
    #     eps = torch.randn_like(std)
    #     return eps * std + mu

    def forward(self, input, **kwargs):
        mu = self.encode(input)
        # z = self.reparameterize(mu, log_var)
        return self.decode(mu)

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        # mu = args[2]
        # log_var = args[3]

        # kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss
        return loss
        # {'loss': loss, 'Reconstruction_Loss': recons_loss.detach(), 'KLD': recons_loss.detach()}

    def generate(self, x, **kwargs):
        """
        Given an input image x, returns the reconstructed image.
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        # forward already returns the reconstruction tensor, so no indexing
        # is needed (indexing [0] would drop the batch dimension)
        return self.forward(x)

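The `hidden_dims[-1] * 4` sizes of `fc_mu`/`fc_var` and the `view(-1, 512, 2, 2)` in `decode` both assume the encoder ends at a 2x2 feature map, i.e. a 64x64 input through five stride-2 convolutions. A quick check of that arithmetic in plain Python (the helper `conv_out` is ours):

```python
# Each encoder stage is a stride-2 conv with kernel 3 and padding 1,
# which halves the spatial size: out = floor((in + 2*1 - 3) / 2) + 1.
def conv_out(size, kernel=3, stride=2, padding=1):
    return (size + 2 * padding - kernel) // stride + 1

size = 64
for _ in range(5):  # one stage per entry of hidden_dims = [32, 64, 128, 256, 512]
    size = conv_out(size)
print(size)  # 2  -> flattened features = 512 * 2 * 2 = hidden_dims[-1] * 4
```

Feeding an input resolution other than 64x64 would therefore break the linear layers, which is worth keeping in mind when reusing this class.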
@ -0,0 +1,75 @@
"""vgg in pytorch

[1] Karen Simonyan, Andrew Zisserman

    Very Deep Convolutional Networks for Large-Scale Image Recognition.
    https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in Pytorch.'''

import torch
import torch.nn as nn

cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}

class VGG(nn.Module):

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features

        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class)
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)

        return output

def make_layers(cfg, batch_norm=False):
    layers = []

    input_channel = 3
    for l in cfg:
        if l == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            continue

        layers += [nn.Conv2d(input_channel, l, kernel_size=3, padding=1)]

        if batch_norm:
            layers += [nn.BatchNorm2d(l)]

        layers += [nn.ReLU(inplace=True)]
        input_channel = l

    return nn.Sequential(*layers)

def vgg11_bn():
    return VGG(make_layers(cfg['A'], batch_norm=True))

def vgg13_bn():
    return VGG(make_layers(cfg['B'], batch_norm=True))

def vgg16_bn():
    return VGG(make_layers(cfg['D'], batch_norm=True))

def vgg19_bn():
    return VGG(make_layers(cfg['E'], batch_norm=True))

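The VGG variant names count weight layers: the conv layers listed in each cfg entry plus the three fully connected layers in the classifier. A self-contained check of that correspondence (the cfg dict is repeated here so the snippet runs on its own):

```python
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
for name, layout in cfg.items():
    convs = sum(1 for l in layout if l != 'M')  # 'M' entries are pooling, not weights
    print(name, convs + 3)  # A 11, B 13, D 16, E 19 -> VGG11/13/16/19
```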
@ -0,0 +1,205 @@
import sys
import os
import math
import time
import pathlib
import warnings
import logging
import collections
from collections import OrderedDict
from datetime import datetime
from typing import Union, Optional, List, Tuple, Text, BinaryIO

import numpy as np
import dateutil.tz
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import PIL
from PIL import Image, ImageDraw, ImageFont, ImageColor

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.models import vgg19

from dataset import Dataset_FullImg, Dataset_DiscRegion
from lucent.optvis.param.spatial import pixel_image, fft_image, init_image
from lucent.optvis.param.color import to_valid_rgb
import cfg

args = cfg.parse_args()
device = torch.device('cuda', args.gpu_device)
cnn = vgg19(pretrained=True).features.to(device).eval()

content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

class ContentLoss(nn.Module):

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
        # will throw an error.
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

def gram_matrix(input):
    a, b, c, d = input.size()  # a = batch size (=1)
    # b = number of feature maps
    # (c, d) = dimensions of a feature map (N = c * d)

    features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL

    G = torch.mm(features, features.t())  # compute the gram product

    # we 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map.
    return G.div(a * b * c * d)

class StyleLoss(nn.Module):

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input

# create a module to normalize the input image so we can easily put it in a
# nn.Sequential
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, img):
        # normalize img
        return (img - self.mean) / self.std

def run_precpt(cnn, normalization_mean, normalization_std,
               content_img, style_img, input_img,
               style_weight=1000000, content_weight=1):
    model, style_losses, content_losses = precpt_loss(
        cnn, normalization_mean, normalization_std, style_img, content_img)

    # We want to optimize the input and not the model parameters, so we
    # update all the requires_grad fields accordingly
    model.requires_grad_(False)
    input_img.requires_grad_(True)

    model(input_img)
    style_score = 0
    content_score = 0

    for sl in style_losses:
        style_score += sl.loss
    for cl in content_losses:
        content_score += cl.loss

    # note: these assignments override the style_weight/content_weight arguments
    content_weight = 100
    style_weight = 100000
    style_score *= style_weight
    content_score *= content_weight

    loss = style_score + content_score
    # loss = content_score

    return loss

def precpt_loss(cnn, normalization_mean, normalization_std,
                style_img, content_img,
                content_layers=content_layers_default,
                style_layers=style_layers_default):

    # normalization module
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    # just in order to have an iterable access to the list of content/style
    # losses
    content_losses = []
    style_losses = []
    # assuming that cnn is a nn.Sequential, we make a new nn.Sequential
    # to put in modules that are supposed to be activated sequentially
    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place version doesn't play very nicely with the ContentLoss
            # and StyleLoss we insert below. So we replace with out-of-place
            # ones here.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # add content loss:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # add style loss:
            if style_img.size(1) == 1:
                style_img = style_img.expand(style_img.size(0), 3, style_img.size(2), style_img.size(3))
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # now we trim off the layers after the last content and style losses
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses

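The layer-naming loop above is why `content_layers_default` and `style_layers_default` use names like `conv_4`: the counter `i` only increments on `Conv2d`, and relu/pool/bn layers inherit the current conv index. A sketch of that naming scheme with plain strings instead of real modules (the helper `name_layers` is ours):

```python
def name_layers(types):
    """Name a sequence of layer kinds the way precpt_loss does."""
    i, names = 0, []
    for t in types:
        if t == 'conv':
            i += 1                            # counter advances on convs only
            names.append('conv_{}'.format(i))
        elif t == 'relu':
            names.append('relu_{}'.format(i))  # tagged with the preceding conv
        elif t == 'pool':
            names.append('pool_{}'.format(i))
    return names

print(name_layers(['conv', 'relu', 'conv', 'relu', 'pool', 'conv']))
# ['conv_1', 'relu_1', 'conv_2', 'relu_2', 'pool_2', 'conv_3']
```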
@ -0,0 +1,73 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

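The SSIM window is an 11-tap Gaussian (sigma 1.5) normalized to sum to 1, so the windowed means and variances in `_ssim` are proper weighted averages. A pure-Python mirror of that construction (lists standing in for the torch tensor; not the same function as the torch `gaussian` above):

```python
from math import exp

def gaussian(window_size, sigma):
    """1-D Gaussian window, normalized so the weights sum to 1."""
    gauss = [exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
             for x in range(window_size)]
    total = sum(gauss)
    return [g / total for g in gauss]

w = gaussian(11, 1.5)
print(round(sum(w), 6))  # 1.0 -- windowed statistics are weighted averages
```

The 2-D window used in the convolutions is just the outer product of this vector with itself, expanded to one kernel per channel with `groups=channel` so each channel is filtered independently.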
@ -0,0 +1,136 @@
import os
import time
from datetime import datetime
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
import torchvision
import torchvision.transforms as transforms
from skimage import io
from torch.autograd import Variable
from PIL import Image
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, random_split
#from models.discriminatorlayer import discriminator
from dataset import *
from conf import settings
import cfg
from tqdm import tqdm
from utils import *
import function

args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution=args.distributed)

optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # learning rate decay

'''load pretrained model'''
if args.weights != 0:
    print(f'=> resuming from {args.weights}')
    assert os.path.exists(args.weights)
    checkpoint_file = os.path.join(args.weights)
    assert os.path.exists(checkpoint_file)
    loc = 'cuda:{}'.format(args.gpu_device)
    checkpoint = torch.load(checkpoint_file, map_location=loc)
    start_epoch = checkpoint['epoch']
    best_tol = checkpoint['best_tol']

    net.load_state_dict(checkpoint['state_dict'], strict=False)
    # optimizer.load_state_dict(checkpoint['optimizer'])

    args.path_helper = checkpoint['path_helper']
    logger = create_logger(args.path_helper['log_path'])
    print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)

if args.dataset == 'oneprompt':
    nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)
elif args.dataset == 'polyp':
    # polyp dataset
    transform_train = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    transform_train_seg = transforms.Compose([
        transforms.Resize((args.out_size, args.out_size)),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    transform_test_seg = transforms.Compose([
        transforms.Resize((args.out_size, args.out_size)),
        transforms.ToTensor(),
    ])

    train_dataset = CombinedPolypDataset(args, args.data_path, transform=transform_train, transform_msk=transform_train_seg, mode='Training')
    test_dataset = CombinedPolypDataset(args, args.data_path, transform=transform_test, transform_msk=transform_test_seg, mode='Test')

    nice_train_loader = DataLoader(train_dataset, batch_size=args.b, shuffle=True, num_workers=args.w, pin_memory=True)
    nice_test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=args.w, pin_memory=True)

'''checkpoint path and tensorboard'''
checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
# use tensorboard
if not os.path.exists(settings.LOG_DIR):
    os.mkdir(settings.LOG_DIR)
writer = SummaryWriter(log_dir=os.path.join(
    settings.LOG_DIR, args.net, settings.TIME_NOW))

if not os.path.exists(checkpoint_path):
    os.makedirs(checkpoint_path)
checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

'''begin training'''
best_acc = 0.0
best_tol = 1e4
for epoch in range(settings.EPOCH):
    net.train()
    time_start = time.time()

    loss = function.train_one(args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis)
    logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
    time_end = time.time()
    print('time_for_training ', time_end - time_start)

    net.eval()
    if (epoch and epoch % args.val_freq == 0) or epoch == settings.EPOCH - 1:
        tol, (eiou, edice) = function.validation_one(args, nice_test_loader, epoch, net, writer)
        logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')

        if args.distributed != 'none':
            sd = net.module.state_dict()
        else:
            sd = net.state_dict()

        if tol < best_tol:
            best_tol = tol
            is_best = True

            save_checkpoint({
                'epoch': epoch + 1,
                'model': args.net,
                'state_dict': sd,
                'optimizer': optimizer.state_dict(),
                'best_tol': best_tol,
                'path_helper': args.path_helper,
            }, is_best, args.path_helper['ckpt_path'], filename="best_checkpoint")
        else:
            is_best = False

writer.close()
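The `StepLR(optimizer, step_size=10, gamma=0.5)` schedule in this script halves the learning rate every 10 epochs, so the effective rate at epoch `e` is `lr * gamma ** (e // step_size)`. A pure-Python sketch of that decay (the helper `stepped_lr` is ours, for illustration only):

```python
def stepped_lr(base_lr, epoch, step_size=10, gamma=0.5):
    """Learning rate produced by a StepLR-style schedule at a given epoch."""
    return base_lr * gamma ** (epoch // step_size)

# With the script's defaults, the rate stays flat within each 10-epoch
# window and halves at the boundary:
print([stepped_lr(1e-4, e) for e in (0, 9, 10, 25)])
# [0.0001, 0.0001, 5e-05, 2.5e-05]
```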
@ -0,0 +1,127 @@
import os
import sys
import argparse
import time
from datetime import datetime
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
import torchvision
import torchvision.transforms as transforms
from skimage import io
from torch.autograd import Variable
from PIL import Image
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, random_split
#from models.discriminatorlayer import discriminator
from dataset import ISIC2016, REFUGE, PolypDataset, CombinedPolypDataset
from conf import settings
import cfg
from tqdm import tqdm
from utils import *
import function

args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)

net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution=args.distributed)

'''load pretrained model'''
assert args.weights != 0
print(f'=> resuming from {args.weights}')
assert os.path.exists(args.weights)
checkpoint_file = os.path.join(args.weights)
assert os.path.exists(checkpoint_file)
loc = 'cuda:{}'.format(args.gpu_device)
checkpoint = torch.load(checkpoint_file, map_location=loc)
start_epoch = checkpoint['epoch']
best_tol = checkpoint['best_tol']

state_dict = checkpoint['state_dict']
if args.distributed != 'none':
    # a checkpoint saved without DataParallel needs the 'module.' prefix
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = 'module.' + k
        new_state_dict[name] = v
else:
    new_state_dict = state_dict

net.load_state_dict(new_state_dict)

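The key-remapping loop above exists because `nn.DataParallel` wraps the model and prefixes every parameter name with `module.`, so a checkpoint saved from a bare model will not match. A self-contained sketch of the remap with dummy values (the helper `add_module_prefix` is ours):

```python
from collections import OrderedDict

def add_module_prefix(state_dict):
    """Rewrite a single-GPU state_dict for a DataParallel-wrapped model."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict['module.' + k] = v  # every key gains the wrapper prefix
    return new_state_dict

sd = OrderedDict([('conv.weight', 1), ('conv.bias', 2)])
print(list(add_module_prefix(sd)))
# ['module.conv.weight', 'module.conv.bias']
```

Going the other direction (loading a DataParallel checkpoint into a bare model) would strip the prefix instead.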
args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)

'''segmentation data'''
transform_train = transforms.Compose([
    transforms.Resize((args.image_size, args.image_size)),
    transforms.ToTensor(),
])

transform_train_seg = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((args.image_size, args.image_size)),
])

transform_test = transforms.Compose([
    transforms.Resize((args.image_size, args.image_size)),
    transforms.ToTensor(),
])

transform_test_seg = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((args.image_size, args.image_size)),
])
'''data end'''
if args.dataset == 'isic':
    '''isic data'''
    isic_train_dataset = ISIC2016(args, args.data_path, transform=transform_train, transform_msk=transform_train_seg, mode='Training')
    isic_test_dataset = ISIC2016(args, args.data_path, transform=transform_test, transform_msk=transform_test_seg, mode='Test')

    nice_train_loader = DataLoader(isic_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
    nice_test_loader = DataLoader(isic_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
    '''end'''

elif args.dataset == 'oneprompt':
    nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)

elif args.dataset == 'REFUGE':
    '''REFUGE data'''
    refuge_train_dataset = REFUGE(args, args.data_path, transform=transform_train, transform_msk=transform_train_seg, mode='Training')
    refuge_test_dataset = REFUGE(args, args.data_path, transform=transform_test, transform_msk=transform_test_seg, mode='Test')

    nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
    nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
    '''end'''

elif args.dataset == 'polyp':
    '''Polyp data'''
    transform_test_seg = transforms.Compose([
        transforms.Resize((args.out_size, args.out_size)),
        transforms.ToTensor(),
    ])
    polyp_test_dataset = CombinedPolypDataset(args, args.data_path, transform=transform_test, transform_msk=transform_test_seg, mode='Test')
    nice_test_loader = DataLoader(polyp_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
    '''end'''

'''begin validation'''
best_acc = 0.0
best_tol = 1e4

if args.mod == 'sam_adpt' or args.mod == 'one_adpt':
    net.eval()
    tol, (eiou, edice) = function.validation_one(args, nice_test_loader, 0, net)
    logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')