You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

13 KiB

本项目用于论文复现https://arxiv.org/abs/2405.06880

EMCAD高效多尺度卷积注意力解码器用于医学图像分割

本项目是论文《EMCAD用于医学图像分割的高效多尺度卷积注意力解码器》的官方PyTorch实现该论文发表于CVPR 2024。项目作者包括Md Mostafijur Rahman、Mustafa Munir和Radu Marculescu均来自德克萨斯大学奥斯汀分校。

项目概述

EMCAD是一种创新的医学图像分割框架它通过引入高效多尺度卷积注意力解码器来提升分割精度和效率。该方法在多个医学图像数据集上展现了卓越的性能特别是在Synapse多器官分割任务中取得了优异的成绩。

本项目的核心贡献在于提出了一种新型的解码器架构该架构能够有效地融合多尺度特征信息并通过注意力机制增强特征表达能力。与传统的解码器相比EMCAD在保持较低计算开销的同时显著提升了分割精度这使得它特别适合于资源受限的医疗应用场景。

主要特性

EMCAD具有以下核心特性首先它采用了高效的多尺度卷积注意力机制能够同时捕获不同尺度的特征信息并通过注意力机制进行特征增强。其次EMCAD支持多种骨干网络包括PVTv2系列和ResNet系列用户可以根据实际需求选择合适的骨干网络进行特征提取。第三EMCAD提供了灵活的配置选项支持自定义卷积核大小、扩展因子、注意力机制类型等参数以适应不同的应用场景。

此外EMCAD还支持深度监督和变异监督两种训练策略用户可以根据具体任务选择合适的训练方式。模型还支持并行深度卷积和特征级联两种融合模式提供了丰富的可调节参数。

性能指标

EMCAD在Synapse数据集上展现了优异的分割性能平均Dice系数达到了85%以上同时保持了较低的计算复杂度FLOPs和参数量。与现有的主流分割方法相比EMCAD在精度和效率之间取得了更好的平衡。

项目结构

EMCAD/
├── README.md                    # 项目说明文档
├── setup.py                     # 安装配置脚本
├── requirements.txt             # 依赖包列表
├── .gitignore                   # Git忽略文件配置
│
├── src/                         # 源代码目录
│   ├── core/                    # 核心模型定义
│   │   ├── networks.py          # EMCADNet主网络结构
│   │   ├── decoders.py          # EMCAD解码器模块
│   │   ├── pvtv2.py             # PVTv2骨干网络
│   │   └── resnet.py            # ResNet系列骨干网络
│   │
│   └── utils/                   # 工具模块
│       ├── dataset_synapse.py   # Synapse数据集加载器
│       ├── dataset_ACDC.py      # ACDC数据集加载器
│       ├── dataloader.py        # 通用数据加载器
│       ├── trainer.py           # 训练核心逻辑
│       ├── utils.py             # 实用工具函数
│       ├── transforms.py        # 数据变换操作
│       ├── joint_transforms.py  # 联合变换操作
│       ├── format_conversion.py # 格式转换工具
│       ├── preprocess_synapse_data.py     # Synapse数据预处理
│       ├── preprocess_synapse_data_3d.py  # 3D数据预处理
│       └── misc.py              # 其他工具函数
│
├── scripts/                     # 脚本目录
│   └── train_synapse.py         # Synapse数据集训练入口
│
├── tests/                       # 测试代码目录
│   └── test_synapse.py          # 模型测试脚本
│
├── data/                        # 数据目录
│   ├── lists/                   # 数据列表文件
│   │   └── lists_Synapse/       # Synapse数据划分
│   │       ├── train.txt
│   │       ├── test_vol.txt
│   │       └── .ipynb_checkpoints/
│   │
│   └── ACDC/                    # ACDC数据集
│       └── ACDC/
│           ├── lists_ACDC/      # ACDC数据划分
│           │   ├── train.txt
│           │   ├── valid.txt
│           │   └── test.txt
│           ├── train/           # 训练数据
│           ├── valid/           # 验证数据
│           └── test/            # 测试数据
│
├── model_pth/                   # 预训练模型权重目录
├── pretrained_pth/              # 预训练权重目录
│   └── pvt/                     # PVTv2预训练权重
│
└── SimpleITK.whl                # SimpleITK安装包

