You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

207 lines
6.2 KiB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training script for One-Prompt Medical Image Segmentation.

This script is the main entry point for training the One-Prompt
segmentation model. The ``-dataset`` flag selects the data pipeline:

* ``oneprompt`` -- train/test loaders come from ``utils.get_decath_loader``.
* ``polyp``     -- ``dataset.CombinedPolypDataset`` splits read from ``-data_path``.

Usage:
    python scripts/train.py -net oneprompt -mod one_adpt -exp_name experiment1 \\
        -dataset polyp -data_path ./data/polyp

Example:
    python scripts/train.py \\
        -net oneprompt \\
        -mod one_adpt \\
        -exp_name polyp_training \\
        -dataset polyp \\
        -data_path /path/to/data
"""
import os
import sys
import time
# Put the directory one level above this script's directory on sys.path, so the
# project-local modules imported below resolve when run as ``scripts/train.py``.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
# Local (project) imports; these presumably rely on the sys.path insertion above.
import cfg
from conf import settings
from dataset import CombinedPolypDataset
from utils import (
    get_network,
    get_decath_loader,
    create_logger,
    set_log_dir,
    save_checkpoint,
)
import function
def _resize_to_tensor(size):
    """Return a transform that resizes its input to ``(size, size)`` and then applies ``ToTensor``."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])


def _resume_from_checkpoint(args, net):
    """Load the checkpoint named by ``args.weights`` into ``net``.

    The checkpoint is expected to be a dict in the format written by ``main``
    (keys ``'epoch'``, ``'best_tol'``, ``'state_dict'``, ``'path_helper'``).
    Weights are loaded with ``strict=False``, so missing or unexpected keys are
    tolerated instead of raising.

    Side effects: replaces ``args.path_helper`` with the checkpoint's value and
    calls ``create_logger`` on its ``'log_path'``.

    Args:
        args: parsed CLI namespace; reads ``weights`` and ``gpu_device``.
        net: model whose parameters are updated in place.

    Returns:
        ``(start_epoch, best_tol)`` taken from the checkpoint. ``start_epoch`` is
        the next epoch to run, because ``main`` saves ``epoch + 1``.

    Raises:
        FileNotFoundError: if ``args.weights`` does not exist.
    """
    checkpoint_file = args.weights
    print(f'=> resuming from {checkpoint_file}')
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')
    checkpoint = torch.load(checkpoint_file, map_location=f'cuda:{args.gpu_device}')
    start_epoch = checkpoint['epoch']
    best_tol = checkpoint['best_tol']
    net.load_state_dict(checkpoint['state_dict'], strict=False)
    args.path_helper = checkpoint['path_helper']
    # NOTE(review): ``main`` replaces both ``args.path_helper`` and the logger right
    # after this returns. The call is kept because ``create_logger`` (in ``utils``)
    # may configure logging globally -- confirm whether it is still wanted.
    create_logger(args.path_helper['log_path'])
    print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    return start_epoch, best_tol


def _build_loaders(args):
    """Build the ``(train_loader, test_loader)`` pair for ``args.dataset``.

    ``'oneprompt'`` delegates to ``get_decath_loader``. ``'polyp'`` builds
    ``CombinedPolypDataset`` training/test splits from ``args.data_path``, with
    images resized to ``args.image_size`` and masks to ``args.out_size``. The
    test loader uses batch size 1 and no shuffling.

    Raises:
        ValueError: for any other ``args.dataset``.
    """
    if args.dataset == 'oneprompt':
        # The transforms and file lists also returned here are not used by training.
        train_loader, test_loader, _, _, _, _ = get_decath_loader(args)
        return train_loader, test_loader
    if args.dataset != 'polyp':
        # Fail fast: otherwise the loaders would be undefined and training would
        # crash later with an unrelated NameError.
        raise ValueError(
            f"unsupported dataset {args.dataset!r}; expected 'oneprompt' or 'polyp'"
        )

    train_dataset = CombinedPolypDataset(
        args, args.data_path,
        transform=_resize_to_tensor(args.image_size),
        transform_msk=_resize_to_tensor(args.out_size),
        mode='Training',
    )
    test_dataset = CombinedPolypDataset(
        args, args.data_path,
        transform=_resize_to_tensor(args.image_size),
        transform_msk=_resize_to_tensor(args.out_size),
        mode='Test',
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.b,
        shuffle=True,
        num_workers=args.w,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.w,
        pin_memory=True,
    )
    return train_loader, test_loader


