|
|
2 months ago | |
|---|---|---|
| configs | 2 months ago | |
| data | 2 months ago | |
| docs | 2 months ago | |
| scripts | 2 months ago | |
| src | 2 months ago | |
| tests | 2 months ago | |
| .gitignore | 2 months ago | |
| .pre-commit-config.yaml | 2 months ago | |
| EMCAD:面向心脏图像分割的高效多尺度卷积注意力解码器.pdf | 2 months ago | |
| Makefile | 2 months ago | |
| README.md | 2 months ago | |
| requirements.txt | 2 months ago | |
| setup.py | 2 months ago | |
README.md
本项目用于论文复现https://arxiv.org/abs/2405.06880
EMCAD:高效多尺度卷积注意力解码器用于医学图像分割
本项目是论文《EMCAD:用于医学图像分割的高效多尺度卷积注意力解码器》的官方PyTorch实现,该论文发表于CVPR 2024。项目作者包括Md Mostafijur Rahman、Mustafa Munir和Radu Marculescu,均来自德克萨斯大学奥斯汀分校。
项目概述
EMCAD是一种创新的医学图像分割框架,它通过引入高效多尺度卷积注意力解码器来提升分割精度和效率。该方法在多个医学图像数据集上展现了卓越的性能,特别是在Synapse多器官分割任务中取得了优异的成绩。
本项目的核心贡献在于提出了一种新型的解码器架构,该架构能够有效地融合多尺度特征信息,并通过注意力机制增强特征表达能力。与传统的解码器相比,EMCAD在保持较低计算开销的同时,显著提升了分割精度,这使得它特别适合于资源受限的医疗应用场景。
主要特性
EMCAD具有以下核心特性:首先,它采用了高效的多尺度卷积注意力机制,能够同时捕获不同尺度的特征信息,并通过注意力机制进行特征增强。其次,EMCAD支持多种骨干网络,包括PVTv2系列和ResNet系列,用户可以根据实际需求选择合适的骨干网络进行特征提取。第三,EMCAD提供了灵活的配置选项,支持自定义卷积核大小、扩展因子、注意力机制类型等参数,以适应不同的应用场景。
此外,EMCAD还支持深度监督和变异监督两种训练策略,用户可以根据具体任务选择合适的训练方式。模型还支持并行深度卷积和特征级联两种融合模式,提供了丰富的可调节参数。
性能指标
EMCAD在Synapse数据集上展现了优异的分割性能,平均Dice系数达到了85%以上,同时保持了较低的计算复杂度(FLOPs)和参数量。与现有的主流分割方法相比,EMCAD在精度和效率之间取得了更好的平衡。
项目结构
EMCAD/
├── README.md # 项目说明文档
├── setup.py # 安装配置脚本
├── requirements.txt # 依赖包列表
├── .gitignore # Git忽略文件配置
│
├── src/ # 源代码目录
│ ├── core/ # 核心模型定义
│ │ ├── networks.py # EMCADNet主网络结构
│ │ ├── decoders.py # EMCAD解码器模块
│ │ ├── pvtv2.py # PVTv2骨干网络
│ │ └── resnet.py # ResNet系列骨干网络
│ │
│ └── utils/ # 工具模块
│ ├── dataset_synapse.py # Synapse数据集加载器
│ ├── dataset_ACDC.py # ACDC数据集加载器
│ ├── dataloader.py # 通用数据加载器
│ ├── trainer.py # 训练核心逻辑
│ ├── utils.py # 实用工具函数
│ ├── transforms.py # 数据变换操作
│ ├── joint_transforms.py # 联合变换操作
│ ├── format_conversion.py # 格式转换工具
│ ├── preprocess_synapse_data.py # Synapse数据预处理
│ ├── preprocess_synapse_data_3d.py # 3D数据预处理
│ └── misc.py # 其他工具函数
│
├── scripts/ # 脚本目录
│ └── train_synapse.py # Synapse数据集训练入口
│
├── tests/ # 测试代码目录
│ └── test_synapse.py # 模型测试脚本
│
├── data/ # 数据目录
│ ├── lists/ # 数据列表文件
│ │ └── lists_Synapse/ # Synapse数据划分
│ │ ├── train.txt
│ │ ├── test_vol.txt
│ │ └── .ipynb_checkpoints/
│ │
│ └── ACDC/ # ACDC数据集
│ └── ACDC/
│ ├── lists_ACDC/ # ACDC数据划分
│ │ ├── train.txt
│ │ ├── valid.txt
│ │ └── test.txt
│ ├── train/ # 训练数据
│ ├── valid/ # 验证数据
│ └── test/ # 测试数据
│
├── model_pth/ # 预训练模型权重目录
├── pretrained_pth/ # 预训练权重目录
│ └── pvt/ # PVTv2预训练权重
│
└── SimpleITK.whl # SimpleITK安装包
环境配置
系统要求
本项目推荐在Linux操作系统上运行,也支持Windows和macOS系统。建议使用Python 3.8或更高版本,强烈建议使用Anaconda或Miniconda创建独立的虚拟环境,以避免与系统其他Python项目产生依赖冲突。
依赖安装
首先创建并激活虚拟环境:
conda create -n emcadenv python=3.8
conda activate emcadenv
然后安装PyTorch和相关依赖。对于CUDA 11.3用户,建议使用以下命令安装兼容的PyTorch版本:
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
如果使用其他版本的CUDA,请访问PyTorch官方网站获取对应的安装命令。接下来安装其他必要的依赖包:
pip install -r requirements.txt
对于Windows系统,可能需要手动安装SimpleITK:
pip install SimpleITK.whl
依赖包说明
requirements.txt中包含了项目运行所需的所有依赖包,主要包括:
- numpy:数值计算
- torch和torchvision:深度学习框架
- h5py:处理HDF5格式数据
- scipy:科学计算
- matplotlib:可视化
- tqdm:显示进度条
- tensorboardX:可视化训练过程
- nibabel:处理医学图像格式
- medpy:计算医学图像评价指标
- ptflops和thop:计算模型复杂度
- segmentation-mask-overlay:叠加分割掩膜可视化
- timm:提供丰富的预训练模型
数据准备
Synapse多器官数据集
Synapse数据集是医学图像分割领域常用的基准数据集,包含了30个腹部CT扫描案例,涉及脾脏、肾脏、肝脏、胃、主动脉等多个器官的分割标注。首先需要注册并从Synapse官方网站下载数据集,然后按照TransUNet项目提供的数据划分方式,将RawData文件夹划分为包含18个扫描的训练集和包含12个扫描的测试集,存放在./data/synapse/Abdomen/RawData/目录下。
数据预处理有两种方式:
方式一:运行预处理脚本
python src/utils/preprocess_synapse_data.py
方式二:直接下载预处理的Synapse数据
从项目 releases 页面下载预处理的Synapse数据,保存到./data/目录下。
ACDC数据集
ACDC数据集是心脏MRI分割的标准数据集。推荐从MT-UNet项目的Google Drive下载预处理的ACDC数据集,然后解压到./data/ACDC/目录即可使用。
预训练模型
EMCAD使用PVTv2作为默认的骨干网络,需要下载预训练的PVTv2模型权重。预训练模型可以从以下来源下载:
- Google Drive:项目提供的预训练权重
- PVT GitHub releases页面:官方PVTv2预训练权重
下载后将模型文件放入./model_pth/或./pretrained_pth/pvt/目录。默认的预训练路径可以通过--pretrained_dir参数进行配置。
快速开始
训练模型
在Synapse数据集上训练EMCAD模型的基本命令如下:
python scripts/train_synapse.py --root_path /path/to/train/data --volume_path /path/to/test/data --encoder pvt_v2_b2
常用参数说明:
| 参数 | 说明 | 默认值 |
|---|---|---|
| --root_path | 训练数据的根目录 | /data/ACDC/train |
| --volume_path | 测试数据的根目录 | /data/ACDC/test |
| --dataset | 数据集名称 | ADC |
| --list_dir | 数据列表目录 | /data/ACDC/lists_ACDC |
| --num_classes | 分割类别数 | 4(ACDC)/9(Synapse) |
| --encoder | 骨干网络类型 | pvt_v2_b2 |
| --batch_size | 每批训练的样本数 | 6 |
| --base_lr | 基础学习率 | 0.01 |
| --img_size | 输入图像尺寸 | 224 |
| --max_epochs | 训练轮数 | 300 |
| --expansion_factor | MSCB块扩展因子 | 2 |
| --kernel_sizes | 多尺度卷积核大小 | [1, 3, 5] |
模型推理
训练完成后,可以使用以下命令进行模型推理:
python tests/test_synapse.py
推理结果默认保存在predictions目录下,可以选择保存为nii格式的分割结果文件。
使用ACDC数据集
在ACDC数据集上训练时,需要将--dataset参数设置为ACDC,并调整相应的数据路径和类别数:
python scripts/train_synapse.py \
--root_path ./data/ACDC/train \
--volume_path ./data/ACDC/test \
--list_dir ./data/ACDC/ACDC/lists_ACDC \
--dataset ACDC \
--num_classes 4 \
--encoder pvt_v2_b2
API参考
EMCADNet类
EMCADNet是项目的主模型类,用于构建完整的分割网络。
构造函数参数:
| 参数 | 类型 | 说明 | 默认值 |
|---|---|---|---|
| num_classes | int | 输出类别数 | 1 |
| kernel_sizes | list | 多尺度卷积核大小列表 | [1, 3, 5] |
| expansion_factor | int | 扩展因子 | 2 |
| dw_parallel | bool | 是否使用并行深度卷积 | True |
| add | bool | 特征融合方式 | True |
| lgag_ks | int | LGAG模块卷积核大小 | 3 |
| activation | str | 激活函数类型 | relu |
| encoder | str | 骨干网络名称 | pvt_v2_b2 |
| pretrain | bool | 是否加载预训练权重 | True |
| pretrained_dir | str | 预训练权重路径 | ./pretrained_pth/pvt/ |
数据集类
Synapse_dataset类
用于加载Synapse数据集,支持训练和测试两种模式。
from src.utils.dataset_synapse import Synapse_dataset, RandomGenerator
db_train = Synapse_dataset(
base_dir="./data/synapse",
list_dir="./data/lists/lists_Synapse",
split="train",
nclass=9,
transform=transforms.Compose([RandomGenerator(output_size=[224, 224])])
)
ACDCdataset类
用于加载ACDC数据集。
from src.utils.dataset_ACDC import ACDCdataset
db_train = ACDCdataset(
base_dir="./data/ACDC",
list_dir="./data/ACDC/ACDC/lists_ACDC",
split="train",
transform=transform
)
工具函数
项目提供了丰富的工具函数:
| 函数 | 说明 |
|---|---|
| DiceLoss | 计算Dice损失 |
| powerset | 生成子集 |
| val_single_volume | 单个体数据的验证推理 |
| test_single_volume | 单个体数据的完整推理和结果保存 |
扩展与定制
自定义骨干网络
如果需要使用其他骨干网络,可以在src/core/networks.py文件中参考现有实现,添加新的骨干网络支持代码。新的骨干网络需要实现特征提取功能,并返回四个尺度的特征图。
添加新的数据集
要支持新的数据集,可以在src/utils/目录下创建新的数据集类,参考Synapse_dataset的实现方式。新数据集类需要继承Dataset类,并实现__len__和__getitem__方法。
调整模型配置
模型的各个组件都可以通过参数进行调整,包括卷积核大小、扩展因子、注意力机制类型等。建议从默认配置开始,逐步调整以找到最适合特定任务和数据的配置。
运行测试
项目提供了测试脚本来验证模型是否正确:
python tests/test_synapse.py --dataset Synapse --num_classes 9
常见问题
1. 导入模块时提示ModuleNotFoundError
请确保已正确安装项目:
pip install -e .
或者设置PYTHONPATH环境变量:
export PYTHONPATH=$PYTHONPATH:/path/to/EMCAD
2. 预训练权重加载失败
请检查预训练权重路径是否正确,确保PVTv2权重文件已下载并放置在正确的目录。
3. 内存不足错误
尝试减小batch_size或img_size参数的值。
引用
如果您在研究中使用了本项目,请引用以下论文:
@article{rahman2024emcad,
title={EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation},
author={Rahman, Md Mostafijur and Munir, Mustafa and Marculescu, Radu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={11769--11779},
year={2024}
}
致谢
本项目的实现参考了以下优秀的开源项目:
在此向这些项目的作者表示感谢。
许可证
本项目遵循MIT许可证开源。