#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training script for One-Prompt Medical Image Segmentation.

This script provides the main entry point for training the One-Prompt
segmentation model on various medical imaging datasets.

Usage:
    python scripts/train.py -net oneprompt -mod one_adpt -exp_name experiment1 \\
        -dataset polyp -data_path ./data/polyp

Example:
    python scripts/train.py \\
        -net oneprompt \\
        -mod one_adpt \\
        -exp_name polyp_training \\
        -dataset polyp \\
        -data_path /path/to/data
"""

import os
import sys
import time

# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

# Local imports
import cfg
from conf import settings
from dataset import CombinedPolypDataset
from utils import (
    get_network,
    get_decath_loader,
    create_logger,
    set_log_dir,
    save_checkpoint,
)
import function


def _resume_checkpoint(args, net):
    """Restore training state from ``args.weights`` if one was given.

    Loads model weights (non-strict, so adapter-only checkpoints still
    apply), the epoch counter, the best validation score, and the original
    ``path_helper`` so a resumed run keeps logging/checkpointing into its
    original experiment directory.

    Returns:
        (start_epoch, best_tol, resumed) -- ``resumed`` is False when no
        checkpoint was requested (``args.weights == 0``, the cfg default).

    Raises:
        FileNotFoundError: if the checkpoint path does not exist.
    """
    start_epoch = 0
    best_tol = 1e4  # lower is better; validation compares against this
    if args.weights == 0:
        return start_epoch, best_tol, False

    print(f'=> resuming from {args.weights}')
    if not os.path.exists(args.weights):
        raise FileNotFoundError(f'checkpoint not found: {args.weights}')

    # Map storage onto the training GPU regardless of where it was saved.
    loc = f'cuda:{args.gpu_device}'
    checkpoint = torch.load(args.weights, map_location=loc)

    start_epoch = checkpoint['epoch']
    best_tol = checkpoint['best_tol']
    net.load_state_dict(checkpoint['state_dict'], strict=False)
    args.path_helper = checkpoint['path_helper']
    print(f'=> loaded checkpoint {args.weights} (epoch {start_epoch})')
    return start_epoch, best_tol, True


def _build_dataloaders(args):
    """Construct (train_loader, test_loader) for the selected dataset.

    Raises:
        ValueError: for an unrecognized ``args.dataset`` (previously this
        fell through silently and crashed later with a NameError).
    """
    if args.dataset == 'oneprompt':
        # get_decath_loader also returns transforms/file lists; only the
        # loaders are needed here.
        nice_train_loader, nice_test_loader, *_ = get_decath_loader(args)
        return nice_train_loader, nice_test_loader

    if args.dataset == 'polyp':
        # Train and test pipelines are identical: resize + to-tensor, with
        # images at args.image_size and masks at args.out_size.
        def _resize_to_tensor(size):
            return transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor(),
            ])

        transform_img = _resize_to_tensor(args.image_size)
        transform_msk = _resize_to_tensor(args.out_size)

        train_dataset = CombinedPolypDataset(
            args, args.data_path,
            transform=transform_img,
            transform_msk=transform_msk,
            mode='Training'
        )
        test_dataset = CombinedPolypDataset(
            args, args.data_path,
            transform=transform_img,
            transform_msk=transform_msk,
            mode='Test'
        )

        nice_train_loader = DataLoader(
            train_dataset,
            batch_size=args.b,
            shuffle=True,
            num_workers=args.w,
            pin_memory=True
        )
        # Validation runs one sample at a time, in order.
        nice_test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.w,
            pin_memory=True
        )
        return nice_train_loader, nice_test_loader

    raise ValueError(f'Unknown dataset: {args.dataset}')


def main():
    """Main training function: build model, data, and run the epoch loop."""
    # Parse arguments
    args = cfg.parse_args()

    # Setup device
    gpu_device = torch.device('cuda', args.gpu_device)

    # Build network
    net = get_network(
        args, args.net,
        use_gpu=args.gpu,
        gpu_device=gpu_device,
        distribution=args.distributed
    )

    # Setup optimizer and scheduler
    optimizer = optim.Adam(
        net.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
        amsgrad=False
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Load pretrained model if specified
    start_epoch, best_tol, resumed = _resume_checkpoint(args, net)

    # Setup logging. BUG FIX: the resumed path_helper was previously
    # overwritten unconditionally here; only create a fresh experiment
    # directory for brand-new runs.
    if not resumed:
        args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    # Load data
    nice_train_loader, nice_test_loader = _build_dataloaders(args)

    # Setup checkpoint path and tensorboard
    checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
    os.makedirs(settings.LOG_DIR, exist_ok=True)
    writer = SummaryWriter(
        log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW)
    )
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

    # Training loop. BUG FIX: start from start_epoch so resumed runs do not
    # restart at epoch 0.
    for epoch in range(start_epoch, settings.EPOCH):
        net.train()
        time_start = time.time()
        loss = function.train_one(
            args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis
        )
        logger.info(f'Train loss: {loss}|| @ epoch {epoch}.')
        time_end = time.time()
        print(f'time_for_training {time_end - time_start}')

        # BUG FIX: the StepLR scheduler was constructed but never stepped,
        # so the learning rate never decayed.
        scheduler.step()

        net.eval()
        # Validate every val_freq epochs (skipping epoch 0) and always on
        # the final epoch.
        if (epoch and epoch % args.val_freq == 0) or epoch == settings.EPOCH - 1:
            tol, (eiou, edice) = function.validation_one(
                args, nice_test_loader, epoch, net, writer
            )
            logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')

            # Unwrap DataParallel/DistributedDataParallel before saving.
            if args.distributed != 'none':
                sd = net.module.state_dict()
            else:
                sd = net.state_dict()

            # tol is a "lower is better" score; checkpoint only improvements.
            if tol < best_tol:
                best_tol = tol
                save_checkpoint({
                    'epoch': epoch + 1,
                    'model': args.net,
                    'state_dict': sd,
                    'optimizer': optimizer.state_dict(),
                    'best_tol': best_tol,
                    'path_helper': args.path_helper,
                }, True, args.path_helper['ckpt_path'], filename="best_checkpoint")

    writer.close()
    logger.info("Training completed!")


if __name__ == '__main__':
    main()