环境配置

系统要求

本项目推荐在Linux操作系统上运行也支持Windows和macOS系统。建议使用Python 3.8或更高版本强烈建议使用Anaconda或Miniconda创建独立的虚拟环境以避免与系统其他Python项目产生依赖冲突。

依赖安装

首先创建并激活虚拟环境:

conda create -n emcadenv python=3.8
conda activate emcadenv

然后安装PyTorch和相关依赖。对于CUDA 11.3用户建议使用以下命令安装兼容的PyTorch版本

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

如果使用其他版本的CUDA请访问PyTorch官方网站获取对应的安装命令。接下来安装其他必要的依赖包

pip install -r requirements.txt

对于Windows系统可能需要手动安装SimpleITK

pip install SimpleITK.whl

依赖包说明

requirements.txt中包含了项目运行所需的所有依赖包主要包括

  • numpy数值计算
  • torch和torchvision深度学习框架
  • h5py处理HDF5格式数据
  • scipy科学计算
  • matplotlib可视化
  • tqdm显示进度条
  • tensorboardX可视化训练过程
  • nibabel处理医学图像格式
  • medpy计算医学图像评价指标
  • ptflops和thop计算模型复杂度
  • segmentation-mask-overlay叠加分割掩膜可视化
  • timm提供丰富的预训练模型

数据准备

Synapse多器官数据集

Synapse数据集是医学图像分割领域常用的基准数据集包含了30个腹部CT扫描案例涉及脾脏、肾脏、肝脏、胃、主动脉等多个器官的分割标注。首先需要注册并从Synapse官方网站下载数据集然后按照TransUNet项目提供的数据划分方式将RawData文件夹划分为包含18个扫描的训练集和包含12个扫描的测试集存放在./data/synapse/Abdomen/RawData/目录下。

数据预处理有两种方式:

方式一:运行预处理脚本

python src/utils/preprocess_synapse_data.py

方式二直接下载预处理的Synapse数据

从项目 releases 页面下载预处理的Synapse数据保存到./data/目录下。

ACDC数据集

ACDC数据集是心脏MRI分割的标准数据集。推荐从MT-UNet项目的Google Drive下载预处理的ACDC数据集然后解压到./data/ACDC/目录即可使用。

预训练模型

EMCAD使用PVTv2作为默认的骨干网络需要下载预训练的PVTv2模型权重。预训练模型可以从以下来源下载

  • Google Drive项目提供的预训练权重
  • PVT GitHub releases页面官方PVTv2预训练权重

下载后将模型文件放入./model_pth/或./pretrained_pth/pvt/目录。默认的预训练路径可以通过--pretrained_dir参数进行配置。

快速开始

训练模型

在Synapse数据集上训练EMCAD模型的基本命令如下

python scripts/train_synapse.py --root_path /path/to/train/data --volume_path /path/to/test/data --encoder pvt_v2_b2

常用参数说明:

参数 说明 默认值
--root_path 训练数据的根目录 /data/ACDC/train
--volume_path 测试数据的根目录 /data/ACDC/test
--dataset 数据集名称 ADC
--list_dir 数据列表目录 /data/ACDC/lists_ACDC
--num_classes 分割类别数 4ACDC/9Synapse
--encoder 骨干网络类型 pvt_v2_b2
--batch_size 每批训练的样本数 6
--base_lr 基础学习率 0.01
--img_size 输入图像尺寸 224
--max_epochs 训练轮数 300
--expansion_factor MSCB块扩展因子 2
--kernel_sizes 多尺度卷积核大小 [1, 3, 5]

