You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
146 lines
5.4 KiB
146 lines
5.4 KiB
|
|
|
|
import os
|
|
from datetime import datetime
|
|
from collections import OrderedDict
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from sklearn.metrics import roc_auc_score, accuracy_score,confusion_matrix
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from skimage import io
|
|
from torch.utils.data import DataLoader
|
|
#from dataset import *
|
|
from torch.autograd import Variable
|
|
from PIL import Image
|
|
from tensorboardX import SummaryWriter
|
|
#from models.discriminatorlayer import discriminator
|
|
from dataset import *
|
|
from conf import settings
|
|
import time
|
|
import cfg
|
|
from tqdm import tqdm
|
|
from torch.utils.data import DataLoader, random_split
|
|
from utils import *
|
|
import function
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Command-line configuration, model, optimizer and LR scheduler
# ---------------------------------------------------------------------------
args = cfg.parse_args()

# CUDA device whose index comes from args.gpu_device.
GPUdevice = torch.device('cuda', args.gpu_device)

# Build the network named by args.net on that device.
net = get_network(
    args,
    args.net,
    use_gpu=args.gpu,
    gpu_device=GPUdevice,
    distribution=args.distributed,
)

# Adam with explicit betas/eps, no weight decay and no AMSGrad variant.
optimizer = optim.Adam(
    net.parameters(),
    lr=args.lr,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

# Learning-rate decay: multiplies the LR by 0.5 after every 10 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# ---------------------------------------------------------------------------
# Optionally resume from a checkpoint (enabled when args.weights != 0).
#
# Defines, when resuming: checkpoint, start_epoch, best_tol, and it also
# replaces args.path_helper and logger with the values stored in the checkpoint.
# ---------------------------------------------------------------------------
if args.weights != 0:
    print(f'=> resuming from {args.weights}')

    checkpoint_file = args.weights
    # Raise instead of `assert`: asserts are stripped under `python -O`, which
    # would turn a bad path into a less clear failure inside torch.load.
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')

    # Load every tensor onto the selected GPU, whatever device it was saved from.
    loc = 'cuda:{}'.format(args.gpu_device)
    checkpoint = torch.load(checkpoint_file, map_location=loc)

    # The training loop saves 'epoch' as (saving epoch + 1), i.e. the next epoch to run.
    start_epoch = checkpoint['epoch']
    best_tol = checkpoint['best_tol']

    # strict=False: missing/unexpected keys are ignored instead of raising.
    net.load_state_dict(checkpoint['state_dict'], strict=False)

    # NOTE(review): optimizer state is not restored (the checkpoint's 'optimizer'
    # entry is ignored; Optimizer.load_state_dict() takes no `strict=` argument).
    # The freshly created optimizer therefore starts with empty state.

    args.path_helper = checkpoint['path_helper']
    logger = create_logger(args.path_helper['log_path'])

    print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

# Obtain this run's log/checkpoint paths ('log_path', 'ckpt_path', ...) from
# set_log_dir('logs', args.exp_name), open the logger, and record the arguments.
# NOTE(review): this unconditionally replaces args.path_helper and logger,
# including the ones restored from a checkpoint when args.weights != 0 —
# confirm that starting a new log directory on resume is intended.
run_paths = set_log_dir('logs', args.exp_name)
args.path_helper = run_paths
logger = create_logger(run_paths['log_path'])
logger.info(args)

# ---------------------------------------------------------------------------
# Datasets and data loaders.
#
# Defines nice_train_loader / nice_test_loader (plus the transforms) for the
# selected args.dataset. Any other dataset name is rejected here, before
# checkpoint directories or the TensorBoard writer are created, instead of
# failing later with a NameError on nice_train_loader.
# ---------------------------------------------------------------------------
if args.dataset == 'oneprompt':
    nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)
