import os
import time

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

import cfg
import function
from conf import settings
from dataset import *
from utils import *

args = cfg.parse_args()
GPUdevice = torch.device('cuda', args.gpu_device)

net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution=args.distributed)

optimizer = optim.Adam(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # learning-rate decay

'''load pretrained model'''
if args.weights != 0:
    print(f'=> resuming from {args.weights}')
    checkpoint_file = os.path.join(args.weights)
    assert os.path.exists(checkpoint_file)
    loc = 'cuda:{}'.format(args.gpu_device)
    checkpoint = torch.load(checkpoint_file, map_location=loc)

    start_epoch = checkpoint['epoch']
    best_tol = checkpoint['best_tol']

    net.load_state_dict(checkpoint['state_dict'], strict=False)
    # optimizer.load_state_dict(checkpoint['optimizer'])

    args.path_helper = checkpoint['path_helper']
    logger = create_logger(args.path_helper['log_path'])
    print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')

# a fresh run directory and logger are created for every run
args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)

if args.dataset == 'oneprompt':
    nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)

elif args.dataset == 'isic' or args.dataset == 'custom':
    # define the data transforms
    transform_train = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    transform_val = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    transform_train_seg = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor()
    ])

    # create the datasets
    train_dataset = ISIC2016(args, args.data_path, transform=transform_train, transform_msk=transform_train_seg, mode='Training', prompt='click')
    test_dataset = ISIC2016(args, args.data_path, transform=transform_val, transform_msk=transform_train_seg, mode='Test', prompt='click')

    # create the data loaders
    nice_train_loader = DataLoader(train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
    nice_test_loader = DataLoader(test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)

'''checkpoint path and tensorboard'''
checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)

# use tensorboard
if not os.path.exists(settings.LOG_DIR):
    os.mkdir(settings.LOG_DIR)
writer = SummaryWriter(log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW))

if not os.path.exists(checkpoint_path):
    os.makedirs(checkpoint_path)
checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

'''begin training'''
best_acc = 0.0
best_tol = 1e4

for epoch in range(settings.EPOCH):

    # training
    net.train()
    time_start = time.time()
    loss = function.train_one(args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis)
    logger.info(f'Train loss: {loss} || @ epoch {epoch}.')
    time_end = time.time()
    print('time_for_training ', time_end - time_start)

    # validation
    net.eval()
    if (epoch and epoch % args.val_freq == 0) or epoch == settings.EPOCH - 1:
        tol, metrics = function.validation_one(args, nice_test_loader, epoch, net, writer)

        # handle both the 2-metric (IoU/Dice) and 4-metric (disc/cup) cases
        if len(metrics) == 2:
            eiou, edice = metrics
            logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
        elif len(metrics) == 4:
            iou_d, iou_c, disc_dice, cup_dice = metrics
            logger.info(f'Total score: {tol}, Disc IOU: {iou_d}, Cup IOU: {iou_c}, Disc DICE: {disc_dice}, Cup DICE: {cup_dice} || @ epoch {epoch}.')
        else:
            logger.info(f'Total score: {tol}, Metrics: {metrics} || @ epoch {epoch}.')

        if args.distributed != 'none':
            sd = net.module.state_dict()
        else:
            sd = net.state_dict()

        # keep the checkpoint with the lowest validation score
        if tol < best_tol:
            best_tol = tol
            is_best = True
            save_checkpoint({
                'epoch': epoch + 1,
                'model': args.net,
                'state_dict': sd,
                'optimizer': optimizer.state_dict(),
                'best_tol': best_tol,
                'path_helper': args.path_helper,
            }, is_best, args.path_helper['ckpt_path'], filename="best_checkpoint")
        else:
            is_best = False

writer.close()
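# Example invocation (a sketch, not part of the script): the flag names below are
# assumptions that mirror the attributes read from cfg.parse_args() above
# (args.net, args.dataset, args.data_path, args.exp_name, args.image_size, args.b,
# args.gpu_device); check cfg.py for the exact argument spellings and defaults.
#
#   python train.py -net sam -dataset isic -data_path ./data/ISIC \
#       -exp_name isic_baseline -image_size 1024 -b 2 -gpu_device 0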