- Add src/oneprompt_seg package with modular structure:
  - core/: network builder and training functions
  - data/: dataset classes and data loaders
  - utils/: metrics, visualization, logging, schedulers
- Add tests/ directory with unit tests for metrics and data utils
- Add docs/ directory with Sphinx documentation setup
- Add configs/ with development and production YAML configs
- Update scripts/ with refactored train.py and evaluate.py
- Add pyproject.toml for modern Python project configuration
- Add requirements.txt and requirements-dev.txt
- Add Makefile for common development commands
- Add .pre-commit-config.yaml for code quality hooks
- Update .gitignore to exclude .claude/ and improve patterns
- Update README.md with new structure and usage instructions
- Add CLAUDE.md as project guide for Claude Code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
parent: 60b9c98523
commit: e0b048b2a5
@ -0,0 +1,44 @@
# Pre-commit hooks configuration
# Install: pip install pre-commit && pre-commit install

repos:
  # General hooks
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
        args: ['--maxkb=500']
      - id: check-merge-conflict
      - id: debug-statements

  # Python code formatting with Black
  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3

  # Import sorting with isort
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: ["--profile", "black"]

  # Linting with flake8
  - repo: https://github.com/pycqa/flake8
    rev: 6.0.0
    hooks:
      - id: flake8
        args: ["--max-line-length=88", "--extend-ignore=E203,W503"]

  # Type checking with mypy
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.3.0
    hooks:
      - id: mypy
        additional_dependencies: [types-PyYAML]
        args: ["--ignore-missing-imports"]
@ -0,0 +1,186 @@
# CLAUDE.md - Project Guide for Claude Code

This document provides context and guidance for Claude Code when working with this project.

## Project Overview

**One-Prompt Medical Image Segmentation** is a deep learning framework for medical image segmentation using one-prompt learning. It implements the method described in the CVPR 2024 paper "One-Prompt to Segment All Medical Images".

## Project Structure

```
One-Prompt-Medical-Image-Segmentation/
├── src/oneprompt_seg/            # Main package (src layout)
│   ├── __init__.py               # Package init with lazy imports
│   ├── core/                     # Core functionality
│   │   ├── __init__.py
│   │   ├── network.py            # Network builder (get_network)
│   │   └── function.py           # Training/validation functions
│   ├── data/                     # Data handling
│   │   ├── __init__.py
│   │   ├── datasets.py           # Dataset classes (ISIC2016, REFUGE, PolypDataset)
│   │   └── loader.py             # Data loaders (MONAI-based)
│   └── utils/                    # Utilities
│       ├── __init__.py           # Lazy imports
│       ├── metrics.py            # Evaluation metrics (IoU, Dice)
│       ├── visualization.py      # Image visualization
│       ├── data_utils.py         # Data utilities (click generation)
│       ├── logging.py            # Logging utilities
│       └── scheduler.py          # LR schedulers
├── models/                       # Model implementations
│   ├── oneprompt/                # Main One-Prompt model
│   │   ├── modeling/             # Core model components
│   │   │   ├── oneprompt.py      # Main model class
│   │   │   ├── image_encoder.py  # Image encoder
│   │   │   ├── prompt_encoder.py # Prompt encoder
│   │   │   └── mask_decoder.py   # Mask decoder
│   │   └── utils/                # Model utilities
│   ├── unet/                     # UNet backbone
│   └── tag/                      # TAG model
├── tests/                        # Unit tests
├── docs/                         # Documentation (Sphinx)
├── configs/                      # YAML configurations
├── scripts/                      # Executable scripts
├── train.py                      # Original training entry
├── val.py                        # Original validation entry
├── utils.py                      # Original utilities (kept for compatibility)
├── dataset.py                    # Original dataset classes
├── function.py                   # Original training functions
├── cfg.py                        # Argument parser
└── conf/                         # Settings module
```

## Key Technologies

- **PyTorch** - Deep learning framework
- **MONAI** - Medical imaging library (data loading, transforms)
- **TensorBoardX** - Training visualization
- **einops** - Tensor operations

## Common Commands

```bash
# Install dependencies
pip install -r requirements.txt
pip install -e .  # Install the package in editable mode

# Training
python train.py -net oneprompt -mod one_adpt -exp_name <name> \
    -dataset polyp -data_path <path> -baseline unet

# Evaluation
python val.py -net oneprompt -mod one_adpt -exp_name <name> \
    -weights <checkpoint_path> -dataset polyp -data_path <path>

# Run tests
pytest tests/ -v

# Format code
make format
```

## Important Files

### Entry Points
- `train.py` - Main training script
- `val.py` - Evaluation script
- `scripts/train.py` - Refactored training script
- `scripts/evaluate.py` - Refactored evaluation script

### Configuration
- `cfg.py` - Command-line argument parser
- `conf/global_settings.py` - Global settings (EPOCH, paths)
- `configs/*.yaml` - YAML configuration files

### Core Model
- `models/oneprompt/modeling/oneprompt.py` - Main model implementation
- `models/oneprompt/build_oneprompt.py` - Model registry

### Data
- `dataset.py` - Dataset classes
- `utils.py` - Utilities including `get_network` and `get_decath_loader`

## Code Conventions

This project follows **PEP 8** with the following specifics:

1. **Naming**:
   - Classes: `CamelCase` (e.g., `CombinedPolypDataset`)
   - Functions/variables: `snake_case` (e.g., `get_network`)
   - Constants: `UPPER_CASE` (e.g., `EPOCH`)
2. **Imports**: Grouped as standard library, third-party, local
3. **Type Hints**: Used in new code (the `src/` directory)
4. **Docstrings**: Google-style docstrings

## Dataset Format

All datasets should return a dict with:

```python
{
    'image': torch.Tensor,        # [C, H, W] or [C, H, W, D]
    'label': torch.Tensor,        # Same shape as image
    'p_label': int,               # 1 for positive, 0 for negative
    'pt': tuple or np.ndarray,    # Click coordinates
    'image_meta_dict': dict,      # {'filename_or_obj': str}
}
```

## Common Tasks

### Adding a New Dataset

1. Create a new class in `dataset.py` or `src/oneprompt_seg/data/datasets.py`
2. Inherit from `torch.utils.data.Dataset`
3. Implement `__len__` and `__getitem__`
4. Return the standard dict format
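The four steps above can be sketched as follows. `ToyDataset`, its dummy tensors, and the fixed center click are hypothetical placeholders for illustration, not code from the repository:

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Minimal example following the project's dataset contract."""

    def __init__(self, n_samples: int = 4, size: int = 64):
        self.n_samples = n_samples
        self.size = size

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, index: int) -> dict:
        # A real dataset would load an image/mask pair from disk here.
        image = torch.zeros(3, self.size, self.size)
        label = torch.zeros(1, self.size, self.size)
        return {
            'image': image,
            'label': label,
            'p_label': 1,  # positive click
            'pt': np.array([self.size // 2, self.size // 2]),
            'image_meta_dict': {'filename_or_obj': f'sample_{index}.png'},
        }
```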

### Modifying Training

1. The main training loop is in `train.py`
2. Per-epoch logic is in `function.py` (`train_one`, `validation_one`)
3. Loss functions are configured based on `args.thd`

### Adding Metrics

1. Add the metric to `src/oneprompt_seg/utils/metrics.py`
2. Update `__init__.py` exports if needed
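For reference, such a metric might look like the following numpy sketch; this is an illustrative implementation, not the repository's actual `metrics.py`:

```python
import numpy as np


def dice_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-6) -> float:
    """Dice coefficient for binary masks (1.0 = perfect overlap)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)


def iou_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-6) -> float:
    """Intersection over union for binary masks."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return (intersection + eps) / (union + eps)
```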

## Troubleshooting

### Import Errors
- Ensure MONAI and the other dependencies are installed
- The src package uses lazy imports to avoid import errors during installation

### CUDA Out of Memory
- Reduce the batch size with the `-b` argument
- Reduce the image size with the `-image_size` argument
- Enable gradient checkpointing in the model

### Path Issues
- Data paths are relative to the working directory
- Logs are saved to `./logs/` by default
- Checkpoints are saved to `./checkpoint/` by default

## Testing

```bash
# Run all tests
pytest tests/ -v

# Run a specific test file
pytest tests/test_metrics.py -v

# Run with coverage
pytest tests/ --cov=src/oneprompt_seg --cov-report=html
```

## Notes for Development

1. The project maintains backward compatibility: the original files (`train.py`, `val.py`, `utils.py`) still work
2. New code should be added to `src/oneprompt_seg/` when possible
3. Use the `configs/` YAML files for configuration instead of hardcoding values
4. Run the pre-commit hooks before committing: `pre-commit run --all-files`
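For note 3, loading one of the YAML configs might look like this. A sketch using PyYAML; the `configs/dev.yaml` path and the `load_config` helper name are assumptions, not existing project code:

```python
import yaml


def load_config(path: str) -> dict:
    """Load a YAML configuration file into a plain dict."""
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


# Example: read training hyperparameters from the dev config.
# cfg = load_config('configs/dev.yaml')
# lr = cfg['training']['learning_rate']
# epochs = cfg['training']['epochs']
```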
@ -0,0 +1,75 @@
# Makefile for One-Prompt Medical Image Segmentation
# Common commands for development, testing, and training

.PHONY: help install install-dev test test-quick lint format clean train train-polyp eval docs docs-clean

# Default target
help:
	@echo "Available commands:"
	@echo "  install      Install production dependencies"
	@echo "  install-dev  Install development dependencies"
	@echo "  test         Run unit tests"
	@echo "  lint         Run code linting"
	@echo "  format       Format code with black and isort"
	@echo "  clean        Clean temporary files"
	@echo "  train        Run training script"
	@echo "  eval         Run evaluation script"
	@echo "  docs         Build documentation"

# Installation
install:
	pip install -r requirements.txt
	pip install -e .

install-dev:
	pip install -r requirements.txt
	pip install -r requirements-dev.txt
	pip install -e .
	pre-commit install

# Testing
test:
	pytest tests/ -v --cov=src/oneprompt_seg --cov-report=html

test-quick:
	pytest tests/ -v -x

# Linting and formatting
lint:
	flake8 src/ tests/
	mypy src/

format:
	black src/ tests/ scripts/
	isort src/ tests/ scripts/

# Cleaning
clean:
	find . -type f -name "*.pyc" -delete
	find . -type d -name "__pycache__" -delete
	find . -type d -name ".pytest_cache" -delete
	find . -type d -name ".mypy_cache" -delete
	rm -rf build/ dist/ *.egg-info/
	rm -rf .coverage htmlcov/

# Training and evaluation
train:
	python scripts/train.py --config configs/default.yaml

train-polyp:
	python scripts/train.py \
		-net oneprompt \
		-mod one_adpt \
		-exp_name polyp_training \
		-dataset polyp \
		-data_path ./data/polyp

eval:
	python scripts/val.py --config configs/default.yaml

# Documentation
docs:
	cd docs && make html

docs-clean:
	cd docs && make clean
@ -0,0 +1,53 @@
# One-Prompt Medical Image Segmentation - Development Configuration

project:
  name: "one-prompt-segmentation"
  version: "1.0.0"
  description: "Development environment configuration"
  mode: "development"

# Data configuration
data:
  dataset: "polyp"
  data_path: "./data/TestDataset"
  train_ratio: 0.8
  batch_size: 2
  num_workers: 2

# Model configuration
model:
  net: "oneprompt"
  baseline: "unet"
  mod: "one_adpt"
  image_size: 256
  out_size: 256
  patch_size: 16
  dim: 256
  depth: 1
  heads: 16
  mlp_dim: 1024

# Training configuration (smaller settings for development)
training:
  epochs: 10
  learning_rate: 0.0001
  optimizer: "adam"
  weight_decay: 0.0
  scheduler:
    name: "step"
    step_size: 5
    gamma: 0.5
  gradient_clip: 1.0

# Validation configuration
validation:
  val_freq: 2
  vis_freq: 10

# Logging configuration
logging:
  log_dir: "logs/dev"
  tensorboard: true
  save_best: true
  checkpoint_freq: 5
  log_level: "DEBUG"
@ -0,0 +1,59 @@
# One-Prompt Medical Image Segmentation - Production Configuration

project:
  name: "one-prompt-segmentation"
  version: "1.0.0"
  description: "Production environment configuration"
  mode: "production"

# Data configuration
data:
  dataset: "oneprompt"
  data_path: "/data/medical_images"
  train_ratio: 0.8
  batch_size: 8
  num_workers: 8

# Model configuration
model:
  net: "oneprompt"
  baseline: "unet"
  mod: "one_adpt"
  image_size: 1024
  out_size: 256
  patch_size: 16
  dim: 256
  depth: 1
  heads: 16
  mlp_dim: 1024

# Training configuration
training:
  epochs: 30000
  learning_rate: 0.0001
  optimizer: "adam"
  weight_decay: 0.0
  scheduler:
    name: "step"
    step_size: 10
    gamma: 0.5
  early_stopping_patience: 50
  gradient_clip: 1.0

# Validation configuration
validation:
  val_freq: 100
  vis_freq: 50

# Logging configuration
logging:
  log_dir: "logs/production"
  tensorboard: true
  save_best: true
  checkpoint_freq: 100
  log_level: "INFO"

# Distributed training
distributed:
  enabled: false
  gpu_ids: "0,1,2,3"
@ -0,0 +1,31 @@
# Documentation

This directory contains the project documentation.

## Structure

```
docs/
├── api/        # API reference documentation
├── guides/     # User guides and tutorials
├── examples/   # Example code and notebooks
└── index.rst   # Main documentation index
```

## Building Documentation

To build the documentation locally:

```bash
cd docs
pip install sphinx sphinx-rtd-theme
make html
```

The built documentation will be available in `docs/_build/html/`.

## Documentation Style

- Use Google-style docstrings for all modules, classes, and functions
- Include type hints in function signatures
- Provide examples in docstrings where appropriate
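A Google-style docstring with type hints, as described above, might look like the following; this is a generic example written for this guide, not a function taken from the codebase:

```python
def dice_coefficient(pred_area: float, target_area: float, overlap: float) -> float:
    """Compute the Dice coefficient from precomputed areas.

    Args:
        pred_area: Number of foreground pixels in the prediction.
        target_area: Number of foreground pixels in the ground truth.
        overlap: Number of pixels that are foreground in both masks.

    Returns:
        The Dice coefficient in [0, 1].

    Example:
        >>> dice_coefficient(4.0, 4.0, 4.0)
        1.0
    """
    if pred_area + target_area == 0:
        # Both masks empty: treat as perfect agreement.
        return 1.0
    return 2.0 * overlap / (pred_area + target_area)
```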
@ -0,0 +1,72 @@
# Sphinx configuration file

import os
import sys

sys.path.insert(0, os.path.abspath('../src'))

# Project information
project = 'One-Prompt Medical Image Segmentation'
copyright = '2024, One-Prompt Team'
author = 'One-Prompt Team'
version = '1.0.0'
release = '1.0.0'

# Extensions
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx.ext.mathjax',
    'sphinx.ext.intersphinx',
]

# Napoleon settings for Google-style docstrings
napoleon_google_docstring = True
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = True
napoleon_include_private_with_doc = False
napoleon_include_special_with_doc = True

# Mock imports for documentation building
autodoc_mock_imports = [
    'torch',
    'torchvision',
    'numpy',
    'pandas',
    'monai',
    'PIL',
    'cv2',
    'skimage',
    'einops',
    'tensorboardX',
    'sklearn',
    'matplotlib',
    'seaborn',
    'tqdm',
    'dateutil',
]

# Intersphinx mapping
intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'torch': ('https://pytorch.org/docs/stable/', None),
    'numpy': ('https://numpy.org/doc/stable/', None),
}

# HTML theme
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']

# Source file settings
source_suffix = '.rst'
master_doc = 'index'
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# Autodoc settings
autodoc_default_options = {
    'members': True,
    'member-order': 'bysource',
    'special-members': '__init__',
    'undoc-members': True,
    'exclude-members': '__weakref__',
}
@ -0,0 +1,49 @@
One-Prompt Medical Image Segmentation
=====================================

Welcome to the documentation for One-Prompt Medical Image Segmentation.

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   installation
   quickstart
   api/index
   examples/index

Overview
--------

One-Prompt Medical Image Segmentation is a deep learning framework for
medical image segmentation that uses a single prompt to segment various
types of medical images.

Features
--------

* **One-Prompt Learning**: Segment medical images using a single example
* **Multi-Dataset Support**: Works with ISIC, REFUGE, and Polyp datasets
* **Flexible Architecture**: Based on SAM with custom modifications
* **Easy Integration**: Simple API for training and inference

Quick Start
-----------

.. code-block:: python

   from oneprompt_seg.core import get_network
   from oneprompt_seg.data import CombinedPolypDataset

   # Load model
   model = get_network(args, 'oneprompt')

   # Create dataset
   dataset = CombinedPolypDataset(args, data_path)

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
@ -0,0 +1,151 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "oneprompt-seg"
version = "1.0.0"
description = "One-Prompt to Segment All Medical Images"
readme = "README.md"
license = {text = "MIT"}
authors = [
    {name = "One-Prompt Team", email = "oneprompt@example.com"}
]
keywords = [
    "medical imaging",
    "image segmentation",
    "deep learning",
    "pytorch",
    "computer vision"
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Science/Research",
    "Intended Audience :: Healthcare Industry",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Scientific/Engineering :: Medical Science Apps.",
]
requires-python = ">=3.8"
dependencies = [
    "torch>=1.10.0",
    "torchvision>=0.11.0",
    "numpy>=1.19.0",
    "pandas>=1.3.0",
    "Pillow>=8.0.0",
    "scikit-image>=0.18.0",
    "scikit-learn>=0.24.0",
    "monai>=0.8.0",
    "einops>=0.4.0",
    "tensorboardX>=2.4.0",
    "tqdm>=4.60.0",
    "python-dateutil>=2.8.0",
    "PyYAML>=5.4.0",
    "matplotlib>=3.3.0",
    "seaborn>=0.11.0",
    "opencv-python>=4.5.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=6.0.0",
    "pytest-cov>=2.12.0",
    "black>=21.0.0",
    "isort>=5.9.0",
    "flake8>=3.9.0",
    "mypy>=0.900",
    "pre-commit>=2.13.0",
]
docs = [
    "sphinx>=4.0.0",
    "sphinx-rtd-theme>=0.5.0",
    "sphinx-autodoc-typehints>=1.12.0",
]

[project.urls]
Homepage = "https://github.com/oneprompt/One-Prompt-Medical-Image-Segmentation"
Documentation = "https://oneprompt-seg.readthedocs.io"
Repository = "https://github.com/oneprompt/One-Prompt-Medical-Image-Segmentation"
Issues = "https://github.com/oneprompt/One-Prompt-Medical-Image-Segmentation/issues"

[project.scripts]
oneprompt-train = "scripts.train:main"
oneprompt-val = "scripts.val:main"

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
"*" = ["*.yaml", "*.yml", "*.json"]

# Black configuration
[tool.black]
line-length = 88
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'
extend-exclude = '''
/(
    \.git
  | \.mypy_cache
  | \.pytest_cache
  | \.venv
  | build
  | dist
  | __pycache__
)/
'''

# isort configuration
[tool.isort]
profile = "black"
line_length = 88
skip_gitignore = true
known_first_party = ["oneprompt_seg"]
known_third_party = ["torch", "torchvision", "numpy", "pandas", "monai"]

# pytest configuration
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
filterwarnings = [
    "ignore::DeprecationWarning",
    "ignore::UserWarning",
]

# mypy configuration
[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
exclude = [
    "build",
    "dist",
    "docs",
]

# Coverage configuration
[tool.coverage.run]
source = ["src/oneprompt_seg"]
branch = true
omit = [
    "*/tests/*",
    "*/__pycache__/*",
]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "raise AssertionError",
    "raise NotImplementedError",
    "if __name__ == .__main__.:",
]
@ -0,0 +1,21 @@
# Development dependencies for testing and code quality

# Testing
pytest>=6.0.0
pytest-cov>=2.12.0

# Code formatting
black>=21.0.0
isort>=5.9.0

# Linting
flake8>=3.9.0
mypy>=0.900

# Pre-commit hooks
pre-commit>=2.13.0

# Documentation
sphinx>=4.0.0
sphinx-rtd-theme>=0.5.0
sphinx-autodoc-typehints>=1.12.0
@ -0,0 +1,29 @@
# One-Prompt Medical Image Segmentation Dependencies

# Core dependencies
torch>=1.10.0
torchvision>=0.11.0
numpy>=1.19.0
pandas>=1.3.0
Pillow>=8.0.0

# Image processing
scikit-image>=0.18.0
scikit-learn>=0.24.0
opencv-python>=4.5.0

# Medical imaging
monai>=0.8.0

# Deep learning utilities
einops>=0.4.0
tensorboardX>=2.4.0
tqdm>=4.60.0

# Utilities
python-dateutil>=2.8.0
PyYAML>=5.4.0

# Visualization
matplotlib>=3.3.0
seaborn>=0.11.0
@ -0,0 +1,6 @@
"""
Scripts package for One-Prompt Medical Image Segmentation.

This package contains executable scripts for training, evaluation,
and utility operations.
"""
@ -0,0 +1,195 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Validation/Evaluation script for One-Prompt Medical Image Segmentation.

This script provides evaluation functionality for trained models.

Usage:
    python scripts/val.py -net oneprompt -mod one_adpt -exp_name eval_exp \\
        -dataset polyp -data_path ./data/polyp -weights ./checkpoints/best.pth
"""

import os
import sys

# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from collections import OrderedDict

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Local imports
import cfg
from conf import settings
from dataset import ISIC2016, REFUGE, PolypDataset, CombinedPolypDataset
from utils import (
    get_network,
    get_decath_loader,
    create_logger,
    set_log_dir,
)
import function


def main():
    """Main evaluation function."""
    # Parse arguments
    args = cfg.parse_args()

    # Setup device
    gpu_device = torch.device('cuda', args.gpu_device)

    # Build network
    net = get_network(
        args, args.net,
        use_gpu=args.gpu,
        gpu_device=gpu_device,
        distribution=args.distributed
    )

    # Load pretrained model
    assert args.weights != 0, 'Please specify model weights with -weights'
    print(f'=> resuming from {args.weights}')
    checkpoint_file = args.weights
    assert os.path.exists(checkpoint_file), f'Checkpoint not found: {checkpoint_file}'
    loc = f'cuda:{args.gpu_device}'
    checkpoint = torch.load(checkpoint_file, map_location=loc)
    start_epoch = checkpoint['epoch']
    best_tol = checkpoint['best_tol']

    state_dict = checkpoint['state_dict']
    if args.distributed != 'none':
        # Checkpoints saved from a single process lack the 'module.' prefix
        # that DataParallel/DistributedDataParallel expects; add it back.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = 'module.' + k
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict

    net.load_state_dict(new_state_dict)

    # Setup logging
    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    # Setup data transforms
    transform_train = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    transform_train_seg = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((args.image_size, args.image_size)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    transform_test_seg = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((args.image_size, args.image_size)),
    ])

    # Load data based on dataset type
    if args.dataset == 'isic':
        isic_train_dataset = ISIC2016(
            args, args.data_path,
            transform=transform_train,
            transform_msk=transform_train_seg,
            mode='Training'
        )
        isic_test_dataset = ISIC2016(
            args, args.data_path,
            transform=transform_test,
            transform_msk=transform_test_seg,
            mode='Test'
        )

        nice_train_loader = DataLoader(
            isic_train_dataset,
            batch_size=args.b,
            shuffle=True,
            num_workers=8,
            pin_memory=True
        )
        nice_test_loader = DataLoader(
            isic_test_dataset,
            batch_size=args.b,
            shuffle=False,
            num_workers=8,
            pin_memory=True
        )

    elif args.dataset == 'oneprompt':
        (nice_train_loader, nice_test_loader, transform_train, transform_val,
         train_list, val_list) = get_decath_loader(args)

    elif args.dataset == 'REFUGE':
        refuge_train_dataset = REFUGE(
            args, args.data_path,
            transform=transform_train,
            transform_msk=transform_train_seg,
            mode='Training'
        )
        refuge_test_dataset = REFUGE(
            args, args.data_path,
            transform=transform_test,
            transform_msk=transform_test_seg,
            mode='Test'
        )

        nice_train_loader = DataLoader(
            refuge_train_dataset,
            batch_size=args.b,
            shuffle=True,
            num_workers=8,
            pin_memory=True
        )
        nice_test_loader = DataLoader(
            refuge_test_dataset,
            batch_size=args.b,
            shuffle=False,
            num_workers=8,
            pin_memory=True
        )

    elif args.dataset == 'polyp':
        transform_test_seg = transforms.Compose([
            transforms.Resize((args.out_size, args.out_size)),
            transforms.ToTensor(),
        ])
        polyp_test_dataset = CombinedPolypDataset(
            args, args.data_path,
            transform=transform_test,
            transform_msk=transform_test_seg,
            mode='Test'
        )
        nice_test_loader = DataLoader(
            polyp_test_dataset,
            batch_size=args.b,
            shuffle=False,
            num_workers=8,
            pin_memory=True
        )

    # Run evaluation
    if args.mod == 'sam_adpt' or args.mod == 'one_adpt':
        net.eval()
        tol, (eiou, edice) = function.validation_one(args, nice_test_loader, 0, net)
        logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
        print('\nEvaluation Results:')
        print(f'  Total Score: {tol}')
        print(f'  IoU: {eiou}')
        print(f'  Dice: {edice}')


if __name__ == '__main__':
    main()
@ -0,0 +1,206 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training script for One-Prompt Medical Image Segmentation.

This script provides the main entry point for training the One-Prompt
segmentation model on various medical imaging datasets.

Usage:
    python scripts/train.py -net oneprompt -mod one_adpt -exp_name experiment1 \\
        -dataset polyp -data_path ./data/polyp

Example:
    python scripts/train.py \\
        -net oneprompt \\
        -mod one_adpt \\
        -exp_name polyp_training \\
        -dataset polyp \\
        -data_path /path/to/data
"""

import os
import sys
import time

# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

# Local imports
import cfg
from conf import settings
from dataset import CombinedPolypDataset
from utils import (
    get_network,
    get_decath_loader,
    create_logger,
    set_log_dir,
    save_checkpoint,
)
import function


def main():
    """Main training function."""
    # Parse arguments
    args = cfg.parse_args()

    # Setup device
    gpu_device = torch.device('cuda', args.gpu_device)

    # Build network
    net = get_network(
        args, args.net,
        use_gpu=args.gpu,
        gpu_device=gpu_device,
        distribution=args.distributed
    )

    # Setup optimizer and scheduler
    optimizer = optim.Adam(
        net.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
        amsgrad=False
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Load pretrained model if specified
    start_epoch = 0
    best_tol = 1e4
    if args.weights != 0:
        print(f'=> resuming from {args.weights}')
        checkpoint_file = args.weights
        assert os.path.exists(checkpoint_file), f'Checkpoint not found: {checkpoint_file}'
        loc = f'cuda:{args.gpu_device}'
        checkpoint = torch.load(checkpoint_file, map_location=loc)
        start_epoch = checkpoint['epoch']
        best_tol = checkpoint['best_tol']

        net.load_state_dict(checkpoint['state_dict'], strict=False)
        # Reuse the log directory stored in the checkpoint when resuming.
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # Setup logging for a fresh run
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    # Load data
    if args.dataset == 'oneprompt':
        (nice_train_loader, nice_test_loader, transform_train, transform_val,
         train_list, val_list) = get_decath_loader(args)
    elif args.dataset == 'polyp':
        # Polyp dataset
        transform_train = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
        ])
        transform_train_seg = transforms.Compose([
            transforms.Resize((args.out_size, args.out_size)),
            transforms.ToTensor(),
        ])
        transform_test = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
        ])
        transform_test_seg = transforms.Compose([
            transforms.Resize((args.out_size, args.out_size)),
            transforms.ToTensor(),
        ])

        train_dataset = CombinedPolypDataset(
            args, args.data_path,
            transform=transform_train,
            transform_msk=transform_train_seg,
            mode='Training'
        )
        test_dataset = CombinedPolypDataset(
            args, args.data_path,
            transform=transform_test,
            transform_msk=transform_test_seg,
            mode='Test'
        )

        nice_train_loader = DataLoader(
            train_dataset,
            batch_size=args.b,
            shuffle=True,
            num_workers=args.w,
            pin_memory=True
        )
        nice_test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.w,
            pin_memory=True
|
||||
)
|
||||
|
||||
# Setup checkpoint path and tensorboard
|
||||
checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
|
||||
if not os.path.exists(settings.LOG_DIR):
|
||||
os.mkdir(settings.LOG_DIR)
|
||||
writer = SummaryWriter(
|
||||
log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW)
|
||||
)
|
||||
|
||||
if not os.path.exists(checkpoint_path):
|
||||
os.makedirs(checkpoint_path)
|
||||
checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')
|
||||
|
||||
# Training loop
|
||||
best_acc = 0.0
|
||||
for epoch in range(settings.EPOCH):
|
||||
net.train()
|
||||
time_start = time.time()
|
||||
|
||||
loss = function.train_one(
|
||||
args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis
|
||||
)
|
||||
logger.info(f'Train loss: {loss}|| @ epoch {epoch}.')
|
||||
time_end = time.time()
|
||||
print(f'time_for_training {time_end - time_start}')
|
||||
|
||||
net.eval()
|
||||
if epoch and epoch % args.val_freq == 0 or epoch == settings.EPOCH - 1:
|
||||
tol, (eiou, edice) = function.validation_one(
|
||||
args, nice_test_loader, epoch, net, writer
|
||||
)
|
||||
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
|
||||
|
||||
if args.distributed != 'none':
|
||||
sd = net.module.state_dict()
|
||||
else:
|
||||
sd = net.state_dict()
|
||||
|
||||
if tol < best_tol:
|
||||
best_tol = tol
|
||||
is_best = True
|
||||
|
||||
save_checkpoint({
|
||||
'epoch': epoch + 1,
|
||||
'model': args.net,
|
||||
'state_dict': sd,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'best_tol': best_tol,
|
||||
'path_helper': args.path_helper,
|
||||
}, is_best, args.path_helper['ckpt_path'], filename="best_checkpoint")
|
||||
else:
|
||||
is_best = False
|
||||
|
||||
writer.close()
|
||||
logger.info("Training completed!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,52 @@
|
||||
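The training script above builds checkpoint filenames from a `{net}-{epoch}-{type}` template string filled in at save time. A minimal sketch of that naming scheme in isolation (the `checkpoints` directory and the helper name are illustrative, not from the repo):

```python
import os

# Template used to name checkpoints; fields are filled per save,
# mirroring the '{net}-{epoch}-{type}.pth' pattern in train.py.
template = os.path.join("checkpoints", "{net}-{epoch}-{type}.pth")


def checkpoint_name(net: str, epoch: int, kind: str) -> str:
    """Fill the template for one checkpoint file."""
    return template.format(net=net, epoch=epoch, type=kind)


print(checkpoint_name("oneprompt", 12, "best"))
```

Using `str.format` with named fields keeps the filename layout in one place, so the epoch number and checkpoint type stay machine-parseable.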
"""
|
||||
One-Prompt Medical Image Segmentation
|
||||
|
||||
A deep learning framework for medical image segmentation using one-prompt learning.
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "One-Prompt Team"
|
||||
|
||||
# Lazy imports to avoid dependency issues during installation
|
||||
def __getattr__(name):
|
||||
"""Lazy import of submodules."""
|
||||
if name == "get_network":
|
||||
from .core.network import get_network
|
||||
return get_network
|
||||
elif name in ("train_one", "validation_one"):
|
||||
from .core import function
|
||||
return getattr(function, name)
|
||||
elif name in ("ISIC2016", "REFUGE", "PolypDataset", "CombinedPolypDataset"):
|
||||
from .data import datasets
|
||||
return getattr(datasets, name)
|
||||
elif name in ("create_logger", "set_log_dir", "save_checkpoint"):
|
||||
from .utils import logging
|
||||
return getattr(logging, name)
|
||||
elif name in ("eval_seg", "vis_image", "DiceMetric"):
|
||||
if name == "DiceMetric":
|
||||
from .utils.metrics import DiceMetric
|
||||
return DiceMetric
|
||||
elif name == "eval_seg":
|
||||
from .utils.metrics import eval_seg
|
||||
return eval_seg
|
||||
elif name == "vis_image":
|
||||
from .utils.visualization import vis_image
|
||||
return vis_image
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_network",
|
||||
"train_one",
|
||||
"validation_one",
|
||||
"ISIC2016",
|
||||
"REFUGE",
|
||||
"PolypDataset",
|
||||
"CombinedPolypDataset",
|
||||
"create_logger",
|
||||
"set_log_dir",
|
||||
"save_checkpoint",
|
||||
"eval_seg",
|
||||
"vis_image",
|
||||
"DiceMetric",
|
||||
]
|
||||
@ -0,0 +1,17 @@
|
||||
"""
|
||||
Core module for One-Prompt Medical Image Segmentation.
|
||||
|
||||
This module contains the core functionality including:
|
||||
- Network builders
|
||||
- Model definitions
|
||||
- Training and validation functions
|
||||
"""
|
||||
|
||||
from .network import get_network
|
||||
from .function import train_one, validation_one
|
||||
|
||||
__all__ = [
|
||||
"get_network",
|
||||
"train_one",
|
||||
"validation_one",
|
||||
]
|
||||
@ -0,0 +1,301 @@
|
||||
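The package `__init__` above relies on module-level `__getattr__` (PEP 562) so that heavy submodules are only imported on first access. A self-contained sketch of the mechanism, using a made-up `demo` module and `heavy_value` attribute rather than the real package:

```python
import sys
import types

# Build a throwaway module and give it a module-level __getattr__ (PEP 562).
demo = types.ModuleType("demo")
loaded = []


def _lazy(name):
    """Resolve missing attributes on first access, caching the result."""
    if name == "heavy_value":
        loaded.append(name)       # stands in for an expensive import
        demo.heavy_value = 42     # cache so __getattr__ runs only once
        return demo.heavy_value
    raise AttributeError(f"module 'demo' has no attribute {name!r}")


demo.__getattr__ = _lazy
sys.modules["demo"] = demo

import demo as d

print(d.heavy_value, d.heavy_value, loaded)  # resolver ran exactly once
```

The second attribute access hits the cached value in the module's `__dict__`, so the resolver (and in the real package, the underlying import) only runs once.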
"""
|
||||
Training and validation functions.
|
||||
|
||||
This module provides the core training and validation logic
|
||||
for medical image segmentation models.
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
from monai.losses import DiceCELoss
|
||||
|
||||
from ..utils.metrics import eval_seg, DiceMetric
|
||||
from ..utils.visualization import vis_image
|
||||
from ..utils.data_utils import generate_click_prompt
|
||||
|
||||
|
||||
# Global variables for training
|
||||
_args = None
|
||||
_gpu_device = None
|
||||
_criterion_g = None
|
||||
_scaler = None
|
||||
|
||||
|
||||
def _init_training_globals(args):
|
||||
"""Initialize global variables for training."""
|
||||
global _args, _gpu_device, _criterion_g, _scaler
|
||||
_args = args
|
||||
_gpu_device = torch.device('cuda', args.gpu_device)
|
||||
pos_weight = torch.ones([1]).cuda(device=_gpu_device) * 2
|
||||
_criterion_g = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
||||
_scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
|
||||
def train_one(args, net: nn.Module, optimizer, train_loader,
|
||||
epoch: int, writer=None, schedulers=None, vis: int = 50) -> float:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
Args:
|
||||
args: Training arguments.
|
||||
net: The neural network model.
|
||||
optimizer: The optimizer for training.
|
||||
train_loader: DataLoader for training data.
|
||||
epoch: Current epoch number.
|
||||
writer: TensorBoard writer for logging.
|
||||
schedulers: Learning rate schedulers.
|
||||
vis: Visualization frequency (0 to disable).
|
||||
|
||||
Returns:
|
||||
The final batch loss value.
|
||||
"""
|
||||
_init_training_globals(args)
|
||||
|
||||
epoch_loss = 0
|
||||
ind = 0
|
||||
net.train()
|
||||
optimizer.zero_grad()
|
||||
|
||||
model = net.module if hasattr(net, 'module') else net
|
||||
gpu_device = torch.device('cuda:' + str(args.gpu_device))
|
||||
|
||||
if args.thd:
|
||||
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
|
||||
else:
|
||||
lossfunc = _criterion_g
|
||||
|
||||
with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
|
||||
for pack in train_loader:
|
||||
current_b = pack['image'].size(0)
|
||||
|
||||
if ind == 0:
|
||||
tmp_img = pack['image'].to(dtype=torch.float32, device=gpu_device)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
|
||||
tmp_mask = pack['label'].to(dtype=torch.float32, device=gpu_device)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
|
||||
if 'pt' not in pack:
|
||||
tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
|
||||
else:
|
||||
pt = pack['pt']
|
||||
point_labels = pack['p_label']
|
||||
|
||||
if point_labels[0] != -1:
|
||||
point_coords = pt
|
||||
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=gpu_device)
|
||||
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=gpu_device)
|
||||
coords_torch = coords_torch[0:1, :].repeat(current_b, 1)
|
||||
labels_torch = labels_torch[0:1].repeat(current_b)
|
||||
coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
|
||||
tmp_pt = (coords_torch, labels_torch)
|
||||
else:
|
||||
if tmp_img.size(0) != current_b:
|
||||
tmp_img = tmp_img[0:1].repeat(current_b, 1, 1, 1)
|
||||
tmp_mask = tmp_mask[0:1].repeat(current_b, 1, 1, 1)
|
||||
if 'tmp_pt' in dir():
|
||||
coords_torch = tmp_pt[0][0:1].repeat(current_b, 1, 1)
|
||||
labels_torch = tmp_pt[1][0:1].repeat(current_b, 1)
|
||||
tmp_pt = (coords_torch, labels_torch)
|
||||
|
||||
imgs = pack['image'].to(dtype=torch.float32, device=gpu_device)
|
||||
masks = pack['label'].to(dtype=torch.float32, device=gpu_device)
|
||||
name = pack['image_meta_dict']['filename_or_obj']
|
||||
|
||||
if 'pt' in pack:
|
||||
pt = pack['pt']
|
||||
point_labels = pack['p_label']
|
||||
if point_labels[0] != -1:
|
||||
point_coords = pt
|
||||
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=gpu_device)
|
||||
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=gpu_device)
|
||||
coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
|
||||
pt = (coords_torch, labels_torch)
|
||||
|
||||
if args.thd:
|
||||
pt = rearrange(pt, 'b n d -> (b d) n')
|
||||
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
|
||||
masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
|
||||
imgs = imgs.repeat(1, 3, 1, 1)
|
||||
point_labels = torch.ones(imgs.size(0))
|
||||
imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
|
||||
masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)
|
||||
|
||||
showp = pt
|
||||
mask_type = torch.float32
|
||||
ind += 1
|
||||
b_size, c, w, h = imgs.size()
|
||||
longsize = w if w >= h else h
|
||||
|
||||
imgs = imgs.to(dtype=mask_type, device=gpu_device)
|
||||
|
||||
with torch.amp.autocast('cuda'):
|
||||
with torch.no_grad():
|
||||
imge, skips = model.image_encoder(imgs)
|
||||
timge, tskips = model.image_encoder(tmp_img)
|
||||
|
||||
p1, p2, se, de = model.prompt_encoder(
|
||||
points=pt,
|
||||
boxes=None,
|
||||
doodles=None,
|
||||
masks=None,
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pred, _ = model.mask_decoder(
|
||||
skips_raw=skips,
|
||||
skips_tmp=tskips,
|
||||
raw_emb=imge,
|
||||
tmp_emb=timge,
|
||||
pt1=p1,
|
||||
pt2=p2,
|
||||
image_pe=model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=se,
|
||||
dense_prompt_embeddings=de,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
if pred.shape[-2:] != masks.shape[-2:]:
|
||||
pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)
|
||||
|
||||
loss = lossfunc(pred, masks)
|
||||
|
||||
if torch.isnan(loss) or torch.isinf(loss):
|
||||
optimizer.zero_grad()
|
||||
pbar.set_postfix(**{'loss (batch)': 'nan/inf skipped'})
|
||||
pbar.update()
|
||||
ind += 1
|
||||
continue
|
||||
|
||||
pbar.set_postfix(**{'loss (batch)': loss.item()})
|
||||
epoch_loss += loss.item()
|
||||
|
||||
_scaler.scale(loss).backward()
|
||||
_scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
|
||||
_scaler.step(optimizer)
|
||||
_scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if vis:
|
||||
if ind % vis == 0:
|
||||
namecat = 'Train'
|
||||
for na in name:
|
||||
namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
|
||||
vis_image(imgs, pred, masks,
|
||||
os.path.join(args.path_helper['sample_path'],
|
||||
namecat + 'epoch+' + str(epoch) + '.jpg'),
|
||||
reverse=False)
|
||||
|
||||
pbar.update()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def validation_one(args, val_loader, epoch: int, net: nn.Module,
|
||||
writer=None, clean_dir: bool = True):
|
||||
"""Validate the model.
|
||||
|
||||
Args:
|
||||
args: Validation arguments.
|
||||
val_loader: DataLoader for validation data.
|
||||
epoch: Current epoch number.
|
||||
net: The neural network model.
|
||||
writer: TensorBoard writer for logging.
|
||||
clean_dir: Whether to clean the output directory.
|
||||
|
||||
Returns:
|
||||
Tuple of (total_loss, (iou, dice)).
|
||||
"""
|
||||
_init_training_globals(args)
|
||||
|
||||
net.eval()
|
||||
model = net.module if hasattr(net, 'module') else net
|
||||
|
||||
mask_type = torch.float32
|
||||
n_val = len(val_loader)
|
||||
mix_res = (0, 0, 0, 0)
|
||||
tot = 0
|
||||
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
|
||||
gpu_device = torch.device('cuda:' + str(args.gpu_device))
|
||||
|
||||
if args.thd:
|
||||
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
|
||||
else:
|
||||
lossfunc = _criterion_g
|
||||
|
||||
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
|
||||
for ind, pack in enumerate(val_loader):
|
||||
if ind == 0:
|
||||
tmp_img = pack['image'].to(dtype=torch.float32, device=gpu_device)[0, :, :, :].unsqueeze(0).repeat(args.b, 1, 1, 1)
|
||||
tmp_mask = pack['label'].to(dtype=torch.float32, device=gpu_device)[0, :, :, :].unsqueeze(0).repeat(args.b, 1, 1, 1)
|
||||
if 'pt' not in pack:
|
||||
tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
|
||||
else:
|
||||
pt = pack['pt']
|
||||
point_labels = pack['p_label']
|
||||
|
||||
if point_labels[0] != -1:
|
||||
point_coords = pt
|
||||
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=gpu_device)
|
||||
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=gpu_device)
|
||||
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
||||
pt = (coords_torch, labels_torch)
|
||||
|
||||
imgs = pack['image'].to(dtype=torch.float32, device=gpu_device)
|
||||
masks = pack['label'].to(dtype=torch.float32, device=gpu_device)
|
||||
name = pack['image_meta_dict']['filename_or_obj']
|
||||
|
||||
showp = pt
|
||||
mask_type = torch.float32
|
||||
ind += 1
|
||||
b_size, c, w, h = imgs.size()
|
||||
longsize = w if w >= h else h
|
||||
|
||||
imgs = imgs.to(dtype=mask_type, device=gpu_device)
|
||||
|
||||
with torch.no_grad():
|
||||
imge, skips = model.image_encoder(imgs)
|
||||
timge, tskips = model.image_encoder(tmp_img)
|
||||
|
||||
p1, p2, se, de = model.prompt_encoder(
|
||||
points=pt,
|
||||
boxes=None,
|
||||
doodles=None,
|
||||
masks=None,
|
||||
)
|
||||
pred, _ = model.mask_decoder(
|
||||
skips_raw=skips,
|
||||
skips_tmp=tskips,
|
||||
raw_emb=imge,
|
||||
tmp_emb=timge,
|
||||
pt1=p1,
|
||||
pt2=p2,
|
||||
image_pe=model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=se,
|
||||
dense_prompt_embeddings=de,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
if pred.shape[-2:] != masks.shape[-2:]:
|
||||
pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)
|
||||
|
||||
tot += lossfunc(pred, masks)
|
||||
|
||||
if args.vis and ind % args.vis == 0:
|
||||
namecat = 'Test'
|
||||
for na in name:
|
||||
img_name = na.split('/')[-1].split('.')[0]
|
||||
namecat = namecat + img_name + '+'
|
||||
vis_image(imgs, pred, masks,
|
||||
os.path.join(args.path_helper['sample_path'],
|
||||
namecat + 'epoch+' + str(epoch) + '.jpg'),
|
||||
reverse=False)
|
||||
|
||||
temp = eval_seg(pred, masks, threshold)
|
||||
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
|
||||
|
||||
pbar.update()
|
||||
|
||||
return tot / n_val, tuple([a / n_val for a in mix_res])
|
||||
@ -0,0 +1,54 @@
|
||||
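`validation_one` above sums per-batch metric tuples element-wise with `zip` and divides by the batch count at the end. The accumulation pattern on its own, with dummy numbers standing in for real IoU/Dice values:

```python
# Element-wise accumulation of fixed-size metric tuples, as in validation_one.
batches = [(0.8, 0.9), (0.6, 0.7), (0.7, 0.8)]  # (iou, dice) per batch

mix_res = (0.0, 0.0)
for temp in batches:
    # zip pairs corresponding entries, so each metric is summed separately.
    mix_res = tuple(a + b for a, b in zip(mix_res, temp))

n_val = len(batches)
means = tuple(a / n_val for a in mix_res)
print(means)
```

Because `zip` stops at the shorter input, an accumulator initialized with more slots than the metric tuple (as with the four-slot `mix_res` in the code above) silently shrinks to the metric tuple's length after the first batch.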
"""
|
||||
Network builder module.
|
||||
|
||||
This module provides functions to build and initialize neural networks
|
||||
for medical image segmentation.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_network(args, net: str, use_gpu: bool = True,
|
||||
gpu_device: torch.device = None,
|
||||
distribution: str = 'none') -> nn.Module:
|
||||
"""Return the specified network architecture.
|
||||
|
||||
Args:
|
||||
args: Arguments containing model configuration.
|
||||
net: Network type identifier (e.g., 'oneprompt').
|
||||
use_gpu: Whether to use GPU acceleration.
|
||||
gpu_device: The GPU device to use.
|
||||
distribution: Distributed training configuration.
|
||||
- 'none': No distribution
|
||||
- comma-separated GPU IDs for DataParallel
|
||||
|
||||
Returns:
|
||||
The initialized neural network model.
|
||||
|
||||
Raises:
|
||||
SystemExit: If the network type is not supported.
|
||||
"""
|
||||
if gpu_device is None:
|
||||
gpu_device = torch.device('cuda', args.gpu_device)
|
||||
|
||||
if net == 'oneprompt':
|
||||
from models.oneprompt import OnePredictor, one_model_registry
|
||||
from models.oneprompt.utils.transforms import ResizeLongestSide
|
||||
model = one_model_registry[args.baseline](args).to(gpu_device)
|
||||
else:
|
||||
print('the network name you have entered is not supported yet')
|
||||
sys.exit()
|
||||
|
||||
if use_gpu:
|
||||
if distribution != 'none':
|
||||
model = torch.nn.DataParallel(
|
||||
model,
|
||||
device_ids=[int(id) for id in distribution.split(',')]
|
||||
)
|
||||
model = model.to(device=gpu_device)
|
||||
else:
|
||||
model = model.to(device=gpu_device)
|
||||
|
||||
return model
|
||||
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Data module for One-Prompt Medical Image Segmentation.
|
||||
|
||||
This module contains dataset classes for various medical imaging datasets.
|
||||
"""
|
||||
|
||||
from .datasets import ISIC2016, REFUGE, PolypDataset, CombinedPolypDataset
|
||||
from .loader import get_decath_loader
|
||||
|
||||
__all__ = [
|
||||
"ISIC2016",
|
||||
"REFUGE",
|
||||
"PolypDataset",
|
||||
"CombinedPolypDataset",
|
||||
"get_decath_loader",
|
||||
]
|
||||
@ -0,0 +1,357 @@
|
||||
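The dataset classes exported above keep random image and mask augmentations in lockstep by snapshotting the torch RNG state before transforming the image and restoring it before transforming the mask. The same idea, sketched with the stdlib `random` module instead of torch so it runs anywhere:

```python
import random


def jitter(values):
    """A stand-in random augmentation: add noise drawn from the global RNG."""
    return [v + random.random() for v in values]


random.seed(0)
state = random.getstate()        # snapshot, like torch.get_rng_state()
img = jitter([1.0, 2.0])

random.setstate(state)           # restore, like torch.set_rng_state()
msk = jitter([1.0, 2.0])

print(img == msk)  # both calls drew the identical noise sequence
```

Restoring the saved state replays the exact same random draws, so a random crop or flip applied to the image is applied identically to its mask.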
"""
|
||||
Dataset classes for medical image segmentation.
|
||||
|
||||
This module provides dataset implementations for various medical imaging datasets:
|
||||
- ISIC2016: Skin lesion segmentation
|
||||
- REFUGE: Optic disc/cup segmentation
|
||||
- PolypDataset: Polyp segmentation
|
||||
- CombinedPolypDataset: Combined polyp datasets
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from PIL import Image
|
||||
import torch.nn.functional as F
|
||||
import pandas as pd
|
||||
|
||||
from ..utils.data_utils import random_click
|
||||
|
||||
|
||||
class ISIC2016(Dataset):
|
||||
"""ISIC 2016 skin lesion segmentation dataset.
|
||||
|
||||
Args:
|
||||
args: Arguments containing image_size and other configurations.
|
||||
data_path: Path to the dataset directory.
|
||||
transform: Transform to apply to images.
|
||||
transform_msk: Transform to apply to masks.
|
||||
mode: 'Training' or 'Test'.
|
||||
prompt: Prompt type ('click').
|
||||
plane: Whether to use plane mode.
|
||||
"""
|
||||
|
||||
def __init__(self, args, data_path: str, transform=None,
|
||||
transform_msk=None, mode: str = 'Training',
|
||||
prompt: str = 'click', plane: bool = False):
|
||||
df = pd.read_csv(
|
||||
os.path.join(data_path, 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv'),
|
||||
encoding='gbk'
|
||||
)
|
||||
self.name_list = df.iloc[:, 1].tolist()
|
||||
self.label_list = df.iloc[:, 2].tolist()
|
||||
self.data_path = data_path
|
||||
self.mode = mode
|
||||
self.prompt = prompt
|
||||
self.img_size = args.image_size
|
||||
self.transform = transform
|
||||
self.transform_msk = transform_msk
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.name_list)
|
||||
|
||||
def __getitem__(self, index: int) -> dict:
|
||||
inout = 1
|
||||
point_label = 1
|
||||
|
||||
name = self.name_list[index]
|
||||
img_path = os.path.join(self.data_path, name)
|
||||
|
||||
mask_name = self.label_list[index]
|
||||
msk_path = os.path.join(self.data_path, mask_name)
|
||||
|
||||
img = Image.open(img_path).convert('RGB')
|
||||
mask = Image.open(msk_path).convert('L')
|
||||
|
||||
newsize = (self.img_size, self.img_size)
|
||||
mask = mask.resize(newsize)
|
||||
|
||||
if self.prompt == 'click':
|
||||
pt = random_click(np.array(mask) / 255, point_label, inout)
|
||||
|
||||
if self.transform:
|
||||
state = torch.get_rng_state()
|
||||
img = self.transform(img)
|
||||
torch.set_rng_state(state)
|
||||
|
||||
if self.transform_msk:
|
||||
mask = self.transform_msk(mask)
|
||||
|
||||
name = name.split('/')[-1].split(".jpg")[0]
|
||||
image_meta_dict = {'filename_or_obj': name}
|
||||
|
||||
return {
|
||||
'image': img,
|
||||
'label': mask,
|
||||
'p_label': point_label,
|
||||
'pt': pt,
|
||||
'image_meta_dict': image_meta_dict,
|
||||
}
|
||||
|
||||
|
||||
class PolypDataset(Dataset):
|
||||
"""Polyp segmentation dataset.
|
||||
|
||||
Supports CVC-300, CVC-ClinicDB, CVC-ColonDB, ETIS-LaribPolypDB, Kvasir datasets.
|
||||
|
||||
Directory structure:
|
||||
data_path/
|
||||
images/
|
||||
xxx.png
|
||||
masks/
|
||||
xxx.png
|
||||
|
||||
Args:
|
||||
args: Arguments containing image_size and out_size.
|
||||
data_path: Path to the dataset directory.
|
||||
transform: Transform to apply to images.
|
||||
transform_msk: Transform to apply to masks.
|
||||
mode: 'Training' or 'Test'.
|
||||
prompt: Prompt type ('click').
|
||||
plane: Whether to use plane mode.
|
||||
"""
|
||||
|
||||
def __init__(self, args, data_path: str, transform=None,
|
||||
transform_msk=None, mode: str = 'Training',
|
||||
prompt: str = 'click', plane: bool = False):
|
||||
self.data_path = data_path
|
||||
self.mode = mode
|
||||
self.prompt = prompt
|
||||
self.img_size = args.image_size
|
||||
self.out_size = args.out_size
|
||||
self.transform = transform
|
||||
self.transform_msk = transform_msk
|
||||
|
||||
img_dir = os.path.join(data_path, 'images')
|
||||
self.img_list = sorted([
|
||||
f for f in os.listdir(img_dir)
|
||||
if f.endswith(('.png', '.jpg', '.jpeg'))
|
||||
])
|
||||
|
||||
split_idx = int(len(self.img_list) * 0.8)
|
||||
if mode == 'Training':
|
||||
self.img_list = self.img_list[:split_idx]
|
||||
else:
|
||||
self.img_list = self.img_list[split_idx:]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.img_list)
|
||||
|
||||
def __getitem__(self, index: int) -> dict:
|
||||
point_label = 1
|
||||
inout = 1
|
||||
|
||||
img_name = self.img_list[index]
|
||||
img_path = os.path.join(self.data_path, 'images', img_name)
|
||||
mask_path = os.path.join(self.data_path, 'masks', img_name)
|
||||
|
||||
img = Image.open(img_path).convert('RGB')
|
||||
mask = Image.open(mask_path).convert('L')
|
||||
|
||||
newsize = (self.img_size, self.img_size)
|
||||
mask_resized = mask.resize(newsize)
|
||||
|
||||
if self.prompt == 'click':
|
||||
pt = random_click(np.array(mask_resized) / 255, point_label, inout)
|
||||
|
||||
if self.transform:
|
||||
state = torch.get_rng_state()
|
||||
img = self.transform(img)
|
||||
torch.set_rng_state(state)
|
||||
|
||||
if self.transform_msk:
|
||||
mask = self.transform_msk(mask)
|
||||
|
||||
name = img_name.split('.')[0]
|
||||
image_meta_dict = {'filename_or_obj': name}
|
||||
|
||||
return {
|
||||
'image': img,
|
||||
'label': mask,
|
||||
'p_label': point_label,
|
||||
'pt': pt,
|
||||
'image_meta_dict': image_meta_dict,
|
||||
}
|
||||
|
||||
|
||||
class CombinedPolypDataset(Dataset):
|
||||
"""Combined polyp dataset for training on multiple datasets.
|
||||
|
||||
Args:
|
||||
args: Arguments containing image_size and out_size.
|
||||
data_path: Path to the root directory containing dataset folders.
|
||||
transform: Transform to apply to images.
|
||||
transform_msk: Transform to apply to masks.
|
||||
mode: 'Training' or 'Test'.
|
||||
prompt: Prompt type ('click').
|
||||
plane: Whether to use plane mode.
|
||||
"""
|
||||
|
||||
def __init__(self, args, data_path: str, transform=None,
|
||||
transform_msk=None, mode: str = 'Training',
|
||||
prompt: str = 'click', plane: bool = False):
|
||||
self.datasets = []
|
||||
|
||||
dataset_dirs = [
|
||||
'CVC-300', 'CVC-ClinicDB', 'CVC-ColonDB',
|
||||
'ETIS-LaribPolypDB', 'Kvasir'
|
||||
]
|
||||
|
||||
for dataset_dir in dataset_dirs:
|
||||
full_path = os.path.join(data_path, dataset_dir)
|
||||
if os.path.exists(full_path):
|
||||
ds = PolypDataset(
|
||||
args, full_path, transform, transform_msk,
|
||||
mode, prompt, plane
|
||||
)
|
||||
if len(ds) > 0:
|
||||
self.datasets.append(ds)
|
||||
print(f"Loaded {dataset_dir}: {len(ds)} samples ({mode})")
|
||||
|
||||
self.cumulative_sizes = []
|
||||
total = 0
|
||||
for ds in self.datasets:
|
||||
total += len(ds)
|
||||
self.cumulative_sizes.append(total)
|
||||
|
||||
print(f"Total {mode} samples: {total}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.cumulative_sizes[-1] if self.cumulative_sizes else 0
|
||||
|
||||
def __getitem__(self, index: int) -> dict:
|
||||
for i, cumsize in enumerate(self.cumulative_sizes):
|
||||
if index < cumsize:
|
||||
if i == 0:
|
||||
return self.datasets[i][index]
|
||||
else:
|
||||
return self.datasets[i][index - self.cumulative_sizes[i - 1]]
|
||||
raise IndexError("Index out of range")
|
||||
|
||||
|
||||
class REFUGE(Dataset):
|
||||
"""REFUGE optic disc/cup segmentation dataset with multi-rater annotations.
|
||||
|
||||
Args:
|
||||
args: Arguments containing image_size and out_size.
|
||||
data_path: Path to the dataset directory.
|
||||
transform: Transform to apply to images.
|
||||
transform_msk: Transform to apply to masks.
|
||||
mode: 'Training' or 'Test'.
|
||||
prompt: Prompt type ('click').
|
||||
plane: Whether to use plane mode.
|
||||
"""
|
||||
|
||||
def __init__(self, args, data_path: str, transform=None,
|
||||
transform_msk=None, mode: str = 'Training',
|
||||
prompt: str = 'click', plane: bool = False):
|
||||
self.data_path = data_path
|
||||
self.subfolders = [
|
||||
f.path for f in os.scandir(os.path.join(data_path, mode + '-400'))
|
||||
if f.is_dir()
|
||||
]
|
||||
self.mode = mode
|
||||
self.prompt = prompt
|
||||
self.img_size = args.image_size
|
||||
self.mask_size = args.out_size
|
||||
self.transform = transform
|
||||
self.transform_msk = transform_msk
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.subfolders)
|
||||
|
||||
def __getitem__(self, index: int) -> dict:
|
||||
inout = 1
|
||||
point_label = 1
|
||||
|
||||
subfolder = self.subfolders[index]
|
||||
name = subfolder.split('/')[-1]
|
||||
|
||||
img_path = os.path.join(subfolder, name + '.jpg')
|
||||
multi_rater_cup_path = [
|
||||
os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png')
|
||||
for i in range(1, 8)
|
||||
]
|
||||
multi_rater_disc_path = [
|
||||
os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png')
|
||||
for i in range(1, 8)
|
||||
]
|
||||
|
||||
img = Image.open(img_path).convert('RGB')
|
||||
multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path]
|
||||
multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path]
|
||||
|
||||
newsize = (self.img_size, self.img_size)
|
||||
multi_rater_cup_np = [
|
||||
np.array(single_rater.resize(newsize))
|
||||
for single_rater in multi_rater_cup
|
||||
]
|
||||
multi_rater_disc_np = [
|
||||
np.array(single_rater.resize(newsize))
|
||||
for single_rater in multi_rater_disc
|
||||
]
|
||||
|
||||
if self.prompt == 'click':
|
||||
pt_cup = random_click(
|
||||
np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255,
|
||||
point_label, inout
|
||||
)
|
||||
pt_disc = random_click(
|
||||
np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255,
|
||||
point_label, inout
|
||||
)
|
||||
|
||||
if self.transform:
|
||||
state = torch.get_rng_state()
|
||||
img = self.transform(img)
|
||||
|
||||
multi_rater_cup = [
|
||||
torch.as_tensor(
|
||||
(self.transform(single_rater) > 0.5).float(),
|
||||
dtype=torch.float32
|
||||
)
|
||||
for single_rater in multi_rater_cup
|
||||
]
|
||||
multi_rater_cup = torch.stack(multi_rater_cup, dim=0)
|
||||
mask_cup = F.interpolate(
|
||||
multi_rater_cup,
|
||||
size=(self.mask_size, self.mask_size),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
).mean(dim=0)
|
||||
|
||||
multi_rater_disc = [
|
||||
torch.as_tensor(
|
||||
(self.transform(single_rater) > 0.5).float(),
|
||||
dtype=torch.float32
|
||||
)
|
||||
for single_rater in multi_rater_disc
|
||||
]
|
||||
multi_rater_disc = torch.stack(multi_rater_disc, dim=0)
|
||||
mask_disc = F.interpolate(
|
||||
multi_rater_disc,
|
||||
size=(self.mask_size, self.mask_size),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
).mean(dim=0)
|
||||
|
||||
torch.set_rng_state(state)
|
||||
|
||||
image_meta_dict = {'filename_or_obj': name}
|
||||
|
||||
return {
|
||||
'image': img,
|
||||
'multi_rater_cup': multi_rater_cup,
|
||||
'multi_rater_disc': multi_rater_disc,
|
||||
'mask_cup': mask_cup,
|
||||
'mask_disc': mask_disc,
|
||||
'label': mask_disc,
|
||||
'p_label': point_label,
|
||||
'pt_cup': pt_cup,
|
||||
'pt_disc': pt_disc,
|
||||
'pt': pt_disc,
|
||||
'selected_rater': torch.tensor(np.arange(7)),
|
||||
'image_meta_dict': image_meta_dict,
|
||||
}
|
||||
@ -0,0 +1,161 @@
|
||||
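`CombinedPolypDataset.__getitem__` maps a global sample index onto a (dataset, local index) pair by scanning the cumulative sizes. The lookup in isolation, with toy per-dataset lengths:

```python
# Map a global sample index onto (dataset_id, local_index) using cumulative
# sizes, mirroring CombinedPolypDataset.__getitem__.
sizes = [3, 5, 2]                      # toy per-dataset lengths
cumulative = []
total = 0
for s in sizes:
    total += s
    cumulative.append(total)           # becomes [3, 8, 10]


def locate(index):
    """Return (dataset_id, local_index) for a global index."""
    for i, cumsize in enumerate(cumulative):
        if index < cumsize:
            local = index if i == 0 else index - cumulative[i - 1]
            return i, local
    raise IndexError("Index out of range")


print([locate(i) for i in (0, 2, 3, 7, 9)])
```

This is the same bookkeeping `torch.utils.data.ConcatDataset` does internally; keeping it explicit here lets the combined dataset also filter out missing or empty sub-datasets at construction time.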
"""
|
||||
Data loader utilities for medical imaging datasets.
|
||||
|
||||
This module provides data loading functions for various medical imaging formats,
|
||||
particularly for MONAI-based volumetric data.
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
from monai.transforms import (
|
||||
Compose,
|
||||
CropForegroundd,
|
||||
LoadImaged,
|
||||
Orientationd,
|
||||
RandFlipd,
|
||||
RandCropByPosNegLabeld,
|
||||
RandShiftIntensityd,
|
||||
ScaleIntensityRanged,
|
||||
Spacingd,
|
||||
RandRotate90d,
|
||||
EnsureTyped,
|
||||
)
|
||||
from monai.data import (
|
||||
ThreadDataLoader,
|
||||
CacheDataset,
|
||||
load_decathlon_datalist,
|
||||
set_track_meta,
|
||||
)
|
||||
|
||||
|
||||
def get_decath_loader(args):
|
||||
"""Get data loaders for decathlon-style datasets.
|
||||
|
||||
Creates training and validation data loaders with appropriate transforms
|
||||
for medical imaging data.
|
||||
|
||||
Args:
|
||||
args: Arguments containing:
|
||||
- data_path: Path to the dataset
|
||||
- gpu_device: GPU device ID
|
||||
- roi_size: Region of interest size
|
||||
- chunk: Volume chunk size
|
||||
- num_sample: Number of positive/negative samples
|
||||
- b: Batch size
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- train_loader: DataLoader for training
|
||||
- val_loader: DataLoader for validation
|
||||
- train_transforms: Training transforms
|
||||
- val_transforms: Validation transforms
|
||||
- datalist: Training data list
|
||||
- val_files: Validation data list
|
||||
"""
|
||||
device = torch.device('cuda', args.gpu_device)
|
||||
|
||||
train_transforms = Compose([
|
||||
LoadImaged(keys=["image", "label"], ensure_channel_first=True),
|
||||
ScaleIntensityRanged(
|
||||
keys=["image"],
|
||||
a_min=-175,
|
||||
a_max=250,
|
||||
b_min=0.0,
|
||||
b_max=1.0,
|
||||
clip=True,
|
||||
),
|
||||
CropForegroundd(keys=["image", "label"], source_key="image"),
|
||||
Orientationd(keys=["image", "label"], axcodes="RAS"),
|
||||
Spacingd(
|
||||
keys=["image", "label"],
|
||||
pixdim=(1.5, 1.5, 2.0),
|
||||
mode=("bilinear", "nearest"),
|
||||
),
|
||||
EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
|
||||
RandCropByPosNegLabeld(
|
||||
keys=["image", "label"],
|
||||
label_key="label",
|
||||
spatial_size=(args.roi_size, args.roi_size, args.chunk),
|
||||
pos=1,
|
||||
neg=1,
|
||||
num_samples=args.num_sample,
|
||||
image_key="image",
|
||||
image_threshold=0,
|
||||
),
|
||||
RandFlipd(
|
||||
keys=["image", "label"],
|
||||
spatial_axis=[0],
|
||||
prob=0.10,
|
||||
),
|
||||
RandFlipd(
|
||||
keys=["image", "label"],
|
||||
spatial_axis=[1],
|
||||
prob=0.10,
|
||||
),
|
||||
RandFlipd(
|
||||
keys=["image", "label"],
|
||||
spatial_axis=[2],
|
||||
prob=0.10,
|
||||
),
|
||||
RandRotate90d(
|
||||
keys=["image", "label"],
|
||||
prob=0.10,
|
||||
max_k=3,
|
||||
),
|
||||
RandShiftIntensityd(
|
||||
keys=["image"],
|
||||
offsets=0.10,
|
||||
prob=0.50,
|
||||
),
|
||||
])
|
||||
|
||||
val_transforms = Compose([
|
||||
LoadImaged(keys=["image", "label"], ensure_channel_first=True),
|
||||
ScaleIntensityRanged(
|
||||
keys=["image"],
|
||||
a_min=-175,
|
||||
a_max=250,
|
||||
b_min=0.0,
|
||||
b_max=1.0,
|
||||
clip=True
|
||||
),
|
||||
CropForegroundd(keys=["image", "label"], source_key="image"),
|
||||
Orientationd(keys=["image", "label"], axcodes="RAS"),
|
||||
Spacingd(
|
||||
keys=["image", "label"],
|
||||
pixdim=(1.5, 1.5, 2.0),
|
||||
mode=("bilinear", "nearest"),
|
||||
),
|
||||
EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
|
||||
])
|
||||
|
||||
data_dir = args.data_path
|
||||
split_json = "dataset_0.json"
|
||||
|
||||
datasets = os.path.join(data_dir, split_json)
|
||||
datalist = load_decathlon_datalist(datasets, True, "training")
|
||||
val_files = load_decathlon_datalist(datasets, True, "validation")
|
||||
|
||||
train_ds = CacheDataset(
|
||||
data=datalist,
|
||||
transform=train_transforms,
|
||||
cache_num=24,
|
||||
cache_rate=1.0,
|
||||
num_workers=8,
|
||||
)
|
||||
train_loader = ThreadDataLoader(
|
||||
train_ds, num_workers=0, batch_size=args.b, shuffle=True
|
||||
)
|
||||
|
||||
val_ds = CacheDataset(
|
||||
data=val_files,
|
||||
transform=val_transforms,
|
||||
cache_num=2,
|
||||
cache_rate=1.0,
|
||||
num_workers=0
|
||||
)
|
||||
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
|
||||
|
||||
set_track_meta(False)
|
||||
|
||||
return train_loader, val_loader, train_transforms, val_transforms, datalist, val_files
|
||||
@@ -0,0 +1,52 @@
"""
Utility modules for One-Prompt Medical Image Segmentation.

This module provides various utility functions including:
- Metrics (Dice, IoU)
- Visualization tools
- Data utilities
- Logging utilities
"""

# Use lazy imports to avoid dependency issues
__all__ = [
    # Metrics
    "DiceMetric",
    "eval_seg",
    "iou",
    "dice_coeff",
    "DiceCoeff",
    # Visualization
    "vis_image",
    "save_image",
    "make_grid",
    # Data utilities
    "random_click",
    "generate_click_prompt",
    # Logging
    "create_logger",
    "set_log_dir",
    "save_checkpoint",
    # Scheduler
    "WarmUpLR",
]


def __getattr__(name):
    """Lazy import of submodules."""
    if name in ("DiceMetric", "eval_seg", "iou", "dice_coeff", "DiceCoeff"):
        from . import metrics
        return getattr(metrics, name)
    elif name in ("vis_image", "save_image", "make_grid"):
        from . import visualization
        return getattr(visualization, name)
    elif name in ("random_click", "generate_click_prompt"):
        from . import data_utils
        return getattr(data_utils, name)
    elif name in ("create_logger", "set_log_dir", "save_checkpoint"):
        from . import logging
        return getattr(logging, name)
    elif name == "WarmUpLR":
        from .scheduler import WarmUpLR
        return WarmUpLR
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
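The module-level `__getattr__` above relies on the lazy attribute hook from PEP 562 (Python 3.7+): an attribute missing from the module's namespace falls through to `__getattr__`, so the heavy submodule imports only happen on first access. A minimal, self-contained sketch of the same mechanism, using a throwaway module (`lazy_demo` and `answer` are illustrative names, not part of this package):

```python
import sys
import types

# Build a module object by hand instead of a .py file, so the
# sketch runs anywhere.
mod = types.ModuleType("lazy_demo")


def _module_getattr(name):
    # Stands in for "from . import submodule" done on first access.
    if name == "answer":
        return 42
    raise AttributeError(f"module 'lazy_demo' has no attribute {name!r}")


# PEP 562: a __getattr__ in the module's namespace handles misses.
mod.__getattr__ = _module_getattr
sys.modules["lazy_demo"] = mod

import lazy_demo

print(lazy_demo.answer)  # resolved lazily via __getattr__ -> 42
```

The same fallthrough raises `AttributeError` for unknown names, which is what the final `raise` in the package `__init__` provides.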
@@ -0,0 +1,78 @@
"""
Data utility functions for medical image segmentation.

This module provides utility functions for data processing,
including click prompt generation and random sampling.
"""

import random

import numpy as np
import torch


def random_click(mask: np.ndarray, point_labels: int = 1,
                 inout: int = 1) -> np.ndarray:
    """Generate a random click point on the mask.

    Args:
        mask: Binary mask array.
        point_labels: Label for the point (1 for foreground).
        inout: Value to match in the mask (1 for foreground).

    Returns:
        Coordinates of the random click point [y, x].
    """
    indices = np.argwhere(mask == inout)
    if len(indices) == 0:
        # Return the image centre if the mask has no matching points
        h, w = mask.shape[:2]
        return np.array([h // 2, w // 2])
    return indices[np.random.randint(len(indices))]


def generate_click_prompt(img: torch.Tensor, msk: torch.Tensor,
                          pt_label: int = 1) -> tuple:
    """Generate click prompts from image and mask.

    Creates point prompts by randomly selecting positions from the mask
    for each slice in a 3D volume.

    Args:
        img: Input image tensor of shape (B, C, H, W, D).
        msk: Mask tensor of shape (B, C, H, W, D).
        pt_label: Point label value.

    Returns:
        Tuple of (image, point_coordinates, processed_mask).
    """
    pt_list = []
    msk_list = []
    b, c, h, w, d = msk.size()
    msk = msk[:, 0, :, :, :]

    for i in range(d):
        pt_list_s = []
        msk_list_s = []
        for j in range(b):
            msk_s = msk[j, :, :, i]
            indices = torch.nonzero(msk_s)
            if indices.size(0) == 0:
                # Empty slice: fall back to a random point, keep the slice
                random_index = torch.randint(0, h, (2,)).to(device=msk.device)
                new_s = msk_s
            else:
                random_index = random.choice(indices)
                label = msk_s[random_index[0], random_index[1]]
                new_s = (msk_s == label).to(dtype=torch.float)
            pt_list_s.append(random_index)
            msk_list_s.append(new_s)
        pts = torch.stack(pt_list_s, dim=0)
        msks = torch.stack(msk_list_s, dim=0)
        pt_list.append(pts)
        msk_list.append(msks)

    pt = torch.stack(pt_list, dim=-1)
    msk = torch.stack(msk_list, dim=-1)
    msk = msk.unsqueeze(1)

    return img, pt, msk
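The two paths of `random_click` can be exercised standalone; the function body below is copied from above so the snippet runs without the package installed. With an empty mask the image centre comes back; otherwise a foreground pixel is chosen.

```python
import numpy as np


def random_click(mask, point_labels=1, inout=1):
    # Same logic as the package's random_click above.
    indices = np.argwhere(mask == inout)
    if len(indices) == 0:
        h, w = mask.shape[:2]
        return np.array([h // 2, w // 2])
    return indices[np.random.randint(len(indices))]


empty = np.zeros((64, 64))
print(random_click(empty))  # centre fallback: [32 32]

blob = np.zeros((64, 64))
blob[20:40, 20:40] = 1
y, x = random_click(blob)
print(20 <= y < 40 and 20 <= x < 40)  # click lands on the blob: True
```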
@@ -0,0 +1,103 @@
"""
Logging utilities for training and experiment tracking.

This module provides functions for creating loggers, managing log directories,
and saving checkpoints.
"""

import os
import time
import logging
from datetime import datetime
from typing import Dict, Any

import dateutil.tz
import torch


def create_logger(log_dir: str, phase: str = 'train') -> logging.Logger:
    """Create a logger for training.

    Args:
        log_dir: Directory to save log files.
        phase: Training phase name ('train', 'val', etc.).

    Returns:
        Configured logger instance.
    """
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file), format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    return logger


def set_log_dir(root_dir: str, exp_name: str) -> Dict[str, str]:
    """Set up logging directories for an experiment.

    Creates the following directory structure:
        root_dir/
            exp_name_timestamp/
                Model/    - for checkpoints
                Log/      - for log files
                Samples/  - for sample images

    Args:
        root_dir: Root directory for all experiments.
        exp_name: Name of the experiment.

    Returns:
        Dictionary containing paths:
        - 'prefix': Base experiment directory
        - 'ckpt_path': Checkpoint directory
        - 'log_path': Log file directory
        - 'sample_path': Sample image directory
    """
    path_dict = {}
    os.makedirs(root_dir, exist_ok=True)

    # Set experiment prefix with a local timestamp
    exp_path = os.path.join(root_dir, exp_name)
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    prefix = exp_path + '_' + timestamp
    os.makedirs(prefix)
    path_dict['prefix'] = prefix

    # Set checkpoint path
    ckpt_path = os.path.join(prefix, 'Model')
    os.makedirs(ckpt_path)
    path_dict['ckpt_path'] = ckpt_path

    # Set log file path
    log_path = os.path.join(prefix, 'Log')
    os.makedirs(log_path)
    path_dict['log_path'] = log_path

    # Set sample image path
    sample_path = os.path.join(prefix, 'Samples')
    os.makedirs(sample_path)
    path_dict['sample_path'] = sample_path

    return path_dict


def save_checkpoint(states: Dict[str, Any], is_best: bool,
                    output_dir: str, filename: str = 'checkpoint.pth') -> None:
    """Save model checkpoint.

    Args:
        states: Dictionary containing model state and metadata.
        is_best: Whether this is the best model so far.
        output_dir: Directory to save the checkpoint.
        filename: Name of the checkpoint file.
    """
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))
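The directory layout `set_log_dir` produces can be sketched with only the standard library; the subdirectory names and timestamp format below mirror the function, while the `exp_` prefix and temp root are illustrative:

```python
import os
import tempfile
from datetime import datetime

# Throwaway root standing in for root_dir.
root = tempfile.mkdtemp()

# Timestamped experiment prefix, as in set_log_dir.
prefix = os.path.join(root, 'exp_' + datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
paths = {'prefix': prefix}

# Model/, Log/, Samples/ under the prefix.
for key, sub in (('ckpt_path', 'Model'),
                 ('log_path', 'Log'),
                 ('sample_path', 'Samples')):
    paths[key] = os.path.join(prefix, sub)
    os.makedirs(paths[key])

print(sorted(os.listdir(prefix)))  # ['Log', 'Model', 'Samples']
```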
@@ -0,0 +1,157 @@
"""
Evaluation metrics for medical image segmentation.

This module provides various metrics for evaluating segmentation quality:
- Dice coefficient
- IoU (Intersection over Union)
- DiceMetric class for batch evaluation
"""

import numpy as np
import torch
from torch.autograd import Function

# Optional import for MONAI DiceMetric
try:
    from monai.metrics import DiceMetric
except ImportError:
    DiceMetric = None


def iou(outputs: np.ndarray, labels: np.ndarray) -> float:
    """Compute Intersection over Union (IoU) metric.

    Args:
        outputs: Predicted binary masks of shape (N, H, W).
        labels: Ground truth binary masks of shape (N, H, W).

    Returns:
        Mean IoU score across all samples.
    """
    SMOOTH = 1e-6
    intersection = (outputs & labels).sum((1, 2))
    union = (outputs | labels).sum((1, 2))
    iou_score = (intersection + SMOOTH) / (union + SMOOTH)
    return iou_score.mean()


class DiceCoeff(Function):
    """Dice coefficient for individual examples.

    This is a custom autograd function for computing the Dice coefficient
    with gradient support.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute Dice coefficient.

        Args:
            input: Predicted mask tensor.
            target: Ground truth mask tensor.

        Returns:
            Dice coefficient value.
        """
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps
        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    def backward(self, grad_output):
        """Compute gradients for backpropagation.

        Args:
            grad_output: Gradient of the loss with respect to the output.

        Returns:
            Tuple of gradients for input and target.
        """
        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute Dice coefficient for batches.

    Args:
        input: Predicted mask tensor of shape (N, H, W).
        target: Ground truth mask tensor of shape (N, H, W).

    Returns:
        Mean Dice coefficient across the batch.
    """
    if input.is_cuda:
        s = torch.FloatTensor(1).to(device=input.device).zero_()
    else:
        s = torch.FloatTensor(1).zero_()

    for i, c in enumerate(zip(input, target)):
        s = s + DiceCoeff().forward(c[0], c[1])

    return s / (i + 1)


def eval_seg(pred: torch.Tensor, true_mask_p: torch.Tensor,
             threshold: tuple) -> tuple:
    """Evaluate segmentation predictions.

    Args:
        pred: Predicted masks of shape (B, C, H, W).
        true_mask_p: Ground truth masks of shape (B, C, H, W).
        threshold: Tuple of threshold values for binarization.

    Returns:
        Tuple of evaluation metrics (IoU and Dice scores).
    """
    b, c, h, w = pred.size()

    if c == 2:
        # Two-channel output (disc and cup)
        iou_d, iou_c, disc_dice, cup_dice = 0, 0, 0, 0
        for th in threshold:
            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')
            cup_pred = vpred_cpu[:, 1, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].squeeze(1).cpu().numpy().astype('int32')
            cup_mask = gt_vmask_p[:, 1, :, :].squeeze(1).cpu().numpy().astype('int32')

            iou_d += iou(disc_pred, disc_mask)
            iou_c += iou(cup_pred, cup_mask)

            disc_dice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()
            cup_dice += dice_coeff(vpred[:, 1, :, :], gt_vmask_p[:, 1, :, :]).item()

        return (
            iou_d / len(threshold),
            iou_c / len(threshold),
            disc_dice / len(threshold),
            cup_dice / len(threshold)
        )
    else:
        # Single-channel output
        eiou, edice = 0, 0
        for th in threshold:
            gt_vmask_p = (true_mask_p > th).float()
            vpred = (pred > th).float()
            vpred_cpu = vpred.cpu()
            disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')

            disc_mask = gt_vmask_p[:, 0, :, :].squeeze(1).cpu().numpy().astype('int32')

            eiou += iou(disc_pred, disc_mask)
            edice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()

        return eiou / len(threshold), edice / len(threshold)
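A quick numeric check of the smoothed IoU formula above, on two masks that overlap in one of three occupied rows; the computation is inlined (same `SMOOTH` constant and axis conventions) so the snippet is self-contained:

```python
import numpy as np

SMOOTH = 1e-6

pred = np.zeros((1, 4, 4), dtype=np.int32)
target = np.zeros((1, 4, 4), dtype=np.int32)
pred[0, :2, :] = 1     # rows 0-1
target[0, 1:3, :] = 1  # rows 1-2

# Per-sample bitwise intersection / union over (H, W), as in iou().
inter = (pred & target).sum((1, 2))  # row 1 only: 4 pixels
union = (pred | target).sum((1, 2))  # rows 0-2: 12 pixels
score = (inter + SMOOTH) / (union + SMOOTH)

print(round(float(score.mean()), 4))  # 4/12 -> 0.3333
```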
@@ -0,0 +1,38 @@
"""
Learning rate schedulers for training.

This module provides custom learning rate scheduler implementations.
"""

from torch.optim.lr_scheduler import _LRScheduler


class WarmUpLR(_LRScheduler):
    """Warmup learning rate scheduler.

    Gradually increases the learning rate from 0 to the initial learning rate
    during the warmup phase.

    Args:
        optimizer: Wrapped optimizer.
        total_iters: Total number of warmup iterations.
        last_epoch: The index of last epoch.
    """

    def __init__(self, optimizer, total_iters: int, last_epoch: int = -1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Calculate learning rate for the current step.

        During warmup, the learning rate is linearly scaled:
            lr = base_lr * current_step / total_iters

        Returns:
            List of learning rates for each parameter group.
        """
        return [
            base_lr * self.last_epoch / (self.total_iters + 1e-8)
            for base_lr in self.base_lrs
        ]
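`get_lr` above reduces to the linear ramp `base_lr * step / total_iters` (the `1e-8` only guards against division by zero). A pure-Python check of the schedule over a hypothetical five-step warmup, no optimizer needed:

```python
# Hypothetical warmup: base_lr and total_iters are illustrative values.
base_lr, total_iters = 0.1, 5

# One value per step, exactly the expression from WarmUpLR.get_lr.
lrs = [base_lr * step / (total_iters + 1e-8)
       for step in range(1, total_iters + 1)]

print([round(lr, 3) for lr in lrs])  # [0.02, 0.04, 0.06, 0.08, 0.1]
```

After `total_iters` steps the learning rate reaches `base_lr`, at which point training typically hands over to the main scheduler.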
@@ -0,0 +1,204 @@
"""
Visualization utilities for medical image segmentation.

This module provides functions for visualizing segmentation results,
creating image grids, and saving output images.
"""

import math
import warnings
import pathlib
from typing import Union, Optional, List, Tuple, Text, BinaryIO

import torch
import torchvision.utils as vutils
from PIL import Image


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs
) -> torch.Tensor:
    """Make a grid of images.

    Args:
        tensor: 4D mini-batch Tensor of shape (B, C, H, W)
            or a list of images all of the same size.
        nrow: Number of images displayed in each row of the grid.
        padding: Amount of padding.
        normalize: If True, shift the image to the range (0, 1).
        value_range: Tuple (min, max) for normalization.
        scale_each: If True, scale each image in the batch independently.
        pad_value: Value for the padded pixels.

    Returns:
        Grid tensor of shape (C, H, W).
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if "range" in kwargs.keys():
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image C x H x W
        if tensor.size(0) == 1:  # single-channel image: replicate to 3 channels
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel batch
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified."

        def norm_ip(img, low, high):
            img.clamp(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full(
        (num_channels, height * ymaps + padding, width * xmaps + padding),
        pad_value
    )
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            grid.narrow(1, y * height + padding, height - padding).narrow(
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid


@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) -> None:
    """Save a tensor as an image file.

    Args:
        tensor: Image tensor to be saved.
        fp: Filename or file object.
        format: Image format. If omitted, determined from filename extension.
        **kwargs: Other arguments passed to make_grid.
    """
    grid = make_grid(tensor, **kwargs)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


def vis_image(imgs: torch.Tensor, pred_masks: torch.Tensor,
              gt_masks: torch.Tensor, save_path: str,
              reverse: bool = False, points=None) -> None:
    """Visualize segmentation results.

    Creates a grid visualization with input images, predicted masks,
    and ground truth masks.

    Args:
        imgs: Input images tensor of shape (B, C, H, W).
        pred_masks: Predicted masks tensor of shape (B, C, H, W).
        gt_masks: Ground truth masks tensor of shape (B, C, H, W).
        save_path: Path to save the visualization.
        reverse: If True, invert the mask colors.
        points: Optional point annotations to overlay.
    """
    import torchvision

    b, c, h, w = pred_masks.size()
    row_num = min(b, 4)

    # Map logits to probabilities if values fall outside [0, 1]
    if torch.max(pred_masks) > 1 or torch.min(pred_masks) < 0:
        pred_masks = torch.sigmoid(pred_masks)

    if reverse:
        pred_masks = 1 - pred_masks
        gt_masks = 1 - gt_masks

    if c == 2:
        # Two-channel (disc and cup)
        pred_disc = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        pred_cup = pred_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_disc = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_cup = gt_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)

        compose = torch.cat((
            pred_disc[:row_num, :, :, :],
            pred_cup[:row_num, :, :, :],
            gt_disc[:row_num, :, :, :],
            gt_cup[:row_num, :, :, :]
        ), 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
    else:
        # Single channel
        imgs = torchvision.transforms.Resize((h, w))(imgs)
        if imgs.size(1) == 1:
            imgs = imgs[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        pred_masks = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
        gt_masks = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)

        if points is not None:
            # Import args for point visualization
            import cfg
            args = cfg.parse_args()
            for i in range(b):
                # torch.round, not np.round: np.round converts the tensor to
                # an ndarray, which has no .to() method
                if args.thd:
                    p = torch.round(points.cpu() / args.roi_size * args.out_size).to(dtype=torch.int)
                else:
                    p = torch.round(points.cpu() / args.image_size * args.out_size).to(dtype=torch.int)
                gt_masks[i, 0, p[i, 0]-5:p[i, 0]+5, p[i, 1]-5:p[i, 1]+5] = 0.5
                gt_masks[i, 1, p[i, 0]-5:p[i, 0]+5, p[i, 1]-5:p[i, 1]+5] = 0.1
                gt_masks[i, 2, p[i, 0]-5:p[i, 0]+5, p[i, 1]-5:p[i, 1]+5] = 0.4

        tup = (
            imgs[:row_num, :, :, :],
            pred_masks[:row_num, :, :, :],
            gt_masks[:row_num, :, :, :]
        )
        compose = torch.cat(tup, 0)
        vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
@@ -0,0 +1,5 @@
"""
Test package for One-Prompt Medical Image Segmentation.

This package contains unit tests and integration tests for the project.
"""
@@ -0,0 +1,56 @@
"""
Unit tests for data utilities.
"""

import os
import sys

import numpy as np
import torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from oneprompt_seg.utils.data_utils import random_click, generate_click_prompt


class TestRandomClick:
    """Tests for random_click function."""

    def test_click_on_foreground(self):
        """Test that click is generated on foreground."""
        mask = np.zeros((64, 64))
        mask[20:40, 20:40] = 1
        pt = random_click(mask, point_labels=1, inout=1)
        assert 20 <= pt[0] < 40
        assert 20 <= pt[1] < 40

    def test_empty_mask(self):
        """Test click generation with empty mask."""
        mask = np.zeros((64, 64))
        pt = random_click(mask, point_labels=1, inout=1)
        # Should return center when no valid points
        assert pt[0] == 32
        assert pt[1] == 32

    def test_click_coordinates_shape(self):
        """Test that click returns correct shape."""
        mask = np.ones((64, 64))
        pt = random_click(mask, point_labels=1, inout=1)
        assert len(pt) == 2


class TestGenerateClickPrompt:
    """Tests for generate_click_prompt function."""

    def test_output_shapes(self):
        """Test output tensor shapes."""
        img = torch.rand(2, 3, 64, 64, 4)
        msk = torch.rand(2, 1, 64, 64, 4)
        msk = (msk > 0.5).float()

        out_img, pt, out_msk = generate_click_prompt(img, msk)

        assert out_img.shape == img.shape
        assert pt.shape[0] == 2  # batch size
        assert pt.shape[-1] == 4  # depth
        assert out_msk.shape[0] == 2  # batch size
@@ -0,0 +1,78 @@
"""
Unit tests for metrics module.
"""

import os
import sys

import numpy as np
import pytest
import torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from oneprompt_seg.utils.metrics import iou, dice_coeff, eval_seg


class TestIoU:
    """Tests for IoU metric."""

    def test_perfect_overlap(self):
        """Test IoU with perfect overlap."""
        pred = np.ones((1, 64, 64), dtype=np.int32)
        target = np.ones((1, 64, 64), dtype=np.int32)
        result = iou(pred, target)
        assert result == pytest.approx(1.0, abs=1e-5)

    def test_no_overlap(self):
        """Test IoU with no overlap."""
        pred = np.ones((1, 64, 64), dtype=np.int32)
        target = np.zeros((1, 64, 64), dtype=np.int32)
        result = iou(pred, target)
        assert result == pytest.approx(0.0, abs=1e-5)

    def test_partial_overlap(self):
        """Test IoU with partial overlap."""
        pred = np.zeros((1, 64, 64), dtype=np.int32)
        target = np.zeros((1, 64, 64), dtype=np.int32)
        pred[0, :32, :] = 1
        target[0, 16:48, :] = 1
        result = iou(pred, target)
        assert 0 < result < 1


class TestDiceCoeff:
    """Tests for Dice coefficient."""

    def test_perfect_overlap(self):
        """Test Dice with perfect overlap."""
        pred = torch.ones(1, 64, 64)
        target = torch.ones(1, 64, 64)
        result = dice_coeff(pred, target)
        assert result.item() == pytest.approx(1.0, abs=1e-3)

    def test_no_overlap(self):
        """Test Dice with no overlap."""
        pred = torch.ones(1, 64, 64)
        target = torch.zeros(1, 64, 64)
        result = dice_coeff(pred, target)
        assert result.item() == pytest.approx(0.0, abs=1e-3)


class TestEvalSeg:
    """Tests for eval_seg function."""

    def test_single_channel(self):
        """Test evaluation with single channel output."""
        pred = torch.rand(2, 1, 64, 64)
        target = torch.rand(2, 1, 64, 64)
        threshold = (0.5,)
        result = eval_seg(pred, target, threshold)
        assert len(result) == 2  # IoU and Dice

    def test_two_channel(self):
        """Test evaluation with two channel output."""
        pred = torch.rand(2, 2, 64, 64)
        target = torch.rand(2, 2, 64, 64)
        threshold = (0.5,)
        result = eval_seg(pred, target, threshold)
        assert len(result) == 4  # IoU_d, IoU_c, Dice_d, Dice_c