def main():
    """Train the One-Prompt model as configured on the command line.

    Parses arguments with ``cfg.parse_args`` and builds the network and an Adam
    optimizer. It optionally resumes from ``-weights``, then sets up logging,
    data loaders and TensorBoard. Training runs from the resumed epoch up to
    ``settings.EPOCH``.

    Validation runs every ``args.val_freq`` epochs (never at epoch 0) and always
    on the last epoch. Whenever the validation total score ``tol`` improves
    (lower is better), a ``best_checkpoint`` is saved under
    ``args.path_helper['ckpt_path']``.

    Raises:
        FileNotFoundError: if ``-weights`` names a missing file.
        ValueError: if ``-dataset`` is not ``oneprompt`` or ``polyp``.
    """
    args = cfg.parse_args()

    gpu_device = torch.device('cuda', args.gpu_device)
    net = get_network(
        args, args.net,
        use_gpu=args.gpu,
        gpu_device=gpu_device,
        distribution=args.distributed,
    )

    optimizer = optim.Adam(
        net.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
        amsgrad=False,
    )
    # NOTE(review): ``scheduler.step()`` is never called, so the learning rate
    # stays at ``args.lr`` for the whole run. Left unchanged to avoid altering
    # training dynamics -- confirm whether step decay was intended.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # noqa: F841

    # Resume state. ``best_tol`` is the validation total score to beat (lower is better).
    start_epoch = 0
    best_tol = 1e4
    if args.weights != 0:
        start_epoch, best_tol = _resume_from_checkpoint(args, net)
        # NOTE(review): the optimizer state stored in the checkpoint is not restored.

    # Per-run log directory derived from exp_name. This replaces any path_helper
    # loaded from a checkpoint above.
    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    nice_train_loader, nice_test_loader = _build_loaders(args)

    # NOTE(review): ``checkpoint_path`` is created but checkpoints are saved under
    # ``args.path_helper['ckpt_path']`` -- confirm whether this directory is needed.
    checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
    os.makedirs(settings.LOG_DIR, exist_ok=True)
    writer = SummaryWriter(
        log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW)
    )
    os.makedirs(checkpoint_path, exist_ok=True)

    last_epoch = settings.EPOCH - 1
    if start_epoch > last_epoch:
        logger.info(f'start epoch {start_epoch} >= EPOCH {settings.EPOCH}; nothing to train.')

    try:
        # Continue from the resumed epoch rather than restarting at 0.
        for epoch in range(start_epoch, settings.EPOCH):
            net.train()
            time_start = time.time()
            loss = function.train_one(
                args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis
            )
            logger.info(f'Train loss: {loss}|| @ epoch {epoch}.')
            print(f'time_for_training {time.time() - time_start}')

            net.eval()
            # Validate every ``val_freq`` epochs (skipping epoch 0) and always on the final epoch.
            if (epoch and epoch % args.val_freq == 0) or epoch == last_epoch:
                tol, (eiou, edice) = function.validation_one(
                    args, nice_test_loader, epoch, net, writer
                )
                logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
                if tol < best_tol:
                    best_tol = tol
                    # In distributed mode the weights are read from ``net.module``
                    # (presumably a wrapper such as DataParallel/DDP -- confirm).
                    if args.distributed != 'none':
                        sd = net.module.state_dict()
                    else:
                        sd = net.state_dict()
                    save_checkpoint({
                        'epoch': epoch + 1,  # next epoch to run when resuming
                        'model': args.net,
                        'state_dict': sd,
                        'optimizer': optimizer.state_dict(),
                        'best_tol': best_tol,
                        'path_helper': args.path_helper,
                    }, True, args.path_helper['ckpt_path'], filename="best_checkpoint")
    finally:
        # Always flush and close TensorBoard output, even if training raises.
        writer.close()
    logger.info("Training completed!")
if __name__ == "__main__":
    # Script entry point: importing this module does not start training.
    main()