# NOTE(review): removed web-page scrape residue that preceded the source
# (repository topic-selection hint and "335 lines / 12 KiB" file metadata);
# it was never part of this Python module.
# Standard library
import argparse
import os
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from datetime import datetime

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pytorch_ssim
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric  # needed by the module-level DiceMetric instance
from monai.transforms import (
    AsDiscrete,
)
from PIL import Image
from skimage import io
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

# Local
import cfg
from conf import settings
from utils import *
#from dataset import *
#from models.discriminatorlayer import discriminator
args = cfg.parse_args()
|
|
|
|
GPUdevice = torch.device('cuda', args.gpu_device)
|
|
pos_weight = torch.ones([1]).cuda(device=GPUdevice)*2
|
|
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
|
seed = torch.randint(1,11,(args.b,7))
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
|
|
scaler = torch.cuda.amp.GradScaler()
|
|
max_iterations = settings.EPOCH
|
|
post_label = AsDiscrete(to_onehot=14)
|
|
post_pred = AsDiscrete(argmax=True, to_onehot=14)
|
|
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
|
|
dice_val_best = 0.0
|
|
global_step_best = 0
|
|
epoch_loss_values = []
|
|
metric_values = []
|
|
|
|
def train_one(args, net: nn.Module, optimizer, train_loader,
          epoch, writer, schedulers=None, vis = 50):
    """Train ``net`` for one epoch on ``train_loader``.

    The first batch provides a "template" image/mask (``tmp_img``/``tmp_mask``)
    and, when click prompts are available, a template prompt ``tmp_pt``.  The
    templates are encoded alongside every batch and re-tiled whenever the batch
    size changes (e.g. a smaller final batch).

    Args:
        args: parsed cfg namespace (uses ``gpu_device``, ``thd``,
            ``image_size``, ``out_size``, ``path_helper``).
        net: model exposing ``image_encoder`` / ``prompt_encoder`` /
            ``mask_decoder`` (possibly wrapped in ``DataParallel``).
        optimizer: optimizer over the trainable (decoder) parameters.
        train_loader: yields dicts with ``'image'``, ``'label'``,
            ``'image_meta_dict'`` and optionally ``'pt'`` / ``'p_label'``.
        epoch: current epoch index (progress bar / visualisation names).
        writer: tensorboard writer (unused here; kept for interface parity).
        schedulers: unused; kept for interface compatibility.
        vis: visualise every ``vis``-th batch (falsy disables).

    Returns:
        The loss tensor of the last processed batch (``epoch_loss`` is
        accumulated but not returned — matches existing caller contract).
    """
    hard = 0
    epoch_loss = 0
    ind = 0
    # train mode
    net.train()
    optimizer.zero_grad()

    # Unwrap DataParallel so submodules can be addressed directly.
    model = net.module if hasattr(net, 'module') else net

    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    # 3-D (thd) data uses Dice+CE; 2-D uses the module-level weighted BCE.
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            # Actual size of the current batch (last batch may be smaller).
            current_b = pack['image'].size(0)

            if ind == 0:
                # Build templates from the first sample, tiled across the batch.
                tmp_img = pack['image'].to(dtype = torch.float32, device = GPUdevice)[0,:,:,:].unsqueeze(0).repeat(current_b, 1, 1, 1)
                tmp_mask = pack['label'].to(dtype = torch.float32, device = GPUdevice)[0,:,:,:].unsqueeze(0).repeat(current_b, 1, 1, 1)
                if 'pt' not in pack:
                    tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
                else:
                    pt = pack['pt']
                    point_labels = pack['p_label']

                # NOTE(review): ``point_labels`` is only bound on the
                # ``'pt' in pack`` path — confirm loaders without 'pt' never
                # reach here, or this raises NameError.
                if point_labels[0] != -1:
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    # Keep only the first click and tile it across the batch.
                    coords_torch = coords_torch[0:1, :].repeat(current_b, 1)
                    labels_torch = labels_torch[0:1].repeat(current_b)
                    coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
                    tmp_pt = (coords_torch, labels_torch)
            else:
                # Re-tile the templates to match a changed batch size.
                if tmp_img.size(0) != current_b:
                    tmp_img = tmp_img[0:1].repeat(current_b, 1, 1, 1)
                    tmp_mask = tmp_mask[0:1].repeat(current_b, 1, 1, 1)
                    if 'tmp_pt' in dir():  # only when a template prompt was built
                        coords_torch = tmp_pt[0][0:1].repeat(current_b, 1, 1)
                        labels_torch = tmp_pt[1][0:1].repeat(current_b, 1)
                        tmp_pt = (coords_torch, labels_torch)

            imgs = pack['image'].to(dtype = torch.float32, device = GPUdevice)
            masks = pack['label'].to(dtype = torch.float32, device = GPUdevice)

            name = pack['image_meta_dict']['filename_or_obj']

            # Per-batch click prompt for the current images.
            if 'pt' in pack:
                pt = pack['pt']
                point_labels = pack['p_label']
                if point_labels[0] != -1:
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
                    pt = (coords_torch, labels_torch)

            if args.thd:
                # 3-D volumes: fold the depth axis into the batch dimension.
                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w ')

                imgs = imgs.repeat(1,3,1,1)
                point_labels = torch.ones(imgs.size(0))

                imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks)

            showp = pt

            mask_type = torch.float32
            ind += 1
            b_size,c,w,h = imgs.size()
            longsize = w if w >=h else h

            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
            imgs = imgs.to(dtype = mask_type,device = GPUdevice)

            # Mixed-precision forward pass.
            with torch.amp.autocast('cuda'):
                with torch.no_grad():
                    # Encoders run without gradients — only the decoder trains.
                    imge, skips= model.image_encoder(imgs)
                    timge, tskips = model.image_encoder(tmp_img)

                    p1, p2, se, de = model.prompt_encoder(
                        points=pt,
                        boxes=None,
                        doodles= None,
                        masks=None,
                    )

                # Release cached encoder activations before decoding.
                torch.cuda.empty_cache()

                pred, _ = model.mask_decoder(
                    skips_raw = skips,
                    skips_tmp = tskips,
                    raw_emb = imge,
                    tmp_emb = timge,
                    pt1 = p1,
                    pt2 = p2,
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )

                # Match prediction resolution to the target masks.
                if pred.shape[-2:] != masks.shape[-2:]:
                    pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)

                loss = lossfunc(pred, masks)

            # Skip batches that produce a non-finite loss.
            # FIX: previously `ind` was incremented a second time here (it is
            # already incremented once above), double-counting skipped batches
            # in the `ind % vis` visualisation schedule.
            if torch.isnan(loss) or torch.isinf(loss):
                optimizer.zero_grad()
                pbar.set_postfix(**{'loss (batch)': 'nan/inf skipped'})
                pbar.update()
                continue

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()

            scaler.scale(loss).backward()
            # Gradient clipping requires unscaling first under AMP.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            '''vis images'''
            if vis:
                if ind % vis == 0:
                    namecat = 'Train'
                    for na in name:
                        namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                    vis_image(imgs,pred,masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False)

            pbar.update()

    return loss
