# 本项目用于论文复现https://arxiv.org/abs/2405.06880 # EMCAD:高效多尺度卷积注意力解码器用于医学图像分割 本项目是论文《EMCAD:用于医学图像分割的高效多尺度卷积注意力解码器》的官方PyTorch实现,该论文发表于CVPR 2024。项目作者包括Md Mostafijur Rahman、Mustafa Munir和Radu Marculescu,均来自德克萨斯大学奥斯汀分校。 ## 项目概述 EMCAD是一种创新的医学图像分割框架,它通过引入高效多尺度卷积注意力解码器来提升分割精度和效率。该方法在多个医学图像数据集上展现了卓越的性能,特别是在Synapse多器官分割任务中取得了优异的成绩。 本项目的核心贡献在于提出了一种新型的解码器架构,该架构能够有效地融合多尺度特征信息,并通过注意力机制增强特征表达能力。与传统的解码器相比,EMCAD在保持较低计算开销的同时,显著提升了分割精度,这使得它特别适合于资源受限的医疗应用场景。 ### 主要特性 EMCAD具有以下核心特性:首先,它采用了高效的多尺度卷积注意力机制,能够同时捕获不同尺度的特征信息,并通过注意力机制进行特征增强。其次,EMCAD支持多种骨干网络,包括PVTv2系列和ResNet系列,用户可以根据实际需求选择合适的骨干网络进行特征提取。第三,EMCAD提供了灵活的配置选项,支持自定义卷积核大小、扩展因子、注意力机制类型等参数,以适应不同的应用场景。 此外,EMCAD还支持深度监督和变异监督两种训练策略,用户可以根据具体任务选择合适的训练方式。模型还支持并行深度卷积和特征级联两种融合模式,提供了丰富的可调节参数。 ### 性能指标 EMCAD在Synapse数据集上展现了优异的分割性能,平均Dice系数达到了85%以上,同时保持了较低的计算复杂度(FLOPs)和参数量。与现有的主流分割方法相比,EMCAD在精度和效率之间取得了更好的平衡。 ## 项目结构 ``` EMCAD/ ├── README.md # 项目说明文档 ├── setup.py # 安装配置脚本 ├── requirements.txt # 依赖包列表 ├── .gitignore # Git忽略文件配置 │ ├── src/ # 源代码目录 │ ├── core/ # 核心模型定义 │ │ ├── networks.py # EMCADNet主网络结构 │ │ ├── decoders.py # EMCAD解码器模块 │ │ ├── pvtv2.py # PVTv2骨干网络 │ │ └── resnet.py # ResNet系列骨干网络 │ │ │ └── utils/ # 工具模块 │ ├── dataset_synapse.py # Synapse数据集加载器 │ ├── dataset_ACDC.py # ACDC数据集加载器 │ ├── dataloader.py # 通用数据加载器 │ ├── trainer.py # 训练核心逻辑 │ ├── utils.py # 实用工具函数 │ ├── transforms.py # 数据变换操作 │ ├── joint_transforms.py # 联合变换操作 │ ├── format_conversion.py # 格式转换工具 │ ├── preprocess_synapse_data.py # Synapse数据预处理 │ ├── preprocess_synapse_data_3d.py # 3D数据预处理 │ └── misc.py # 其他工具函数 │ ├── scripts/ # 脚本目录 │ └── train_synapse.py # Synapse数据集训练入口 │ ├── tests/ # 测试代码目录 │ └── test_synapse.py # 模型测试脚本 │ ├── data/ # 数据目录 │ ├── lists/ # 数据列表文件 │ │ └── lists_Synapse/ # Synapse数据划分 │ │ ├── train.txt │ │ ├── test_vol.txt │ │ └── .ipynb_checkpoints/ │ │ │ └── ACDC/ # ACDC数据集 │ └── ACDC/ │ ├── lists_ACDC/ # ACDC数据划分 │ │ ├── train.txt │ │ ├── valid.txt │ │ └── test.txt │ ├── train/ # 训练数据 │ ├── valid/ # 验证数据 │ └── test/ # 测试数据 │ ├── model_pth/ # 预训练模型权重目录 ├── pretrained_pth/ # 预训练权重目录 │ └── pvt/ # PVTv2预训练权重 │ └── SimpleITK.whl # SimpleITK安装包 ``` ## 环境配置 ### 系统要求 本项目推荐在Linux操作系统上运行,也支持Windows和macOS系统。建议使用Python 3.8或更高版本,强烈建议使用Anaconda或Miniconda创建独立的虚拟环境,以避免与系统其他Python项目产生依赖冲突。 ### 依赖安装 首先创建并激活虚拟环境: ```bash conda create -n emcadenv python=3.8 conda activate emcadenv ``` 然后安装PyTorch和相关依赖。对于CUDA 11.3用户,建议使用以下命令安装兼容的PyTorch版本: ```bash pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 ``` 如果使用其他版本的CUDA,请访问PyTorch官方网站获取对应的安装命令。接下来安装其他必要的依赖包: ```bash pip install -r requirements.txt ``` 对于Windows系统,可能需要手动安装SimpleITK: ```bash pip install SimpleITK.whl ``` ### 依赖包说明 requirements.txt中包含了项目运行所需的所有依赖包,主要包括: - numpy:数值计算 - torch和torchvision:深度学习框架 - h5py:处理HDF5格式数据 - scipy:科学计算 - matplotlib:可视化 - tqdm:显示进度条 - tensorboardX:可视化训练过程 - nibabel:处理医学图像格式 - medpy:计算医学图像评价指标 - ptflops和thop:计算模型复杂度 - segmentation-mask-overlay:叠加分割掩膜可视化 - timm:提供丰富的预训练模型 ## 数据准备 ### Synapse多器官数据集 Synapse数据集是医学图像分割领域常用的基准数据集,包含了30个腹部CT扫描案例,涉及脾脏、肾脏、肝脏、胃、主动脉等多个器官的分割标注。首先需要注册并从Synapse官方网站下载数据集,然后按照TransUNet项目提供的数据划分方式,将RawData文件夹划分为包含18个扫描的训练集和包含12个扫描的测试集,存放在./data/synapse/Abdomen/RawData/目录下。 数据预处理有两种方式: **方式一:运行预处理脚本** ```bash python src/utils/preprocess_synapse_data.py ``` **方式二:直接下载预处理的Synapse数据** 从项目 releases 页面下载预处理的Synapse数据,保存到./data/目录下。 ### ACDC数据集 ACDC数据集是心脏MRI分割的标准数据集。推荐从MT-UNet项目的Google Drive下载预处理的ACDC数据集,然后解压到./data/ACDC/目录即可使用。 ## 预训练模型 EMCAD使用PVTv2作为默认的骨干网络,需要下载预训练的PVTv2模型权重。预训练模型可以从以下来源下载: - Google Drive:项目提供的预训练权重 - PVT GitHub releases页面:官方PVTv2预训练权重 下载后将模型文件放入./model_pth/或./pretrained_pth/pvt/目录。默认的预训练路径可以通过--pretrained_dir参数进行配置。 ## 快速开始 ### 训练模型 在Synapse数据集上训练EMCAD模型的基本命令如下: ```bash python scripts/train_synapse.py --root_path /path/to/train/data --volume_path /path/to/test/data --encoder pvt_v2_b2 ``` **常用参数说明:** | 参数 | 说明 | 默认值 | |------|------|--------| | --root_path | 训练数据的根目录 | /data/ACDC/train | | --volume_path | 测试数据的根目录 | /data/ACDC/test | | --dataset | 数据集名称 | ADC | | --list_dir | 数据列表目录 | /data/ACDC/lists_ACDC | | --num_classes | 分割类别数 | 4(ACDC)/9(Synapse) | | --encoder | 骨干网络类型 | pvt_v2_b2 | | --batch_size | 每批训练的样本数 | 6 | | --base_lr | 基础学习率 | 0.01 | | --img_size | 输入图像尺寸 | 224 | | --max_epochs | 训练轮数 | 300 | | --expansion_factor | MSCB块扩展因子 | 2 | | --kernel_sizes | 多尺度卷积核大小 | [1, 3, 5] | ### 模型推理 训练完成后,可以使用以下命令进行模型推理: ```bash python tests/test_synapse.py ``` 推理结果默认保存在predictions目录下,可以选择保存为nii格式的分割结果文件。 ### 使用ACDC数据集 在ACDC数据集上训练时,需要将--dataset参数设置为ACDC,并调整相应的数据路径和类别数: ```bash python scripts/train_synapse.py \ --root_path ./data/ACDC/train \ --volume_path ./data/ACDC/test \ --list_dir ./data/ACDC/ACDC/lists_ACDC \ --dataset ACDC \ --num_classes 4 \ --encoder pvt_v2_b2 ``` ## API参考 ### EMCADNet类 EMCADNet是项目的主模型类,用于构建完整的分割网络。 **构造函数参数:** | 参数 | 类型 | 说明 | 默认值 | |------|------|------|--------| | num_classes | int | 输出类别数 | 1 | | kernel_sizes | list | 多尺度卷积核大小列表 | [1, 3, 5] | | expansion_factor | int | 扩展因子 | 2 | | dw_parallel | bool | 是否使用并行深度卷积 | True | | add | bool | 特征融合方式 | True | | lgag_ks | int | LGAG模块卷积核大小 | 3 | | activation | str | 激活函数类型 | relu | | encoder | str | 骨干网络名称 | pvt_v2_b2 | | pretrain | bool | 是否加载预训练权重 | True | | pretrained_dir | str | 预训练权重路径 | ./pretrained_pth/pvt/ | ### 数据集类 **Synapse_dataset类** 用于加载Synapse数据集,支持训练和测试两种模式。 ```python from src.utils.dataset_synapse import Synapse_dataset, RandomGenerator db_train = Synapse_dataset( base_dir="./data/synapse", list_dir="./data/lists/lists_Synapse", split="train", nclass=9, transform=transforms.Compose([RandomGenerator(output_size=[224, 224])]) ) ``` **ACDCdataset类** 用于加载ACDC数据集。 ```python from src.utils.dataset_ACDC import ACDCdataset db_train = ACDCdataset( base_dir="./data/ACDC", list_dir="./data/ACDC/ACDC/lists_ACDC", split="train", transform=transform ) ``` ### 工具函数 项目提供了丰富的工具函数: | 函数 | 说明 | |------|------| | DiceLoss | 计算Dice损失 | | powerset | 生成子集 | | val_single_volume | 单个体数据的验证推理 | | test_single_volume | 单个体数据的完整推理和结果保存 | ## 扩展与定制 ### 自定义骨干网络 如果需要使用其他骨干网络,可以在src/core/networks.py文件中参考现有实现,添加新的骨干网络支持代码。新的骨干网络需要实现特征提取功能,并返回四个尺度的特征图。 ### 添加新的数据集 要支持新的数据集,可以在src/utils/目录下创建新的数据集类,参考Synapse_dataset的实现方式。新数据集类需要继承Dataset类,并实现__len__和__getitem__方法。 ### 调整模型配置 模型的各个组件都可以通过参数进行调整,包括卷积核大小、扩展因子、注意力机制类型等。建议从默认配置开始,逐步调整以找到最适合特定任务和数据的配置。 ### 运行测试 项目提供了测试脚本来验证模型是否正确: ```bash python tests/test_synapse.py --dataset Synapse --num_classes 9 ``` ## 常见问题 ### 1. 导入模块时提示ModuleNotFoundError 请确保已正确安装项目: ```bash pip install -e . ``` 或者设置PYTHONPATH环境变量: ```bash export PYTHONPATH=$PYTHONPATH:/path/to/EMCAD ``` ### 2. 预训练权重加载失败 请检查预训练权重路径是否正确,确保PVTv2权重文件已下载并放置在正确的目录。 ### 3. 内存不足错误 尝试减小batch_size或img_size参数的值。 ## 引用 如果您在研究中使用了本项目,请引用以下论文: ```bibtex @article{rahman2024emcad, title={EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation}, author={Rahman, Md Mostafijur and Munir, Mustafa and Marculescu, Radu}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, pages={11769--11779}, year={2024} } ``` ## 致谢 本项目的实现参考了以下优秀的开源项目: - [timm](https://github.com/rwightman/pytorch-image-models):提供丰富的预训练模型 - [TransUNet](https://github.com/Beckschen/TransUNet):医学图像分割框架参考 - [CASCADE](https://github.com/GuoHuang19/CASCADE)、[MERIT](https://github.com/JunrenT/Class-Balanced-Loss)、[G-CASCADE](https://github.com/linzhlgithub/G-CASCADE):分割方法参考 在此向这些项目的作者表示感谢。 ## 许可证 本项目遵循MIT许可证开源。