模型推理

训练完成后,可以使用以下命令进行模型推理:

python tests/test_synapse.py

推理结果默认保存在predictions目录下可以选择保存为nii格式的分割结果文件。

使用ACDC数据集

在ACDC数据集上训练时需要将--dataset参数设置为ACDC并调整相应的数据路径和类别数

python scripts/train_synapse.py \
    --root_path ./data/ACDC/train \
    --volume_path ./data/ACDC/test \
    --list_dir ./data/ACDC/ACDC/lists_ACDC \
    --dataset ACDC \
    --num_classes 4 \
    --encoder pvt_v2_b2

API参考

EMCADNet类

EMCADNet是项目的主模型类用于构建完整的分割网络。

构造函数参数:

参数 类型 说明 默认值
num_classes int 输出类别数 1
kernel_sizes list 多尺度卷积核大小列表 [1, 3, 5]
expansion_factor int 扩展因子 2
dw_parallel bool 是否使用并行深度卷积 True
add bool 特征融合方式 True
lgag_ks int LGAG模块卷积核大小 3
activation str 激活函数类型 relu
encoder str 骨干网络名称 pvt_v2_b2
pretrain bool 是否加载预训练权重 True
pretrained_dir str 预训练权重路径 ./pretrained_pth/pvt/

数据集类

Synapse_dataset类

用于加载Synapse数据集支持训练和测试两种模式。

from src.utils.dataset_synapse import Synapse_dataset, RandomGenerator

db_train = Synapse_dataset(
    base_dir="./data/synapse",
    list_dir="./data/lists/lists_Synapse",
    split="train",
    nclass=9,
    transform=transforms.Compose([RandomGenerator(output_size=[224, 224])])
)

ACDCdataset类

用于加载ACDC数据集。

from src.utils.dataset_ACDC import ACDCdataset

db_train = ACDCdataset(
    base_dir="./data/ACDC",
    list_dir="./data/ACDC/ACDC/lists_ACDC",
    split="train",
    transform=transform
)

工具函数

项目提供了丰富的工具函数:

函数 说明
DiceLoss 计算Dice损失
powerset 生成子集
val_single_volume 单个体数据的验证推理
test_single_volume 单个体数据的完整推理和结果保存

扩展与定制

自定义骨干网络

如果需要使用其他骨干网络可以在src/core/networks.py文件中参考现有实现添加新的骨干网络支持代码。新的骨干网络需要实现特征提取功能并返回四个尺度的特征图。

添加新的数据集

要支持新的数据集可以在src/utils/目录下创建新的数据集类参考Synapse_dataset的实现方式。新数据集类需要继承Dataset类并实现__len__和__getitem__方法。

调整模型配置

模型的各个组件都可以通过参数进行调整,包括卷积核大小、扩展因子、注意力机制类型等。建议从默认配置开始,逐步调整以找到最适合特定任务和数据的配置。

运行测试

项目提供了测试脚本来验证模型是否正确:

python tests/test_synapse.py --dataset Synapse --num_classes 9

常见问题

1. 导入模块时提示ModuleNotFoundError

请确保已正确安装项目:

pip install -e .

或者设置PYTHONPATH环境变量

export PYTHONPATH=$PYTHONPATH:/path/to/EMCAD

2. 预训练权重加载失败

请检查预训练权重路径是否正确确保PVTv2权重文件已下载并放置在正确的目录。

3. 内存不足错误

尝试减小batch_size或img_size参数的值。

引用

如果您在研究中使用了本项目,请引用以下论文:

@article{rahman2024emcad,
  title={EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation},
  author={Rahman, Md Mostafijur and Munir, Mustafa and Marculescu, Radu},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={11769--11779},
  year={2024}
}

致谢

本项目的实现参考了以下优秀的开源项目:

在此向这些项目的作者表示感谢。

许可证

本项目遵循MIT许可证开源。