import os
import sys
import time
import shutil
import tempfile
import argparse
from datetime import datetime
from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from skimage import io
from PIL import Image
from einops import rearrange
from tensorboardX import SummaryWriter
from tqdm import tqdm
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric  # used by dice_metric below, missing from the original imports
from monai.transforms import AsDiscrete

import cfg
import pytorch_ssim
from conf import settings
from utils import *
#from dataset import *
#from models.discriminatorlayer import discriminator
args = cfg.parse_args()
GPUdevice = torch.device('cuda', args.gpu_device)

pos_weight = torch.ones([1]).cuda(device=GPUdevice) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1, 11, (args.b, 7))

torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
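# Mixed-precision setup: `scaler` (torch.cuda.amp.GradScaler) pairs with the
# torch.amp.autocast('cuda') context inside train_one below; forward passes run
# in reduced precision while the scaler rescales gradients before the fp32
# optimizer step. `criterion_G` weights positive pixels 2x via `pos_weight`,
# biasing the BCE loss toward recall on the single-channel mask logits.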
def train_one(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=50):
    """Run one training epoch and return the average batch loss."""
    hard = 0
    epoch_loss = 0
    ind = 0
    # train mode
    net.train()
    optimizer.zero_grad()
    # unwrap DataParallel if the model is wrapped
    model = net.module if hasattr(net, 'module') else net
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            # actual size of the current batch (the last batch may be smaller)
            current_b = pack['image'].size(0)
            if ind == 0:
                tmp_img = pack['image'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
                tmp_mask = pack['label'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
                if 'pt' not in pack:
                    tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
                else:
                    pt = pack['pt']
                    point_labels = pack['p_label']
                    if point_labels[0] != -1:
                        point_coords = pt
                        coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                        labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                        # keep only the first point and repeat it across the current batch
                        coords_torch = coords_torch[0:1, :].repeat(current_b, 1)
                        labels_torch = labels_torch[0:1].repeat(current_b)
                        coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
                        tmp_pt = (coords_torch, labels_torch)
            else:
                # re-repeat the template tensors to match the current batch size
                if tmp_img.size(0) != current_b:
                    tmp_img = tmp_img[0:1].repeat(current_b, 1, 1, 1)
                    tmp_mask = tmp_mask[0:1].repeat(current_b, 1, 1, 1)
                    if 'tmp_pt' in locals():
                        coords_torch = tmp_pt[0][0:1].repeat(current_b, 1, 1)
                        labels_torch = tmp_pt[1][0:1].repeat(current_b, 1)
                        tmp_pt = (coords_torch, labels_torch)

            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            name = pack['image_meta_dict']['filename_or_obj']

            # click prompts for the current batch
            if 'pt' in pack:
                pt = pack['pt']
                point_labels = pack['p_label']
                if point_labels[0] != -1:
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
                    pt = (coords_torch, labels_torch)

            if args.thd:
                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
                imgs = imgs.repeat(1, 3, 1, 1)
                point_labels = torch.ones(imgs.size(0))
                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            showp = pt
            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
            imgs = imgs.to(dtype=mask_type, device=GPUdevice)
            # mixed-precision forward pass
            with torch.amp.autocast('cuda'):
                # the image encoders run without gradients (frozen) to save memory
                with torch.no_grad():
                    imge, skips = model.image_encoder(imgs)
                    timge, tskips = model.image_encoder(tmp_img)
                    # imge = net.image_encoder(imgs)
                p1, p2, se, de = model.prompt_encoder(
                    points=pt,
                    boxes=None,
                    doodles=None,
                    masks=None,
                )
                # release cached GPU memory before the decoder runs
                torch.cuda.empty_cache()
                pred, _ = model.mask_decoder(
                    skips_raw=skips,
                    skips_tmp=tskips,
                    raw_emb=imge,
                    tmp_emb=timge,
                    pt1=p1,
                    pt2=p2,
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )
                # resize predictions to match the target masks
                if pred.shape[-2:] != masks.shape[-2:]:
                    pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)
                loss = lossfunc(pred, masks)

            # skip batches whose loss is NaN/Inf (ind was already advanced above)
            if torch.isnan(loss) or torch.isinf(loss):
                optimizer.zero_grad()
                pbar.set_postfix(**{'loss (batch)': 'nan/inf skipped'})
                pbar.update()
                continue

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()
            scaler.scale(loss).backward()
            # gradient clipping: unscale first so the norm is computed on real gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            '''vis images'''
            if vis:
                if ind % vis == 0:
                    namecat = 'Train'
                    for na in name:
                        namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                    vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False)

            pbar.update()

    # average loss over the epoch
    return epoch_loss / max(ind, 1)
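# Note: `schedulers` is accepted by train_one but never stepped inside it, so any
# LR schedule must be driven by the caller. A minimal sketch of that wiring,
# assuming an Adam optimizer and a MultiStepLR schedule (hypothetical choices,
# not defined in this file):
#
#   optimizer = optim.Adam(net.parameters(), lr=1e-4)
#   scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[64, 96])
#   for epoch in range(settings.EPOCH):
#       loss = train_one(args, net, optimizer, train_loader, epoch, writer)
#       scheduler.step()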
def validation_one(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    """Run one validation pass; return (average loss, averaged seg metrics)."""
    # eval mode
    net.eval()
    # unwrap DataParallel if the model is wrapped
    model = net.module if hasattr(net, 'module') else net

    mask_type = torch.float32
    n_val = len(val_loader)  # the number of batches
    ave_res, mix_res = (0, 0, 0, 0), (0, 0, 0, 0)
    rater_res = [(0, 0, 0, 0) for _ in range(6)]
    tot = 0
    hard = 0
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    device = GPUdevice
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G
    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            if ind == 0:
                # template image/mask repeated to the batch size (assumes fixed batches of args.b)
                tmp_img = pack['image'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(args.b, 1, 1, 1)
                tmp_mask = pack['label'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(args.b, 1, 1, 1)
                if 'pt' not in pack:
                    tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
                else:
                    pt = pack['pt']
                    point_labels = pack['p_label']
                    if point_labels[0] != -1:
                        # point_coords = onetrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                        point_coords = pt
                        coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                        labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                        coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
                        pt = (coords_torch, labels_torch)

            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            name = pack['image_meta_dict']['filename_or_obj']

            showp = pt
            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
                #true_mask_ave = cons_tensor(true_mask_ave)
            imgs = imgs.to(dtype=mask_type, device=GPUdevice)
            '''test'''
            with torch.no_grad():
                imge, skips = model.image_encoder(imgs)
                timge, tskips = model.image_encoder(tmp_img)
                p1, p2, se, de = model.prompt_encoder(
                    points=pt,
                    boxes=None,
                    doodles=None,
                    masks=None,
                )
                pred, _ = model.mask_decoder(
                    skips_raw=skips,
                    skips_tmp=tskips,
                    raw_emb=imge,
                    tmp_emb=timge,
                    pt1=p1,
                    pt2=p2,
                    image_pe=model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )
                # resize predictions to match the target masks
                if pred.shape[-2:] != masks.shape[-2:]:
                    pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)
                tot += lossfunc(pred, masks)

                '''vis images'''
                if args.vis and ind % args.vis == 0:
                    namecat = 'Test'
                    for na in name:
                        img_name = na.split('/')[-1].split('.')[0]
                        namecat = namecat + img_name + '+'
                    vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False)

                temp = eval_seg(pred, masks, threshold)
                mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

            pbar.update()

    return tot / n_val, tuple([a / n_val for a in mix_res])
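# A minimal end-to-end sketch of how the two loops above might be wired together.
# `get_network` and `get_dataloaders` are hypothetical helpers (this file only
# star-imports `utils`, so the real factory names and signatures may differ):
#
# if __name__ == '__main__':
#     net = get_network(args).to(GPUdevice)             # hypothetical model factory
#     optimizer = optim.Adam(net.parameters(), lr=1e-4)
#     train_loader, val_loader = get_dataloaders(args)  # hypothetical data helper
#     writer = SummaryWriter()
#     best_tol = float('inf')
#     for epoch in range(settings.EPOCH):
#         train_loss = train_one(args, net, optimizer, train_loader, epoch, writer)
#         writer.add_scalar('train/loss', train_loss, epoch)
#         tol, metrics = validation_one(args, val_loader, epoch, net)
#         writer.add_scalar('val/loss', tol, epoch)
#         if tol < best_tol:                            # lower validation loss is better
#             best_tol = tol
#             torch.save(net.state_dict(), 'checkpoint_best.pth')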