You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

349 lines
13 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# 本项目用于论文复现https://arxiv.org/abs/2405.06880
# EMCAD高效多尺度卷积注意力解码器用于医学图像分割
本项目是论文《EMCAD用于医学图像分割的高效多尺度卷积注意力解码器》的官方PyTorch实现该论文发表于CVPR 2024。项目作者包括Md Mostafijur Rahman、Mustafa Munir和Radu Marculescu均来自德克萨斯大学奥斯汀分校。
## 项目概述
EMCAD是一种创新的医学图像分割框架它通过引入高效多尺度卷积注意力解码器来提升分割精度和效率。该方法在多个医学图像数据集上展现了卓越的性能特别是在Synapse多器官分割任务中取得了优异的成绩。
本项目的核心贡献在于提出了一种新型的解码器架构该架构能够有效地融合多尺度特征信息并通过注意力机制增强特征表达能力。与传统的解码器相比EMCAD在保持较低计算开销的同时显著提升了分割精度这使得它特别适合于资源受限的医疗应用场景。
### 主要特性
EMCAD具有以下核心特性首先它采用了高效的多尺度卷积注意力机制能够同时捕获不同尺度的特征信息并通过注意力机制进行特征增强。其次EMCAD支持多种骨干网络包括PVTv2系列和ResNet系列用户可以根据实际需求选择合适的骨干网络进行特征提取。第三EMCAD提供了灵活的配置选项支持自定义卷积核大小、扩展因子、注意力机制类型等参数以适应不同的应用场景。
此外EMCAD还支持深度监督和变异监督两种训练策略用户可以根据具体任务选择合适的训练方式。模型还支持并行深度卷积和特征级联两种融合模式提供了丰富的可调节参数。
### 性能指标
EMCAD在Synapse数据集上展现了优异的分割性能平均Dice系数达到了85%以上同时保持了较低的计算复杂度FLOPs和参数量。与现有的主流分割方法相比EMCAD在精度和效率之间取得了更好的平衡。
## 项目结构
```
EMCAD/
├── README.md # 项目说明文档
├── setup.py # 安装配置脚本
├── requirements.txt # 依赖包列表
├── .gitignore # Git忽略文件配置
├── src/ # 源代码目录
│ ├── core/ # 核心模型定义
│ │ ├── networks.py # EMCADNet主网络结构
│ │ ├── decoders.py # EMCAD解码器模块
│ │ ├── pvtv2.py # PVTv2骨干网络
│ │ └── resnet.py # ResNet系列骨干网络
│ │
│ └── utils/ # 工具模块
│ ├── dataset_synapse.py # Synapse数据集加载器
│ ├── dataset_ACDC.py # ACDC数据集加载器
│ ├── dataloader.py # 通用数据加载器
│ ├── trainer.py # 训练核心逻辑
│ ├── utils.py # 实用工具函数
│ ├── transforms.py # 数据变换操作
│ ├── joint_transforms.py # 联合变换操作
│ ├── format_conversion.py # 格式转换工具
│ ├── preprocess_synapse_data.py # Synapse数据预处理
│ ├── preprocess_synapse_data_3d.py # 3D数据预处理
│ └── misc.py # 其他工具函数
├── scripts/ # 脚本目录
│ └── train_synapse.py # Synapse数据集训练入口
├── tests/ # 测试代码目录
│ └── test_synapse.py # 模型测试脚本
├── data/ # 数据目录
│ ├── lists/ # 数据列表文件
│ │ └── lists_Synapse/ # Synapse数据划分
│ │ ├── train.txt
│ │ ├── test_vol.txt
│ │ └── .ipynb_checkpoints/
│ │
│ └── ACDC/ # ACDC数据集
│ └── ACDC/
│ ├── lists_ACDC/ # ACDC数据划分
│ │ ├── train.txt
│ │ ├── valid.txt
│ │ └── test.txt
│ ├── train/ # 训练数据
│ ├── valid/ # 验证数据
│ └── test/ # 测试数据
├── model_pth/ # 预训练模型权重目录
├── pretrained_pth/ # 预训练权重目录
│ └── pvt/ # PVTv2预训练权重
└── SimpleITK.whl # SimpleITK安装包
```
## 环境配置
### 系统要求
本项目推荐在Linux操作系统上运行也支持Windows和macOS系统。建议使用Python 3.8或更高版本强烈建议使用Anaconda或Miniconda创建独立的虚拟环境以避免与系统其他Python项目产生依赖冲突。
### 依赖安装
首先创建并激活虚拟环境:
```bash
conda create -n emcadenv python=3.8
conda activate emcadenv
```
然后安装PyTorch和相关依赖。对于CUDA 11.3用户建议使用以下命令安装兼容的PyTorch版本
```bash
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
如果使用其他版本的CUDA请访问PyTorch官方网站获取对应的安装命令。接下来安装其他必要的依赖包
```bash
pip install -r requirements.txt
```
对于Windows系统可能需要手动安装SimpleITK
```bash
pip install SimpleITK.whl
```
### 依赖包说明
requirements.txt中包含了项目运行所需的所有依赖包主要包括
- numpy数值计算
- torch和torchvision深度学习框架
- h5py处理HDF5格式数据
- scipy科学计算
- matplotlib可视化
- tqdm显示进度条
- tensorboardX可视化训练过程
- nibabel处理医学图像格式
- medpy计算医学图像评价指标
- ptflops和thop计算模型复杂度
- segmentation-mask-overlay叠加分割掩膜可视化
- timm提供丰富的预训练模型
## 数据准备
### Synapse多器官数据集
Synapse数据集是医学图像分割领域常用的基准数据集包含了30个腹部CT扫描案例涉及脾脏、肾脏、肝脏、胃、主动脉等多个器官的分割标注。首先需要注册并从Synapse官方网站下载数据集然后按照TransUNet项目提供的数据划分方式将RawData文件夹划分为包含18个扫描的训练集和包含12个扫描的测试集存放在./data/synapse/Abdomen/RawData/目录下。
数据预处理有两种方式:
**方式一:运行预处理脚本**
```bash
python src/utils/preprocess_synapse_data.py
```
**方式二直接下载预处理的Synapse数据**
从项目 releases 页面下载预处理的Synapse数据保存到./data/目录下。
### ACDC数据集
ACDC数据集是心脏MRI分割的标准数据集。推荐从MT-UNet项目的Google Drive下载预处理的ACDC数据集然后解压到./data/ACDC/目录即可使用。
## 预训练模型
EMCAD使用PVTv2作为默认的骨干网络需要下载预训练的PVTv2模型权重。预训练模型可以从以下来源下载
- Google Drive项目提供的预训练权重
- PVT GitHub releases页面官方PVTv2预训练权重
下载后将模型文件放入./model_pth/或./pretrained_pth/pvt/目录。默认的预训练路径可以通过--pretrained_dir参数进行配置。
## 快速开始
### 训练模型
在Synapse数据集上训练EMCAD模型的基本命令如下
```bash
python scripts/train_synapse.py --root_path /path/to/train/data --volume_path /path/to/test/data --encoder pvt_v2_b2
```
**常用参数说明:**
| 参数 | 说明 | 默认值 |
|------|------|--------|
| --root_path | 训练数据的根目录 | /data/ACDC/train |
| --volume_path | 测试数据的根目录 | /data/ACDC/test |
| --dataset | 数据集名称 | ADC |
| --list_dir | 数据列表目录 | /data/ACDC/lists_ACDC |
| --num_classes | 分割类别数 | 4ACDC/9Synapse |
| --encoder | 骨干网络类型 | pvt_v2_b2 |
| --batch_size | 每批训练的样本数 | 6 |
| --base_lr | 基础学习率 | 0.01 |
| --img_size | 输入图像尺寸 | 224 |
| --max_epochs | 训练轮数 | 300 |
| --expansion_factor | MSCB块扩展因子 | 2 |
| --kernel_sizes | 多尺度卷积核大小 | [1, 3, 5] |
### 模型推理
训练完成后,可以使用以下命令进行模型推理:
```bash
python tests/test_synapse.py
```
推理结果默认保存在predictions目录下可以选择保存为nii格式的分割结果文件。
### 使用ACDC数据集
在ACDC数据集上训练时需要将--dataset参数设置为ACDC并调整相应的数据路径和类别数
```bash
python scripts/train_synapse.py \
--root_path ./data/ACDC/train \
--volume_path ./data/ACDC/test \
--list_dir ./data/ACDC/ACDC/lists_ACDC \
--dataset ACDC \
--num_classes 4 \
--encoder pvt_v2_b2
```
## API参考
### EMCADNet类
EMCADNet是项目的主模型类用于构建完整的分割网络。
**构造函数参数:**
| 参数 | 类型 | 说明 | 默认值 |
|------|------|------|--------|
| num_classes | int | 输出类别数 | 1 |
| kernel_sizes | list | 多尺度卷积核大小列表 | [1, 3, 5] |
| expansion_factor | int | 扩展因子 | 2 |
| dw_parallel | bool | 是否使用并行深度卷积 | True |
| add | bool | 特征融合方式 | True |
| lgag_ks | int | LGAG模块卷积核大小 | 3 |
| activation | str | 激活函数类型 | relu |
| encoder | str | 骨干网络名称 | pvt_v2_b2 |
| pretrain | bool | 是否加载预训练权重 | True |
| pretrained_dir | str | 预训练权重路径 | ./pretrained_pth/pvt/ |
### 数据集类
**Synapse_dataset类**
用于加载Synapse数据集支持训练和测试两种模式。
```python
from src.utils.dataset_synapse import Synapse_dataset, RandomGenerator
db_train = Synapse_dataset(
base_dir="./data/synapse",
list_dir="./data/lists/lists_Synapse",
split="train",
nclass=9,
transform=transforms.Compose([RandomGenerator(output_size=[224, 224])])
)
```
**ACDCdataset类**
用于加载ACDC数据集。
```python
from src.utils.dataset_ACDC import ACDCdataset
db_train = ACDCdataset(
base_dir="./data/ACDC",
list_dir="./data/ACDC/ACDC/lists_ACDC",
split="train",
transform=transform
)
```
### 工具函数
项目提供了丰富的工具函数:
| 函数 | 说明 |
|------|------|
| DiceLoss | 计算Dice损失 |
| powerset | 生成子集 |
| val_single_volume | 单个体数据的验证推理 |
| test_single_volume | 单个体数据的完整推理和结果保存 |
## 扩展与定制
### 自定义骨干网络
如果需要使用其他骨干网络可以在src/core/networks.py文件中参考现有实现添加新的骨干网络支持代码。新的骨干网络需要实现特征提取功能并返回四个尺度的特征图。
### 添加新的数据集
要支持新的数据集可以在src/utils/目录下创建新的数据集类参考Synapse_dataset的实现方式。新数据集类需要继承Dataset类并实现__len__和__getitem__方法。
### 调整模型配置
模型的各个组件都可以通过参数进行调整,包括卷积核大小、扩展因子、注意力机制类型等。建议从默认配置开始,逐步调整以找到最适合特定任务和数据的配置。
### 运行测试
项目提供了测试脚本来验证模型是否正确:
```bash
python tests/test_synapse.py --dataset Synapse --num_classes 9
```
## 常见问题
### 1. 导入模块时提示ModuleNotFoundError
请确保已正确安装项目:
```bash
pip install -e .
```
或者设置PYTHONPATH环境变量
```bash
export PYTHONPATH=$PYTHONPATH:/path/to/EMCAD
```
### 2. 预训练权重加载失败
请检查预训练权重路径是否正确确保PVTv2权重文件已下载并放置在正确的目录。
### 3. 内存不足错误
尝试减小batch_size或img_size参数的值。
## 引用
如果您在研究中使用了本项目,请引用以下论文:
```bibtex
@article{rahman2024emcad,
title={EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation},
author={Rahman, Md Mostafijur and Munir, Mustafa and Marculescu, Radu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={11769--11779},
year={2024}
}
```
## 致谢
本项目的实现参考了以下优秀的开源项目:
- [timm](https://github.com/rwightman/pytorch-image-models):提供丰富的预训练模型
- [TransUNet](https://github.com/Beckschen/TransUNet):医学图像分割框架参考
- [CASCADE](https://github.com/GuoHuang19/CASCADE)、[MERIT](https://github.com/JunrenT/Class-Balanced-Loss)、[G-CASCADE](https://github.com/linzhlgithub/G-CASCADE):分割方法参考
在此向这些项目的作者表示感谢。
## 许可证
本项目遵循MIT许可证开源。