def validation_one(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    """Run one validation pass and return averaged loss and metrics.

    Mirrors ``train_one``: the first batch provides a template image/mask and
    (optionally) a click prompt that are reused for every batch.

    FIX: the templates were tiled to a fixed ``args.b`` and never re-tiled, so
    a smaller final batch made the template batch dimension disagree with the
    per-batch embeddings fed jointly into ``mask_decoder``; they are now tiled
    to the current batch size, consistent with ``train_one``.

    Args:
        args: parsed cfg namespace (uses ``gpu_device``, ``thd``, ``vis``,
            ``path_helper``).
        val_loader: yields dicts with ``'image'``, ``'label'``,
            ``'image_meta_dict'`` and optionally ``'pt'`` / ``'p_label'``.
        epoch: epoch index, used only to name visualisation files.
        net: model exposing ``image_encoder`` / ``prompt_encoder`` /
            ``mask_decoder`` (possibly wrapped in ``DataParallel``).
        clean_dir: unused; kept for interface compatibility.

    Returns:
        Tuple ``(tot / n_val, per-threshold metrics averaged over batches)``.
    """
    # eval mode
    net.eval()

    # Unwrap DataParallel so submodules can be addressed directly.
    model = net.module if hasattr(net, 'module') else net

    mask_type = torch.float32
    n_val = len(val_loader) # the number of batches
    ave_res, mix_res = (0,0,0,0), (0,0,0,0)
    rater_res = [(0,0,0,0) for _ in range(6)]
    tot = 0
    hard = 0
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    device = GPUdevice

    # 3-D (thd) data uses Dice+CE; 2-D uses the module-level weighted BCE.
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            # Actual size of the current batch (last batch may be smaller).
            current_b = pack['image'].size(0)

            if ind == 0:
                # Build templates from the first sample, tiled across the batch.
                tmp_img = pack['image'].to(dtype = torch.float32, device = GPUdevice)[0,:,:,:].unsqueeze(0).repeat(current_b, 1, 1, 1)
                tmp_mask = pack['label'].to(dtype = torch.float32, device = GPUdevice)[0,:,:,:].unsqueeze(0).repeat(current_b, 1, 1, 1)
                if 'pt' not in pack:
                    tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
                else:
                    pt = pack['pt']
                    point_labels = pack['p_label']

                # NOTE(review): ``point_labels`` is only bound on the
                # ``'pt' in pack`` path — confirm loaders without 'pt' never
                # reach here, or this raises NameError.
                if point_labels[0] != -1:
                    # point_coords = onetrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
                    pt = (coords_torch, labels_torch)
            elif tmp_img.size(0) != current_b:
                # Re-tile the templates to match a changed batch size.
                tmp_img = tmp_img[0:1].repeat(current_b, 1, 1, 1)
                tmp_mask = tmp_mask[0:1].repeat(current_b, 1, 1, 1)

            imgs = pack['image'].to(dtype = torch.float32, device = GPUdevice)
            masks = pack['label'].to(dtype = torch.float32, device = GPUdevice)

            name = pack['image_meta_dict']['filename_or_obj']

            showp = pt

            mask_type = torch.float32
            # NOTE: shadows the enumerate index; keeps the historical 1-based
            # index used by the ``ind % args.vis`` check below.
            ind += 1
            b_size,c,w,h = imgs.size()
            longsize = w if w >=h else h

            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
                #true_mask_ave = cons_tensor(true_mask_ave)
            imgs = imgs.to(dtype = mask_type,device = GPUdevice)

            '''test'''
            with torch.no_grad():
                imge, skips= model.image_encoder(imgs)
                timge, tskips = model.image_encoder(tmp_img)

                p1, p2, se, de = model.prompt_encoder(
                    points=pt,
                    boxes=None,
                    doodles= None,
                    masks=None,
                )
                pred, _ = model.mask_decoder(
                    skips_raw = skips,
                    skips_tmp = tskips,
                    raw_emb = imge,
                    tmp_emb = timge,
                    pt1 = p1,
                    pt2 = p2,
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )

                # Match prediction resolution to the target masks.
                if pred.shape[-2:] != masks.shape[-2:]:
                    pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)

                tot += lossfunc(pred, masks)

                '''vis images'''
                if args.vis and ind % args.vis == 0:
                    namecat = 'Test'
                    for na in name:
                        img_name = na.split('/')[-1].split('.')[0]
                        namecat = namecat + img_name + '+'
                    vis_image(imgs,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False)

                temp = eval_seg(pred, masks, threshold)
                mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

            pbar.update()

    return tot/ n_val , tuple([a/n_val for a in mix_res])