Refactor project structure following Python code standards

- Add src/oneprompt_seg package with modular structure:
  - core/: network builder and training functions
  - data/: dataset classes and data loaders
  - utils/: metrics, visualization, logging, schedulers
- Add tests/ directory with unit tests for metrics and data utils
- Add docs/ directory with Sphinx documentation setup
- Add configs/ with development and production YAML configs
- Update scripts/ with refactored train.py and evaluate.py
- Add pyproject.toml for modern Python project configuration
- Add requirements.txt and requirements-dev.txt
- Add Makefile for common development commands
- Add .pre-commit-config.yaml for code quality hooks
- Update .gitignore to exclude .claude/ and improve patterns
- Update README.md with new structure and usage instructions
- Add CLAUDE.md as project guide for Claude Code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Branch: main
USM202504148 committed 4 months ago
parent 60b9c98523
commit e0b048b2a5

.gitignore
@@ -1,3 +1,6 @@
# Claude Code
.claude/
# Python
__pycache__/
*.py[cod]
@@ -9,6 +12,8 @@ __pycache__/
.eggs/
dist/
build/
.mypy_cache/
.pytest_cache/
# Virtual environments
.env
@@ -30,9 +35,11 @@ ENV/
# Logs and outputs
logs/
runs/
experiments/
# Model checkpoints (large files)
checkpoint/
checkpoints/
*.pth
*.pt
*.ckpt
@@ -40,15 +47,32 @@ checkpoint/
*.safetensors
# Data files (usually large)
data/
datasets/
/data/
/datasets/
# Documentation build
docs/_build/
docs/api/
# Coverage reports
.coverage
htmlcov/
coverage.xml
# OS generated files
.DS_Store
Thumbs.db
desktop.ini
nul
# Temporary files
*.tmp
*.temp
*.bak
# Local configuration
*.local.json
*.local.yaml
# Word documents (not needed in git)
*.docx

@@ -0,0 +1,44 @@
# Pre-commit hooks configuration
# Install: pip install pre-commit && pre-commit install
repos:
  # General hooks
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
        args: ['--maxkb=500']
      - id: check-merge-conflict
      - id: debug-statements

  # Python code formatting with Black
  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3

  # Import sorting with isort
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: ["--profile", "black"]

  # Linting with flake8
  - repo: https://github.com/pycqa/flake8
    rev: 6.0.0
    hooks:
      - id: flake8
        args: ["--max-line-length=88", "--extend-ignore=E203,W503"]

  # Type checking with mypy
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.3.0
    hooks:
      - id: mypy
        additional_dependencies: [types-PyYAML]
        args: ["--ignore-missing-imports"]

@@ -0,0 +1,186 @@
# CLAUDE.md - Project Guide for Claude Code
This document provides context and guidance for Claude Code when working with this project.
## Project Overview
**One-Prompt Medical Image Segmentation** is a deep learning framework for medical image segmentation using one-prompt learning. It implements the method described in the CVPR 2024 paper "One-Prompt to Segment All Medical Images".
## Project Structure
```
One-Prompt-Medical-Image-Segmentation/
├── src/oneprompt_seg/ # Main package (src layout)
│ ├── __init__.py # Package init with lazy imports
│ ├── core/ # Core functionality
│ │ ├── __init__.py
│ │ ├── network.py # Network builder (get_network)
│ │ └── function.py # Training/validation functions
│ ├── data/ # Data handling
│ │ ├── __init__.py
│ │ ├── datasets.py # Dataset classes (ISIC2016, REFUGE, PolypDataset)
│ │ └── loader.py # Data loaders (MONAI-based)
│ └── utils/ # Utilities
│ ├── __init__.py # Lazy imports
│ ├── metrics.py # Evaluation metrics (IoU, Dice)
│ ├── visualization.py # Image visualization
│ ├── data_utils.py # Data utilities (click generation)
│ ├── logging.py # Logging utilities
│ └── scheduler.py # LR schedulers
├── models/ # Model implementations
│ ├── oneprompt/ # Main One-Prompt model
│ │ ├── modeling/ # Core model components
│ │ │ ├── oneprompt.py # Main model class
│ │ │ ├── image_encoder.py # Image encoder
│ │ │ ├── prompt_encoder.py# Prompt encoder
│ │ │ └── mask_decoder.py # Mask decoder
│ │ └── utils/ # Model utilities
│ ├── unet/ # UNet backbone
│ └── tag/ # TAG model
├── tests/ # Unit tests
├── docs/ # Documentation (Sphinx)
├── configs/ # YAML configurations
├── scripts/ # Executable scripts
├── train.py # Original training entry
├── val.py # Original validation entry
├── utils.py # Original utilities (kept for compatibility)
├── dataset.py # Original dataset classes
├── function.py # Original training functions
├── cfg.py # Argument parser
└── conf/ # Settings module
```
## Key Technologies
- **PyTorch** - Deep learning framework
- **MONAI** - Medical imaging library (data loading, transforms)
- **TensorBoardX** - Training visualization
- **einops** - Tensor operations
## Common Commands
```bash
# Install dependencies
pip install -r requirements.txt
pip install -e . # Install package in editable mode
# Training
python train.py -net oneprompt -mod one_adpt -exp_name <name> \
-dataset polyp -data_path <path> -baseline unet
# Evaluation
python val.py -net oneprompt -mod one_adpt -exp_name <name> \
-weights <checkpoint_path> -dataset polyp -data_path <path>
# Run tests
pytest tests/ -v
# Format code
make format
```
## Important Files
### Entry Points
- `train.py` - Main training script
- `val.py` - Evaluation script
- `scripts/train.py` - Refactored training script
- `scripts/evaluate.py` - Refactored evaluation script
### Configuration
- `cfg.py` - Command line argument parser
- `conf/global_settings.py` - Global settings (EPOCH, paths)
- `configs/*.yaml` - YAML configuration files
### Core Model
- `models/oneprompt/modeling/oneprompt.py` - Main model implementation
- `models/oneprompt/build_oneprompt.py` - Model registry
### Data
- `dataset.py` - Dataset classes
- `utils.py` - Utilities including `get_network`, `get_decath_loader`
## Code Conventions
This project follows **PEP 8** with the following specifics:
1. **Naming**:
- Classes: `CamelCase` (e.g., `CombinedPolypDataset`)
- Functions/variables: `snake_case` (e.g., `get_network`)
- Constants: `UPPER_CASE` (e.g., `EPOCH`)
2. **Imports**: Grouped as standard library, third-party, local
3. **Type Hints**: Used in new code (src/ directory)
4. **Docstrings**: Google-style docstrings
## Dataset Format
All datasets should return a dict with:
```python
{
    'image': torch.Tensor,        # [C, H, W] or [C, H, W, D]
    'label': torch.Tensor,        # Same shape as image
    'p_label': int,               # 1 for positive, 0 for negative
    'pt': tuple or np.ndarray,    # Click coordinates
    'image_meta_dict': dict,      # {'filename_or_obj': str}
}
```
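The `pt` click prompt in the dict above can be generated from a mask. A minimal sketch follows; the helper name `random_click` and its return shape are assumptions for illustration (the project keeps its real click generation in `src/oneprompt_seg/utils/data_utils.py`):

```python
import random

def random_click(mask, point_label=1):
    # Pick a random pixel whose mask value equals `point_label`,
    # returning (label, (y, x)) as a click prompt.
    # `mask` is a 2D list of 0/1 values standing in for a tensor.
    candidates = [
        (y, x)
        for y, row in enumerate(mask)
        for x, v in enumerate(row)
        if v == point_label
    ]
    if not candidates:
        # No positive pixel: fall back to a negative click at the origin.
        return 0, (0, 0)
    return point_label, random.choice(candidates)
```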
## Common Tasks
### Adding a New Dataset
1. Create a new class in `dataset.py` or `src/oneprompt_seg/data/datasets.py`
2. Inherit from `torch.utils.data.Dataset`
3. Implement `__len__` and `__getitem__`
4. Return the standard dict format
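The steps above can be sketched as follows. `MyOrganDataset` and its placeholder loading are hypothetical; a real implementation would inherit `torch.utils.data.Dataset` and return `torch.Tensor`s:

```python
import random

class MyOrganDataset:  # in practice: inherit torch.utils.data.Dataset
    """Hypothetical dataset returning the project's standard dict format."""

    def __init__(self, image_paths, label_paths):
        self.image_paths = image_paths
        self.label_paths = label_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Placeholder loading: a real version would read images from disk
        # and apply transforms; nested lists stand in for [C, H, W] tensors.
        image = [[0.0] * 4 for _ in range(4)]
        label = [[0.0] * 4 for _ in range(4)]
        pt = (random.randrange(4), random.randrange(4))  # one click prompt
        return {
            'image': image,
            'label': label,
            'p_label': 1,  # always positive when negative prompts are unused
            'pt': pt,
            'image_meta_dict': {'filename_or_obj': self.image_paths[idx]},
        }
```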
### Modifying Training
1. Main training loop is in `train.py`
2. Per-epoch logic is in `function.py` (`train_one`, `validation_one`)
3. Loss functions are configured based on `args.thd`
### Adding Metrics
1. Add to `src/oneprompt_seg/utils/metrics.py`
2. Update `__init__.py` exports if needed
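As an illustration of what a new metric might look like, here is a minimal Dice coefficient over flat 0/1 lists. The project's `metrics.py` operates on torch tensors; this pure-Python sketch only shows the arithmetic:

```python
def dice_coefficient(pred, target, eps=1e-6):
    """Dice = 2|A ∩ B| / (|A| + |B|) for binary masks given as flat 0/1 lists."""
    intersection = sum(p * t for p, t in zip(pred, target))
    # eps guards against division by zero when both masks are empty.
    return (2.0 * intersection + eps) / (sum(pred) + sum(target) + eps)
```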
## Troubleshooting
### Import Errors
- Ensure MONAI and other dependencies are installed
- The src package uses lazy imports to avoid import errors during installation
### CUDA Out of Memory
- Reduce batch size with `-b` argument
- Reduce image size with `-image_size` argument
- Enable gradient checkpointing in the model
### Path Issues
- Data paths are relative to the working directory
- Logs are saved to `./logs/` by default
- Checkpoints are saved to `./checkpoint/` by default
## Testing
```bash
# Run all tests
pytest tests/ -v
# Run specific test file
pytest tests/test_metrics.py -v
# Run with coverage
pytest tests/ --cov=src/oneprompt_seg --cov-report=html
```
## Notes for Development
1. The project maintains backward compatibility - original files (`train.py`, `val.py`, `utils.py`) still work
2. New code should be added to `src/oneprompt_seg/` when possible
3. Use `configs/` YAML files for configuration instead of hardcoding
4. Run `pre-commit` hooks before committing: `pre-commit run --all-files`

@@ -0,0 +1,75 @@
# Makefile for One-Prompt Medical Image Segmentation
# Common commands for development, testing, and training
.PHONY: help install install-dev test lint format clean train eval docs
# Default target
help:
	@echo "Available commands:"
	@echo "  install       Install production dependencies"
	@echo "  install-dev   Install development dependencies"
	@echo "  test          Run unit tests"
	@echo "  lint          Run code linting"
	@echo "  format        Format code with black and isort"
	@echo "  clean         Clean temporary files"
	@echo "  train         Run training script"
	@echo "  eval          Run evaluation script"
	@echo "  docs          Build documentation"

# Installation
install:
	pip install -r requirements.txt
	pip install -e .

install-dev:
	pip install -r requirements.txt
	pip install -r requirements-dev.txt
	pip install -e .
	pre-commit install

# Testing
test:
	pytest tests/ -v --cov=src/oneprompt_seg --cov-report=html

test-quick:
	pytest tests/ -v -x

# Linting and formatting
lint:
	flake8 src/ tests/
	mypy src/

format:
	black src/ tests/ scripts/
	isort src/ tests/ scripts/

# Cleaning
clean:
	find . -type f -name "*.pyc" -delete
	find . -type d -name "__pycache__" -delete
	find . -type d -name ".pytest_cache" -delete
	find . -type d -name ".mypy_cache" -delete
	rm -rf build/ dist/ *.egg-info/
	rm -rf .coverage htmlcov/

# Training and evaluation
train:
	python scripts/train.py --config configs/default.yaml

train-polyp:
	python scripts/train.py \
		-net oneprompt \
		-mod one_adpt \
		-exp_name polyp_training \
		-dataset polyp \
		-data_path ./data/polyp

eval:
	python scripts/val.py --config configs/default.yaml

# Documentation
docs:
	cd docs && make html

docs-clean:
	cd docs && make clean

@@ -1,31 +1,133 @@
# One-Prompt to Segment All Medical Images
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch 1.10+](https://img.shields.io/badge/pytorch-1.10+-ee4c2c.svg)](https://pytorch.org/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
One-Prompt to Segment All Medical Images, or One-Prompt for short, combines the strengths of one-shot and interactive methods. At inference time, with just one prompted sample, it can handle an unseen task in a single forward pass.
This method is elaborated in the paper [One-Prompt to Segment All Medical Images](https://arxiv.org/abs/2305.10300) (CVPR 2024).
## A Quick Overview
<img width="800" height="580" src="https://github.com/KidsWithTokens/one-prompt/blob/main/figs/oneprompt.png">
## Project Structure
```
One-Prompt-Medical-Image-Segmentation/
├── src/oneprompt_seg/ # Source package (recommended)
│ ├── core/ # Core modules (network builder, training functions)
│ ├── data/ # Data modules (datasets, loaders)
│ └── utils/ # Utility modules (metrics, visualization, logging)
├── models/ # Model implementations
│ ├── oneprompt/ # One-Prompt model
│ ├── unet/ # UNet model
│ └── ...
├── tests/ # Unit tests
├── docs/ # Project documentation
├── configs/ # Configuration files
│ ├── default.yaml # Default config
│ ├── development.yaml # Development config
│ └── production.yaml # Production config
├── scripts/ # Scripts
│ ├── train.py # Training script
│ └── evaluate.py # Evaluation script
├── train.py # Original training entry
├── val.py # Original validation entry
├── pyproject.toml # Project configuration
├── requirements.txt # Dependency list
└── Makefile # Common commands
```
## Installation
### Method 1: Using Conda (Recommended)
```bash
# Create the environment
conda env create -f environment.yml
conda activate oneprompt

# Install the project package (optional, for the src layout)
pip install -e .
```
### Method 2: Using pip
```bash
# Create a virtual environment
python -m venv venv
source venv/bin/activate  # Linux/Mac
# or venv\Scripts\activate  # Windows

# Install dependencies
pip install -r requirements.txt

# Install the project package (optional)
pip install -e .
```
### Development Installation
```bash
# Install development dependencies
pip install -r requirements-dev.txt

# Install pre-commit hooks
pre-commit install
```
## Quick Start
### Using Makefile
```bash
# Install dependencies
make install

# Run tests
make test

# Format code
make format

# Lint code
make lint

# Train the model
make train
```
### Training
```bash
# Basic training command
python train.py -net oneprompt -mod one_adpt -exp_name basic_exp \
    -b 64 -dataset oneprompt -data_path ../data -baseline 'unet'

# Or use the refactored script under scripts/
python scripts/train.py -net oneprompt -mod one_adpt -exp_name basic_exp \
    -dataset polyp -data_path ./data/polyp
```
### Evaluation
```bash
python val.py -net oneprompt -mod one_adpt -exp_name One-ISIC \
    -weights <weight_path> -b 1 -dataset isic -data_path ../dataset/isic \
    -vis 10 -baseline 'unet'
```
## Dataset
### Download the Open-source Datasets
We collected 78 **open-source** datasets for training and testing the model. The datasets and their download links are listed [here](https://drive.google.com/file/d/1iXFm9M1ocrWNkEIthWUWnZYY2-1l-qya/view?usp=share_link).
### Download the Prompts
The prompts corresponding to the datasets can be downloaded [here](https://drive.google.com/file/d/1cNv2WW_Cv2NYzpt90vvELaweM5ltIe8n/view?usp=share_link). Each prompt is saved as a JSON message with the format `{DATASET_NAME, SAMPLE_INDEX, PROMPT_TYPE, PROMPT_CONTENT}`.
## Test Examples
@@ -33,69 +135,135 @@ run ``python train.py -net oneprompt -mod one_adpt -exp_name basic_exp -b 64 -da
1. Download ISIC dataset part 1 from https://challenge.isic-archive.com/data/. Then put the csv files in "./data/isic" under your data path. Your dataset folder under "your_data_path" should look like:
```
ISIC/
├── ISBI2016_ISIC_Part1_Test_Data/...
├── ISBI2016_ISIC_Part1_Training_Data/...
├── ISBI2016_ISIC_Part1_Test_GroundTruth.csv
└── ISBI2016_ISIC_Part1_Training_GroundTruth.csv
```
2. Run:
```bash
python val.py -net oneprompt -mod one_adpt -exp_name One-ISIC \
    -weights <weight_path> -b 1 -dataset isic -data_path ../dataset/isic \
    -vis 10 -baseline 'unet'
```
Change `data_path` and `exp_name` for your own usage. You can decrease the `image size` or batch size `b` if out of memory.
3. **Evaluation**: The code automatically evaluates the model on the test set during training; set `--val_freq` to control how many epochs pass between evaluations. You can also run `val.py` for independent evaluation.
4. **Result Visualization**: Set the `--vis` parameter to control how often (in epochs) results are visualized during training or evaluation.
By default, everything is saved at `./logs/`.
### REFUGE: Optic-disc Segmentation from Fundus Images (2D)
The [REFUGE](https://refuge.grand-challenge.org/) dataset contains 1200 fundus images with optic disc/cup segmentations and clinical glaucoma labels.
1. Download the dataset manually from [here](https://huggingface.co/datasets/realslimman/REFUGE-MultiRater/tree/main), or using command lines:
```bash
git lfs install
git clone git@hf.co:datasets/realslimman/REFUGE-MultiRater
unzip ./REFUGE-MultiRater.zip
mv REFUGE-MultiRater ./data
```
2. For training the adapter, run:
```bash
python val.py -net oneprompt -mod one_adpt -exp_name One-REFUGE \
    -weights <weight_path> -b 1 -baseline 'unet' \
    -dataset REFUGE -data_path ./data/REFUGE-MultiRater
```
You can change `exp_name` to anything you want, and decrease the `image size` or batch size `b` if out of memory.
## Run on Your Own Dataset
It is simple to run One-Prompt on other datasets. Just write another dataset class following those in `./dataset.py`. You only need to make sure it returns a dict with:
```python
{
    'image': A tensor holding the images, with size [C, H, W] for 2D images
             and [C, H, W, D] for 3D data.
             D is the depth of the 3D volume; C is the channel of a scan/frame,
             which is commonly 1 for CT, MRI, and US data.
             If processing, say, a colorful surgical video, D could be
             the number of time frames, and C would be 3 for an RGB frame.
    'label': The target masks. Same size as the images except the
             resolutions (H and W).
    'p_label': The prompt label deciding positive/negative prompts.
               To simplify, you can always set 1 if you don't need the
               negative prompt function.
    'pt': The prompt; e.g., a click prompt should be [x of click, y of click],
          one click per scan/frame when using 3D data.
    'image_meta_dict': Optional. If you want to save/visualize the result,
                       put the name of the image in it under
                       the key 'filename_or_obj'.
    ...(others as you want)
}
```
## API Usage (New)
After installing with `pip install -e .`, you can use the package programmatically:
```python
from oneprompt_seg import get_network, CombinedPolypDataset
from oneprompt_seg.utils import eval_seg, vis_image, create_logger

# Build network
model = get_network(args, 'oneprompt')

# Create dataset
dataset = CombinedPolypDataset(args, data_path, transform=transform)

# Evaluate
iou, dice = eval_seg(predictions, targets, threshold=(0.5,))
```
## Configuration
The project supports YAML configuration files in `configs/`:
- `default.yaml` - Default training configuration
- `development.yaml` - Development/debugging configuration
- `production.yaml` - Production training configuration
## Development
### Running Tests
```bash
# Run all tests
pytest tests/ -v
# Run with coverage
pytest tests/ -v --cov=src/oneprompt_seg --cov-report=html
```
### Code Quality
```bash
# Format code
black src/ tests/ scripts/
isort src/ tests/ scripts/
# Lint code
flake8 src/ tests/
mypy src/
```
## Cite
```bibtex
@InProceedings{Wu_2024_CVPR,
author = {Wu, Junde and Xu, Min},
title = {One-Prompt to Segment All Medical Images},
@@ -106,6 +274,6 @@ It is simple to run omeprompt on the other datasets. Just write another dataset
}
```
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

@@ -0,0 +1,53 @@
# One-Prompt Medical Image Segmentation - Development Configuration
project:
  name: "one-prompt-segmentation"
  version: "1.0.0"
  description: "Development environment configuration"
  mode: "development"

# Data configuration
data:
  dataset: "polyp"
  data_path: "./data/TestDataset"
  train_ratio: 0.8
  batch_size: 2
  num_workers: 2

# Model configuration
model:
  net: "oneprompt"
  baseline: "unet"
  mod: "one_adpt"
  image_size: 256
  out_size: 256
  patch_size: 16
  dim: 256
  depth: 1
  heads: 16
  mlp_dim: 1024

# Training configuration (smaller settings for development)
training:
  epochs: 10
  learning_rate: 0.0001
  optimizer: "adam"
  weight_decay: 0.0
  scheduler:
    name: "step"
    step_size: 5
    gamma: 0.5
  gradient_clip: 1.0

# Validation configuration
validation:
  val_freq: 2
  vis_freq: 10

# Logging configuration
logging:
  log_dir: "logs/dev"
  tensorboard: true
  save_best: true
  checkpoint_freq: 5
  log_level: "DEBUG"
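One plausible way to combine this development config with `configs/default.yaml` is a key-by-key deep merge, sketched below; `deep_merge` and the inline dicts are illustrative stand-ins, not project code:

```python
def deep_merge(base, override):
    """Overlay `override` onto `base`, merging nested dicts key by key."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Hypothetical defaults standing in for a parsed configs/default.yaml
default_cfg = {
    'data': {'dataset': 'polyp', 'batch_size': 8, 'num_workers': 8},
    'model': {'net': 'oneprompt'},
    'training': {'epochs': 100},
}
# Development overrides mirroring the values above
dev_cfg = {
    'data': {'batch_size': 2, 'num_workers': 2},
    'training': {'epochs': 10},
}
cfg = deep_merge(default_cfg, dev_cfg)
```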

@@ -0,0 +1,59 @@
# One-Prompt Medical Image Segmentation - Production Configuration
project:
  name: "one-prompt-segmentation"
  version: "1.0.0"
  description: "Production environment configuration"
  mode: "production"

# Data configuration
data:
  dataset: "oneprompt"
  data_path: "/data/medical_images"
  train_ratio: 0.8
  batch_size: 8
  num_workers: 8

# Model configuration
model:
  net: "oneprompt"
  baseline: "unet"
  mod: "one_adpt"
  image_size: 1024
  out_size: 256
  patch_size: 16
  dim: 256
  depth: 1
  heads: 16
  mlp_dim: 1024

# Training configuration
training:
  epochs: 30000
  learning_rate: 0.0001
  optimizer: "adam"
  weight_decay: 0.0
  scheduler:
    name: "step"
    step_size: 10
    gamma: 0.5
  early_stopping_patience: 50
  gradient_clip: 1.0

# Validation configuration
validation:
  val_freq: 100
  vis_freq: 50

# Logging configuration
logging:
  log_dir: "logs/production"
  tensorboard: true
  save_best: true
  checkpoint_freq: 100
  log_level: "INFO"

# Distributed training
distributed:
  enabled: false
  gpu_ids: "0,1,2,3"
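The `step` scheduler settings above imply a simple decay rule: the learning rate is multiplied by `gamma` every `step_size` epochs. A sketch of the arithmetic (not the project's actual scheduler implementation):

```python
def step_lr(base_lr, epoch, step_size, gamma):
    # Step decay: lr = base_lr * gamma ** (epoch // step_size).
    # With step_size=10 and gamma=0.5, the LR halves every 10 epochs.
    return base_lr * gamma ** (epoch // step_size)
```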

@@ -0,0 +1,31 @@
# Documentation
This directory contains the project documentation.
## Structure
```
docs/
├── api/ # API reference documentation
├── guides/ # User guides and tutorials
├── examples/ # Example code and notebooks
└── index.rst # Main documentation index
```
## Building Documentation
To build the documentation locally:
```bash
cd docs
pip install sphinx sphinx-rtd-theme
make html
```
The built documentation will be available in `docs/_build/html/`.
## Documentation Style
- Use Google-style docstrings for all modules, classes, and functions
- Include type hints in function signatures
- Provide examples in docstrings where appropriate

@@ -0,0 +1,72 @@
# Sphinx configuration file
import os
import sys
sys.path.insert(0, os.path.abspath('../src'))
# Project information
project = 'One-Prompt Medical Image Segmentation'
copyright = '2024, One-Prompt Team'
author = 'One-Prompt Team'
version = '1.0.0'
release = '1.0.0'
# Extensions
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.mathjax',
'sphinx.ext.intersphinx',
]
# Napoleon settings for Google-style docstrings
napoleon_google_docstring = True
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = True
napoleon_include_private_with_doc = False
napoleon_include_special_with_doc = True
# Mock imports for documentation building
autodoc_mock_imports = [
'torch',
'torchvision',
'numpy',
'pandas',
'monai',
'PIL',
'cv2',
'skimage',
'einops',
'tensorboardX',
'sklearn',
'matplotlib',
'seaborn',
'tqdm',
'dateutil',
]
# Intersphinx mapping
intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
'torch': ('https://pytorch.org/docs/stable/', None),
'numpy': ('https://numpy.org/doc/stable/', None),
}
# HTML theme
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
# Source file settings
source_suffix = '.rst'
master_doc = 'index'
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# Autodoc settings
autodoc_default_options = {
'members': True,
'member-order': 'bysource',
'special-members': '__init__',
'undoc-members': True,
'exclude-members': '__weakref__'
}

@@ -0,0 +1,49 @@
One-Prompt Medical Image Segmentation
=====================================
Welcome to the documentation for One-Prompt Medical Image Segmentation.
.. toctree::
   :maxdepth: 2
   :caption: Contents:

   installation
   quickstart
   api/index
   examples/index
Overview
--------
One-Prompt Medical Image Segmentation is a deep learning framework for
medical image segmentation that uses a single prompt to segment various
types of medical images.
Features
--------
* **One-Prompt Learning**: Segment medical images using a single example
* **Multi-Dataset Support**: Works with ISIC, REFUGE, Polyp datasets
* **Flexible Architecture**: Based on SAM with custom modifications
* **Easy Integration**: Simple API for training and inference
Quick Start
-----------
.. code-block:: python

   from oneprompt_seg.core import get_network
   from oneprompt_seg.data import CombinedPolypDataset

   # Load model
   model = get_network(args, 'oneprompt')

   # Create dataset
   dataset = CombinedPolypDataset(args, data_path)
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

@@ -0,0 +1,151 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "oneprompt-seg"
version = "1.0.0"
description = "One-Prompt to Segment All Medical Images"
readme = "README.md"
license = {text = "MIT"}
authors = [
{name = "One-Prompt Team", email = "oneprompt@example.com"}
]
keywords = [
"medical imaging",
"image segmentation",
"deep learning",
"pytorch",
"computer vision"
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Intended Audience :: Healthcare Industry",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Medical Science Apps.",
]
requires-python = ">=3.8"
dependencies = [
"torch>=1.10.0",
"torchvision>=0.11.0",
"numpy>=1.19.0",
"pandas>=1.3.0",
"Pillow>=8.0.0",
"scikit-image>=0.18.0",
"scikit-learn>=0.24.0",
"monai>=0.8.0",
"einops>=0.4.0",
"tensorboardX>=2.4.0",
"tqdm>=4.60.0",
"python-dateutil>=2.8.0",
"PyYAML>=5.4.0",
"matplotlib>=3.3.0",
"seaborn>=0.11.0",
"opencv-python>=4.5.0",
]
[project.optional-dependencies]
dev = [
"pytest>=6.0.0",
"pytest-cov>=2.12.0",
"black>=21.0.0",
"isort>=5.9.0",
"flake8>=3.9.0",
"mypy>=0.900",
"pre-commit>=2.13.0",
]
docs = [
"sphinx>=4.0.0",
"sphinx-rtd-theme>=0.5.0",
"sphinx-autodoc-typehints>=1.12.0",
]
[project.urls]
Homepage = "https://github.com/oneprompt/One-Prompt-Medical-Image-Segmentation"
Documentation = "https://oneprompt-seg.readthedocs.io"
Repository = "https://github.com/oneprompt/One-Prompt-Medical-Image-Segmentation"
Issues = "https://github.com/oneprompt/One-Prompt-Medical-Image-Segmentation/issues"
[project.scripts]
oneprompt-train = "scripts.train:main"
oneprompt-val = "scripts.val:main"
[tool.setuptools.packages.find]
where = ["src"]
[tool.setuptools.package-data]
"*" = ["*.yaml", "*.yml", "*.json"]
# Black configuration
[tool.black]
line-length = 88
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'
extend-exclude = '''
/(
\.git
| \.mypy_cache
| \.pytest_cache
| \.venv
| build
| dist
| __pycache__
)/
'''
# isort configuration
[tool.isort]
profile = "black"
line_length = 88
skip_gitignore = true
known_first_party = ["oneprompt_seg"]
known_third_party = ["torch", "torchvision", "numpy", "pandas", "monai"]
# pytest configuration
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
filterwarnings = [
"ignore::DeprecationWarning",
"ignore::UserWarning",
]
# mypy configuration
[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
exclude = [
"build",
"dist",
"docs",
]
# Coverage configuration
[tool.coverage.run]
source = ["src/oneprompt_seg"]
branch = true
omit = [
"*/tests/*",
"*/__pycache__/*",
]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise AssertionError",
"raise NotImplementedError",
"if __name__ == .__main__.:",
]

@@ -0,0 +1,21 @@
# Development dependencies for testing and code quality
# Testing
pytest>=6.0.0
pytest-cov>=2.12.0
# Code formatting
black>=21.0.0
isort>=5.9.0
# Linting
flake8>=3.9.0
mypy>=0.900
# Pre-commit hooks
pre-commit>=2.13.0
# Documentation
sphinx>=4.0.0
sphinx-rtd-theme>=0.5.0
sphinx-autodoc-typehints>=1.12.0

@@ -0,0 +1,29 @@
# One-Prompt Medical Image Segmentation Dependencies
# Core dependencies
torch>=1.10.0
torchvision>=0.11.0
numpy>=1.19.0
pandas>=1.3.0
Pillow>=8.0.0
# Image processing
scikit-image>=0.18.0
scikit-learn>=0.24.0
opencv-python>=4.5.0
# Medical imaging
monai>=0.8.0
# Deep learning utilities
einops>=0.4.0
tensorboardX>=2.4.0
tqdm>=4.60.0
# Utilities
python-dateutil>=2.8.0
PyYAML>=5.4.0
# Visualization
matplotlib>=3.3.0
seaborn>=0.11.0

@@ -0,0 +1,6 @@
"""
Scripts package for One-Prompt Medical Image Segmentation.
This package contains executable scripts for training, evaluation,
and utility operations.
"""

@@ -0,0 +1,195 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Validation/Evaluation script for One-Prompt Medical Image Segmentation.
This script provides evaluation functionality for trained models.
Usage:
python scripts/val.py -net oneprompt -mod one_adpt -exp_name eval_exp \\
-dataset polyp -data_path ./data/polyp -weights ./checkpoints/best.pth
"""
import os
import sys
# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import OrderedDict
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Local imports
import cfg
from conf import settings
from dataset import ISIC2016, REFUGE, PolypDataset, CombinedPolypDataset
from utils import (
get_network,
get_decath_loader,
create_logger,
set_log_dir,
)
import function
def main():
    """Main evaluation function."""
    # Parse arguments
    args = cfg.parse_args()

    # Setup device
    gpu_device = torch.device('cuda', args.gpu_device)

    # Build network
    net = get_network(
        args, args.net,
        use_gpu=args.gpu,
        gpu_device=gpu_device,
        distribution=args.distributed
    )

    # Load pretrained model
    assert args.weights, "Please specify model weights with -weights"
    print(f'=> resuming from {args.weights}')
    checkpoint_file = args.weights
    assert os.path.exists(checkpoint_file), f'checkpoint not found: {checkpoint_file}'

    loc = f'cuda:{args.gpu_device}'
    checkpoint = torch.load(checkpoint_file, map_location=loc)
    start_epoch = checkpoint['epoch']
    best_tol = checkpoint['best_tol']
    state_dict = checkpoint['state_dict']

    if args.distributed != 'none':
        # Checkpoints saved without DataParallel lack the 'module.' prefix
        # expected by distributed wrappers, so add it back here.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = 'module.' + k
            new_state_dict[name] = v
    else:
        new_state_dict = state_dict
    net.load_state_dict(new_state_dict)

    # Setup logging
    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)

    # Setup data transforms
    transform_train = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    transform_train_seg = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((args.image_size, args.image_size)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])
    transform_test_seg = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((args.image_size, args.image_size)),
    ])

    # Load data based on dataset type
    if args.dataset == 'isic':
        isic_train_dataset = ISIC2016(
            args, args.data_path,
            transform=transform_train,
            transform_msk=transform_train_seg,
            mode='Training'
        )
        isic_test_dataset = ISIC2016(
            args, args.data_path,
            transform=transform_test,
            transform_msk=transform_test_seg,
            mode='Test'
        )
        nice_train_loader = DataLoader(
            isic_train_dataset,
            batch_size=args.b,
            shuffle=True,
            num_workers=8,
            pin_memory=True
        )
        nice_test_loader = DataLoader(
            isic_test_dataset,
            batch_size=args.b,
            shuffle=False,
            num_workers=8,
            pin_memory=True
        )
    elif args.dataset == 'oneprompt':
        nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)
    elif args.dataset == 'REFUGE':
        refuge_train_dataset = REFUGE(
            args, args.data_path,
            transform=transform_train,
            transform_msk=transform_train_seg,
            mode='Training'
        )
        refuge_test_dataset = REFUGE(
            args, args.data_path,
            transform=transform_test,
            transform_msk=transform_test_seg,
            mode='Test'
        )
        nice_train_loader = DataLoader(
            refuge_train_dataset,
            batch_size=args.b,
            shuffle=True,
            num_workers=8,
            pin_memory=True
        )
        nice_test_loader = DataLoader(
            refuge_test_dataset,
            batch_size=args.b,
            shuffle=False,
            num_workers=8,
            pin_memory=True
        )
    elif args.dataset == 'polyp':
        transform_test_seg = transforms.Compose([
            transforms.Resize((args.out_size, args.out_size)),
            transforms.ToTensor(),
        ])
        polyp_test_dataset = CombinedPolypDataset(
            args, args.data_path,
            transform=transform_test,
transform_msk=transform_test_seg,
mode='Test'
)
nice_test_loader = DataLoader(
polyp_test_dataset,
batch_size=args.b,
shuffle=False,
num_workers=8,
pin_memory=True
)
# Run evaluation
    if args.mod in ('sam_adpt', 'one_adpt'):
net.eval()
tol, (eiou, edice) = function.validation_one(args, nice_test_loader, 0, net)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {start_epoch}.')
print(f'\nEvaluation Results:')
print(f' Total Score: {tol}')
print(f' IoU: {eiou}')
print(f' Dice: {edice}')
if __name__ == '__main__':
main()
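The key remapping in the loading branch above, which prepends `module.` to every checkpoint key when the model is wrapped in `DataParallel`, can be exercised in isolation. A minimal sketch with plain values standing in for tensors (the helper name `add_module_prefix` is illustrative, not part of the codebase):

```python
from collections import OrderedDict

def add_module_prefix(state_dict):
    """Remap checkpoint keys for a DataParallel-wrapped model: parameters
    of a wrapped model live under 'module.<name>', so a checkpoint saved
    from a bare model needs its keys prefixed before load_state_dict()."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict['module.' + k] = v
    return new_state_dict

plain = OrderedDict([('encoder.weight', 1), ('decoder.bias', 2)])
wrapped = add_module_prefix(plain)
print(list(wrapped))  # ['module.encoder.weight', 'module.decoder.bias']
```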

@@ -0,0 +1,206 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Training script for One-Prompt Medical Image Segmentation.
This script provides the main entry point for training the One-Prompt
segmentation model on various medical imaging datasets.
Usage:
python scripts/train.py -net oneprompt -mod one_adpt -exp_name experiment1 \\
-dataset polyp -data_path ./data/polyp
Example:
python scripts/train.py \\
-net oneprompt \\
-mod one_adpt \\
-exp_name polyp_training \\
-dataset polyp \\
-data_path /path/to/data
"""
import os
import sys
import time
# Add project root to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
# Local imports
import cfg
from conf import settings
from dataset import CombinedPolypDataset
from utils import (
get_network,
get_decath_loader,
create_logger,
set_log_dir,
save_checkpoint,
)
import function
def main():
"""Main training function."""
# Parse arguments
args = cfg.parse_args()
# Setup device
gpu_device = torch.device('cuda', args.gpu_device)
# Build network
net = get_network(
args, args.net,
use_gpu=args.gpu,
gpu_device=gpu_device,
distribution=args.distributed
)
# Setup optimizer and scheduler
optimizer = optim.Adam(
net.parameters(),
lr=args.lr,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0,
amsgrad=False
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# Load pretrained model if specified
start_epoch = 0
best_tol = 1e4
    if args.weights:
        print(f'=> resuming from {args.weights}')
        checkpoint_file = args.weights
        assert os.path.exists(checkpoint_file), f'checkpoint not found: {checkpoint_file}'
loc = f'cuda:{args.gpu_device}'
checkpoint = torch.load(checkpoint_file, map_location=loc)
start_epoch = checkpoint['epoch']
best_tol = checkpoint['best_tol']
net.load_state_dict(checkpoint['state_dict'], strict=False)
args.path_helper = checkpoint['path_helper']
logger = create_logger(args.path_helper['log_path'])
print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
# Setup logging
args.path_helper = set_log_dir('logs', args.exp_name)
logger = create_logger(args.path_helper['log_path'])
logger.info(args)
# Load data
if args.dataset == 'oneprompt':
nice_train_loader, nice_test_loader, transform_train, transform_val, train_list, val_list = get_decath_loader(args)
elif args.dataset == 'polyp':
# Polyp dataset
transform_train = transforms.Compose([
transforms.Resize((args.image_size, args.image_size)),
transforms.ToTensor(),
])
transform_train_seg = transforms.Compose([
transforms.Resize((args.out_size, args.out_size)),
transforms.ToTensor(),
])
transform_test = transforms.Compose([
transforms.Resize((args.image_size, args.image_size)),
transforms.ToTensor(),
])
transform_test_seg = transforms.Compose([
transforms.Resize((args.out_size, args.out_size)),
transforms.ToTensor(),
])
train_dataset = CombinedPolypDataset(
args, args.data_path,
transform=transform_train,
transform_msk=transform_train_seg,
mode='Training'
)
test_dataset = CombinedPolypDataset(
args, args.data_path,
transform=transform_test,
transform_msk=transform_test_seg,
mode='Test'
)
nice_train_loader = DataLoader(
train_dataset,
batch_size=args.b,
shuffle=True,
num_workers=args.w,
pin_memory=True
)
nice_test_loader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=args.w,
pin_memory=True
)
# Setup checkpoint path and tensorboard
checkpoint_path = os.path.join(settings.CHECKPOINT_PATH, args.net, settings.TIME_NOW)
    os.makedirs(settings.LOG_DIR, exist_ok=True)
writer = SummaryWriter(
log_dir=os.path.join(settings.LOG_DIR, args.net, settings.TIME_NOW)
)
    os.makedirs(checkpoint_path, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')
# Training loop
best_acc = 0.0
for epoch in range(settings.EPOCH):
net.train()
time_start = time.time()
loss = function.train_one(
args, net, optimizer, nice_train_loader, epoch, writer, vis=args.vis
)
logger.info(f'Train loss: {loss}|| @ epoch {epoch}.')
time_end = time.time()
print(f'time_for_training {time_end - time_start}')
net.eval()
        if (epoch and epoch % args.val_freq == 0) or epoch == settings.EPOCH - 1:
tol, (eiou, edice) = function.validation_one(
args, nice_test_loader, epoch, net, writer
)
logger.info(f'Total score: {tol}, IOU: {eiou}, DICE: {edice} || @ epoch {epoch}.')
if args.distributed != 'none':
sd = net.module.state_dict()
else:
sd = net.state_dict()
if tol < best_tol:
best_tol = tol
is_best = True
save_checkpoint({
'epoch': epoch + 1,
'model': args.net,
'state_dict': sd,
'optimizer': optimizer.state_dict(),
'best_tol': best_tol,
'path_helper': args.path_helper,
}, is_best, args.path_helper['ckpt_path'], filename="best_checkpoint")
else:
is_best = False
writer.close()
logger.info("Training completed!")
if __name__ == '__main__':
main()
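For reference, the `StepLR(optimizer, step_size=10, gamma=0.5)` scheduler configured above halves the learning rate every 10 epochs; the resulting schedule has a simple closed form (the helper `stepped_lr` below is illustrative only, not a project function):

```python
def stepped_lr(base_lr, epoch, step_size=10, gamma=0.5):
    """Learning rate after `epoch` epochs under StepLR(step_size, gamma):
    the rate is multiplied by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)

print(stepped_lr(1e-3, 0))   # 0.001
print(stepped_lr(1e-3, 10))  # 0.0005
print(stepped_lr(1e-3, 25))  # 0.00025
```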

@@ -0,0 +1,52 @@
"""
One-Prompt Medical Image Segmentation
A deep learning framework for medical image segmentation using one-prompt learning.
"""
__version__ = "1.0.0"
__author__ = "One-Prompt Team"
# Lazy imports to avoid dependency issues during installation
def __getattr__(name):
"""Lazy import of submodules."""
if name == "get_network":
from .core.network import get_network
return get_network
elif name in ("train_one", "validation_one"):
from .core import function
return getattr(function, name)
elif name in ("ISIC2016", "REFUGE", "PolypDataset", "CombinedPolypDataset"):
from .data import datasets
return getattr(datasets, name)
elif name in ("create_logger", "set_log_dir", "save_checkpoint"):
from .utils import logging
return getattr(logging, name)
elif name in ("eval_seg", "vis_image", "DiceMetric"):
if name == "DiceMetric":
from .utils.metrics import DiceMetric
return DiceMetric
elif name == "eval_seg":
from .utils.metrics import eval_seg
return eval_seg
elif name == "vis_image":
from .utils.visualization import vis_image
return vis_image
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = [
"get_network",
"train_one",
"validation_one",
"ISIC2016",
"REFUGE",
"PolypDataset",
"CombinedPolypDataset",
"create_logger",
"set_log_dir",
"save_checkpoint",
"eval_seg",
"vis_image",
"DiceMetric",
]
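The module-level `__getattr__` used above is the PEP 562 lazy-import hook: the function is only consulted when normal attribute lookup on the module fails. A minimal, dependency-free sketch of the same mechanism (the `lazy_demo` module and its `answer` attribute are invented for illustration):

```python
import sys
import types

# Build a throwaway module whose attributes resolve on first access,
# mirroring the package __init__ above (PEP 562 module-level __getattr__).
mod = types.ModuleType("lazy_demo")

def _lazy_getattr(name):
    if name == "answer":
        return 42  # stands in for an import performed on demand
    raise AttributeError(f"module 'lazy_demo' has no attribute {name!r}")

mod.__getattr__ = _lazy_getattr
sys.modules["lazy_demo"] = mod

import lazy_demo
print(lazy_demo.answer)  # 42
```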

@@ -0,0 +1,17 @@
"""
Core module for One-Prompt Medical Image Segmentation.
This module contains the core functionality including:
- Network builders
- Model definitions
- Training and validation functions
"""
from .network import get_network
from .function import train_one, validation_one
__all__ = [
"get_network",
"train_one",
"validation_one",
]

@@ -0,0 +1,301 @@
"""
Training and validation functions.
This module provides the core training and validation logic
for medical image segmentation models.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm import tqdm
from einops import rearrange
from monai.losses import DiceCELoss
from ..utils.metrics import eval_seg, DiceMetric
from ..utils.visualization import vis_image
from ..utils.data_utils import generate_click_prompt
# Global variables for training
_args = None
_gpu_device = None
_criterion_g = None
_scaler = None
def _init_training_globals(args):
"""Initialize global variables for training."""
global _args, _gpu_device, _criterion_g, _scaler
_args = args
_gpu_device = torch.device('cuda', args.gpu_device)
pos_weight = torch.ones([1]).cuda(device=_gpu_device) * 2
_criterion_g = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
_scaler = torch.cuda.amp.GradScaler()
def train_one(args, net: nn.Module, optimizer, train_loader,
epoch: int, writer=None, schedulers=None, vis: int = 50) -> float:
"""Train the model for one epoch.
Args:
args: Training arguments.
net: The neural network model.
optimizer: The optimizer for training.
train_loader: DataLoader for training data.
epoch: Current epoch number.
writer: TensorBoard writer for logging.
schedulers: Learning rate schedulers.
vis: Visualization frequency (0 to disable).
Returns:
The final batch loss value.
"""
_init_training_globals(args)
epoch_loss = 0
ind = 0
net.train()
optimizer.zero_grad()
model = net.module if hasattr(net, 'module') else net
gpu_device = torch.device('cuda:' + str(args.gpu_device))
if args.thd:
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
else:
lossfunc = _criterion_g
with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
for pack in train_loader:
current_b = pack['image'].size(0)
if ind == 0:
tmp_img = pack['image'].to(dtype=torch.float32, device=gpu_device)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
tmp_mask = pack['label'].to(dtype=torch.float32, device=gpu_device)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
if 'pt' not in pack:
tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
else:
pt = pack['pt']
point_labels = pack['p_label']
if point_labels[0] != -1:
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=gpu_device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=gpu_device)
coords_torch = coords_torch[0:1, :].repeat(current_b, 1)
labels_torch = labels_torch[0:1].repeat(current_b)
coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
tmp_pt = (coords_torch, labels_torch)
else:
if tmp_img.size(0) != current_b:
tmp_img = tmp_img[0:1].repeat(current_b, 1, 1, 1)
tmp_mask = tmp_mask[0:1].repeat(current_b, 1, 1, 1)
if 'tmp_pt' in dir():
coords_torch = tmp_pt[0][0:1].repeat(current_b, 1, 1)
labels_torch = tmp_pt[1][0:1].repeat(current_b, 1)
tmp_pt = (coords_torch, labels_torch)
imgs = pack['image'].to(dtype=torch.float32, device=gpu_device)
masks = pack['label'].to(dtype=torch.float32, device=gpu_device)
name = pack['image_meta_dict']['filename_or_obj']
if 'pt' in pack:
pt = pack['pt']
point_labels = pack['p_label']
if point_labels[0] != -1:
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=gpu_device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=gpu_device)
coords_torch, labels_torch = coords_torch[:, None, :], labels_torch[:, None]
pt = (coords_torch, labels_torch)
if args.thd:
pt = rearrange(pt, 'b n d -> (b d) n')
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
imgs = imgs.repeat(1, 3, 1, 1)
point_labels = torch.ones(imgs.size(0))
imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)
showp = pt
mask_type = torch.float32
ind += 1
b_size, c, w, h = imgs.size()
longsize = w if w >= h else h
imgs = imgs.to(dtype=mask_type, device=gpu_device)
with torch.amp.autocast('cuda'):
with torch.no_grad():
imge, skips = model.image_encoder(imgs)
timge, tskips = model.image_encoder(tmp_img)
p1, p2, se, de = model.prompt_encoder(
points=pt,
boxes=None,
doodles=None,
masks=None,
)
torch.cuda.empty_cache()
pred, _ = model.mask_decoder(
skips_raw=skips,
skips_tmp=tskips,
raw_emb=imge,
tmp_emb=timge,
pt1=p1,
pt2=p2,
image_pe=model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
if pred.shape[-2:] != masks.shape[-2:]:
pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)
loss = lossfunc(pred, masks)
if torch.isnan(loss) or torch.isinf(loss):
optimizer.zero_grad()
pbar.set_postfix(**{'loss (batch)': 'nan/inf skipped'})
pbar.update()
continue
pbar.set_postfix(**{'loss (batch)': loss.item()})
epoch_loss += loss.item()
_scaler.scale(loss).backward()
_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
_scaler.step(optimizer)
_scaler.update()
optimizer.zero_grad()
if vis:
if ind % vis == 0:
namecat = 'Train'
for na in name:
namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
vis_image(imgs, pred, masks,
os.path.join(args.path_helper['sample_path'],
namecat + 'epoch+' + str(epoch) + '.jpg'),
reverse=False)
pbar.update()
return loss
def validation_one(args, val_loader, epoch: int, net: nn.Module,
writer=None, clean_dir: bool = True):
"""Validate the model.
Args:
args: Validation arguments.
val_loader: DataLoader for validation data.
epoch: Current epoch number.
net: The neural network model.
writer: TensorBoard writer for logging.
clean_dir: Whether to clean the output directory.
Returns:
Tuple of (total_loss, (iou, dice)).
"""
_init_training_globals(args)
net.eval()
model = net.module if hasattr(net, 'module') else net
mask_type = torch.float32
n_val = len(val_loader)
mix_res = (0, 0, 0, 0)
tot = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
gpu_device = torch.device('cuda:' + str(args.gpu_device))
if args.thd:
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
else:
lossfunc = _criterion_g
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
for ind, pack in enumerate(val_loader):
if ind == 0:
                current_b = pack['image'].size(0)  # actual batch size (may differ from args.b if the dataset is small)
                tmp_img = pack['image'].to(dtype=torch.float32, device=gpu_device)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
                tmp_mask = pack['label'].to(dtype=torch.float32, device=gpu_device)[0, :, :, :].unsqueeze(0).repeat(current_b, 1, 1, 1)
if 'pt' not in pack:
tmp_img, pt, tmp_mask = generate_click_prompt(tmp_img, tmp_mask)
else:
pt = pack['pt']
point_labels = pack['p_label']
if point_labels[0] != -1:
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=gpu_device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=gpu_device)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
pt = (coords_torch, labels_torch)
imgs = pack['image'].to(dtype=torch.float32, device=gpu_device)
masks = pack['label'].to(dtype=torch.float32, device=gpu_device)
name = pack['image_meta_dict']['filename_or_obj']
showp = pt
mask_type = torch.float32
ind += 1
b_size, c, w, h = imgs.size()
longsize = w if w >= h else h
imgs = imgs.to(dtype=mask_type, device=gpu_device)
with torch.no_grad():
imge, skips = model.image_encoder(imgs)
timge, tskips = model.image_encoder(tmp_img)
p1, p2, se, de = model.prompt_encoder(
points=pt,
boxes=None,
doodles=None,
masks=None,
)
pred, _ = model.mask_decoder(
skips_raw=skips,
skips_tmp=tskips,
raw_emb=imge,
tmp_emb=timge,
pt1=p1,
pt2=p2,
image_pe=model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
if pred.shape[-2:] != masks.shape[-2:]:
pred = F.interpolate(pred, size=masks.shape[-2:], mode='bilinear', align_corners=False)
tot += lossfunc(pred, masks)
if args.vis and ind % args.vis == 0:
namecat = 'Test'
for na in name:
img_name = na.split('/')[-1].split('.')[0]
namecat = namecat + img_name + '+'
vis_image(imgs, pred, masks,
os.path.join(args.path_helper['sample_path'],
namecat + 'epoch+' + str(epoch) + '.jpg'),
reverse=False)
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
pbar.update()
return tot / n_val, tuple([a / n_val for a in mix_res])
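The running `mix_res` accumulation and final division in `validation_one` amount to an element-wise mean over per-batch metric tuples. A stand-alone sketch of that bookkeeping (the name `average_metrics` is illustrative, not a project function):

```python
def average_metrics(per_batch):
    """Element-wise mean over per-batch metric tuples, as validation_one
    does with mix_res: accumulate tuple sums, then divide by batch count."""
    totals = None
    for scores in per_batch:
        totals = scores if totals is None else tuple(a + b for a, b in zip(totals, scores))
    return tuple(t / len(per_batch) for t in totals)

# Two batches of (iou, dice) scores
print(average_metrics([(0.5, 1.0), (1.5, 3.0)]))  # (1.0, 2.0)
```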

@@ -0,0 +1,54 @@
"""
Network builder module.
This module provides functions to build and initialize neural networks
for medical image segmentation.
"""
import sys
import torch
import torch.nn as nn
def get_network(args, net: str, use_gpu: bool = True,
gpu_device: torch.device = None,
distribution: str = 'none') -> nn.Module:
"""Return the specified network architecture.
Args:
args: Arguments containing model configuration.
net: Network type identifier (e.g., 'oneprompt').
use_gpu: Whether to use GPU acceleration.
gpu_device: The GPU device to use.
distribution: Distributed training configuration.
- 'none': No distribution
- comma-separated GPU IDs for DataParallel
Returns:
The initialized neural network model.
Raises:
SystemExit: If the network type is not supported.
"""
if gpu_device is None:
gpu_device = torch.device('cuda', args.gpu_device)
if net == 'oneprompt':
        from models.oneprompt import one_model_registry
model = one_model_registry[args.baseline](args).to(gpu_device)
else:
        print(f'the network name {net!r} is not supported yet')
        sys.exit(1)
if use_gpu:
if distribution != 'none':
model = torch.nn.DataParallel(
model,
device_ids=[int(id) for id in distribution.split(',')]
)
model = model.to(device=gpu_device)
else:
model = model.to(device=gpu_device)
return model
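The `distribution` argument accepted above is either the sentinel `'none'` or a comma-separated GPU list; turning it into the `device_ids` that `DataParallel` expects is a one-liner worth isolating. A hedged sketch (`parse_device_ids` is illustrative, not part of the project API):

```python
def parse_device_ids(distribution):
    """Turn the comma-separated GPU list accepted by get_network (e.g.
    '0,1,3') into the device_ids list DataParallel expects; the sentinel
    'none' means data parallelism is disabled."""
    if distribution == 'none':
        return None
    return [int(gpu_id) for gpu_id in distribution.split(',')]

print(parse_device_ids('0,1,3'))  # [0, 1, 3]
print(parse_device_ids('none'))   # None
```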

@@ -0,0 +1,16 @@
"""
Data module for One-Prompt Medical Image Segmentation.
This module contains dataset classes for various medical imaging datasets.
"""
from .datasets import ISIC2016, REFUGE, PolypDataset, CombinedPolypDataset
from .loader import get_decath_loader
__all__ = [
"ISIC2016",
"REFUGE",
"PolypDataset",
"CombinedPolypDataset",
"get_decath_loader",
]

@@ -0,0 +1,357 @@
"""
Dataset classes for medical image segmentation.
This module provides dataset implementations for various medical imaging datasets:
- ISIC2016: Skin lesion segmentation
- REFUGE: Optic disc/cup segmentation
- PolypDataset: Polyp segmentation
- CombinedPolypDataset: Combined polyp datasets
"""
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torch.nn.functional as F
import pandas as pd
from ..utils.data_utils import random_click
class ISIC2016(Dataset):
"""ISIC 2016 skin lesion segmentation dataset.
Args:
args: Arguments containing image_size and other configurations.
data_path: Path to the dataset directory.
transform: Transform to apply to images.
transform_msk: Transform to apply to masks.
mode: 'Training' or 'Test'.
prompt: Prompt type ('click').
plane: Whether to use plane mode.
"""
def __init__(self, args, data_path: str, transform=None,
transform_msk=None, mode: str = 'Training',
prompt: str = 'click', plane: bool = False):
df = pd.read_csv(
os.path.join(data_path, 'ISBI2016_ISIC_Part3B_' + mode + '_GroundTruth.csv'),
encoding='gbk'
)
self.name_list = df.iloc[:, 1].tolist()
self.label_list = df.iloc[:, 2].tolist()
self.data_path = data_path
self.mode = mode
self.prompt = prompt
self.img_size = args.image_size
self.transform = transform
self.transform_msk = transform_msk
def __len__(self) -> int:
return len(self.name_list)
def __getitem__(self, index: int) -> dict:
inout = 1
point_label = 1
name = self.name_list[index]
img_path = os.path.join(self.data_path, name)
mask_name = self.label_list[index]
msk_path = os.path.join(self.data_path, mask_name)
img = Image.open(img_path).convert('RGB')
mask = Image.open(msk_path).convert('L')
newsize = (self.img_size, self.img_size)
mask = mask.resize(newsize)
        pt = np.array([self.img_size // 2, self.img_size // 2])  # fallback when no click prompt is used
        if self.prompt == 'click':
            pt = random_click(np.array(mask) / 255, point_label, inout)
if self.transform:
state = torch.get_rng_state()
img = self.transform(img)
torch.set_rng_state(state)
if self.transform_msk:
mask = self.transform_msk(mask)
name = name.split('/')[-1].split(".jpg")[0]
image_meta_dict = {'filename_or_obj': name}
return {
'image': img,
'label': mask,
'p_label': point_label,
'pt': pt,
'image_meta_dict': image_meta_dict,
}
class PolypDataset(Dataset):
"""Polyp segmentation dataset.
Supports CVC-300, CVC-ClinicDB, CVC-ColonDB, ETIS-LaribPolypDB, Kvasir datasets.
Directory structure:
data_path/
images/
xxx.png
masks/
xxx.png
Args:
args: Arguments containing image_size and out_size.
data_path: Path to the dataset directory.
transform: Transform to apply to images.
transform_msk: Transform to apply to masks.
mode: 'Training' or 'Test'.
prompt: Prompt type ('click').
plane: Whether to use plane mode.
"""
def __init__(self, args, data_path: str, transform=None,
transform_msk=None, mode: str = 'Training',
prompt: str = 'click', plane: bool = False):
self.data_path = data_path
self.mode = mode
self.prompt = prompt
self.img_size = args.image_size
self.out_size = args.out_size
self.transform = transform
self.transform_msk = transform_msk
img_dir = os.path.join(data_path, 'images')
self.img_list = sorted([
f for f in os.listdir(img_dir)
if f.endswith(('.png', '.jpg', '.jpeg'))
])
split_idx = int(len(self.img_list) * 0.8)
if mode == 'Training':
self.img_list = self.img_list[:split_idx]
else:
self.img_list = self.img_list[split_idx:]
def __len__(self) -> int:
return len(self.img_list)
def __getitem__(self, index: int) -> dict:
point_label = 1
inout = 1
img_name = self.img_list[index]
img_path = os.path.join(self.data_path, 'images', img_name)
mask_path = os.path.join(self.data_path, 'masks', img_name)
img = Image.open(img_path).convert('RGB')
mask = Image.open(mask_path).convert('L')
newsize = (self.img_size, self.img_size)
mask_resized = mask.resize(newsize)
        pt = np.array([self.img_size // 2, self.img_size // 2])  # fallback when no click prompt is used
        if self.prompt == 'click':
            pt = random_click(np.array(mask_resized) / 255, point_label, inout)
if self.transform:
state = torch.get_rng_state()
img = self.transform(img)
torch.set_rng_state(state)
if self.transform_msk:
mask = self.transform_msk(mask)
name = img_name.split('.')[0]
image_meta_dict = {'filename_or_obj': name}
return {
'image': img,
'label': mask,
'p_label': point_label,
'pt': pt,
'image_meta_dict': image_meta_dict,
}
class CombinedPolypDataset(Dataset):
"""Combined polyp dataset for training on multiple datasets.
Args:
args: Arguments containing image_size and out_size.
data_path: Path to the root directory containing dataset folders.
transform: Transform to apply to images.
transform_msk: Transform to apply to masks.
mode: 'Training' or 'Test'.
prompt: Prompt type ('click').
plane: Whether to use plane mode.
"""
def __init__(self, args, data_path: str, transform=None,
transform_msk=None, mode: str = 'Training',
prompt: str = 'click', plane: bool = False):
self.datasets = []
dataset_dirs = [
'CVC-300', 'CVC-ClinicDB', 'CVC-ColonDB',
'ETIS-LaribPolypDB', 'Kvasir'
]
for dataset_dir in dataset_dirs:
full_path = os.path.join(data_path, dataset_dir)
if os.path.exists(full_path):
ds = PolypDataset(
args, full_path, transform, transform_msk,
mode, prompt, plane
)
if len(ds) > 0:
self.datasets.append(ds)
print(f"Loaded {dataset_dir}: {len(ds)} samples ({mode})")
self.cumulative_sizes = []
total = 0
for ds in self.datasets:
total += len(ds)
self.cumulative_sizes.append(total)
print(f"Total {mode} samples: {total}")
def __len__(self) -> int:
return self.cumulative_sizes[-1] if self.cumulative_sizes else 0
def __getitem__(self, index: int) -> dict:
for i, cumsize in enumerate(self.cumulative_sizes):
if index < cumsize:
if i == 0:
return self.datasets[i][index]
else:
return self.datasets[i][index - self.cumulative_sizes[i - 1]]
raise IndexError("Index out of range")
class REFUGE(Dataset):
"""REFUGE optic disc/cup segmentation dataset with multi-rater annotations.
Args:
args: Arguments containing image_size and out_size.
data_path: Path to the dataset directory.
transform: Transform to apply to images.
transform_msk: Transform to apply to masks.
mode: 'Training' or 'Test'.
prompt: Prompt type ('click').
plane: Whether to use plane mode.
"""
def __init__(self, args, data_path: str, transform=None,
transform_msk=None, mode: str = 'Training',
prompt: str = 'click', plane: bool = False):
self.data_path = data_path
self.subfolders = [
f.path for f in os.scandir(os.path.join(data_path, mode + '-400'))
if f.is_dir()
]
self.mode = mode
self.prompt = prompt
self.img_size = args.image_size
self.mask_size = args.out_size
self.transform = transform
self.transform_msk = transform_msk
def __len__(self) -> int:
return len(self.subfolders)
def __getitem__(self, index: int) -> dict:
inout = 1
point_label = 1
subfolder = self.subfolders[index]
        name = os.path.basename(subfolder)
img_path = os.path.join(subfolder, name + '.jpg')
multi_rater_cup_path = [
os.path.join(subfolder, name + '_seg_cup_' + str(i) + '.png')
for i in range(1, 8)
]
multi_rater_disc_path = [
os.path.join(subfolder, name + '_seg_disc_' + str(i) + '.png')
for i in range(1, 8)
]
img = Image.open(img_path).convert('RGB')
multi_rater_cup = [Image.open(path).convert('L') for path in multi_rater_cup_path]
multi_rater_disc = [Image.open(path).convert('L') for path in multi_rater_disc_path]
newsize = (self.img_size, self.img_size)
multi_rater_cup_np = [
np.array(single_rater.resize(newsize))
for single_rater in multi_rater_cup
]
multi_rater_disc_np = [
np.array(single_rater.resize(newsize))
for single_rater in multi_rater_disc
]
if self.prompt == 'click':
pt_cup = random_click(
np.array(np.mean(np.stack(multi_rater_cup_np), axis=0)) / 255,
point_label, inout
)
pt_disc = random_click(
np.array(np.mean(np.stack(multi_rater_disc_np), axis=0)) / 255,
point_label, inout
)
if self.transform:
state = torch.get_rng_state()
img = self.transform(img)
multi_rater_cup = [
torch.as_tensor(
(self.transform(single_rater) > 0.5).float(),
dtype=torch.float32
)
for single_rater in multi_rater_cup
]
multi_rater_cup = torch.stack(multi_rater_cup, dim=0)
mask_cup = F.interpolate(
multi_rater_cup,
size=(self.mask_size, self.mask_size),
mode='bilinear',
align_corners=False
).mean(dim=0)
multi_rater_disc = [
torch.as_tensor(
(self.transform(single_rater) > 0.5).float(),
dtype=torch.float32
)
for single_rater in multi_rater_disc
]
multi_rater_disc = torch.stack(multi_rater_disc, dim=0)
mask_disc = F.interpolate(
multi_rater_disc,
size=(self.mask_size, self.mask_size),
mode='bilinear',
align_corners=False
).mean(dim=0)
torch.set_rng_state(state)
image_meta_dict = {'filename_or_obj': name}
return {
'image': img,
'multi_rater_cup': multi_rater_cup,
'multi_rater_disc': multi_rater_disc,
'mask_cup': mask_cup,
'mask_disc': mask_disc,
'label': mask_disc,
'p_label': point_label,
'pt_cup': pt_cup,
'pt_disc': pt_disc,
'pt': pt_disc,
'selected_rater': torch.tensor(np.arange(7)),
'image_meta_dict': image_meta_dict,
}
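`CombinedPolypDataset.__getitem__` maps a global sample index through `cumulative_sizes` with a linear scan; the same mapping can be expressed with `bisect`. A stand-alone sketch (the `locate` helper below is illustrative, not project code):

```python
import bisect

def locate(cumulative_sizes, index):
    """Map a global sample index to (dataset_idx, local_idx), mirroring
    the scan in CombinedPolypDataset.__getitem__ over cumulative sizes."""
    if not cumulative_sizes or index < 0 or index >= cumulative_sizes[-1]:
        raise IndexError("Index out of range")
    ds = bisect.bisect_right(cumulative_sizes, index)
    local = index if ds == 0 else index - cumulative_sizes[ds - 1]
    return ds, local

# Three datasets of sizes 3, 2 and 4 -> cumulative sizes [3, 5, 9]
print(locate([3, 5, 9], 4))  # (1, 1): second dataset, second sample
```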

@@ -0,0 +1,161 @@
"""
Data loader utilities for medical imaging datasets.
This module provides data loading functions for various medical imaging formats,
particularly for MONAI-based volumetric data.
"""
import os
import torch
from monai.transforms import (
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandFlipd,
RandCropByPosNegLabeld,
RandShiftIntensityd,
ScaleIntensityRanged,
Spacingd,
RandRotate90d,
EnsureTyped,
)
from monai.data import (
ThreadDataLoader,
CacheDataset,
load_decathlon_datalist,
set_track_meta,
)
def get_decath_loader(args):
"""Get data loaders for decathlon-style datasets.
Creates training and validation data loaders with appropriate transforms
for medical imaging data.
Args:
args: Arguments containing:
- data_path: Path to the dataset
- gpu_device: GPU device ID
- roi_size: Region of interest size
- chunk: Volume chunk size
- num_sample: Number of positive/negative samples
- b: Batch size
Returns:
Tuple containing:
- train_loader: DataLoader for training
- val_loader: DataLoader for validation
- train_transforms: Training transforms
- val_transforms: Validation transforms
- datalist: Training data list
- val_files: Validation data list
"""
device = torch.device('cuda', args.gpu_device)
train_transforms = Compose([
LoadImaged(keys=["image", "label"], ensure_channel_first=True),
ScaleIntensityRanged(
keys=["image"],
a_min=-175,
a_max=250,
b_min=0.0,
b_max=1.0,
clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(args.roi_size, args.roi_size, args.chunk),
pos=1,
neg=1,
num_samples=args.num_sample,
image_key="image",
image_threshold=0,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[0],
prob=0.10,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[1],
prob=0.10,
),
RandFlipd(
keys=["image", "label"],
spatial_axis=[2],
prob=0.10,
),
RandRotate90d(
keys=["image", "label"],
prob=0.10,
max_k=3,
),
RandShiftIntensityd(
keys=["image"],
offsets=0.10,
prob=0.50,
),
])
val_transforms = Compose([
LoadImaged(keys=["image", "label"], ensure_channel_first=True),
ScaleIntensityRanged(
keys=["image"],
a_min=-175,
a_max=250,
b_min=0.0,
b_max=1.0,
clip=True
),
CropForegroundd(keys=["image", "label"], source_key="image"),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
])
data_dir = args.data_path
split_json = "dataset_0.json"
datasets = os.path.join(data_dir, split_json)
datalist = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")
train_ds = CacheDataset(
data=datalist,
transform=train_transforms,
cache_num=24,
cache_rate=1.0,
num_workers=8,
)
train_loader = ThreadDataLoader(
train_ds, num_workers=0, batch_size=args.b, shuffle=True
)
val_ds = CacheDataset(
data=val_files,
transform=val_transforms,
cache_num=2,
cache_rate=1.0,
num_workers=0
)
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
set_track_meta(False)
return train_loader, val_loader, train_transforms, val_transforms, datalist, val_files
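`ScaleIntensityRanged` in the pipelines above performs a linear window/level mapping from `[a_min, a_max]` onto `[b_min, b_max]` followed by clipping. A scalar sketch of that arithmetic (the `scale_intensity` helper is illustrative; MONAI's implementation operates on whole arrays):

```python
def scale_intensity(x, a_min=-175.0, a_max=250.0, b_min=0.0, b_max=1.0):
    """Window/level normalisation done by ScaleIntensityRanged: linearly
    map [a_min, a_max] onto [b_min, b_max], then clip to the output range."""
    y = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    return min(max(y, b_min), b_max)

print(scale_intensity(-175.0))  # 0.0 (lower window bound)
print(scale_intensity(37.5))    # 0.5 (window centre)
print(scale_intensity(500.0))   # 1.0 (clipped)
```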

@@ -0,0 +1,52 @@
"""
Utility modules for One-Prompt Medical Image Segmentation.
This module provides various utility functions including:
- Metrics (Dice, IoU)
- Visualization tools
- Data utilities
- Logging utilities
"""
# Use lazy imports to avoid dependency issues
__all__ = [
# Metrics
"DiceMetric",
"eval_seg",
"iou",
"dice_coeff",
"DiceCoeff",
# Visualization
"vis_image",
"save_image",
"make_grid",
# Data utilities
"random_click",
"generate_click_prompt",
# Logging
"create_logger",
"set_log_dir",
"save_checkpoint",
# Scheduler
"WarmUpLR",
]
def __getattr__(name):
"""Lazy import of submodules."""
if name in ("DiceMetric", "eval_seg", "iou", "dice_coeff", "DiceCoeff"):
from . import metrics
return getattr(metrics, name)
elif name in ("vis_image", "save_image", "make_grid"):
from . import visualization
return getattr(visualization, name)
elif name in ("random_click", "generate_click_prompt"):
from . import data_utils
return getattr(data_utils, name)
elif name in ("create_logger", "set_log_dir", "save_checkpoint"):
from . import logging
return getattr(logging, name)
elif name == "WarmUpLR":
from .scheduler import WarmUpLR
return WarmUpLR
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

@ -0,0 +1,78 @@
"""
Data utility functions for medical image segmentation.
This module provides utility functions for data processing,
including click prompt generation and random sampling.
"""
import random
import numpy as np
import torch
def random_click(mask: np.ndarray, point_labels: int = 1,
inout: int = 1) -> np.ndarray:
"""Generate a random click point on the mask.
Args:
mask: Binary mask array.
point_labels: Label for the point (1 for foreground).
inout: Value to match in the mask (1 for foreground).
Returns:
Coordinates of the random click point [y, x].
"""
indices = np.argwhere(mask == inout)
if len(indices) == 0:
# Return center if no valid points
h, w = mask.shape[:2]
return np.array([h // 2, w // 2])
return indices[np.random.randint(len(indices))]
def generate_click_prompt(img: torch.Tensor, msk: torch.Tensor,
pt_label: int = 1) -> tuple:
"""Generate click prompts from image and mask.
Creates point prompts by randomly selecting positions from the mask
for each slice in a 3D volume.
Args:
img: Input image tensor of shape (B, C, H, W, D).
msk: Mask tensor of shape (B, C, H, W, D).
pt_label: Point label value.
Returns:
Tuple of (image, point_coordinates, processed_mask).
"""
pt_list = []
msk_list = []
b, c, h, w, d = msk.size()
msk = msk[:, 0, :, :, :]
for i in range(d):
pt_list_s = []
msk_list_s = []
for j in range(b):
msk_s = msk[j, :, :, i]
indices = torch.nonzero(msk_s)
            if indices.size(0) == 0:
                # No foreground in this slice: fall back to a uniformly
                # random point, sampled over h and w independently.
                random_index = torch.tensor(
                    [random.randrange(h), random.randrange(w)],
                    device=msk.device,
                )
                new_s = msk_s
            else:
                random_index = random.choice(indices)
                label = msk_s[random_index[0], random_index[1]]
                new_s = (msk_s == label).to(dtype=torch.float)
pt_list_s.append(random_index)
msk_list_s.append(new_s)
pts = torch.stack(pt_list_s, dim=0)
msks = torch.stack(msk_list_s, dim=0)
pt_list.append(pts)
msk_list.append(msks)
pt = torch.stack(pt_list, dim=-1)
msk = torch.stack(msk_list, dim=-1)
msk = msk.unsqueeze(1)
return img, pt, msk

@ -0,0 +1,103 @@
"""
Logging utilities for training and experiment tracking.
This module provides functions for creating loggers, managing log directories,
and saving checkpoints.
"""
import os
import time
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional
import dateutil.tz
import torch
def create_logger(log_dir: str, phase: str = 'train') -> logging.Logger:
"""Create a logger for training.
Args:
log_dir: Directory to save log files.
phase: Training phase name ('train', 'val', etc.).
Returns:
Configured logger instance.
"""
time_str = time.strftime('%Y-%m-%d-%H-%M')
log_file = '{}_{}.log'.format(time_str, phase)
final_log_file = os.path.join(log_dir, log_file)
head = '%(asctime)-15s %(message)s'
logging.basicConfig(filename=str(final_log_file), format=head)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logging.getLogger('').addHandler(console)
return logger
def set_log_dir(root_dir: str, exp_name: str) -> Dict[str, str]:
"""Set up logging directories for an experiment.
Creates the following directory structure:
root_dir/
exp_name_timestamp/
Model/ - for checkpoints
Log/ - for log files
Samples/ - for sample images
Args:
root_dir: Root directory for all experiments.
exp_name: Name of the experiment.
Returns:
Dictionary containing paths:
- 'prefix': Base experiment directory
- 'ckpt_path': Checkpoint directory
- 'log_path': Log file directory
- 'sample_path': Sample image directory
"""
path_dict = {}
os.makedirs(root_dir, exist_ok=True)
# Set log path
exp_path = os.path.join(root_dir, exp_name)
now = datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
prefix = exp_path + '_' + timestamp
os.makedirs(prefix)
path_dict['prefix'] = prefix
# Set checkpoint path
ckpt_path = os.path.join(prefix, 'Model')
os.makedirs(ckpt_path)
path_dict['ckpt_path'] = ckpt_path
log_path = os.path.join(prefix, 'Log')
os.makedirs(log_path)
path_dict['log_path'] = log_path
# Set sample image path
sample_path = os.path.join(prefix, 'Samples')
os.makedirs(sample_path)
path_dict['sample_path'] = sample_path
return path_dict
def save_checkpoint(states: Dict[str, Any], is_best: bool,
output_dir: str, filename: str = 'checkpoint.pth') -> None:
"""Save model checkpoint.
Args:
states: Dictionary containing model state and metadata.
is_best: Whether this is the best model so far.
output_dir: Directory to save the checkpoint.
filename: Name of the checkpoint file.
"""
torch.save(states, os.path.join(output_dir, filename))
if is_best:
torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))

@ -0,0 +1,157 @@
"""
Evaluation metrics for medical image segmentation.
This module provides various metrics for evaluating segmentation quality:
- Dice coefficient
- IoU (Intersection over Union)
- DiceMetric class for batch evaluation
"""
import numpy as np
import torch
from torch.autograd import Function
# Optional import for MONAI DiceMetric
try:
from monai.metrics import DiceMetric
except ImportError:
DiceMetric = None
def iou(outputs: np.ndarray, labels: np.ndarray) -> float:
"""Compute Intersection over Union (IoU) metric.
Args:
outputs: Predicted binary masks of shape (N, H, W).
labels: Ground truth binary masks of shape (N, H, W).
Returns:
Mean IoU score across all samples.
"""
SMOOTH = 1e-6
intersection = (outputs & labels).sum((1, 2))
union = (outputs | labels).sum((1, 2))
iou_score = (intersection + SMOOTH) / (union + SMOOTH)
return iou_score.mean()
class DiceCoeff(Function):
    """Dice coefficient for individual examples.
    This is a custom autograd function for computing the Dice coefficient
    with gradient support. It uses the static-method API required by
    modern PyTorch; invoke it via ``DiceCoeff.apply(input, target)``.
    """
    @staticmethod
    def forward(ctx, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute Dice coefficient.
        Args:
            ctx: Autograd context used to stash values for backward.
            input: Predicted mask tensor.
            target: Ground truth mask tensor.
        Returns:
            Dice coefficient value.
        """
        ctx.save_for_backward(input, target)
        eps = 0.0001
        ctx.inter = torch.dot(input.view(-1), target.view(-1))
        ctx.union = torch.sum(input) + torch.sum(target) + eps
        return (2 * ctx.inter.float() + eps) / ctx.union.float()
    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for backpropagation.
        Args:
            ctx: Autograd context holding the saved tensors.
            grad_output: Gradient of the loss with respect to the output.
        Returns:
            Tuple of gradients for input and target (target gets None).
        """
        input, target = ctx.saved_tensors
        grad_input = grad_target = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * ctx.union - ctx.inter) \
                / (ctx.union * ctx.union)
        return grad_input, grad_target
def dice_coeff(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute Dice coefficient for batches.
    Args:
        input: Predicted mask tensor of shape (N, H, W).
        target: Ground truth mask tensor of shape (N, H, W).
    Returns:
        Mean Dice coefficient across the batch.
    """
    s = torch.zeros(1, device=input.device)
    for i, (inp, tgt) in enumerate(zip(input, target)):
        s = s + DiceCoeff.apply(inp, tgt)
    return s / (i + 1)
def eval_seg(pred: torch.Tensor, true_mask_p: torch.Tensor,
threshold: tuple) -> tuple:
"""Evaluate segmentation predictions.
Args:
pred: Predicted masks of shape (B, C, H, W).
true_mask_p: Ground truth masks of shape (B, C, H, W).
threshold: Tuple of threshold values for binarization.
Returns:
Tuple of evaluation metrics (IoU and Dice scores).
"""
b, c, h, w = pred.size()
if c == 2:
# Two-channel output (disc and cup)
iou_d, iou_c, disc_dice, cup_dice = 0, 0, 0, 0
for th in threshold:
gt_vmask_p = (true_mask_p > th).float()
vpred = (pred > th).float()
vpred_cpu = vpred.cpu()
disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')
cup_pred = vpred_cpu[:, 1, :, :].numpy().astype('int32')
            disc_mask = gt_vmask_p[:, 0, :, :].cpu().numpy().astype('int32')
            cup_mask = gt_vmask_p[:, 1, :, :].cpu().numpy().astype('int32')
iou_d += iou(disc_pred, disc_mask)
iou_c += iou(cup_pred, cup_mask)
disc_dice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()
cup_dice += dice_coeff(vpred[:, 1, :, :], gt_vmask_p[:, 1, :, :]).item()
return (
iou_d / len(threshold),
iou_c / len(threshold),
disc_dice / len(threshold),
cup_dice / len(threshold)
)
else:
# Single-channel output
eiou, edice = 0, 0
for th in threshold:
gt_vmask_p = (true_mask_p > th).float()
vpred = (pred > th).float()
vpred_cpu = vpred.cpu()
disc_pred = vpred_cpu[:, 0, :, :].numpy().astype('int32')
            disc_mask = gt_vmask_p[:, 0, :, :].cpu().numpy().astype('int32')
eiou += iou(disc_pred, disc_mask)
edice += dice_coeff(vpred[:, 0, :, :], gt_vmask_p[:, 0, :, :]).item()
return eiou / len(threshold), edice / len(threshold)

@ -0,0 +1,38 @@
"""
Learning rate schedulers for training.
This module provides custom learning rate scheduler implementations.
"""
from torch.optim.lr_scheduler import _LRScheduler
class WarmUpLR(_LRScheduler):
"""Warmup learning rate scheduler.
Gradually increases the learning rate from 0 to the initial learning rate
during the warmup phase.
Args:
optimizer: Wrapped optimizer.
total_iters: Total number of warmup iterations.
last_epoch: The index of last epoch.
"""
def __init__(self, optimizer, total_iters: int, last_epoch: int = -1):
self.total_iters = total_iters
super().__init__(optimizer, last_epoch)
def get_lr(self):
"""Calculate learning rate for the current step.
During warmup, the learning rate is linearly scaled:
lr = base_lr * current_step / total_iters
Returns:
List of learning rates for each parameter group.
"""
return [
base_lr * self.last_epoch / (self.total_iters + 1e-8)
for base_lr in self.base_lrs
]

@ -0,0 +1,204 @@
"""
Visualization utilities for medical image segmentation.
This module provides functions for visualizing segmentation results,
creating image grids, and saving output images.
"""
import math
import warnings
import pathlib
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import numpy as np
import torch
import torchvision.utils as vutils
from PIL import Image
@torch.no_grad()
def make_grid(
tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
**kwargs
) -> torch.Tensor:
"""Make a grid of images.
Args:
tensor: 4D mini-batch Tensor of shape (B, C, H, W)
or a list of images all of the same size.
nrow: Number of images displayed in each row of the grid.
padding: Amount of padding.
normalize: If True, shift the image to the range (0, 1).
value_range: Tuple (min, max) for normalization.
scale_each: If True, scale each image in the batch independently.
pad_value: Value for the padded pixels.
Returns:
Grid tensor of shape (C, H, W).
"""
if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if "range" in kwargs.keys():
warning = "range will be deprecated, please use value_range instead."
warnings.warn(warning)
value_range = kwargs["range"]
if isinstance(tensor, list):
tensor = torch.stack(tensor, dim=0)
if tensor.dim() == 2:
tensor = tensor.unsqueeze(0)
if tensor.dim() == 3:
if tensor.size(0) == 1:
tensor = torch.cat((tensor, tensor, tensor), 0)
tensor = tensor.unsqueeze(0)
if tensor.dim() == 4 and tensor.size(1) == 1:
tensor = torch.cat((tensor, tensor, tensor), 1)
if normalize is True:
tensor = tensor.clone()
if value_range is not None:
assert isinstance(value_range, tuple), \
"value_range has to be a tuple (min, max) if specified."
def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)
img.sub_(low).div_(max(high - low, 1e-5))
def norm_range(t, value_range):
if value_range is not None:
norm_ip(t, value_range[0], value_range[1])
else:
norm_ip(t, float(t.min()), float(t.max()))
if scale_each is True:
for t in tensor:
norm_range(t, value_range)
else:
norm_range(tensor, value_range)
if tensor.size(0) == 1:
return tensor.squeeze(0)
nmaps = tensor.size(0)
xmaps = min(nrow, nmaps)
ymaps = int(math.ceil(float(nmaps) / xmaps))
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
num_channels = tensor.size(1)
grid = tensor.new_full(
(num_channels, height * ymaps + padding, width * xmaps + padding),
pad_value
)
k = 0
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
grid.narrow(1, y * height + padding, height - padding).narrow(
2, x * width + padding, width - padding
).copy_(tensor[k])
k = k + 1
return grid
@torch.no_grad()
def save_image(
tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[Text, pathlib.Path, BinaryIO],
format: Optional[str] = None,
**kwargs
) -> None:
"""Save a tensor as an image file.
Args:
tensor: Image tensor to be saved.
fp: Filename or file object.
format: Image format. If omitted, determined from filename extension.
**kwargs: Other arguments passed to make_grid.
"""
grid = make_grid(tensor, **kwargs)
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)
def vis_image(imgs: torch.Tensor, pred_masks: torch.Tensor,
gt_masks: torch.Tensor, save_path: str,
reverse: bool = False, points=None) -> None:
"""Visualize segmentation results.
Creates a grid visualization with input images, predicted masks,
and ground truth masks.
Args:
imgs: Input images tensor of shape (B, C, H, W).
pred_masks: Predicted masks tensor of shape (B, C, H, W).
gt_masks: Ground truth masks tensor of shape (B, C, H, W).
save_path: Path to save the visualization.
reverse: If True, invert the mask colors.
points: Optional point annotations to overlay.
"""
import torchvision
b, c, h, w = pred_masks.size()
dev = pred_masks.get_device()
row_num = min(b, 4)
if torch.max(pred_masks) > 1 or torch.min(pred_masks) < 0:
pred_masks = torch.sigmoid(pred_masks)
if reverse:
pred_masks = 1 - pred_masks
gt_masks = 1 - gt_masks
if c == 2:
# Two-channel (disc and cup)
pred_disc = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
pred_cup = pred_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
gt_disc = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
gt_cup = gt_masks[:, 1, :, :].unsqueeze(1).expand(b, 3, h, w)
compose = torch.cat((
pred_disc[:row_num, :, :, :],
pred_cup[:row_num, :, :, :],
gt_disc[:row_num, :, :, :],
gt_cup[:row_num, :, :, :]
), 0)
vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)
else:
# Single channel
imgs = torchvision.transforms.Resize((h, w))(imgs)
if imgs.size(1) == 1:
imgs = imgs[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
pred_masks = pred_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
gt_masks = gt_masks[:, 0, :, :].unsqueeze(1).expand(b, 3, h, w)
if points is not None:
# Import args for point visualization
import cfg
args = cfg.parse_args()
for i in range(b):
if args.thd:
                if args.thd:
                    p = torch.round(points.cpu() / args.roi_size * args.out_size).to(dtype=torch.int)
                else:
                    p = torch.round(points.cpu() / args.image_size * args.out_size).to(dtype=torch.int)
gt_masks[i, 0, p[i, 0]-5:p[i, 0]+5, p[i, 1]-5:p[i, 1]+5] = 0.5
gt_masks[i, 1, p[i, 0]-5:p[i, 0]+5, p[i, 1]-5:p[i, 1]+5] = 0.1
gt_masks[i, 2, p[i, 0]-5:p[i, 0]+5, p[i, 1]-5:p[i, 1]+5] = 0.4
tup = (
imgs[:row_num, :, :, :],
pred_masks[:row_num, :, :, :],
gt_masks[:row_num, :, :, :]
)
compose = torch.cat(tup, 0)
vutils.save_image(compose, fp=save_path, nrow=row_num, padding=10)

@ -0,0 +1,5 @@
"""
Test package for One-Prompt Medical Image Segmentation.
This package contains unit tests and integration tests for the project.
"""

@ -0,0 +1,56 @@
"""
Unit tests for data utilities.
"""
import pytest
import numpy as np
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from oneprompt_seg.utils.data_utils import random_click, generate_click_prompt
class TestRandomClick:
"""Tests for random_click function."""
def test_click_on_foreground(self):
"""Test that click is generated on foreground."""
mask = np.zeros((64, 64))
mask[20:40, 20:40] = 1
pt = random_click(mask, point_labels=1, inout=1)
assert 20 <= pt[0] < 40
assert 20 <= pt[1] < 40
def test_empty_mask(self):
"""Test click generation with empty mask."""
mask = np.zeros((64, 64))
pt = random_click(mask, point_labels=1, inout=1)
# Should return center when no valid points
assert pt[0] == 32
assert pt[1] == 32
def test_click_coordinates_shape(self):
"""Test that click returns correct shape."""
mask = np.ones((64, 64))
pt = random_click(mask, point_labels=1, inout=1)
assert len(pt) == 2
class TestGenerateClickPrompt:
"""Tests for generate_click_prompt function."""
def test_output_shapes(self):
"""Test output tensor shapes."""
img = torch.rand(2, 3, 64, 64, 4)
msk = torch.rand(2, 1, 64, 64, 4)
msk = (msk > 0.5).float()
out_img, pt, out_msk = generate_click_prompt(img, msk)
assert out_img.shape == img.shape
assert pt.shape[0] == 2 # batch size
assert pt.shape[-1] == 4 # depth
assert out_msk.shape[0] == 2 # batch size

@ -0,0 +1,78 @@
"""
Unit tests for metrics module.
"""
import pytest
import numpy as np
import torch
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from oneprompt_seg.utils.metrics import iou, dice_coeff, eval_seg
class TestIoU:
"""Tests for IoU metric."""
def test_perfect_overlap(self):
"""Test IoU with perfect overlap."""
pred = np.ones((1, 64, 64), dtype=np.int32)
target = np.ones((1, 64, 64), dtype=np.int32)
result = iou(pred, target)
assert result == pytest.approx(1.0, abs=1e-5)
def test_no_overlap(self):
"""Test IoU with no overlap."""
pred = np.ones((1, 64, 64), dtype=np.int32)
target = np.zeros((1, 64, 64), dtype=np.int32)
result = iou(pred, target)
assert result == pytest.approx(0.0, abs=1e-5)
def test_partial_overlap(self):
"""Test IoU with partial overlap."""
pred = np.zeros((1, 64, 64), dtype=np.int32)
target = np.zeros((1, 64, 64), dtype=np.int32)
pred[0, :32, :] = 1
target[0, 16:48, :] = 1
result = iou(pred, target)
assert 0 < result < 1
class TestDiceCoeff:
"""Tests for Dice coefficient."""
def test_perfect_overlap(self):
"""Test Dice with perfect overlap."""
pred = torch.ones(1, 64, 64)
target = torch.ones(1, 64, 64)
result = dice_coeff(pred, target)
assert result.item() == pytest.approx(1.0, abs=1e-3)
def test_no_overlap(self):
"""Test Dice with no overlap."""
pred = torch.ones(1, 64, 64)
target = torch.zeros(1, 64, 64)
result = dice_coeff(pred, target)
assert result.item() == pytest.approx(0.0, abs=1e-3)
class TestEvalSeg:
"""Tests for eval_seg function."""
def test_single_channel(self):
"""Test evaluation with single channel output."""
pred = torch.rand(2, 1, 64, 64)
target = torch.rand(2, 1, 64, 64)
threshold = (0.5,)
result = eval_seg(pred, target, threshold)
assert len(result) == 2 # IoU and Dice
def test_two_channel(self):
"""Test evaluation with two channel output."""
pred = torch.rand(2, 2, 64, 64)
target = torch.rand(2, 2, 64, 64)
threshold = (0.5,)
result = eval_seg(pred, target, threshold)
assert len(result) == 4 # IoU_d, IoU_c, Dice_d, Dice_c