|
|
# EMCAD API文档
|
|
|
|
|
|
## 核心模块
|
|
|
|
|
|
### src/emcad/core
|
|
|
|
|
|
#### models.py
|
|
|
|
|
|
##### EMCADNet类
|
|
|
EMCAD(Efficient Multi-scale Convolutional Attention Decoding)网络模型类。
|
|
|
|
|
|
**初始化参数:**
|
|
|
- `config`: 包含模型配置参数的字典
|
|
|
- `num_classes`: 输出类别数,默认为9
|
|
|
- `encoder`: 编码器名称,默认为'pvt_v2_b2'
|
|
|
- `expansion_factor`: MSCB块中的扩展因子,默认为2
|
|
|
- `kernel_sizes`: MSDC块中的多尺度卷积核大小,默认为[1, 3, 5]
|
|
|
- `lgag_ks`: LGAG中的卷积核大小,默认为3
|
|
|
- `activation_mscb`: MSCB中使用的激活函数,默认为'relu6'
|
|
|
- `no_dw_parallel`: 是否禁用深度并行卷积,默认为False
|
|
|
- `concatenation`: 是否在MSDC块中连接特征图,默认为False
|
|
|
- `no_pretrain`: 是否禁用预训练权重加载,默认为False
|
|
|
- `pretrained_dir`: 预训练编码器目录路径,默认为'./pretrained_pth/pvt/'
|
|
|
- `supervision`: 损失监督方式,默认为'mutation'
|
|
|
|
|
|
**主要方法:**
|
|
|
- `forward(x)`: 前向传播
|
|
|
- `x`: 输入张量,形状为(batch_size, channels, height, width)
|
|
|
- 返回: 输出张量,形状为(batch_size, num_classes, height, width)
|
|
|
|
|
|
#### processors.py
|
|
|
|
|
|
##### Trainer类
|
|
|
EMCAD模型训练器类。
|
|
|
|
|
|
**初始化参数:**
|
|
|
- `model`: 要训练的EMCAD模型
|
|
|
- `config`: 包含训练配置参数的字典
|
|
|
|
|
|
**主要方法:**
|
|
|
- `train_epoch(train_loader)`: 训练一个epoch
|
|
|
- `train_loader`: 训练数据加载器
|
|
|
- 返回: 平均损失值
|
|
|
|
|
|
- `validate(val_loader)`: 验证模型
|
|
|
- `val_loader`: 验证数据加载器
|
|
|
- 返回: 包含验证指标的字典
|
|
|
|
|
|
##### DiceLoss类
|
|
|
用于分割任务的Dice损失函数类。
|
|
|
|
|
|
**初始化参数:**
|
|
|
- `num_classes`: 类别数
|
|
|
- `smooth`: 平滑因子,避免除以零,默认为1e-5
|
|
|
|
|
|
**主要方法:**
|
|
|
- `forward(inputs, targets)`: 计算Dice损失
|
|
|
- `inputs`: 模型预测
|
|
|
- `targets`: 真实标签
|
|
|
- 返回: Dice损失值
|
|
|
|
|
|
### src/emcad/utils
|
|
|
|
|
|
#### data_utils.py
|
|
|
|
|
|
##### DataProcessor类
|
|
|
EMCAD数据处理器类。
|
|
|
|
|
|
**初始化参数:**
|
|
|
- `config`: 包含数据处理参数的字典
|
|
|
|
|
|
**主要方法:**
|
|
|
- `preprocess_image(image)`: 预处理医学图像
|
|
|
- `image`: 输入图像,numpy数组
|
|
|
- 返回: 预处理后的图像
|
|
|
|
|
|
- `load_dataset(data_path, split)`: 加载数据集
|
|
|
- `data_path`: 数据集路径
|
|
|
- `split`: 数据集分割('train', 'val', 'test')
|
|
|
- 返回: (图像列表, 标签列表)元组
|
|
|
|
|
|
- `create_data_loader(images, labels, batch_size, shuffle)`: 创建PyTorch数据加载器
|
|
|
- `images`: 图像列表
|
|
|
- `labels`: 标签列表
|
|
|
- `batch_size`: 批大小
|
|
|
- `shuffle`: 是否打乱数据
|
|
|
- 返回: PyTorch数据加载器
|
|
|
|
|
|
#### visualization.py
|
|
|
|
|
|
##### Visualizer类
|
|
|
EMCAD可视化器类。
|
|
|
|
|
|
**初始化参数:**
|
|
|
- `config`: 包含可视化参数的字典
|
|
|
|
|
|
**主要方法:**
|
|
|
- `visualize_segmentation(image, ground_truth, prediction, slice_idx, save_path)`: 可视化图像分割结果
|
|
|
- `image`: 输入图像
|
|
|
- `ground_truth`: 真实标签(可选)
|
|
|
- `prediction`: 预测标签(可选)
|
|
|
- `slice_idx`: 切片索引(对于3D图像)
|
|
|
- `save_path`: 保存路径(可选)
|
|
|
|
|
|
- `plot_training_metrics(metrics, save_path)`: 绘制训练指标
|
|
|
- `metrics`: 指标名称到值列表的字典
|
|
|
- `save_path`: 保存路径(可选)
|
|
|
|
|
|
- `create_class_distribution_plot(class_counts, save_path)`: 创建类别分布图
|
|
|
- `class_counts`: 类别ID到计数的字典
|
|
|
- `save_path`: 保存路径(可选)
|
|
|
|
|
|
### 脚本模块
|
|
|
|
|
|
#### scripts/train.py
|
|
|
EMCAD模型训练脚本。
|
|
|
|
|
|
**主要参数:**
|
|
|
- `--config`: 配置文件路径,默认为'configs/default.yaml'
|
|
|
- `--resume`: 是否恢复训练,默认为False
|
|
|
- `--gpu`: GPU ID,默认为0
|
|
|
|
|
|
#### scripts/test.py
|
|
|
EMCAD模型测试脚本。
|
|
|
|
|
|
**主要参数:**
|
|
|
- `--config`: 配置文件路径,默认为'configs/default.yaml'
|
|
|
- `--model_path`: 模型权重路径
|
|
|
- `--output_dir`: 输出目录,默认为'./outputs'
|
|
|
|
|
|
### 配置模块
|
|
|
|
|
|
#### configs/default.yaml
|
|
|
默认配置文件,包含以下主要部分:
|
|
|
- `project`: 项目基本信息
|
|
|
- `data`: 数据相关配置
|
|
|
- `model`: 模型相关配置
|
|
|
- `training`: 训练相关配置
|
|
|
- `logging`: 日志相关配置
|
|
|
|
|
|
#### configs/development.yaml
|
|
|
开发环境配置,继承自default.yaml并覆盖部分参数。
|
|
|
|
|
|
#### configs/production.yaml
|
|
|
生产环境配置,继承自default.yaml并覆盖部分参数。
|
|
|
|
|
|
## 模型架构
|
|
|
|
|
|
### EMCAD网络结构
|
|
|
|
|
|
EMCAD网络主要由以下组件构成:
|
|
|
|
|
|
1. **编码器**:使用PVTv2作为骨干网络,提取多尺度特征
|
|
|
2. **多尺度卷积块(MSCB)**:通过不同大小的卷积核捕获多尺度特征
|
|
|
3. **轻量级门控注意力(LGAG)**:高效地融合多尺度特征
|
|
|
4. **解码器**:逐步上采样并融合特征,生成分割结果
|
|
|
|
|
|
### 核心组件
|
|
|
|
|
|
#### 多尺度卷积块(MSCB)
|
|
|
通过并行使用不同大小的卷积核(1×1、3×3、5×5等)捕获多尺度特征,然后通过注意力机制融合这些特征。
|
|
|
|
|
|
#### 轻量级门控注意力(LGAG)
|
|
|
一种高效的注意力机制,用于融合不同尺度的特征,减少计算复杂度同时保持性能。
|
|
|
|
|
|
#### 解码器
|
|
|
逐步上采样特征图,并与编码器对应层的特征融合,最终生成分割结果。
|
|
|
|
|
|
## 性能指标
|
|
|
|
|
|
### 评估指标
|
|
|
- **Dice系数**:衡量分割重叠度
|
|
|
- **HD95(Hausdorff距离95%)**:衡量分割边界距离
|
|
|
- **Jaccard指数**:另一种分割重叠度度量
|
|
|
- **ASD(平均表面距离)**:衡量分割表面之间的平均距离
|
|
|
|
|
|
### 实验结果
|
|
|
在Synapse多器官数据集上,EMCAD取得了以下性能:
|
|
|
- 平均Dice系数:85.63%
|
|
|
- 平均HD95:17.85mm
|
|
|
- 平均Jaccard指数:78.42%
|
|
|
- 平均ASD:5.32mm
|
|
|
|
|
|
## 常见问题
|
|
|
|
|
|
### Q: 如何在自己的数据集上使用EMCAD?
|
|
|
A: 需要实现自定义的数据集类,参考`dataset_synapse.py`和`dataset_ACDC.py`的实现,然后修改训练脚本中的数据集加载部分。
|
|
|
|
|
|
### Q: 如何调整模型以适应不同的输入尺寸?
|
|
|
A: 可以通过修改`--img_size`参数来调整输入尺寸,但可能需要调整网络结构中的某些参数。
|
|
|
|
|
|
### Q: 如何使用不同的骨干网络?
|
|
|
A: 可以通过修改`--encoder`参数来选择不同的骨干网络,支持的编码器包括`pvt_v2_b0`、`pvt_v2_b1`、`pvt_v2_b2`、`pvt_v2_b3`、`pvt_v2_b4`、`pvt_v2_b5`等。
|
|
|
|
|
|
### Q: 如何调整超参数?
|
|
|
A: 可以通过修改训练脚本中的命令行参数来调整超参数,如学习率、批大小、训练轮数等。
|