#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Validation/Evaluation script for One-Prompt Medical Image Segmentation.

This script provides evaluation functionality for trained models.

Usage:
    python scripts/val.py -net oneprompt -mod one_adpt -exp_name eval_exp \\
        -dataset polyp -data_path ./data/polyp -weights ./checkpoints/best.pth
"""

import os
import sys

# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from collections import OrderedDict

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Local imports
import cfg
from conf import settings
from dataset import ISIC2016, REFUGE, PolypDataset, CombinedPolypDataset
from utils import (
    get_network,
    get_decath_loader,
    create_logger,
    set_log_dir,
)
import function


# Values of ``-mod`` for which this script knows how to run evaluation.
_SUPPORTED_MODS = ('sam_adpt', 'one_adpt')


def _load_weights(args, net):
    """Load the checkpoint at ``args.weights`` into ``net``; return its stored epoch.

    Args:
        args: Parsed CLI namespace (reads ``weights``, ``gpu_device``, ``distributed``).
        net: The network returned by ``get_network``.

    Returns:
        The value stored under the checkpoint's ``'epoch'`` key.

    Raises:
        ValueError: If no weights file was specified.
        FileNotFoundError: If ``args.weights`` is not an existing file.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if not args.weights:
        raise ValueError("Please specify model weights with -weights")
    print(f'=> resuming from {args.weights}')
    if not os.path.isfile(args.weights):
        raise FileNotFoundError(f'Checkpoint file not found: {args.weights}')

    checkpoint = torch.load(args.weights, map_location=f'cuda:{args.gpu_device}')
    state_dict = checkpoint['state_dict']

    if args.distributed != 'none':
        # In distributed mode the parameter names are expected to carry a
        # 'module.' prefix (presumably get_network wraps the model, e.g. in
        # DataParallel -- TODO confirm). Only add the prefix where it is missing,
        # so checkpoints saved from an already-wrapped model still load.
        state_dict = OrderedDict(
            (key if key.startswith('module.') else 'module.' + key, value)
            for key, value in state_dict.items()
        )

    net.load_state_dict(state_dict)
    return checkpoint['epoch']


def _build_test_loader(args):
    """Build the evaluation DataLoader for ``args.dataset``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported datasets.
    """
    image_size = (args.image_size, args.image_size)
    transform_test = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    transform_test_seg = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(image_size),
    ])

    if args.dataset == 'oneprompt':
        # get_decath_loader returns (train_loader, test_loader, ...); only the
        # test loader is needed for evaluation.
        _, test_loader, *_ = get_decath_loader(args)
        return test_loader

    if args.dataset == 'isic':
        test_dataset = ISIC2016(
            args, args.data_path,
            transform=transform_test, transform_msk=transform_test_seg, mode='Test'
        )
    elif args.dataset == 'REFUGE':
        test_dataset = REFUGE(
            args, args.data_path,
            transform=transform_test, transform_msk=transform_test_seg, mode='Test'
        )
    elif args.dataset == 'polyp':
        # Polyp masks are resized to the output resolution (args.out_size) and
        # resized *before* ToTensor, unlike the other datasets' masks.
        polyp_mask_transform = transforms.Compose([
            transforms.Resize((args.out_size, args.out_size)),
            transforms.ToTensor(),
        ])
        test_dataset = CombinedPolypDataset(
            args, args.data_path,
            transform=transform_test, transform_msk=polyp_mask_transform, mode='Test'
        )
    else:
        raise ValueError(f'Unsupported dataset for evaluation: {args.dataset!r}')

    return DataLoader(
        test_dataset, batch_size=args.b, shuffle=False, num_workers=8, pin_memory=True
    )


def main():
    """Evaluate a trained checkpoint on the test split of the selected dataset.

    Parses CLI arguments, builds the network, loads ``-weights``, sets up a log
    directory under ``logs/<exp_name>``, runs ``function.validation_one`` and
    reports the total score, IoU and Dice to both the log and stdout.

    Raises:
        ValueError: On a missing ``-weights`` or an unsupported ``-mod``/``-dataset``.
        FileNotFoundError: If the weights file does not exist.
    """
    args = cfg.parse_args()

    # Fail fast: previously an unsupported mod either skipped evaluation silently
    # or crashed later with a NameError on the undefined results.
    if args.mod not in _SUPPORTED_MODS:
        raise ValueError(
            f'Unsupported -mod {args.mod!r} for evaluation; '
            f'expected one of {_SUPPORTED_MODS}'
        )

    gpu_device = torch.device('cuda', args.gpu_device)
    net = get_network(
        args, args.net,
        use_gpu=args.gpu,
        gpu_device=gpu_device,
        distribution=args.distributed
    )
    start_epoch = _load_weights(args, net)

    # Setup logging
    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    test_loader = _build_test_loader(args)

    net.eval()
    # NOTE(review): epoch 0 is passed here while the log line reports
    # start_epoch -- confirm which value validation_one expects.
    tol, (eiou, edice) = function.validation_one(args, test_loader, 0, net)
    logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
    print('\nEvaluation Results:')
    print(f' Total Score: {tol}')
    print(f' IoU: {eiou}')
    print(f' Dice: {edice}')


if __name__ == '__main__':
    main()