# One-Prompt to Segment All Medical Images

[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/) [![PyTorch 1.10+](https://img.shields.io/badge/pytorch-1.10+-ee4c2c.svg)](https://pytorch.org/) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

One-Prompt to Segment All Medical Images (One-Prompt for short) combines the strengths of one-shot and interactive segmentation methods. At inference time, given a single prompted sample, it can handle an unseen task in one forward pass. The method is described in the paper [One-Prompt to Segment All Medical Images](https://arxiv.org/abs/2305.10300) (CVPR 2024).

## Project Structure

```
One-Prompt-Medical-Image-Segmentation/
├── src/oneprompt_seg/        # Source package (recommended)
│   ├── core/                 # Core modules (network builders, training functions)
│   ├── data/                 # Data modules (datasets, loaders)
│   └── utils/                # Utilities (metrics, visualization, logging)
├── models/                   # Model implementations
│   ├── oneprompt/            # One-Prompt model
│   ├── unet/                 # UNet model
│   └── ...
├── tests/                    # Unit tests
├── docs/                     # Project documentation
├── configs/                  # Configuration files
│   ├── default.yaml          # Default config
│   ├── development.yaml      # Development config
│   └── production.yaml       # Production config
├── scripts/                  # Script files
│   ├── train.py              # Training script
│   └── evaluate.py           # Evaluation script
├── train.py                  # Original training entry point
├── val.py                    # Original validation entry point
├── pyproject.toml            # Project configuration
├── requirements.txt          # Dependency list
└── Makefile                  # Common commands
```

## Installation

### Method 1: Using Conda (Recommended)

```bash
# Create the environment
conda env create -f environment.yml
conda activate oneprompt

# Install the project package (optional, needed for the src layout)
pip install -e .
```

### Method 2: Using pip

```bash
# Create a virtual environment
python -m venv venv
source venv/bin/activate  # Linux/Mac
# or: venv\Scripts\activate  # Windows

# Install dependencies
pip install -r requirements.txt

# Install the project package (optional)
pip install -e .
```

### Development Installation

```bash
# Install development dependencies
pip install -r requirements-dev.txt

# Install pre-commit hooks
pre-commit install
```

## Quick Start

### Using Makefile

```bash
# Install dependencies
make install

# Run tests
make test

# Format code
make format

# Lint code
make lint

# Train the model
make train
```

### Training

```bash
# Basic training command
python train.py -net oneprompt -mod one_adpt -exp_name basic_exp \
    -b 64 -dataset oneprompt -data_path ../data -baseline 'unet'

# Or use the script under scripts/
python scripts/train.py -net oneprompt -mod one_adpt -exp_name basic_exp \
    -dataset polyp -data_path ./data/polyp
```

### Evaluation

```bash
python val.py -net oneprompt -mod one_adpt -exp_name One-ISIC \
    -weights -b 1 -dataset isic -data_path ../dataset/isic \
    -vis 10 -baseline 'unet'
```

## Dataset

### Download the Open-source Datasets

We collected 78 **open-source** datasets for training and testing the model. The datasets and their download links are listed [here](https://drive.google.com/file/d/1iXFm9M1ocrWNkEIthWUWnZYY2-1l-qya/view?usp=share_link).

### Download the Prompts

The prompts corresponding to the datasets can be downloaded [here](https://drive.google.com/file/d/1cNv2WW_Cv2NYzpt90vvELaweM5ltIe8n/view?usp=share_link). Each prompt is saved as a JSON record with the format `{DATASET_NAME, SAMPLE_INDEX, PROMPT_TYPE, PROMPT_CONTENT}`.

## Test Examples

### Melanoma Segmentation from Skin Images (2D)

1. Download ISIC dataset part 1 from https://challenge.isic-archive.com/data/, then place the CSV files in `./data/isic` under your data path. Your dataset folder under `your_data_path` should look like:

    ```
    ISIC/
    ├── ISBI2016_ISIC_Part1_Test_Data/...
    ├── ISBI2016_ISIC_Part1_Training_Data/...
    ├── ISBI2016_ISIC_Part1_Test_GroundTruth.csv
    └── ISBI2016_ISIC_Part1_Training_GroundTruth.csv
    ```

2. Run:

    ```bash
    python val.py -net oneprompt -mod one_adpt -exp_name One-ISIC \
        -weights -b 1 -dataset isic -data_path ../dataset/isic \
        -vis 10 -baseline 'unet'
    ```

    Change `data_path` and `exp_name` for your own usage.
    You can decrease the image size or the batch size `-b` if you run out of memory.

3. **Evaluation**: The code automatically evaluates the model on the test set during training; set `--val_freq` to control how often (in epochs) evaluation runs. You can also run `val.py` for standalone evaluation.

4. **Result Visualization**: Set the `--vis` parameter to control how often (in epochs) results are visualized during training or evaluation. By default, everything is saved to `./logs/`.

### REFUGE: Optic-disc Segmentation from Fundus Images (2D)

The [REFUGE](https://refuge.grand-challenge.org/) dataset contains 1200 fundus images with optic disc/cup segmentations and clinical glaucoma labels.

1. Download the dataset manually from [here](https://huggingface.co/datasets/realslimman/REFUGE-MultiRater/tree/main), or via the command line:

    ```bash
    git lfs install
    git clone git@hf.co:datasets/realslimman/REFUGE-MultiRater
    unzip ./REFUGE-MultiRater.zip
    mv REFUGE-MultiRater ./data
    ```

2. To evaluate on REFUGE, run:

    ```bash
    python val.py -net oneprompt -mod one_adpt -exp_name One-REFUGE \
        -weights -b 1 -baseline 'unet' \
        -dataset REFUGE -data_path ./data/REFUGE-MultiRater
    ```

## Run on Your Own Dataset

It is simple to run One-Prompt on other datasets. Just write another dataset class following the examples in `./dataset.py`. You only need to make sure it returns a dict with:

```python
{
    'image': A tensor of size [C, H, W] for 2D images, or [C, H, W, D] for 3D data.
             D is the depth of the 3D volume; C is the channel count of a scan/frame,
             which is commonly 1 for CT, MRI, or US data. For, say, a color surgical
             video, D could be the number of time frames and C would be 3 for an RGB frame.
    'label': The target masks. Same size as the images except for the resolution (H and W).
    'p_label': The prompt label marking a positive/negative prompt. To keep things
               simple, you can always set it to 1 if you don't need negative prompts.
    'pt': The prompt.
          For example, a click prompt should be [x of click, y of click], with one
          click per scan/frame when using 3D data.
    'image_meta_dict': Optional. If you want to save or visualize the results, put the
                       image name in it under the key 'filename_or_obj'.
    ... (others as you want)
}
```

## API Usage (New)

After installing with `pip install -e .`, you can use the package programmatically:

```python
from oneprompt_seg import get_network, CombinedPolypDataset
from oneprompt_seg.utils import eval_seg, vis_image, create_logger

# Build the network
model = get_network(args, 'oneprompt')

# Create a dataset
dataset = CombinedPolypDataset(args, data_path, transform=transform)

# Evaluate
iou, dice = eval_seg(predictions, targets, threshold=(0.5,))
```

## Configuration

The project supports YAML configuration files in `configs/`:

- `default.yaml` - Default training configuration
- `development.yaml` - Development/debugging configuration
- `production.yaml` - Production training configuration

## Development

### Running Tests

```bash
# Run all tests
pytest tests/ -v

# Run with coverage
pytest tests/ -v --cov=src/oneprompt_seg --cov-report=html
```

### Code Quality

```bash
# Format code
black src/ tests/ scripts/
isort src/ tests/ scripts/

# Lint code
flake8 src/ tests/
mypy src/
```

## Cite

```bibtex
@InProceedings{Wu_2024_CVPR,
    author    = {Wu, Junde and Xu, Min},
    title     = {One-Prompt to Segment All Medical Images},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {11302-11312}
}
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
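## Appendix: Custom Dataset Sketch

As a concrete illustration of the dict contract described in "Run on Your Own Dataset" above, here is a minimal, self-contained sketch of a 2D dataset class. The class name, shapes, and synthetic data are illustrative assumptions, not part of the repository; a real implementation would load images and masks from `data_path` as the classes in `./dataset.py` do.

```python
import torch
from torch.utils.data import Dataset


class ToyOnePromptDataset(Dataset):
    """Illustrative dataset returning the dict One-Prompt expects.

    Synthetic tensors stand in for real image/mask loading; this only
    demonstrates the required keys and shapes for the 2D case.
    """

    def __init__(self, num_samples=4, size=256):
        self.num_samples = num_samples
        self.size = size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        h = w = self.size
        image = torch.rand(1, h, w)                   # [C, H, W], C=1 (e.g. a CT/MRI slice)
        label = (torch.rand(1, h, w) > 0.5).float()   # binary target mask, same H and W
        return {
            'image': image,
            'label': label,
            'p_label': 1,                             # positive prompt
            'pt': torch.tensor([w // 2, h // 2]),     # one [x, y] click prompt
            'image_meta_dict': {'filename_or_obj': f'sample_{idx:04d}'},
        }
```

Wrapping this class in a `torch.utils.data.DataLoader` then yields batches in the same format the training and evaluation scripts consume.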