import os
import sys
import argparse
import time
import shutil
import tempfile
from datetime import datetime
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from skimage import io
from torch.utils.data import DataLoader
from torch.autograd import Variable
from PIL import Image
from tensorboardX import SummaryWriter
from tqdm import tqdm
from einops import rearrange
import pytorch_ssim

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    AsDiscrete,
)

#from dataset import *
#from models.discriminatorlayer import discriminator
import cfg
from conf import settings
from utils import *
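
# ----------------------------------------------------------------------
# Module-level setup: CLI arguments, device, loss functions, AMP scaler,
# post-processing transforms, and metric bookkeeping shared by the
# train/validation routines below.
# ----------------------------------------------------------------------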
args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
# Weight positive pixels 2x in the BCE loss.
pos_weight = torch.ones([1]).cuda(device=GPUdevice) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1, 11, (args.b, 7))

torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
# Post-processing transforms for 14-class one-hot evaluation.
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
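
# ----------------------------------------------------------------------
# train_one: run one training epoch. A template (image, mask) pair taken
# from the first batch is encoded alongside every query batch, and click
# prompts feed the prompt encoder / mask decoder (template-guided,
# one-prompt-style training).
# ----------------------------------------------------------------------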
def train_one(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=50):
    hard = 0
    epoch_loss = 0
    ind = 0
    # train mode
    net.train()
    optimizer.zero_grad()

    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            # Build the template (image, mask) pair from the first batch and reuse it.
            if ind == 0:
                tmp_img = pack['image'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(args.b, 1, 1, 1)
                tmp_mask = pack['label'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0).repeat(args.b, 1, 1, 1)

            if 'pt' not in pack:
                tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
                # Assumption: clicks produced by generate_click_prompt are positive prompts.
                point_labels = torch.ones(pt.shape[0])
            else:
                pt = pack['pt']
                point_labels = pack['p_label']

            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            name = pack['image_meta_dict']['filename_or_obj']
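
            # For 3D data (args.thd), flatten the depth dimension into the
            # batch: (b, c, h, w, d) -> (b*d, c, h, w), replicate to 3
            # channels, and resize to the model's input/output sizes.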
            if args.thd:
                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w')
                imgs = imgs.repeat(1, 3, 1, 1)
                point_labels = torch.ones(imgs.size(0))
                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            # Convert click coordinates/labels into the (coords, labels) tuple the
            # prompt encoder expects. This runs after the 3D flattening above so
            # the prompts stay plain tensors for rearrange; a label of -1 marks
            # "no point provided".
            if point_labels[0] != -1:
                point_coords = pt
                coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
                pt = (coords_torch, labels_torch)

            showp = pt
            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()

            imgs = imgs.to(dtype=mask_type, device=GPUdevice)
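
            # The image encoder runs under no_grad for both the query batch and
            # the template image, so its weights stay frozen here; only the
            # prompt encoder and mask decoder receive gradients from this loss.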
            with torch.no_grad():
                imge, skips = net.image_encoder(imgs)
                timge, tskips = net.image_encoder(tmp_img)

            p1, p2, se, de = net.prompt_encoder(
                points=pt,
                boxes=None,
                doodles=None,
                masks=None,
            )
            pred, _ = net.mask_decoder(
                skips_raw=skips,
                skips_tmp=tskips,
                raw_emb=imge,
                tmp_emb=timge,
                pt1=p1,
                pt2=p2,
                image_pe=net.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=se,
                dense_prompt_embeddings=de,
                multimask_output=False,
            )

            # Resize the predicted mask to match the target mask size.
            pred = F.interpolate(pred, size=(masks.shape[2], masks.shape[3]), mode='bilinear', align_corners=False)
            loss = lossfunc(pred, masks)

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()
            loss.backward()

            # nn.utils.clip_grad_value_(net.parameters(), 0.1)
            optimizer.step()
            optimizer.zero_grad()

            '''vis images'''
            if vis and ind % vis == 0:
                namecat = 'Train'
                for na in name:
                    namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False)

            pbar.update()

    return loss
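

# ----------------------------------------------------------------------
# validation_one: evaluation pass. Mirrors the training forward pass but
# runs fully under torch.no_grad(), accumulating the loss over all batches
# and segmentation metrics over batches with valid (non-dummy) masks.
# ----------------------------------------------------------------------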
def validation_one(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    # eval mode
    net.eval()

    mask_type = torch.float32
    n_val = len(val_loader)  # the number of batches
    rater_res = [(0, 0, 0, 0) for _ in range(6)]
    tot = 0
    hard = 0
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    valid_eval_count = 0  # counter for samples with valid (non-dummy) masks
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    device = GPUdevice

    # The metrics tuple length depends on the number of mask channels, which
    # is only known once the first batch arrives.
    metrics_length = None
    mix_res = None

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G
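
    # The template pair captured from the first batch is reused (repeated to
    # the current batch size) as the prompt example for every batch.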
    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            b_size, _, _, _ = imgs.size()

            # Determine metrics length from the mask channels (first batch only).
            if metrics_length is None:
                mask_channels = masks.shape[1]
                metrics_length = 4 if mask_channels == 2 else 2
                mix_res = (0,) * metrics_length

            if ind == 0:
                tmp_img_base = pack['image'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0)
                tmp_mask_base = pack['label'].to(dtype=torch.float32, device=GPUdevice)[0, :, :, :].unsqueeze(0)

            if 'pt' not in pack:
                # Generate a click prompt from the single template sample.
                single_tmp_img = tmp_img_base
                single_tmp_mask = tmp_mask_base
                single_tmp_img, pt, single_tmp_mask = generate_click_prompt(single_tmp_img, single_tmp_mask)
                # Assumption: clicks produced by generate_click_prompt are positive prompts.
                point_labels = torch.ones(pt.shape[0])
            else:
                pt = pack['pt']
                point_labels = pack['p_label']

            if point_labels[0] != -1:
                # point_coords = onetrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                point_coords = pt
                coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                # Ensure coords_torch is 3-D (B, N, 2); add a batch dim if it is (N, 2).
                if coords_torch.dim() == 2:
                    coords_torch = coords_torch.unsqueeze(0)
                labels_torch = labels_torch.unsqueeze(0) if labels_torch.dim() == 1 else labels_torch
                pt = (coords_torch, labels_torch)

            # Repeat the template pair to match the current batch size.
            tmp_img = tmp_img_base.repeat(b_size, 1, 1, 1)
            tmp_mask = tmp_mask_base.repeat(b_size, 1, 1, 1)

            name = pack['image_meta_dict']['filename_or_obj']
            showp = pt
            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
                #true_mask_ave = cons_tensor(true_mask_ave)
            imgs = imgs.to(dtype=mask_type, device=GPUdevice)

            '''test'''
            with torch.no_grad():
                imge, skips = net.image_encoder(imgs)
                timge, tskips = net.image_encoder(tmp_img)
                p1, p2, se, de = net.prompt_encoder(
                    points=pt,
                    boxes=None,
                    doodles=None,
                    masks=None,
                )
                pred, _ = net.mask_decoder(
                    skips_raw=skips,
                    skips_tmp=tskips,
                    raw_emb=imge,
                    tmp_emb=timge,
                    pt1=p1,
                    pt2=p2,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )

                # Resize the predicted mask to match the target mask size.
                pred = F.interpolate(pred, size=(masks.shape[2], masks.shape[3]), mode='bilinear', align_corners=False)
                tot += lossfunc(pred, masks)

            '''vis images'''
            if args.vis and ind % args.vis == 0:
                namecat = 'Test'
                for na in name:
                    img_name = na.split('/')[-1].split('.')[0]
                    namecat = namecat + img_name + '+'
                vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False)

            # Skip metric evaluation when masks are all zeros (dummy masks).
            is_dummy_mask = torch.all(masks == 0).item()
            if not is_dummy_mask:
                temp = eval_seg(pred, masks, threshold)
                mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
                valid_eval_count += 1

            pbar.update()

    # Average metrics over valid samples only, avoiding division by zero.
    if valid_eval_count > 0:
        return tot / n_val, tuple([a / valid_eval_count for a in mix_res])
    else:
        # No valid masks: return the loss and zeroed metrics.
        # (metrics_length is still None if the loader was empty.)
        return tot / n_val, tuple(0.0 for _ in range(metrics_length or 0))