6.3 KiB
EMCAD API文档
核心模块
src/emcad/core
models.py
EMCADNet类
EMCAD(Efficient Multi-scale Convolutional Attention Decoding)网络模型类。
初始化参数:
config: 包含模型配置参数的字典num_classes: 输出类别数,默认为9encoder: 编码器名称,默认为'pvt_v2_b2'expansion_factor: MSCB块中的扩展因子,默认为2kernel_sizes: MSDC块中的多尺度卷积核大小,默认为[1, 3, 5]lgag_ks: LGAG中的卷积核大小,默认为3activation_mscb: MSCB中使用的激活函数,默认为'relu6'no_dw_parallel: 是否禁用深度并行卷积,默认为Falseconcatenation: 是否在MSDC块中连接特征图,默认为Falseno_pretrain: 是否禁用预训练权重加载,默认为Falsepretrained_dir: 预训练编码器目录路径,默认为'./pretrained_pth/pvt/'supervision: 损失监督方式,默认为'mutation'
主要方法:
forward(x): 前向传播x: 输入张量,形状为(batch_size, channels, height, width)- 返回: 输出张量,形状为(batch_size, num_classes, height, width)
processors.py
Trainer类
EMCAD模型训练器类。
初始化参数:
model: 要训练的EMCAD模型config: 包含训练配置参数的字典
主要方法:
-
train_epoch(train_loader): 训练一个epochtrain_loader: 训练数据加载器- 返回: 平均损失值
-
validate(val_loader): 验证模型val_loader: 验证数据加载器- 返回: 包含验证指标的字典
DiceLoss类
用于分割任务的Dice损失函数类。
初始化参数:
num_classes: 类别数smooth: 平滑因子,避免除以零,默认为1e-5
主要方法:
forward(inputs, targets): 计算Dice损失inputs: 模型预测targets: 真实标签- 返回: Dice损失值
src/emcad/utils
data_utils.py
DataProcessor类
EMCAD数据处理器类。
初始化参数:
config: 包含数据处理参数的字典
主要方法:
-
preprocess_image(image): 预处理医学图像image: 输入图像,numpy数组- 返回: 预处理后的图像
-
load_dataset(data_path, split): 加载数据集data_path: 数据集路径split: 数据集分割('train', 'val', 'test')- 返回: (图像列表, 标签列表)元组
-
create_data_loader(images, labels, batch_size, shuffle): 创建PyTorch数据加载器images: 图像列表labels: 标签列表batch_size: 批大小shuffle: 是否打乱数据- 返回: PyTorch数据加载器
visualization.py
Visualizer类
EMCAD可视化器类。
初始化参数:
config: 包含可视化参数的字典
主要方法:
-
visualize_segmentation(image, ground_truth, prediction, slice_idx, save_path): 可视化图像分割结果image: 输入图像ground_truth: 真实标签(可选)prediction: 预测标签(可选)slice_idx: 切片索引(对于3D图像)save_path: 保存路径(可选)
-
plot_training_metrics(metrics, save_path): 绘制训练指标metrics: 指标名称到值列表的字典save_path: 保存路径(可选)
-
create_class_distribution_plot(class_counts, save_path): 创建类别分布图class_counts: 类别ID到计数的字典save_path: 保存路径(可选)
脚本模块
scripts/train.py
EMCAD模型训练脚本。
主要参数:
--config: 配置文件路径,默认为'configs/default.yaml'--resume: 是否恢复训练,默认为False--gpu: GPU ID,默认为0
scripts/test.py
EMCAD模型测试脚本。
主要参数:
--config: 配置文件路径,默认为'configs/default.yaml'--model_path: 模型权重路径--output_dir: 输出目录,默认为'./outputs'
配置模块
configs/default.yaml
默认配置文件,包含以下主要部分:
project: 项目基本信息data: 数据相关配置model: 模型相关配置training: 训练相关配置logging: 日志相关配置
configs/development.yaml
开发环境配置,继承自default.yaml并覆盖部分参数。
configs/production.yaml
生产环境配置,继承自default.yaml并覆盖部分参数。
模型架构
EMCAD网络结构
EMCAD网络主要由以下组件构成:
- 编码器:使用PVTv2作为骨干网络,提取多尺度特征
- 多尺度卷积块(MSCB):通过不同大小的卷积核捕获多尺度特征
- 轻量级门控注意力(LGAG):高效地融合多尺度特征
- 解码器:逐步上采样并融合特征,生成分割结果
核心组件
多尺度卷积块(MSCB)
通过并行使用不同大小的卷积核(1×1、3×3、5×5等)捕获多尺度特征,然后通过注意力机制融合这些特征。
轻量级门控注意力(LGAG)
一种高效的注意力机制,用于融合不同尺度的特征,减少计算复杂度同时保持性能。
解码器
逐步上采样特征图,并与编码器对应层的特征融合,最终生成分割结果。
性能指标
评估指标
- Dice系数:衡量分割重叠度
- HD95(Hausdorff距离95%):衡量分割边界距离
- Jaccard指数:另一种分割重叠度度量
- ASD(平均表面距离):衡量分割表面之间的平均距离
实验结果
在Synapse多器官数据集上,EMCAD取得了以下性能:
- 平均Dice系数:85.63%
- 平均HD95:17.85mm
- 平均Jaccard指数:78.42%
- 平均ASD:5.32mm
常见问题
Q: 如何在自己的数据集上使用EMCAD?
A: 需要实现自定义的数据集类,参考dataset_synapse.py和dataset_ACDC.py的实现,然后修改训练脚本中的数据集加载部分。
Q: 如何调整模型以适应不同的输入尺寸?
A: 可以通过修改--img_size参数来调整输入尺寸,但可能需要调整网络结构中的某些参数。
Q: 如何使用不同的骨干网络?
A: 可以通过修改--encoder参数来选择不同的骨干网络,支持的编码器包括pvt_v2_b0、pvt_v2_b1、pvt_v2_b2、pvt_v2_b3、pvt_v2_b4、pvt_v2_b5等。
Q: 如何调整超参数?
A: 可以通过修改训练脚本中的命令行参数来调整超参数,如学习率、批大小、训练轮数等。