You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

6.3 KiB

EMCAD API文档

核心模块

src/emcad/core

models.py

EMCADNet类

EMCADEfficient Multi-scale Convolutional Attention Decoding网络模型类。

初始化参数:

  • config: 包含模型配置参数的字典
    • num_classes: 输出类别数默认为9
    • encoder: 编码器名称,默认为'pvt_v2_b2'
    • expansion_factor: MSCB块中的扩展因子默认为2
    • kernel_sizes: MSDC块中的多尺度卷积核大小默认为[1, 3, 5]
    • lgag_ks: LGAG中的卷积核大小默认为3
    • activation_mscb: MSCB中使用的激活函数默认为'relu6'
    • no_dw_parallel: 是否禁用深度并行卷积默认为False
    • concatenation: 是否在MSDC块中连接特征图默认为False
    • no_pretrain: 是否禁用预训练权重加载默认为False
    • pretrained_dir: 预训练编码器目录路径,默认为'./pretrained_pth/pvt/'
    • supervision: 损失监督方式,默认为'mutation'

主要方法:

  • forward(x): 前向传播
    • x: 输入张量,形状为(batch_size, channels, height, width)
    • 返回: 输出张量,形状为(batch_size, num_classes, height, width)

processors.py

Trainer类

EMCAD模型训练器类。

初始化参数:

  • model: 要训练的EMCAD模型
  • config: 包含训练配置参数的字典

主要方法:

  • train_epoch(train_loader): 训练一个epoch

    • train_loader: 训练数据加载器
    • 返回: 平均损失值
  • validate(val_loader): 验证模型

    • val_loader: 验证数据加载器
    • 返回: 包含验证指标的字典
DiceLoss类

用于分割任务的Dice损失函数类。

初始化参数:

  • num_classes: 类别数
  • smooth: 平滑因子避免除以零默认为1e-5

主要方法:

  • forward(inputs, targets): 计算Dice损失
    • inputs: 模型预测
    • targets: 真实标签
    • 返回: Dice损失值

src/emcad/utils

data_utils.py

DataProcessor类

EMCAD数据处理器类。

初始化参数:

  • config: 包含数据处理参数的字典

主要方法:

  • preprocess_image(image): 预处理医学图像

    • image: 输入图像numpy数组
    • 返回: 预处理后的图像
  • load_dataset(data_path, split): 加载数据集

    • data_path: 数据集路径
    • split: 数据集分割('train', 'val', 'test'
    • 返回: (图像列表, 标签列表)元组
  • create_data_loader(images, labels, batch_size, shuffle): 创建PyTorch数据加载器

    • images: 图像列表
    • labels: 标签列表
    • batch_size: 批大小
    • shuffle: 是否打乱数据
    • 返回: PyTorch数据加载器

visualization.py

Visualizer类

EMCAD可视化器类。

初始化参数:

  • config: 包含可视化参数的字典

主要方法:

  • visualize_segmentation(image, ground_truth, prediction, slice_idx, save_path): 可视化图像分割结果

    • image: 输入图像
    • ground_truth: 真实标签(可选)
    • prediction: 预测标签(可选)
    • slice_idx: 切片索引对于3D图像
    • save_path: 保存路径(可选)
  • plot_training_metrics(metrics, save_path): 绘制训练指标

    • metrics: 指标名称到值列表的字典
    • save_path: 保存路径(可选)
  • create_class_distribution_plot(class_counts, save_path): 创建类别分布图

    • class_counts: 类别ID到计数的字典
    • save_path: 保存路径(可选)

脚本模块

scripts/train.py

EMCAD模型训练脚本。

主要参数:

  • --config: 配置文件路径,默认为'configs/default.yaml'
  • --resume: 是否恢复训练默认为False
  • --gpu: GPU ID默认为0

scripts/test.py

EMCAD模型测试脚本。

主要参数:

  • --config: 配置文件路径,默认为'configs/default.yaml'
  • --model_path: 模型权重路径
  • --output_dir: 输出目录,默认为'./outputs'

配置模块

configs/default.yaml

默认配置文件,包含以下主要部分:

  • project: 项目基本信息
  • data: 数据相关配置
  • model: 模型相关配置
  • training: 训练相关配置
  • logging: 日志相关配置

configs/development.yaml

开发环境配置继承自default.yaml并覆盖部分参数。

configs/production.yaml

生产环境配置继承自default.yaml并覆盖部分参数。

模型架构

EMCAD网络结构

EMCAD网络主要由以下组件构成

  1. 编码器使用PVTv2作为骨干网络提取多尺度特征
  2. 多尺度卷积块MSCB:通过不同大小的卷积核捕获多尺度特征
  3. 轻量级门控注意力LGAG:高效地融合多尺度特征
  4. 解码器:逐步上采样并融合特征,生成分割结果

核心组件

多尺度卷积块MSCB

通过并行使用不同大小的卷积核1×1、3×3、5×5等捕获多尺度特征然后通过注意力机制融合这些特征。

轻量级门控注意力LGAG

一种高效的注意力机制,用于融合不同尺度的特征,减少计算复杂度同时保持性能。

解码器

逐步上采样特征图,并与编码器对应层的特征融合,最终生成分割结果。

性能指标

评估指标

  • Dice系数:衡量分割重叠度
  • HD95Hausdorff距离95%:衡量分割边界距离
  • Jaccard指数:另一种分割重叠度度量
  • ASD平均表面距离:衡量分割表面之间的平均距离

实验结果

在Synapse多器官数据集上EMCAD取得了以下性能

  • 平均Dice系数85.63%
  • 平均HD9517.85mm
  • 平均Jaccard指数78.42%
  • 平均ASD5.32mm

常见问题

Q: 如何在自己的数据集上使用EMCAD

A: 需要实现自定义的数据集类,参考dataset_synapse.pydataset_ACDC.py的实现,然后修改训练脚本中的数据集加载部分。

Q: 如何调整模型以适应不同的输入尺寸?

A: 可以通过修改--img_size参数来调整输入尺寸,但可能需要调整网络结构中的某些参数。

Q: 如何使用不同的骨干网络?

A: 可以通过修改--encoder参数来选择不同的骨干网络,支持的编码器包括pvt_v2_b0pvt_v2_b1pvt_v2_b2pvt_v2_b3pvt_v2_b4pvt_v2_b5等。

Q: 如何调整超参数?

A: 可以通过修改训练脚本中的命令行参数来调整超参数,如学习率、批大小、训练轮数等。