"""Evaluate a segmentation network restored from a checkpoint.

Parses options via ``cfg.parse_args()``, builds the network, loads the weights
from ``args.weights``, builds the test loader for ``args.dataset`` and runs
``function.validation_one`` once, logging total score / IOU / DICE.
"""
import os
import sys
import argparse
from datetime import datetime
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, accuracy_score,confusion_matrix
import torchvision
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import DataLoader
#from dataset import *
from torch.autograd import Variable
from PIL import Image
from tensorboardX import SummaryWriter
#from models.discriminatorlayer import discriminator
from dataset import ISIC2016, REFUGE, PolypDataset, CombinedPolypDataset
from conf import settings
import time
import cfg
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split
from utils import *
import function

args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)

net = get_network(args, args.net, use_gpu=args.gpu, gpu_device=GPUdevice, distribution = args.distributed)

# --- Load pretrained model -------------------------------------------------
# Explicit checks rather than `assert`: asserts are stripped under `python -O`.
# `not args.weights` rejects the 0 / None / '' "no weights given" sentinels
# (os.path.exists(0) would otherwise probe file descriptor 0).
if not args.weights:
    raise ValueError('no checkpoint given: set the weights option to the path of a trained model')
print(f'=> resuming from {args.weights}')
checkpoint_file = args.weights
if not os.path.exists(checkpoint_file):
    raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')

loc = 'cuda:{}'.format(args.gpu_device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_file, map_location=loc)
start_epoch = checkpoint['epoch']
best_tol = checkpoint['best_tol']
state_dict = checkpoint['state_dict']

if args.distributed != 'none':
    # When distributed, every parameter name gets a 'module.' prefix so it matches
    # the wrapped model; presumably get_network wraps the net (e.g. DataParallel)
    # and the checkpoint was saved un-prefixed -- TODO confirm.
    new_state_dict = OrderedDict(('module.' + k, v) for k, v in state_dict.items())
else:
    new_state_dict = state_dict
net.load_state_dict(new_state_dict)

args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)

# --- Segmentation data transforms -------------------------------------------
transform_train = transforms.Compose([
    transforms.Resize((args.image_size,args.image_size)),
    transforms.ToTensor(),
])

transform_train_seg = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((args.image_size,args.image_size)),
])

transform_test = transforms.Compose([
    transforms.Resize((args.image_size, args.image_size)),
    transforms.ToTensor(),
])

transform_test_seg = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((args.image_size, args.image_size)),
])

# --- Datasets / loaders -------------------------------------------------------
if args.dataset == 'isic':
    isic_train_dataset = ISIC2016(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
    isic_test_dataset = ISIC2016(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')

    nice_train_loader = DataLoader(isic_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
    nice_test_loader = DataLoader(isic_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)

elif args.dataset == 'oneprompt':
    nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)

elif args.dataset == 'REFUGE':
    refuge_train_dataset = REFUGE(args, args.data_path, transform = transform_train, transform_msk= transform_train_seg, mode = 'Training')
    refuge_test_dataset = REFUGE(args, args.data_path, transform = transform_test, transform_msk= transform_test_seg, mode = 'Test')

    nice_train_loader = DataLoader(refuge_train_dataset, batch_size=args.b, shuffle=True, num_workers=8, pin_memory=True)
    nice_test_loader = DataLoader(refuge_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)

elif args.dataset == 'polyp':
    # Polyp masks are resized to args.out_size (not image_size), and only a test loader is built.
    transform_test_seg = transforms.Compose([
        transforms.Resize((args.out_size, args.out_size)),
        transforms.ToTensor(),
    ])
    polyp_test_dataset = CombinedPolypDataset(args, args.data_path, transform=transform_test, transform_msk=transform_test_seg, mode='Test')
    nice_test_loader = DataLoader(polyp_test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True)

else:
    # Without this, an unknown dataset only surfaces later as a NameError on nice_test_loader.
    raise ValueError(f'unsupported dataset: {args.dataset!r}')

# --- Begin evaluation ---------------------------------------------------------
if args.mod in ('sam_adpt', 'one_adpt'):
    net.eval()
    tol, (eiou, edice) = function.validation_one(args, nice_test_loader, 0, net)
    logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
else:
    # Previously this case exited silently; make the no-op visible in the log.
    logger.info(f'mod {args.mod!r} is not evaluated by this script; no evaluation was run.')