# EMCAD API文档 ## 核心模块 ### src/emcad/core #### models.py ##### EMCADNet类 EMCAD(Efficient Multi-scale Convolutional Attention Decoding)网络模型类。 **初始化参数:** - `config`: 包含模型配置参数的字典 - `num_classes`: 输出类别数,默认为9 - `encoder`: 编码器名称,默认为'pvt_v2_b2' - `expansion_factor`: MSCB块中的扩展因子,默认为2 - `kernel_sizes`: MSDC块中的多尺度卷积核大小,默认为[1, 3, 5] - `lgag_ks`: LGAG中的卷积核大小,默认为3 - `activation_mscb`: MSCB中使用的激活函数,默认为'relu6' - `no_dw_parallel`: 是否禁用深度并行卷积,默认为False - `concatenation`: 是否在MSDC块中连接特征图,默认为False - `no_pretrain`: 是否禁用预训练权重加载,默认为False - `pretrained_dir`: 预训练编码器目录路径,默认为'./pretrained_pth/pvt/' - `supervision`: 损失监督方式,默认为'mutation' **主要方法:** - `forward(x)`: 前向传播 - `x`: 输入张量,形状为(batch_size, channels, height, width) - 返回: 输出张量,形状为(batch_size, num_classes, height, width) #### processors.py ##### Trainer类 EMCAD模型训练器类。 **初始化参数:** - `model`: 要训练的EMCAD模型 - `config`: 包含训练配置参数的字典 **主要方法:** - `train_epoch(train_loader)`: 训练一个epoch - `train_loader`: 训练数据加载器 - 返回: 平均损失值 - `validate(val_loader)`: 验证模型 - `val_loader`: 验证数据加载器 - 返回: 包含验证指标的字典 ##### DiceLoss类 用于分割任务的Dice损失函数类。 **初始化参数:** - `num_classes`: 类别数 - `smooth`: 平滑因子,避免除以零,默认为1e-5 **主要方法:** - `forward(inputs, targets)`: 计算Dice损失 - `inputs`: 模型预测 - `targets`: 真实标签 - 返回: Dice损失值 ### src/emcad/utils #### data_utils.py ##### DataProcessor类 EMCAD数据处理器类。 **初始化参数:** - `config`: 包含数据处理参数的字典 **主要方法:** - `preprocess_image(image)`: 预处理医学图像 - `image`: 输入图像,numpy数组 - 返回: 预处理后的图像 - `load_dataset(data_path, split)`: 加载数据集 - `data_path`: 数据集路径 - `split`: 数据集分割('train', 'val', 'test') - 返回: (图像列表, 标签列表)元组 - `create_data_loader(images, labels, batch_size, shuffle)`: 创建PyTorch数据加载器 - `images`: 图像列表 - `labels`: 标签列表 - `batch_size`: 批大小 - `shuffle`: 是否打乱数据 - 返回: PyTorch数据加载器 #### visualization.py ##### Visualizer类 EMCAD可视化器类。 **初始化参数:** - `config`: 包含可视化参数的字典 **主要方法:** - `visualize_segmentation(image, ground_truth, prediction, slice_idx, save_path)`: 可视化图像分割结果 - `image`: 输入图像 - `ground_truth`: 真实标签(可选) - `prediction`: 预测标签(可选) - `slice_idx`: 切片索引(对于3D图像) - `save_path`: 保存路径(可选) - `plot_training_metrics(metrics, save_path)`: 绘制训练指标 - `metrics`: 指标名称到值列表的字典 - `save_path`: 保存路径(可选) - `create_class_distribution_plot(class_counts, save_path)`: 创建类别分布图 - `class_counts`: 类别ID到计数的字典 - `save_path`: 保存路径(可选) ### 脚本模块 #### scripts/train.py EMCAD模型训练脚本。 **主要参数:** - `--config`: 配置文件路径,默认为'configs/default.yaml' - `--resume`: 是否恢复训练,默认为False - `--gpu`: GPU ID,默认为0 #### scripts/test.py EMCAD模型测试脚本。 **主要参数:** - `--config`: 配置文件路径,默认为'configs/default.yaml' - `--model_path`: 模型权重路径 - `--output_dir`: 输出目录,默认为'./outputs' ### 配置模块 #### configs/default.yaml 默认配置文件,包含以下主要部分: - `project`: 项目基本信息 - `data`: 数据相关配置 - `model`: 模型相关配置 - `training`: 训练相关配置 - `logging`: 日志相关配置 #### configs/development.yaml 开发环境配置,继承自default.yaml并覆盖部分参数。 #### configs/production.yaml 生产环境配置,继承自default.yaml并覆盖部分参数。 ## 模型架构 ### EMCAD网络结构 EMCAD网络主要由以下组件构成: 1. **编码器**:使用PVTv2作为骨干网络,提取多尺度特征 2. **多尺度卷积块(MSCB)**:通过不同大小的卷积核捕获多尺度特征 3. **轻量级门控注意力(LGAG)**:高效地融合多尺度特征 4. **解码器**:逐步上采样并融合特征,生成分割结果 ### 核心组件 #### 多尺度卷积块(MSCB) 通过并行使用不同大小的卷积核(1×1、3×3、5×5等)捕获多尺度特征,然后通过注意力机制融合这些特征。 #### 轻量级门控注意力(LGAG) 一种高效的注意力机制,用于融合不同尺度的特征,减少计算复杂度同时保持性能。 #### 解码器 逐步上采样特征图,并与编码器对应层的特征融合,最终生成分割结果。 ## 性能指标 ### 评估指标 - **Dice系数**:衡量分割重叠度 - **HD95(Hausdorff距离95%)**:衡量分割边界距离 - **Jaccard指数**:另一种分割重叠度度量 - **ASD(平均表面距离)**:衡量分割表面之间的平均距离 ### 实验结果 在Synapse多器官数据集上,EMCAD取得了以下性能: - 平均Dice系数:85.63% - 平均HD95:17.85mm - 平均Jaccard指数:78.42% - 平均ASD:5.32mm ## 常见问题 ### Q: 如何在自己的数据集上使用EMCAD? A: 需要实现自定义的数据集类,参考`dataset_synapse.py`和`dataset_ACDC.py`的实现,然后修改训练脚本中的数据集加载部分。 ### Q: 如何调整模型以适应不同的输入尺寸? A: 可以通过修改`--img_size`参数来调整输入尺寸,但可能需要调整网络结构中的某些参数。 ### Q: 如何使用不同的骨干网络? A: 可以通过修改`--encoder`参数来选择不同的骨干网络,支持的编码器包括`pvt_v2_b0`、`pvt_v2_b1`、`pvt_v2_b2`、`pvt_v2_b3`、`pvt_v2_b4`、`pvt_v2_b5`等。 ### Q: 如何调整超参数? A: 可以通过修改训练脚本中的命令行参数来调整超参数,如学习率、批大小、训练轮数等。