elif args.dataset in ('isic', 'custom'):
    # Image transforms: resize to args.image_size x args.image_size, convert to
    # a tensor, and normalize with the standard ImageNet channel mean/std.
    transform_train = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform_val = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Mask transform (used for both splits): resize + ToTensor, no normalization.
    # NOTE(review): Resize uses its default (bilinear) interpolation here, which
    # can produce non-binary mask values at edges — confirm ISIC2016 thresholds them.
    transform_train_seg = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])

    # Build the datasets; both splits use 'click' prompts.
    train_dataset = ISIC2016(args, args.data_path, transform=transform_train, transform_msk=transform_train_seg, mode='Training', prompt='click')
    test_dataset = ISIC2016(args, args.data_path, transform=transform_val, transform_msk=transform_train_seg, mode='Test', prompt='click')

    # Build the data loaders; only the training loader shuffles.
    nice_train_loader = DataLoader(train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
    nice_test_loader = DataLoader(test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)
else:
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected 'oneprompt', 'isic' or 'custom'."
    )

# ---------------------------------------------------------------------------
# Checkpoint directory and TensorBoard writer.
# ---------------------------------------------------------------------------
checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)

# TensorBoard run directory: settings.LOG_DIR/<net>/<settings.TIME_NOW>
# (TIME_NOW is presumably the launch timestamp — defined in conf.settings).
# makedirs(exist_ok=True) avoids the exists()/mkdir() race and, unlike
# os.mkdir, also creates any missing parent directories.
os.makedirs(settings.LOG_DIR, exist_ok=True)
writer = SummaryWriter(log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW))

os.makedirs(checkpoint_path, exist_ok=True)
# Filename template for per-epoch checkpoints.
# NOTE(review): not used by the training loop in this script, which saves via
# save_checkpoint(..., args.path_helper['ckpt_path'], ...).
checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

# ---------------------------------------------------------------------------
# Training loop.
#
# Trains for epochs [first_epoch, settings.EPOCH), validates periodically, and
# saves "best_checkpoint" whenever the validation total score `tol` improves
# (lower is better). The TensorBoard writer is always closed, even on error.
# ---------------------------------------------------------------------------
if args.weights != 0:
    # Resuming: `start_epoch` (the epoch after the saved one) and `best_tol` were
    # read from the checkpoint earlier in this script. Previously both were
    # discarded (restart at epoch 0, best_tol reset), so a resumed run could
    # overwrite a better "best_checkpoint" with a worse model.
    first_epoch = start_epoch
else:
    first_epoch = 0
    best_tol = 1e4  # sentinel "worst" score; any real tol below it counts as an improvement

# NOTE(review): `scheduler` (StepLR) is created during setup but scheduler.step()
# is never called, so the learning rate stays at args.lr. Left as-is on purpose:
# stepping it once per epoch would halve the LR every 10 epochs — confirm intent.

try:
    for epoch in range(first_epoch, settings.EPOCH):
        net.train()
        time_start = time.time()
        loss = function.train_one(args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis)
        logger.info(f'Train loss: {loss}|| @ epoch {epoch}.')
        time_end = time.time()
        print('time_for_training ', time_end - time_start)

        net.eval()
        # Validate every args.val_freq epochs (never at epoch 0) and always on the last epoch.
        if not ((epoch and epoch % args.val_freq == 0) or epoch == settings.EPOCH - 1):
            continue

        tol, metrics = function.validation_one(args, nice_test_loader, epoch, net, writer)

        # validation_one may return 2 metrics (IOU, DICE) or 4 (disc/cup IOU and DICE).
        if len(metrics) == 2:
            eiou, edice = metrics
            logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
        elif len(metrics) == 4:
            iou_d, iou_c, disc_dice, cup_dice = metrics
            logger.info(f'Total score: {tol}, Disc IOU: {iou_d}, Cup IOU: {iou_c}, Disc DICE: {disc_dice}, Cup DICE: {cup_dice} || @ epoch {epoch}.')
        else:
            logger.info(f'Total score: {tol}, Metrics: {metrics} || @ epoch {epoch}.')

        if tol < best_tol:
            best_tol = tol
            is_best = True
            # With distributed training the model is wrapped (presumably DataParallel/DDP),
            # so save the inner module's weights via `.module`.
            if args.distributed != 'none':
                sd = net.module.state_dict()
            else:
                sd = net.state_dict()
            save_checkpoint({
                'epoch': epoch + 1,
                'model': args.net,
                'state_dict': sd,
                'optimizer': optimizer.state_dict(),
                'best_tol': best_tol,
                'path_helper': args.path_helper,
            }, is_best, args.path_helper['ckpt_path'], filename="best_checkpoint")
        else:
            is_best = False
finally:
    writer.close()