You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

196 lines
5.5 KiB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Validation/Evaluation script for One-Prompt Medical Image Segmentation.
This script provides evaluation functionality for trained models.
Usage:
python scripts/val.py -net oneprompt -mod one_adpt -exp_name eval_exp \\
-dataset polyp -data_path ./data/polyp -weights ./checkpoints/best.pth
"""
import os
import sys
# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import OrderedDict
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Local imports
import cfg
from conf import settings
from dataset import ISIC2016, REFUGE, PolypDataset, CombinedPolypDataset
from utils import (
get_network,
get_decath_loader,
create_logger,
set_log_dir,
)
import function
def _add_module_prefix(state_dict):
    """Return a copy of *state_dict* in which every key starts with ``module.``.

    Used when ``args.distributed != 'none'``; in that case the network built by
    ``get_network`` is expected to name its parameters with a ``module.``
    prefix. Keys that already carry the prefix are kept as-is, so a checkpoint
    saved from an already-prefixed network is not turned into
    ``module.module.*``.

    Args:
        state_dict: Mapping of parameter name -> tensor read from a checkpoint.

    Returns:
        A new ``OrderedDict`` with the same values and insertion order.
    """
    return OrderedDict(
        (key if key.startswith('module.') else 'module.' + key, value)
        for key, value in state_dict.items()
    )


def _load_checkpoint(args, net, map_location):
    """Load the checkpoint named by ``args.weights`` into *net*.

    Args:
        args: Parsed CLI namespace (reads ``weights`` and ``distributed``).
        net: Network whose ``load_state_dict`` receives the weights.
        map_location: Passed straight to ``torch.load``.

    Returns:
        The ``'epoch'`` entry stored in the checkpoint.

    Raises:
        ValueError: ``-weights`` was not given.
        FileNotFoundError: the checkpoint path is not an existing file.
    """
    weights = args.weights
    # NOTE(review): cfg appears to use 0 as the "unset" default for -weights;
    # `not weights` also rejects None / empty string.
    if not weights:
        raise ValueError("Please specify model weights with -weights")
    if not os.path.isfile(weights):
        raise FileNotFoundError(f"Checkpoint not found: {weights}")

    print(f'=> resuming from {weights}')
    checkpoint = torch.load(weights, map_location=map_location)

    state_dict = checkpoint['state_dict']
    if args.distributed != 'none':
        state_dict = _add_module_prefix(state_dict)
    net.load_state_dict(state_dict)
    return checkpoint['epoch']


def _make_test_loader(dataset, args):
    """Wrap *dataset* in an unshuffled DataLoader for evaluation."""
    return DataLoader(
        dataset,
        batch_size=args.b,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )


def _build_test_loader(args):
    """Return the evaluation DataLoader for ``args.dataset``.

    Only the test split is built; training splits are not needed for
    evaluation.

    Raises:
        ValueError: ``args.dataset`` is not one of the supported names.
    """
    image_size = (args.image_size, args.image_size)
    transform_test = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    # Masks are converted to tensors first, then resized.
    transform_test_seg = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(image_size),
    ])

    if args.dataset == 'isic':
        dataset = ISIC2016(
            args, args.data_path,
            transform=transform_test,
            transform_msk=transform_test_seg,
            mode='Test',
        )
        return _make_test_loader(dataset, args)

    if args.dataset == 'REFUGE':
        dataset = REFUGE(
            args, args.data_path,
            transform=transform_test,
            transform_msk=transform_test_seg,
            mode='Test',
        )
        return _make_test_loader(dataset, args)

    if args.dataset == 'polyp':
        # Polyp masks are resized to the network's output size, not the input size.
        polyp_mask_transform = transforms.Compose([
            transforms.Resize((args.out_size, args.out_size)),
            transforms.ToTensor(),
        ])
        dataset = CombinedPolypDataset(
            args, args.data_path,
            transform=transform_test,
            transform_msk=polyp_mask_transform,
            mode='Test',
        )
        return _make_test_loader(dataset, args)

    if args.dataset == 'oneprompt':
        # get_decath_loader returns (train_loader, test_loader, ...); keep only the test loader.
        _, test_loader, *_ = get_decath_loader(args)
        return test_loader

    raise ValueError(f"Unsupported dataset for evaluation: {args.dataset!r}")


def main():
    """Evaluate a trained One-Prompt checkpoint on a test split.

    All options come from ``cfg.parse_args()``. The network is built with
    ``get_network`` and restored from ``-weights``. ``function.validation_one``
    is then run on the test loader, and the total score, IoU and Dice are
    reported to the log file and to stdout.

    Raises:
        ValueError: unsupported ``-mod`` / ``-dataset``, or ``-weights`` missing.
        FileNotFoundError: the checkpoint file does not exist.
    """
    args = cfg.parse_args()

    # Fail fast: previously an unsupported mode silently did nothing after
    # loading the model and data.
    if args.mod not in ('sam_adpt', 'one_adpt'):
        raise ValueError(
            f"Evaluation is only implemented for -mod sam_adpt / one_adpt, got {args.mod!r}"
        )

    gpu_device = torch.device('cuda', args.gpu_device)
    net = get_network(
        args, args.net,
        use_gpu=args.gpu,
        gpu_device=gpu_device,
        distribution=args.distributed,
    )

    # Map tensors to CPU when the GPU is disabled so CPU-only machines can load the checkpoint.
    map_location = gpu_device if args.gpu else 'cpu'
    start_epoch = _load_checkpoint(args, net, map_location)

    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    nice_test_loader = _build_test_loader(args)

    net.eval()
    tol, (eiou, edice) = function.validation_one(args, nice_test_loader, 0, net)
    logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
    print('\nEvaluation Results:')
    print(f' Total Score: {tol}')
    print(f' IoU: {eiou}')
    print(f' Dice: {edice}')
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()