One-Prompt to Segment All Medical Images

Python 3.8+ PyTorch 1.10+ License: MIT

One-Prompt to Segment All Medical Images (One-Prompt for short) combines the strengths of one-shot and interactive methods. At inference time, given just one prompted sample, it can handle an unseen task in a single forward pass.

This method is elaborated in the paper One-Prompt to Segment All Medical Images (CVPR 2024).

A Quick Overview

Project Structure

One-Prompt-Medical-Image-Segmentation/
├── src/oneprompt_seg/           # Source package (recommended)
│   ├── core/                    # Core modules (network building, training functions)
│   ├── data/                    # Data modules (datasets, loaders)
│   └── utils/                   # Utility modules (metrics, visualization, logging)
├── models/                      # Model implementations
│   ├── oneprompt/               # One-Prompt model
│   ├── unet/                    # UNet model
│   └── ...
├── tests/                       # Unit tests
├── docs/                        # Project documentation
├── configs/                     # Configuration files
│   ├── default.yaml             # Default configuration
│   ├── development.yaml         # Development configuration
│   └── production.yaml          # Production configuration
├── scripts/                     # Script files
│   ├── train.py                 # Training script
│   └── evaluate.py              # Evaluation script
├── train.py                     # Original training entry point
├── val.py                       # Original validation entry point
├── pyproject.toml               # Project configuration
├── requirements.txt             # Dependency list
└── Makefile                     # Common commands

Installation

Method 1: Using Conda

# Create the environment
conda env create -f environment.yml
conda activate oneprompt

# Install the project package (optional, needed for the src layout)
pip install -e .

Method 2: Using pip

# Create a virtual environment
python -m venv venv
source venv/bin/activate  # Linux/Mac
# or venv\Scripts\activate  # Windows

# Install dependencies
pip install -r requirements.txt

# Install the project package (optional)
pip install -e .

Development Installation

# Install development dependencies
pip install -r requirements-dev.txt

# Install pre-commit hooks
pre-commit install

Quick Start

Using Makefile

# Install dependencies
make install

# Run tests
make test

# Format code
make format

# Lint code
make lint

# Train a model
make train

Training

# Basic training command
python train.py -net oneprompt -mod one_adpt -exp_name basic_exp \
    -b 64 -dataset oneprompt -data_path ../data -baseline 'unet'

# Or use the script under the scripts directory
python scripts/train.py -net oneprompt -mod one_adpt -exp_name basic_exp \
    -dataset polyp -data_path ./data/polyp

Evaluation

python val.py -net oneprompt -mod one_adpt -exp_name One-ISIC \
    -weights <weight_path> -b 1 -dataset isic -data_path ../dataset/isic \
    -vis 10 -baseline 'unet'

Dataset

Download the Open-source Datasets

We collected 78 open-source datasets for training and testing the model. The datasets and their download links are listed here.

Download the Prompts

The prompts corresponding to the datasets can be downloaded here. Each prompt is saved as a JSON record with the format {DATASET_NAME, SAMPLE_INDEX, PROMPT_TYPE, PROMPT_CONTENT}.
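A prompt record with that format can be read as ordinary JSON. The field values below are hypothetical, made up purely to illustrate the shape; the actual contents depend on the downloaded prompt files:

```python
import json

# Hypothetical prompt record; real DATASET_NAME, SAMPLE_INDEX, PROMPT_TYPE,
# and PROMPT_CONTENT values come from the downloaded prompt files.
record = json.loads(
    '{"DATASET_NAME": "isic", "SAMPLE_INDEX": 0, '
    '"PROMPT_TYPE": "click", "PROMPT_CONTENT": [120, 85]}'
)

# The four keys match the documented format.
keys = set(record)
```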

Test Examples

Melanoma Segmentation from Skin Images (2D)

  1. Download ISIC dataset part 1 from https://challenge.isic-archive.com/data/. Then put the CSV files in "./data/isic" under your data path. Your dataset folder under "your_data_path" should look like:
ISIC/
├── ISBI2016_ISIC_Part1_Test_Data/...
├── ISBI2016_ISIC_Part1_Training_Data/...
├── ISBI2016_ISIC_Part1_Test_GroundTruth.csv
└── ISBI2016_ISIC_Part1_Training_GroundTruth.csv
  2. Run:
python val.py -net oneprompt -mod one_adpt -exp_name One-ISIC \
    -weights <weight_path> -b 1 -dataset isic -data_path ../dataset/isic \
    -vis 10 -baseline 'unet'

Change data_path and exp_name for your own usage.

Decrease the image size or the batch size -b if you run out of memory.

  3. Evaluation: The code automatically evaluates the model on the test set during training; set --val_freq to control how many epochs between evaluations. You can also run val.py for independent evaluation.

  4. Result Visualization: Set the --vis parameter to control how many epochs between result visualizations during training or evaluation.

By default, everything is saved under ./logs/

REFUGE: Optic-disc Segmentation from Fundus Images (2D)

The REFUGE dataset contains 1200 fundus images with optic disc/cup segmentations and clinical glaucoma labels.

  1. Download the dataset manually from here, or via the command line:
git lfs install
git clone git@hf.co:datasets/realslimman/REFUGE-MultiRater
unzip ./REFUGE-MultiRater.zip
mv REFUGE-MultiRater ./data
  2. To evaluate, run:
python val.py -net oneprompt -mod one_adpt -exp_name One-REFUGE \
    -weights <weight_path> -b 1 -baseline 'unet' \
    -dataset REFUGE -data_path ./data/REFUGE-MultiRater

Run on Your Own Dataset

It is simple to run One-Prompt on other datasets. Just write another dataset class following the examples in ./dataset.py. You only need to make sure it returns a dict with:

{
    'image': A tensor holding the image, with size [C, H, W] for a 2D image
             and [C, H, W, D] for 3D data.
             D is the depth of the 3D volume; C is the number of channels
             of a scan/frame, commonly 1 for CT, MRI, and US data.
             For, say, a color surgical video, D could be the number of
             time frames, and C would be 3 for an RGB frame.

    'label': The target masks. The same size as the images, except possibly
             the resolution (H and W).

    'p_label': The prompt label deciding between positive and negative
               prompts. To simplify, you can always set it to 1 if you
               don't need the negative-prompt function.

    'pt': The prompt. E.g., a click prompt should be [x of click, y of click],
          one click per scan/frame for 3D data.

    'image_meta_dict': Optional. If you want to save/visualize the result,
                       put the name of the image in it under the key
                       ['filename_or_obj'].

    ...(others as you want)
}
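The dict above can be produced by a minimal PyTorch Dataset. This is only a sketch under assumptions: the class name MyClickDataset is hypothetical, and the click prompt here is simply placed at the mask's foreground centroid, which is one plausible heuristic rather than the project's own prompt-generation code:

```python
import torch
from torch.utils.data import Dataset


class MyClickDataset(Dataset):
    """Hypothetical 2D dataset returning the dict One-Prompt expects."""

    def __init__(self, images, masks):
        # images: list of [C, H, W] tensors; masks: list of [1, H, W] tensors
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        # Place a single positive click at the mask's foreground centroid.
        ys, xs = torch.nonzero(mask[0], as_tuple=True)
        pt = torch.tensor([xs.float().mean(), ys.float().mean()])
        return {
            'image': image,
            'label': mask,
            'p_label': 1,  # 1 = positive prompt
            'pt': pt,      # [x of click, y of click]
            'image_meta_dict': {'filename_or_obj': f'sample_{idx}.png'},
        }
```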

API Usage (New)

After installing with pip install -e ., you can use the package programmatically:

from oneprompt_seg import get_network, CombinedPolypDataset
from oneprompt_seg.utils import eval_seg, vis_image, create_logger

# Build network
model = get_network(args, 'oneprompt')

# Create dataset
dataset = CombinedPolypDataset(args, data_path, transform=transform)

# Evaluate
iou, dice = eval_seg(predictions, targets, threshold=(0.5,))
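For reference, the IoU and Dice scores that eval_seg returns can be computed from binarized masks as follows. This is a minimal standalone sketch of the standard definitions, not the package's actual implementation, which may average over classes or thresholds:

```python
import numpy as np


def iou_dice(pred, target, threshold=0.5):
    """IoU and Dice for a single binary segmentation prediction.

    pred: array of predicted probabilities; target: binary ground-truth array.
    """
    p = np.asarray(pred) >= threshold
    t = np.asarray(target).astype(bool)
    inter = np.logical_and(p, t).sum()
    union = np.logical_or(p, t).sum()
    iou = inter / union if union else 1.0
    dice = 2 * inter / (p.sum() + t.sum()) if (p.sum() + t.sum()) else 1.0
    return iou, dice
```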

Configuration

The project supports YAML configuration files in configs/:

  • default.yaml - Default training configuration
  • development.yaml - Development/debugging configuration
  • production.yaml - Production training configuration
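A config file like these can be layered on top of defaults with a small loader. This is a hedged sketch: the keys batch_size and lr are hypothetical placeholders, not necessarily the keys used in configs/default.yaml, and the project may ship its own config machinery:

```python
import yaml

# Hypothetical default values; the real keys live in configs/default.yaml.
DEFAULTS = {"batch_size": 64, "lr": 1e-4}


def load_config(text):
    """Merge a YAML override (given as a string) on top of the defaults."""
    cfg = dict(DEFAULTS)
    cfg.update(yaml.safe_load(text) or {})
    return cfg


# Overriding one key keeps the other defaults intact.
cfg = load_config("batch_size: 8\n")
```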

Development

Running Tests

# Run all tests
pytest tests/ -v

# Run with coverage
pytest tests/ -v --cov=src/oneprompt_seg --cov-report=html

Code Quality

# Format code
black src/ tests/ scripts/
isort src/ tests/ scripts/

# Lint code
flake8 src/ tests/
mypy src/

Cite

@InProceedings{Wu_2024_CVPR,
    author    = {Wu, Junde and Xu, Min},
    title     = {One-Prompt to Segment All Medical Images},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {11302-11312}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.