first commit

main
learnljs 4 months ago
parent d64d987ebe
commit f6e9285936

.gitignore vendored

@ -0,0 +1,211 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# poetry
poetry.lock
# pdm
.pdm.toml
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
.idea/
*.iml
*.ipr
*.iws
.vscode/
!.vscode/extensions.json
!.vscode/settings.json
# macOS
.DS_Store
*.DS_Store
# Windows
Thumbs.db
ehthumbs.db
Desktop.ini
# Project specific
model_pth/*
pretrained_pth/*
!model_pth/.gitkeep
!pretrained_pth/.gitkeep
# Data directories
data/synapse/*
data/ACDC/ACDC/train/*
data/ACDC/ACDC/test/*
data/ACDC/ACDC/valid/*
!data/synapse/.gitkeep
!data/ACDC/ACDC/train/.gitkeep
!data/ACDC/ACDC/test/.gitkeep
!data/ACDC/ACDC/valid/.gitkeep
# Training outputs
snapshots/
runs/
experiments/
predictions/
*.log
# Checkpoint files
*.pth
*.pt
# numpy arrays data files
*.npy
*.npz
# Medical image formats
*.nii
*.nii.gz
*.mha
*.mhd
# Checkpoints
checkpoint-*
*.ckpt
# Virtual environment
EMCADenv/
emcadenv/
# SimpleITK wheel
SimpleITK.whl
# TensorBoard
events.out.tfevents.*
# Temporary files
*.tmp
*.bak
*~
\#*\#

@ -0,0 +1,347 @@
# EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation
This project is the official PyTorch implementation of the CVPR 2024 paper "EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation" by Md Mostafijur Rahman, Mustafa Munir, and Radu Marculescu, all of The University of Texas at Austin.
## Project Overview
EMCAD is a medical image segmentation framework that improves both accuracy and efficiency by introducing an efficient multi-scale convolutional attention decoder. It achieves strong results on several medical imaging datasets, notably the Synapse multi-organ segmentation task.
The core contribution is a new decoder architecture that fuses multi-scale features effectively and strengthens feature representations with attention. Compared with conventional decoders, EMCAD delivers a clear accuracy gain while keeping computational overhead low, which makes it well suited to resource-constrained medical applications.
### Key Features
EMCAD offers the following core features:
- An efficient multi-scale convolutional attention mechanism that captures features at several scales simultaneously and enhances them through attention.
- Support for multiple backbones, including the PVTv2 and ResNet families, so the feature extractor can be chosen to match the task.
- Flexible configuration of kernel sizes, the expansion factor, the attention type, and other parameters to fit different application scenarios.
In addition, EMCAD supports two training strategies, deep supervision and mutation supervision, and two fusion modes, parallel depth-wise convolution and feature concatenation, giving a rich set of tunable options.
### Performance
On the Synapse dataset, EMCAD reaches a mean Dice score above 85% while keeping FLOPs and parameter counts low, striking a better accuracy/efficiency trade-off than mainstream segmentation methods.
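For reference, the Dice score cited above measures the overlap between a predicted mask and the ground truth. A minimal sketch of the computation in plain Python (illustrative only; the project itself computes metrics with medpy):

```python
def dice_score(pred, target):
    """Dice = 2 * |A ∩ B| / (|A| + |B|) for binary masks given as flat 0/1 lists."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # Convention: two empty masks count as perfect agreement.
    return 2.0 * intersection / total if total else 1.0

print(dice_score([1, 1, 0, 0], [1, 1, 0, 0]))  # → 1.0 (perfect overlap)
print(dice_score([1, 0, 1, 0], [1, 1, 0, 0]))  # → 0.5 (half overlap)
```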
## Project Structure
```
EMCAD/
├── README.md                        # Project documentation
├── setup.py                         # Installation script
├── requirements.txt                 # Dependency list
├── .gitignore                       # Git ignore rules
├── src/                             # Source code
│   ├── core/                        # Core model definitions
│   │   ├── networks.py              # EMCADNet main network
│   │   ├── decoders.py              # EMCAD decoder modules
│   │   ├── pvtv2.py                 # PVTv2 backbone
│   │   └── resnet.py                # ResNet backbones
│   │
│   └── utils/                       # Utilities
│       ├── dataset_synapse.py       # Synapse dataset loader
│       ├── dataset_ACDC.py          # ACDC dataset loader
│       ├── dataloader.py            # Generic data loader
│       ├── trainer.py               # Core training logic
│       ├── utils.py                 # Helper functions
│       ├── transforms.py            # Data transforms
│       ├── joint_transforms.py      # Joint image/label transforms
│       ├── format_conversion.py     # Format conversion tools
│       ├── preprocess_synapse_data.py    # Synapse data preprocessing
│       ├── preprocess_synapse_data_3d.py # 3D data preprocessing
│       └── misc.py                  # Miscellaneous helpers
├── scripts/                         # Entry-point scripts
│   └── train_synapse.py             # Training entry for the Synapse dataset
├── tests/                           # Test code
│   └── test_synapse.py              # Model evaluation script
├── data/                            # Data directory
│   ├── lists/                       # Data split lists
│   │   └── lists_Synapse/           # Synapse splits
│   │       ├── train.txt
│   │       ├── test_vol.txt
│   │       └── .ipynb_checkpoints/
│   │
│   └── ACDC/                        # ACDC dataset
│       └── ACDC/
│           ├── lists_ACDC/          # ACDC splits
│           │   ├── train.txt
│           │   ├── valid.txt
│           │   └── test.txt
│           ├── train/               # Training data
│           ├── valid/               # Validation data
│           └── test/                # Test data
├── model_pth/                       # Trained model weights
├── pretrained_pth/                  # Pretrained weights
│   └── pvt/                         # PVTv2 pretrained weights
└── SimpleITK.whl                    # SimpleITK wheel
```
## Environment Setup
### System Requirements
The project is developed primarily for Linux; Windows and macOS are also supported. Python 3.8 or later is recommended. We strongly suggest creating an isolated virtual environment with Anaconda or Miniconda to avoid dependency conflicts with other Python projects on your system.
### Installing Dependencies
First, create and activate a virtual environment:
```bash
conda create -n emcadenv python=3.8
conda activate emcadenv
```
Then install PyTorch and related dependencies. For CUDA 11.3, the following command installs compatible versions:
```bash
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
If you use a different CUDA version, visit the PyTorch website for the matching install command. Next, install the remaining dependencies:
```bash
pip install -r requirements.txt
```
On Windows, you may need to install SimpleITK manually from the bundled wheel:
```bash
pip install SimpleITK.whl
```
### Dependency Overview
requirements.txt lists every package the project needs, most importantly:
- numpy: numerical computing
- torch and torchvision: the deep learning framework
- h5py: reading HDF5 data
- scipy: scientific computing
- matplotlib: visualization
- tqdm: progress bars
- tensorboardX: training visualization
- nibabel: medical image I/O
- medpy: medical segmentation metrics
- ptflops and thop: model complexity (FLOPs and parameters)
- segmentation-mask-overlay: overlaying segmentation masks for visualization
- timm: a large collection of pretrained models
## Data Preparation
### Synapse Multi-organ Dataset
The Synapse dataset is a widely used benchmark for medical image segmentation. It contains 30 abdominal CT scans with segmentation labels for the spleen, kidneys, liver, stomach, aorta, and other organs. Register and download the data from the official Synapse website, then follow the TransUNet data split to divide the RawData folder into a training set of 18 scans and a test set of 12 scans under ./data/synapse/Abdomen/RawData/.
There are two ways to preprocess the data:
**Option 1: run the preprocessing script**
```bash
python src/utils/preprocess_synapse_data.py
```
**Option 2: download the preprocessed Synapse data**
Download the preprocessed Synapse data from the project releases page and place it under ./data/.
### ACDC Dataset
ACDC is a standard benchmark for cardiac MRI segmentation. We recommend downloading the preprocessed ACDC data from the MT-UNet project's Google Drive and extracting it to ./data/ACDC/.
## Pretrained Models
EMCAD uses PVTv2 as its default backbone, so pretrained PVTv2 weights must be downloaded first. They are available from:
- Google Drive: pretrained weights provided by this project
- PVT GitHub releases page: official PVTv2 pretrained weights
After downloading, place the weight files in ./model_pth/ or ./pretrained_pth/pvt/. The pretrained path can be configured with the --pretrained_dir argument.
## Quick Start
### Training
The basic command to train EMCAD on the Synapse dataset is:
```bash
python scripts/train_synapse.py --root_path /path/to/train/data --volume_path /path/to/test/data --encoder pvt_v2_b2
```
**Common arguments:**

| Argument | Description | Default |
|------|------|--------|
| --root_path | Root directory of the training data | /data/ACDC/train |
| --volume_path | Root directory of the test data | /data/ACDC/test |
| --dataset | Dataset name | ACDC |
| --list_dir | Directory containing the data split lists | /data/ACDC/lists_ACDC |
| --num_classes | Number of segmentation classes | 4 (ACDC) / 9 (Synapse) |
| --encoder | Backbone type | pvt_v2_b2 |
| --batch_size | Samples per training batch | 6 |
| --base_lr | Base learning rate | 0.0001 |
| --img_size | Input image size | 224 |
| --max_epochs | Number of training epochs | 300 |
| --expansion_factor | Expansion factor of the MSCB block | 2 |
| --kernel_sizes | Multi-scale convolution kernel sizes | [1, 3, 5] |
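Note that --kernel_sizes takes a space-separated list of integers (argparse `nargs="+"`), so for example `--kernel_sizes 1 3 5 7` selects four branches. A minimal sketch of how the training script parses this flag:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--kernel_sizes", type=int, nargs="+", default=[1, 3, 5],
                    help="multi-scale kernel sizes in MSDC block")

# On the command line the values are space-separated, not a Python list literal.
args = parser.parse_args(["--kernel_sizes", "1", "3", "5", "7"])
print(args.kernel_sizes)  # → [1, 3, 5, 7]
```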
### Inference
After training, run inference with:
```bash
python tests/test_synapse.py
```
By default, predictions are saved under the predictions directory, and the segmentation results can optionally be written out in nii format.
### Training on ACDC
To train on the ACDC dataset, set --dataset to ACDC and adjust the data paths and number of classes accordingly:
```bash
python scripts/train_synapse.py \
--root_path ./data/ACDC/train \
--volume_path ./data/ACDC/test \
--list_dir ./data/ACDC/ACDC/lists_ACDC \
--dataset ACDC \
--num_classes 4 \
--encoder pvt_v2_b2
```
## API Reference
### EMCADNet
EMCADNet is the main model class; it builds the complete segmentation network.
**Constructor arguments:**

| Argument | Type | Description | Default |
|------|------|------|--------|
| num_classes | int | Number of output classes | 1 |
| kernel_sizes | list | Multi-scale convolution kernel sizes | [1, 3, 5] |
| expansion_factor | int | Expansion factor | 2 |
| dw_parallel | bool | Use parallel depth-wise convolutions | True |
| add | bool | Fuse features by addition rather than concatenation | True |
| lgag_ks | int | Kernel size of the LGAG module | 3 |
| activation | str | Activation function | relu |
| encoder | str | Backbone name | pvt_v2_b2 |
| pretrain | bool | Load pretrained backbone weights | True |
| pretrained_dir | str | Path to the pretrained weights | ./pretrained_pth/pvt/ |
### Dataset Classes
**Synapse_dataset**
Loads the Synapse dataset; supports both train and test modes.
```python
from torchvision import transforms
from src.utils.dataset_synapse import Synapse_dataset, RandomGenerator
db_train = Synapse_dataset(
base_dir="./data/synapse",
list_dir="./data/lists/lists_Synapse",
split="train",
nclass=9,
transform=transforms.Compose([RandomGenerator(output_size=[224, 224])])
)
```
**ACDCdataset**
Loads the ACDC dataset.
```python
from src.utils.dataset_ACDC import ACDCdataset
db_train = ACDCdataset(
base_dir="./data/ACDC",
list_dir="./data/ACDC/ACDC/lists_ACDC",
split="train",
transform=transform
)
```
### Utility Functions
The project provides a number of helper functions:

| Function | Description |
|------|------|
| DiceLoss | Computes the Dice loss |
| powerset | Generates the subsets of a collection |
| val_single_volume | Validation inference on a single volume |
| test_single_volume | Full inference and result saving for a single volume |
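Of these, powerset underpins the mutation supervision strategy, where the loss is combined over subsets of the decoder's multi-stage outputs. A plausible sketch of such a helper using itertools (illustrative; the actual implementation in src/utils/utils.py may differ):

```python
from itertools import chain, combinations

def powerset(items):
    """Return all non-empty subsets of items as tuples."""
    s = list(items)
    return list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))

# Four decoder outputs yield 2**4 - 1 = 15 non-empty combinations to supervise.
subsets = powerset([0, 1, 2, 3])
print(len(subsets))  # → 15
```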
## Extending and Customizing
### Custom Backbones
To use a different backbone, add support for it in src/core/networks.py, following the existing implementations. A new backbone must implement feature extraction and return feature maps at four scales.
### Adding a New Dataset
To support a new dataset, create a dataset class under src/utils/, following the Synapse_dataset implementation. The new class must inherit from Dataset and implement the __len__ and __getitem__ methods.
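A minimal skeleton for such a class might look like the following (illustrative only; in practice the class would subclass torch.utils.data.Dataset, which requires exactly these two methods, and the sample dict keys follow the convention used by Synapse_dataset):

```python
class MyDataset:  # in practice: class MyDataset(torch.utils.data.Dataset)
    """Skeleton dataset: reads a split list and returns image/label samples."""

    def __init__(self, sample_list, transform=None):
        self.sample_list = sample_list  # e.g. lines read from train.txt
        self.transform = transform

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        name = self.sample_list[idx]
        # Load the image and label here (np.load, h5py, ...); stubbed for illustration.
        sample = {"image": None, "label": None, "case_name": name}
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

ds = MyDataset(["case_001", "case_002"])
print(len(ds))              # → 2
print(ds[0]["case_name"])   # → case_001
```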
### Tuning the Model
Every component of the model can be adjusted through its parameters, including the kernel sizes, the expansion factor, and the attention type. Start from the default configuration and tune step by step to find the settings best suited to your task and data.
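As a rough guide when tuning, the cost of the multi-scale depth-wise branches grows with the sum of squared kernel sizes: a depth-wise convolution with kernel k over C channels uses about k*k*C weights. A back-of-the-envelope estimate (illustrative arithmetic only; it ignores pointwise convolutions, biases, and normalization layers):

```python
def depthwise_params(kernel_sizes, channels):
    """Approximate weight count of parallel depth-wise convs, one per kernel size."""
    return sum(k * k * channels for k in kernel_sizes)

# Default kernels [1, 3, 5] over 64 channels vs. adding a 7x7 branch:
print(depthwise_params([1, 3, 5], 64))     # → 2240
print(depthwise_params([1, 3, 5, 7], 64))  # → 5376
```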
### Running Tests
A test script is provided to verify that the model works correctly:
```bash
python tests/test_synapse.py --dataset Synapse --num_classes 9
```
## 常见问题
### 1. 导入模块时提示ModuleNotFoundError
请确保已正确安装项目:
```bash
pip install -e .
```
Alternatively, add the project root to the PYTHONPATH environment variable:
```bash
export PYTHONPATH=$PYTHONPATH:/path/to/EMCAD
```
### 2. Pretrained weights fail to load
Check that the pretrained weight path is correct and that the PVTv2 weight files have been downloaded and placed in the right directory.
### 3. Out-of-memory errors
Try reducing the batch_size or img_size values.
## Citation
If you use this project in your research, please cite:
```bibtex
@inproceedings{rahman2024emcad,
title={EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation},
author={Rahman, Md Mostafijur and Munir, Mustafa and Marculescu, Radu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={11769--11779},
year={2024}
}
```
## Acknowledgements
This implementation draws on the following excellent open-source projects:
- [timm](https://github.com/rwightman/pytorch-image-models): a rich collection of pretrained models
- [TransUNet](https://github.com/Beckschen/TransUNet): a reference framework for medical image segmentation
- [CASCADE](https://github.com/SLDGroup/CASCADE), [MERIT](https://github.com/SLDGroup/MERIT), [G-CASCADE](https://github.com/SLDGroup/G-CASCADE): reference segmentation methods
We thank the authors of these projects.
## License
This project is released under the MIT License.

@ -0,0 +1,40 @@
case_002_volume_ED.npz
case_002_volume_ES.npz
case_003_volume_ED.npz
case_003_volume_ES.npz
case_008_volume_ED.npz
case_008_volume_ES.npz
case_009_volume_ED.npz
case_009_volume_ES.npz
case_012_volume_ED.npz
case_012_volume_ES.npz
case_014_volume_ED.npz
case_014_volume_ES.npz
case_017_volume_ED.npz
case_017_volume_ES.npz
case_024_volume_ED.npz
case_024_volume_ES.npz
case_042_volume_ED.npz
case_042_volume_ES.npz
case_048_volume_ED.npz
case_048_volume_ES.npz
case_049_volume_ED.npz
case_049_volume_ES.npz
case_053_volume_ED.npz
case_053_volume_ES.npz
case_055_volume_ED.npz
case_055_volume_ES.npz
case_064_volume_ED.npz
case_064_volume_ES.npz
case_067_volume_ED.npz
case_067_volume_ES.npz
case_079_volume_ED.npz
case_079_volume_ES.npz
case_081_volume_ED.npz
case_081_volume_ES.npz
case_088_volume_ED.npz
case_088_volume_ES.npz
case_092_volume_ED.npz
case_092_volume_ES.npz
case_095_volume_ED.npz
case_095_volume_ES.npz

File diff suppressed because it is too large.

@ -0,0 +1,182 @@
case_019_sliceED_0.npz
case_019_sliceED_1.npz
case_019_sliceED_10.npz
case_019_sliceED_2.npz
case_019_sliceED_3.npz
case_019_sliceED_4.npz
case_019_sliceED_5.npz
case_019_sliceED_6.npz
case_019_sliceED_7.npz
case_019_sliceED_8.npz
case_019_sliceED_9.npz
case_019_sliceES_0.npz
case_019_sliceES_1.npz
case_019_sliceES_10.npz
case_019_sliceES_2.npz
case_019_sliceES_3.npz
case_019_sliceES_4.npz
case_019_sliceES_5.npz
case_019_sliceES_6.npz
case_019_sliceES_7.npz
case_019_sliceES_8.npz
case_019_sliceES_9.npz
case_021_sliceED_0.npz
case_021_sliceED_1.npz
case_021_sliceED_2.npz
case_021_sliceED_3.npz
case_021_sliceED_4.npz
case_021_sliceED_5.npz
case_021_sliceED_6.npz
case_021_sliceED_7.npz
case_021_sliceED_8.npz
case_021_sliceED_9.npz
case_021_sliceES_0.npz
case_021_sliceES_1.npz
case_021_sliceES_2.npz
case_021_sliceES_3.npz
case_021_sliceES_4.npz
case_021_sliceES_5.npz
case_021_sliceES_6.npz
case_021_sliceES_7.npz
case_021_sliceES_8.npz
case_021_sliceES_9.npz
case_029_sliceED_0.npz
case_029_sliceED_1.npz
case_029_sliceED_10.npz
case_029_sliceED_2.npz
case_029_sliceED_3.npz
case_029_sliceED_4.npz
case_029_sliceED_5.npz
case_029_sliceED_6.npz
case_029_sliceED_7.npz
case_029_sliceED_8.npz
case_029_sliceED_9.npz
case_029_sliceES_0.npz
case_029_sliceES_1.npz
case_029_sliceES_10.npz
case_029_sliceES_2.npz
case_029_sliceES_3.npz
case_029_sliceES_4.npz
case_029_sliceES_5.npz
case_029_sliceES_6.npz
case_029_sliceES_7.npz
case_029_sliceES_8.npz
case_029_sliceES_9.npz
case_033_sliceED_0.npz
case_033_sliceED_1.npz
case_033_sliceED_2.npz
case_033_sliceED_3.npz
case_033_sliceED_4.npz
case_033_sliceED_5.npz
case_033_sliceED_6.npz
case_033_sliceED_7.npz
case_033_sliceED_8.npz
case_033_sliceED_9.npz
case_033_sliceES_0.npz
case_033_sliceES_1.npz
case_033_sliceES_2.npz
case_033_sliceES_3.npz
case_033_sliceES_4.npz
case_033_sliceES_5.npz
case_033_sliceES_6.npz
case_033_sliceES_7.npz
case_033_sliceES_8.npz
case_033_sliceES_9.npz
case_041_sliceED_0.npz
case_041_sliceED_1.npz
case_041_sliceED_2.npz
case_041_sliceED_3.npz
case_041_sliceED_4.npz
case_041_sliceED_5.npz
case_041_sliceES_0.npz
case_041_sliceES_1.npz
case_041_sliceES_2.npz
case_041_sliceES_3.npz
case_041_sliceES_4.npz
case_041_sliceES_5.npz
case_050_sliceED_0.npz
case_050_sliceED_1.npz
case_050_sliceED_2.npz
case_050_sliceED_3.npz
case_050_sliceED_4.npz
case_050_sliceED_5.npz
case_050_sliceED_6.npz
case_050_sliceED_7.npz
case_050_sliceED_8.npz
case_050_sliceED_9.npz
case_050_sliceES_0.npz
case_050_sliceES_1.npz
case_050_sliceES_2.npz
case_050_sliceES_3.npz
case_050_sliceES_4.npz
case_050_sliceES_5.npz
case_050_sliceES_6.npz
case_050_sliceES_7.npz
case_050_sliceES_8.npz
case_050_sliceES_9.npz
case_061_sliceED_0.npz
case_061_sliceED_1.npz
case_061_sliceED_2.npz
case_061_sliceED_3.npz
case_061_sliceED_4.npz
case_061_sliceED_5.npz
case_061_sliceED_6.npz
case_061_sliceED_7.npz
case_061_sliceED_8.npz
case_061_sliceES_0.npz
case_061_sliceES_1.npz
case_061_sliceES_2.npz
case_061_sliceES_3.npz
case_061_sliceES_4.npz
case_061_sliceES_5.npz
case_061_sliceES_6.npz
case_061_sliceES_7.npz
case_061_sliceES_8.npz
case_071_sliceED_0.npz
case_071_sliceED_1.npz
case_071_sliceED_2.npz
case_071_sliceED_3.npz
case_071_sliceED_4.npz
case_071_sliceED_5.npz
case_071_sliceED_6.npz
case_071_sliceED_7.npz
case_071_sliceED_8.npz
case_071_sliceED_9.npz
case_071_sliceES_0.npz
case_071_sliceES_1.npz
case_071_sliceES_2.npz
case_071_sliceES_3.npz
case_071_sliceES_4.npz
case_071_sliceES_5.npz
case_071_sliceES_6.npz
case_071_sliceES_7.npz
case_071_sliceES_8.npz
case_071_sliceES_9.npz
case_076_sliceED_0.npz
case_076_sliceED_1.npz
case_076_sliceED_2.npz
case_076_sliceED_3.npz
case_076_sliceED_4.npz
case_076_sliceED_5.npz
case_076_sliceED_6.npz
case_076_sliceED_7.npz
case_076_sliceES_0.npz
case_076_sliceES_1.npz
case_076_sliceES_2.npz
case_076_sliceES_3.npz
case_076_sliceES_4.npz
case_076_sliceES_5.npz
case_076_sliceES_6.npz
case_076_sliceES_7.npz
case_080_sliceED_0.npz
case_080_sliceED_1.npz
case_080_sliceED_2.npz
case_080_sliceED_3.npz
case_080_sliceED_4.npz
case_080_sliceED_5.npz
case_080_sliceES_0.npz
case_080_sliceES_1.npz
case_080_sliceES_2.npz
case_080_sliceES_3.npz
case_080_sliceES_4.npz
case_080_sliceES_5.npz

@ -0,0 +1,30 @@
case0031.npy.h5
case0007.npy.h5
case0009.npy.h5
case0005.npy.h5
case0026.npy.h5
case0039.npy.h5
case0024.npy.h5
case0034.npy.h5
case0033.npy.h5
case0030.npy.h5
case0023.npy.h5
case0040.npy.h5
case0010.npy.h5
case0021.npy.h5
case0006.npy.h5
case0027.npy.h5
case0028.npy.h5
case0037.npy.h5
case0008.npy.h5
case0022.npy.h5
case0038.npy.h5
case0036.npy.h5
case0032.npy.h5
case0002.npy.h5
case0029.npy.h5
case0003.npy.h5
case0001.npy.h5
case0004.npy.h5
case0025.npy.h5
case0035.npy.h5

@ -0,0 +1,12 @@
case0008
case0022
case0038
case0036
case0032
case0002
case0029
case0003
case0001
case0004
case0025
case0035

File diff suppressed because it is too large.

Binary file not shown (new image, 1.3 MiB).

@ -0,0 +1,53 @@
CASCADE, G-CASCADE, AND EMCAD
8251 MAR
UT AUSTIN RESEARCH LICENSE
(NONCONFIDENTIAL SOURCE CODE)
The University of Texas at Austin has developed certain software and documentation that it desires to make available without charge to anyone for academic, research, experimental or personal use. If you wish to distribute or make other use of the software, you may purchase a license to do so from The University of Texas at Austin (licensing@otc.utexas.edu).
The accompanying source code is made available to you under the terms of this UT Research License (this “UTRL”). By installing or using the code, you are consenting to be bound by this UTRL. If you do not agree to the terms and conditions of this license, do not install or use any part of the code.
The terms and conditions in this UTRL not only apply to the source code made available by Licensor, but also to any improvements to, or derivative works of, that source code made by you and to any object code compiled from such source code, improvements or derivative works.
1. DEFINITIONS.
1.1 “Commercial Use” shall mean use of Software or Documentation by Licensee for direct or indirect financial, commercial or strategic gain or advantage, including without limitation: (a) bundling or integrating the Software with any hardware product or another software product for transfer, sale or license to a third party (even if distributing the Software on separate media and not charging for the Software); (b) providing customers with a link to the Software or a copy of the Software for use with hardware or another software product purchased by that customer; or (c) use in connection with the performance of services for which Licensee is compensated.
1.2 “Derivative Products” means any improvements to, or other derivative works of, the Software made by Licensee, and any computer software programs, and accompanying documentation, developed by Licensee which are a modification of, enhancement to, derived from or based upon the Licensed Software or documentation provided by Licensor for the Licensed Software, and any object code compiled from such computer software programs.
1.3 “Documentation” shall mean all manuals, user documentation, and other related materials pertaining to the Software that are made available to Licensee in connection with the Software.
1.4 “Licensor” shall mean The University of Texas at Austin, on behalf of the Board of Regents of the University of Texas System, an agency of the State of Texas, whose address is 3925 W. Braker Lane, Suite 1.9A (R3500), Austin, Texas 78759.
1.5 “Licensee” or “you” shall mean the person or entity that has agreed to the terms hereof and is exercising rights granted hereunder.
1.6 “Software” shall mean the computer program(s) referred to as: “CASCADE,” “G-CASCADE,” or “EMCAD” (UT Tech ID 8251 MAR), which is made available under this UTRL in source code form, including any error corrections, bug fixes, patches, updates or other modifications that Licensor may in its sole discretion make available to Licensee from time to time, and any object code compiled from such source code.
2. GRANT OF RIGHTS.
Subject to the terms and conditions hereunder, Licensor hereby grants to Licensee a worldwide, non-transferable, non-exclusive license to (a) install, use and reproduce the Software for academic, research, experimental and personal use (but specifically excluding Commercial Use); (b) use and modify the Software to create Derivative Products, subject to Section 3.2; (c) use the Documentation, if any, solely in connection with Licensee's authorized use of the Software; and (d) a non-exclusive, royalty-free license for academic, research, experimental and personal use (but specifically excluding Commercial Use) to those patents, of which Radu Marculescu is a named inventor, that are licensable by Licensee and that are necessarily infringed by such authorized use of the Software, and solely in connection with Licensee's authorized use of the Software.
3. RESTRICTIONS; COVENANTS.
3.1 Licensee may not: (a) distribute, sub-license or otherwise transfer copies or rights to the Software (or any portion thereof) or the Documentation; (b) use the Software (or any portion thereof) or Documentation for Commercial Use, or for any other use except as described in Section 2; (c) copy the Software or Documentation other than for archival and backup purposes; or (d) remove any product identification, copyright, proprietary notices or labels from the Software and Documentation. This UTRL confers no rights upon Licensee except those expressly granted herein.
3.2 Derivative Products. Licensee hereby agrees that it will provide a copy of all Derivative Products to Licensor and that its use of the Derivative Products will be subject to all of the same terms, conditions, restrictions and limitations on use imposed on the Software under this UTRL. Licensee hereby grants Licensor a worldwide, non-exclusive, royalty-free license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute Derivative Products. Licensee also hereby grants Licensor a worldwide, non-exclusive, royalty-free patent license to make, have made, use, offer to sell, sell, import and otherwise transfer the Derivative Products under those patent claims, from patents of which Radu Marculescu is a named inventor, that are licensable by Licensee and that are necessarily infringed by the Derivative Products.
4. CONFIDENTIALITY; PROTECTION OF SOFTWARE.
4.1 Reserved.
4.2 Proprietary Notices. Licensee shall maintain and place on any copy of Software or Documentation that it reproduces for internal use all notices as are authorized and/or required hereunder. Licensee shall include a copy of this UTRL and the following notice, on each copy of the Software and Documentation. Such license and notice shall be embedded in each copy of the Software, in the video screen display, on the physical medium embodying the Software copy and on any Documentation:
Copyright © 2021, The University of Texas at Austin. All rights reserved.
UNIVERSITY EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES CONCERNING THIS SOFTWARE AND DOCUMENTATION, INCLUDING ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR ANY PARTICULAR PURPOSE, NON-INFRINGEMENT AND WARRANTIES OF PERFORMANCE, AND ANY WARRANTY THAT MIGHT OTHERWISE ARISE FROM COURSE OF DEALING OR USAGE OF TRADE. NO WARRANTY IS EITHER EXPRESS OR IMPLIED WITH RESPECT TO THE USE OF THE SOFTWARE OR DOCUMENTATION. Under no circumstances shall University be liable for incidental, special, indirect, direct or consequential damages or loss of profits, interruption of business, or related expenses which may arise from use of Software or Documentation, including but not limited to those resulting from defects in Software and/or Documentation, or loss or inaccuracy of data of any kind.
5. WARRANTIES.
5.1 Disclaimer of Warranties. TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE AND DOCUMENTATION ARE BEING PROVIDED ON AN “AS IS” BASIS WITHOUT ANY WARRANTIES OF ANY KIND RESPECTING THE SOFTWARE OR DOCUMENTATION, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTY OF DESIGN, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT.
5.2 Limitation of Liability. UNDER NO CIRCUMSTANCES UNLESS REQUIRED BY APPLICABLE LAW SHALL LICENSOR BE LIABLE FOR INCIDENTAL, SPECIAL, INDIRECT, DIRECT OR CONSEQUENTIAL DAMAGES OR LOSS OF PROFITS, INTERRUPTION OF BUSINESS, OR RELATED EXPENSES WHICH MAY ARISE AS A RESULT OF THIS LICENSE OR OUT OF THE USE OR ATTEMPT OF USE OF SOFTWARE OR DOCUMENTATION INCLUDING BUT NOT LIMITED TO THOSE RESULTING FROM DEFECTS IN SOFTWARE AND/OR DOCUMENTATION, OR LOSS OR INACCURACY OF DATA OF ANY KIND. THE FOREGOING EXCLUSIONS AND LIMITATIONS WILL APPLY TO ALL CLAIMS AND ACTIONS OF ANY KIND, WHETHER BASED ON CONTRACT, TORT (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE), OR ANY OTHER GROUNDS.
6. INDEMNIFICATION.
Licensee shall indemnify, defend and hold harmless Licensor, the University of Texas System, their Regents, and their officers, agents and employees from and against any claims, demands, or causes of action whatsoever caused by, or arising out of, or resulting from, the exercise or practice of the license granted hereunder by Licensee, its officers, employees, agents or representatives.
7. TERMINATION.
If Licensee breaches this UTRL, Licensee's right to use the Software and Documentation will terminate immediately without notice, but all provisions of this UTRL except Section 2 will survive termination and continue in effect. Upon termination, Licensee must destroy all copies of the Software and Documentation.
8. GOVERNING LAW; JURISDICTION AND VENUE.
The validity, interpretation, construction and performance of this UTRL shall be governed by the laws of the State of Texas. The Texas state courts of Travis County, Texas (or, if there is exclusive federal jurisdiction, the United States District Court for the Western District of Texas) shall have exclusive jurisdiction and venue over any dispute arising out of this UTRL, and Licensee consents to the jurisdiction of such courts. Application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded.
9. EXPORT CONTROLS.
This license is subject to all applicable export restrictions. Licensee must comply with all export and import laws and restrictions and regulations of any United States or foreign agency or authority relating to the Software and its use.
10. U.S. GOVERNMENT END-USERS.
The Software is a “commercial item,” as that term is defined in 48 C.F.R. 2.101, consisting of “commercial computer software” and “commercial computer software documentation,” as such terms are used in 48 C.F.R. 12.212 (Sept. 1995) and 48 C.F.R. 227.7202 (June 1995). Consistent with 48 C.F.R. 12.212, 48 C.F.R. 27.405(b)(2) (June 1998) and 48 C.F.R. 227.7202, all U.S. Government End Users acquire the Software with only those rights as set forth herein.
11. MISCELLANEOUS
If any provision hereof shall be held illegal, invalid or unenforceable, in whole or in part, such provision shall be modified to the minimum extent necessary to make it legal, valid and enforceable, and the legality, validity and enforceability of all other provisions of this UTRL shall not be affected thereby. Licensee may not assign this UTRL in whole or in part, without Licensor's prior written consent. Any attempt to assign this UTRL without such consent will be null and void. This UTRL is the complete and exclusive statement between Licensee and Licensor relating to the subject matter hereof and supersedes all prior oral and written and all contemporaneous oral negotiations, commitments and understandings of the parties, if any. Any waiver by either party of any default or breach hereunder shall not constitute a waiver of any provision of this UTRL or of any subsequent default or breach of the same or a different kind.
END OF LICENSE

Binary file not shown (new image, 120 KiB).

Binary file not shown (new image, 130 KiB).

Binary file not shown (new image, 130 KiB).

Binary file not shown (new image, 92 KiB).

Binary file not shown (new image, 160 KiB).

@ -0,0 +1,33 @@
numpy==1.22.4
loguru
tqdm
pyyaml
pandas
matplotlib
scikit-learn
scikit-image
scipy
opencv-python
seaborn
albumentations==1.1.0
tabulate
warmup-scheduler
transformers==4.21.3
torchprofile
torchmetrics
einops
ptflops
torchsummary
torchsummaryx
segmentation-mask-overlay==0.3.4
timm==0.6.12
tifffile
pillow
thop
simpleitk
nibabel
h5py
huggingface-hub==0.11.0
ml_collections
tensorboardx
medpy

@ -0,0 +1,289 @@
"""Training entry point for the Synapse and ACDC datasets."""
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from src.core.networks import EMCADNet
from src.utils.trainer import trainer_ACDC, trainer_synapse
def build_parser():
    """Build the argument parser for training."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--root_path",
type=str,
default="/data/ACDC/train",
help="root dir for training data (ACDC: /data/ACDC/train)",
)
parser.add_argument(
"--volume_path",
type=str,
default="/data/ACDC/test",
help="root dir for validation/test volume data",
)
parser.add_argument(
"--dataset",
type=str,
default="ACDC",
choices=["Synapse", "ACDC"],
help="experiment name",
)
parser.add_argument(
"--list_dir",
type=str,
default="/data/ACDC/lists_ACDC",
help="list dir (ACDC: /data/ACDC/lists_ACDC)",
)
parser.add_argument(
"--num_classes",
type=int,
default=4,
help="output channel of network (ACDC = 4)",
)
parser.add_argument(
"--encoder",
type=str,
default="pvt_v2_b2",
help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...",
)
parser.add_argument(
"--expansion_factor",
type=int,
default=2,
help="expansion factor in MSCB block",
)
parser.add_argument(
"--kernel_sizes",
type=int,
nargs="+",
default=[1, 3, 5],
help="multi-scale kernel sizes in MSDC block",
)
parser.add_argument(
"--lgag_ks", type=int, default=3, help="Kernel size in LGAG block"
)
parser.add_argument(
"--activation_mscb",
type=str,
default="relu6",
help="activation used in MSCB: relu6 or relu",
)
parser.add_argument(
"--no_dw_parallel",
action="store_true",
default=False,
help="use this flag to disable depth-wise parallel convolutions",
)
parser.add_argument(
"--concatenation",
action="store_true",
default=False,
help="use this flag to concatenate feature maps in MSDC block",
)
parser.add_argument(
"--no_pretrain",
action="store_true",
default=False,
help="use this flag to turn off loading pretrained encoder weights",
)
parser.add_argument(
"--pretrained_dir",
type=str,
default="./model_pth/",
help="path to pretrained encoder dir, e.g. ./model_pth/",
)
parser.add_argument(
"--supervision",
type=str,
default="mutation",
help="loss supervision: mutation, deep_supervision or last_layer",
)
parser.add_argument(
"--max_iterations", type=int, default=50000, help="maximum total iterations"
)
parser.add_argument(
"--max_epochs", type=int, default=300, help="maximum epoch number to train"
)
parser.add_argument("--batch_size", type=int, default=6, help="batch_size per gpu")
parser.add_argument(
"--base_lr",
type=float,
default=0.0001,
help="segmentation network learning rate",
)
parser.add_argument(
"--img_size",
type=int,
default=224,
help="input patch size of network input",
)
parser.add_argument("--n_gpu", type=int, default=1, help="total gpu")
parser.add_argument(
"--deterministic",
type=int,
default=1,
help="whether use deterministic training",
)
parser.add_argument("--seed", type=int, default=2222, help="random seed")
return parser
def set_deterministic(seed, deterministic):
    """Configure random seeds and deterministic behavior."""
if not deterministic:
cudnn.benchmark = True
cudnn.deterministic = False
else:
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def build_snapshot_path(args, dataset_name):
    """Build the snapshot (output) directory path from the arguments."""
    aggregation = "concat" if args.concatenation else "add"
    dw_mode = "series" if args.no_dw_parallel else "parallel"
    run = 1
    # Shared experiment tag; exp additionally encodes the dataset name and image size.
    base = (
        f"{args.encoder}_EMCAD_kernel_sizes_{args.kernel_sizes}"
        f"_dw_{dw_mode}_{aggregation}_lgag_ks_{args.lgag_ks}"
        f"_ef{args.expansion_factor}_act_mscb_{args.activation_mscb}"
        f"_loss_{args.supervision}_output_final_layer_Run{run}"
    )
    exp = base + "_" + dataset_name + str(args.img_size)
    snapshot_path = "model_pth/{}/{}".format(exp, base)
    # Strip list formatting from kernel_sizes, e.g. [1, 3, 5] -> 1_3_5.
    snapshot_path = snapshot_path.replace("[", "").replace("]", "").replace(", ", "_")
if not args.no_pretrain:
snapshot_path = snapshot_path + "_pretrain"
if args.max_iterations != 50000:
snapshot_path = snapshot_path + "_" + str(args.max_iterations)[0:2] + "k"
if args.max_epochs != 300:
snapshot_path = snapshot_path + "_epo" + str(args.max_epochs)
snapshot_path = snapshot_path + "_bs" + str(args.batch_size)
if args.base_lr != 0.0001:
snapshot_path = snapshot_path + "_lr" + str(args.base_lr)
snapshot_path = snapshot_path + "_" + str(args.img_size)
if args.seed != 1234:
snapshot_path = snapshot_path + "_s" + str(args.seed)
return exp, snapshot_path
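The final `replace` chain exists because `str(args.kernel_sizes)` renders a Python list as `[1, 3, 5]`, which is awkward in a directory name. A minimal sketch of that cleanup (the `clean_tag` helper is hypothetical, for illustration only):

```python
def clean_tag(value):
    """Mimic the bracket/comma cleanup applied to snapshot_path."""
    return str(value).replace("[", "").replace("]", "").replace(", ", "_")

# List-valued arguments become filesystem-safe tags; scalars pass through.
assert clean_tag([1, 3, 5]) == "1_3_5"
assert clean_tag(3) == "3"
```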
def main():
"""主入口函数。"""
parser = build_parser()
args = parser.parse_args()
set_deterministic(args.seed, args.deterministic)
dataset_name = args.dataset
acdc_root = args.root_path
if dataset_name == "ACDC":
tmp = args.root_path.rstrip("/")
if os.path.basename(tmp) == "train":
acdc_root = os.path.dirname(tmp)
else:
acdc_root = tmp
dataset_config = {
"Synapse": {
"root_path": args.root_path,
"volume_path": args.volume_path,
"list_dir": args.list_dir,
"num_classes": args.num_classes,
"z_spacing": 1,
},
"ACDC": {
"root_path": acdc_root,
"volume_path": args.volume_path,
"list_dir": args.list_dir,
"num_classes": args.num_classes,
"z_spacing": 1,
},
}
cfg = dataset_config[dataset_name]
args.num_classes = cfg["num_classes"]
args.root_path = cfg["root_path"]
args.volume_path = cfg["volume_path"]
args.z_spacing = cfg["z_spacing"]
args.list_dir = cfg["list_dir"]
args.exp, snapshot_path = build_snapshot_path(args, dataset_name)
if not os.path.exists(snapshot_path):
os.makedirs(snapshot_path)
model = EMCADNet(
num_classes=args.num_classes,
kernel_sizes=args.kernel_sizes,
expansion_factor=args.expansion_factor,
dw_parallel=not args.no_dw_parallel,
add=not args.concatenation,
lgag_ks=args.lgag_ks,
activation=args.activation_mscb,
encoder=args.encoder,
pretrain=not args.no_pretrain,
pretrained_dir=args.pretrained_dir,
)
model.cuda()
print("Model successfully created.")
trainer_map = {"Synapse": trainer_synapse, "ACDC": trainer_ACDC}
trainer_map[dataset_name](args, model, snapshot_path)
if __name__ == "__main__":
main()

@@ -0,0 +1,46 @@
from setuptools import setup, find_packages
setup(
name="EMCADNet",
version="0.1.0",
author="Your Name",
author_email="your.email@example.com",
description="EMCADNet: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation",
    long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
url="https://github.com/yourusername/EMCADNet",
packages=find_packages(where="src"),
package_dir={"": "src"},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires=">=3.8",
install_requires=[
"torch>=1.11.0",
"torchvision>=0.12.0",
"numpy>=1.22.0",
"h5py>=3.0.0",
"scipy>=1.5.0",
"matplotlib>=3.3.0",
"tqdm>=4.50.0",
"tensorboardX>=2.2",
"nibabel>=3.2.0",
"medpy>=0.4.0",
"ptflops>=0.6.4",
"thop>=0.0.31",
"segmentation-mask-overlay>=0.3.0",
"timm>=0.6.0",
],
entry_points={
"console_scripts": [
"emcad-train=scripts.train_synapse:main",
"emcad-test=scripts.test_synapse:main",
],
},
include_package_data=True,
package_data={
"": ["*.md", "*.txt"],
},
)

@@ -0,0 +1,3 @@
from .networks import EMCADNet
__all__ = ["EMCADNet"]

@@ -0,0 +1,584 @@
"""EMCAD 解码器与注意力模块。"""
import math
from functools import partial
import torch
import torch.nn as nn
from timm.models.helpers import named_apply
from timm.models.layers import trunc_normal_tf_
def gcd(a, b):
"""计算最大公约数。"""
while b:
a, b = b, a % b
return a
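This `gcd` mirrors Euclid's algorithm and, for the positive channel counts it receives from `MSCB`, agrees with the standard library's `math.gcd` (the `gcd_euclid` copy below is only for the comparison):

```python
import math

def gcd_euclid(a, b):
    # same loop as the module-level gcd above
    while b:
        a, b = b, a % b
    return a

# Spot-check against the stdlib on typical channel counts.
for a, b in [(1024, 320), (512, 512), (7, 13), (384, 128)]:
    assert gcd_euclid(a, b) == math.gcd(a, b)
```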
def _init_weights(module, name, scheme=""):
"""按指定方案初始化权重。"""
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
if scheme == "normal":
nn.init.normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif scheme == "trunc_normal":
trunc_normal_tf_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif scheme == "xavier_normal":
nn.init.xavier_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif scheme == "kaiming_normal":
nn.init.kaiming_normal_(
module.weight, mode="fan_out", nonlinearity="relu"
)
if module.bias is not None:
nn.init.zeros_(module.bias)
else:
fan_out = (
module.kernel_size[0]
* module.kernel_size[1]
* module.out_channels
)
fan_out //= module.groups
nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d)):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
"""根据名称创建激活函数层。"""
act = act.lower()
if act == "relu":
layer = nn.ReLU(inplace)
elif act == "relu6":
layer = nn.ReLU6(inplace)
elif act == "leakyrelu":
layer = nn.LeakyReLU(neg_slope, inplace)
elif act == "prelu":
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
elif act == "gelu":
layer = nn.GELU()
elif act == "hswish":
layer = nn.Hardswish(inplace)
else:
raise NotImplementedError("activation layer [%s] is not found" % act)
return layer
def channel_shuffle(x, groups):
"""对通道进行分组混洗。"""
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(batchsize, -1, height, width)
return x
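`channel_shuffle` interleaves channels across groups via a view-transpose-view. The resulting channel order can be verified on plain indices (the `shuffle_indices` helper is illustrative, not part of the module):

```python
def shuffle_indices(num_channels, groups):
    """Channel order produced by channel_shuffle, computed on indices."""
    per_group = num_channels // groups
    grid = [[g * per_group + c for c in range(per_group)] for g in range(groups)]
    # transpose the (groups, per_group) grid, then flatten
    return [grid[g][c] for c in range(per_group) for g in range(groups)]

# 6 channels in 2 groups interleave as [0, 3, 1, 4, 2, 5].
assert shuffle_indices(6, 2) == [0, 3, 1, 4, 2, 5]
# The shuffle is a permutation: every channel appears exactly once.
assert sorted(shuffle_indices(8, 4)) == list(range(8))
```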
class MSDC(nn.Module):
"""多尺度深度卷积模块。"""
def __init__(
self,
in_channels,
kernel_sizes,
stride,
activation="relu6",
dw_parallel=True,
):
"""初始化 MSDC。"""
super(MSDC, self).__init__()
self.in_channels = in_channels
self.kernel_sizes = kernel_sizes
self.activation = activation
self.dw_parallel = dw_parallel
self.dwconvs = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(
self.in_channels,
self.in_channels,
kernel_size,
stride,
kernel_size // 2,
groups=self.in_channels,
bias=False,
),
nn.BatchNorm2d(self.in_channels),
act_layer(self.activation, inplace=True),
)
for kernel_size in self.kernel_sizes
]
)
self.init_weights("normal")
def init_weights(self, scheme=""):
"""初始化权重。"""
named_apply(partial(_init_weights, scheme=scheme), self)
def forward(self, x):
"""前向计算。"""
outputs = []
for dwconv in self.dwconvs:
dw_out = dwconv(x)
outputs.append(dw_out)
if not self.dw_parallel:
x = x + dw_out
return outputs
class MSCB(nn.Module):
"""多尺度卷积块MSCB"""
def __init__(
self,
in_channels,
out_channels,
stride,
kernel_sizes=None,
expansion_factor=2,
dw_parallel=True,
add=True,
activation="relu6",
):
"""初始化 MSCB。"""
super(MSCB, self).__init__()
if kernel_sizes is None:
kernel_sizes = [1, 3, 5]
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.kernel_sizes = kernel_sizes
self.expansion_factor = expansion_factor
self.dw_parallel = dw_parallel
self.add = add
self.activation = activation
self.n_scales = len(self.kernel_sizes)
assert self.stride in [1, 2]
self.use_skip_connection = self.stride == 1
self.ex_channels = int(self.in_channels * self.expansion_factor)
self.pconv1 = nn.Sequential(
nn.Conv2d(self.in_channels, self.ex_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(self.ex_channels),
act_layer(self.activation, inplace=True),
)
self.msdc = MSDC(
self.ex_channels,
self.kernel_sizes,
self.stride,
self.activation,
dw_parallel=self.dw_parallel,
)
if self.add:
self.combined_channels = self.ex_channels * 1
else:
self.combined_channels = self.ex_channels * self.n_scales
self.pconv2 = nn.Sequential(
nn.Conv2d(
self.combined_channels, self.out_channels, 1, 1, 0, bias=False
),
nn.BatchNorm2d(self.out_channels),
)
if self.use_skip_connection and (self.in_channels != self.out_channels):
self.conv1x1 = nn.Conv2d(
self.in_channels, self.out_channels, 1, 1, 0, bias=False
)
self.init_weights("normal")
def init_weights(self, scheme=""):
"""初始化权重。"""
named_apply(partial(_init_weights, scheme=scheme), self)
def forward(self, x):
"""前向计算。"""
pout1 = self.pconv1(x)
msdc_outs = self.msdc(pout1)
if self.add:
dout = 0
for dwout in msdc_outs:
dout = dout + dwout
else:
dout = torch.cat(msdc_outs, dim=1)
dout = channel_shuffle(dout, gcd(self.combined_channels, self.out_channels))
out = self.pconv2(dout)
if self.use_skip_connection:
if self.in_channels != self.out_channels:
x = self.conv1x1(x)
return x + out
return out
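The width entering `pconv2` depends on the aggregation mode: with `add=True` the expanded channels stay constant, whereas concatenation multiplies them by the number of kernel scales. A sketch of that arithmetic (the helper name is hypothetical):

```python
def mscb_combined_channels(in_channels, expansion_factor, n_scales, add):
    """Channels entering pconv2: the expanded width, times n_scales
    when multi-scale outputs are concatenated instead of added."""
    ex = int(in_channels * expansion_factor)
    return ex if add else ex * n_scales

# 64 channels, expansion 2, three kernel scales.
assert mscb_combined_channels(64, 2, 3, add=True) == 128
assert mscb_combined_channels(64, 2, 3, add=False) == 384
```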
def MSCBLayer(
in_channels,
out_channels,
n=1,
stride=1,
kernel_sizes=None,
expansion_factor=2,
dw_parallel=True,
add=True,
activation="relu6",
):
"""构建 MSCB 堆叠层。"""
if kernel_sizes is None:
kernel_sizes = [1, 3, 5]
convs = []
mscb = MSCB(
in_channels,
out_channels,
stride,
kernel_sizes=kernel_sizes,
expansion_factor=expansion_factor,
dw_parallel=dw_parallel,
add=add,
activation=activation,
)
convs.append(mscb)
if n > 1:
for _ in range(1, n):
mscb = MSCB(
out_channels,
out_channels,
1,
kernel_sizes=kernel_sizes,
expansion_factor=expansion_factor,
dw_parallel=dw_parallel,
add=add,
activation=activation,
)
convs.append(mscb)
return nn.Sequential(*convs)
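`MSCBLayer` gives only the first block the stride and channel change; the remaining `n - 1` blocks run at `out_channels` with stride 1. The stacking plan can be sketched as plain tuples (illustrative helper, not part of the module):

```python
def mscb_stack_plan(in_channels, out_channels, n, stride):
    """(in, out, stride) per block: the first block changes width and
    resolution, the rest are identity-shaped (out, out, 1)."""
    plan = [(in_channels, out_channels, stride)]
    plan += [(out_channels, out_channels, 1) for _ in range(n - 1)]
    return plan

assert mscb_stack_plan(64, 128, n=3, stride=2) == [
    (64, 128, 2), (128, 128, 1), (128, 128, 1)]
```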
class EUCB(nn.Module):
"""高效上采样卷积块EUCB"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, activation="relu"):
"""初始化 EUCB。"""
super(EUCB, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_dwc = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(
self.in_channels,
self.in_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
groups=self.in_channels,
bias=False,
),
nn.BatchNorm2d(self.in_channels),
act_layer(activation, inplace=True),
)
self.pwc = nn.Sequential(
nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
)
self.init_weights("normal")
def init_weights(self, scheme=""):
"""初始化权重。"""
named_apply(partial(_init_weights, scheme=scheme), self)
def forward(self, x):
"""前向计算。"""
x = self.up_dwc(x)
x = channel_shuffle(x, self.in_channels)
x = self.pwc(x)
return x
class LGAG(nn.Module):
"""大核分组注意力门控LGAG"""
def __init__(
self, F_g, F_l, F_int, kernel_size=3, groups=1, activation="relu"
):
"""初始化 LGAG。"""
super(LGAG, self).__init__()
if kernel_size == 1:
groups = 1
self.W_g = nn.Sequential(
nn.Conv2d(
F_g,
F_int,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
groups=groups,
bias=True,
),
nn.BatchNorm2d(F_int),
)
self.W_x = nn.Sequential(
nn.Conv2d(
F_l,
F_int,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
groups=groups,
bias=True,
),
nn.BatchNorm2d(F_int),
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid(),
)
self.activation = act_layer(activation, inplace=True)
self.init_weights("normal")
def init_weights(self, scheme=""):
"""初始化权重。"""
named_apply(partial(_init_weights, scheme=scheme), self)
def forward(self, g, x):
"""前向计算。"""
g1 = self.W_g(g)
x1 = self.W_x(x)
psi = self.activation(g1 + x1)
psi = self.psi(psi)
return x * psi
class CAB(nn.Module):
"""通道注意力模块CAB"""
def __init__(self, in_channels, out_channels=None, ratio=16, activation="relu"):
"""初始化 CAB。"""
super(CAB, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if self.in_channels < ratio:
ratio = self.in_channels
self.reduced_channels = self.in_channels // ratio
if self.out_channels is None:
self.out_channels = in_channels
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.activation = act_layer(activation, inplace=True)
self.fc1 = nn.Conv2d(self.in_channels, self.reduced_channels, 1, bias=False)
self.fc2 = nn.Conv2d(self.reduced_channels, self.out_channels, 1, bias=False)
self.sigmoid = nn.Sigmoid()
self.init_weights("normal")
def init_weights(self, scheme=""):
"""初始化权重。"""
named_apply(partial(_init_weights, scheme=scheme), self)
def forward(self, x):
"""前向计算。"""
avg_pool_out = self.avg_pool(x)
avg_out = self.fc2(self.activation(self.fc1(avg_pool_out)))
max_pool_out = self.max_pool(x)
max_out = self.fc2(self.activation(self.fc1(max_pool_out)))
out = avg_out + max_out
return self.sigmoid(out)
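The two 1x1 convolutions form a squeeze-and-excite style bottleneck whose width follows from the clamped ratio: for narrow inputs the ratio is reduced so the bottleneck never drops below one channel. A sketch of that rule (hypothetical helper):

```python
def cab_reduced_channels(in_channels, ratio=16):
    """Bottleneck width of the channel-attention MLP; the ratio is
    clamped so the reduced width never falls below one channel."""
    if in_channels < ratio:
        ratio = in_channels
    return in_channels // ratio

assert cab_reduced_channels(512) == 32   # 512 / 16
assert cab_reduced_channels(8) == 1      # ratio clamped to 8
```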
class SAB(nn.Module):
"""空间注意力模块SAB"""
def __init__(self, kernel_size=7):
"""初始化 SAB。"""
super(SAB, self).__init__()
assert kernel_size in (3, 7, 11), "kernel must be 3 or 7 or 11"
padding = kernel_size // 2
self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
self.init_weights("normal")
def init_weights(self, scheme=""):
"""初始化权重。"""
named_apply(partial(_init_weights, scheme=scheme), self)
def forward(self, x):
"""前向计算。"""
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv(x)
return self.sigmoid(x)
class EMCAD(nn.Module):
"""高效多尺度卷积注意力解码器EMCAD"""
def __init__(
self,
channels=None,
kernel_sizes=None,
expansion_factor=6,
dw_parallel=True,
add=True,
lgag_ks=3,
activation="relu6",
):
"""初始化 EMCAD。"""
super(EMCAD, self).__init__()
if channels is None:
channels = [512, 320, 128, 64]
if kernel_sizes is None:
kernel_sizes = [1, 3, 5]
eucb_ks = 3
self.mscb4 = MSCBLayer(
channels[0],
channels[0],
n=1,
stride=1,
kernel_sizes=kernel_sizes,
expansion_factor=expansion_factor,
dw_parallel=dw_parallel,
add=add,
activation=activation,
)
self.eucb3 = EUCB(
in_channels=channels[0],
out_channels=channels[1],
kernel_size=eucb_ks,
stride=eucb_ks // 2,
)
self.lgag3 = LGAG(
F_g=channels[1],
F_l=channels[1],
F_int=channels[1] // 2,
kernel_size=lgag_ks,
groups=channels[1] // 2,
)
self.mscb3 = MSCBLayer(
channels[1],
channels[1],
n=1,
stride=1,
kernel_sizes=kernel_sizes,
expansion_factor=expansion_factor,
dw_parallel=dw_parallel,
add=add,
activation=activation,
)
self.eucb2 = EUCB(
in_channels=channels[1],
out_channels=channels[2],
kernel_size=eucb_ks,
stride=eucb_ks // 2,
)
self.lgag2 = LGAG(
F_g=channels[2],
F_l=channels[2],
F_int=channels[2] // 2,
kernel_size=lgag_ks,
groups=channels[2] // 2,
)
self.mscb2 = MSCBLayer(
channels[2],
channels[2],
n=1,
stride=1,
kernel_sizes=kernel_sizes,
expansion_factor=expansion_factor,
dw_parallel=dw_parallel,
add=add,
activation=activation,
)
self.eucb1 = EUCB(
in_channels=channels[2],
out_channels=channels[3],
kernel_size=eucb_ks,
stride=eucb_ks // 2,
)
self.lgag1 = LGAG(
F_g=channels[3],
F_l=channels[3],
F_int=int(channels[3] / 2),
kernel_size=lgag_ks,
groups=int(channels[3] / 2),
)
self.mscb1 = MSCBLayer(
channels[3],
channels[3],
n=1,
stride=1,
kernel_sizes=kernel_sizes,
expansion_factor=expansion_factor,
dw_parallel=dw_parallel,
add=add,
activation=activation,
)
self.cab4 = CAB(channels[0])
self.cab3 = CAB(channels[1])
self.cab2 = CAB(channels[2])
self.cab1 = CAB(channels[3])
self.sab = SAB()
def forward(self, x, skips):
"""前向解码计算。"""
d4 = self.cab4(x) * x
d4 = self.sab(d4) * d4
d4 = self.mscb4(d4)
d3 = self.eucb3(d4)
x3 = self.lgag3(g=d3, x=skips[0])
d3 = d3 + x3
d3 = self.cab3(d3) * d3
d3 = self.sab(d3) * d3
d3 = self.mscb3(d3)
d2 = self.eucb2(d3)
x2 = self.lgag2(g=d2, x=skips[1])
d2 = d2 + x2
d2 = self.cab2(d2) * d2
d2 = self.sab(d2) * d2
d2 = self.mscb2(d2)
d1 = self.eucb1(d2)
x1 = self.lgag1(g=d1, x=skips[2])
d1 = d1 + x1
d1 = self.cab1(d1) * d1
d1 = self.sab(d1) * d1
d1 = self.mscb1(d1)
return [d4, d3, d2, d1]
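Starting from the stride-32 encoder output, each of the three EUCBs doubles the spatial side, so d4..d1 sit at strides 32/16/8/4. For a 224-pixel input this can be sketched as:

```python
def decoder_feature_sizes(img_size, n_upsamples=3, encoder_stride=32):
    """Spatial side length of d4..d1 for a stride-32 encoder: each
    EUCB doubles resolution via its scale-factor-2 upsample."""
    sizes = [img_size // encoder_stride]
    for _ in range(n_upsamples):
        sizes.append(sizes[-1] * 2)
    return sizes

assert decoder_feature_sizes(224) == [7, 14, 28, 56]
```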

@@ -0,0 +1,145 @@
"""EMCADNet 网络定义。"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.core.decoders import EMCAD
from src.core.pvtv2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b3, pvt_v2_b4, pvt_v2_b5
from src.core.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
class EMCADNet(nn.Module):
"""EMCAD 端到端网络封装。"""
def __init__(
self,
num_classes=1,
kernel_sizes=None,
expansion_factor=2,
dw_parallel=True,
add=True,
lgag_ks=3,
activation="relu",
encoder="pvt_v2_b2",
pretrain=True,
pretrained_dir="./pretrained_pth/pvt/",
):
"""初始化网络。"""
super(EMCADNet, self).__init__()
if kernel_sizes is None:
kernel_sizes = [1, 3, 5]
self.conv = nn.Sequential(
nn.Conv2d(1, 3, kernel_size=1),
nn.BatchNorm2d(3),
nn.ReLU(inplace=True),
)
if encoder == "pvt_v2_b0":
self.backbone = pvt_v2_b0()
path = pretrained_dir + "/pvt_v2_b0.pth"
channels = [256, 160, 64, 32]
elif encoder == "pvt_v2_b1":
self.backbone = pvt_v2_b1()
path = pretrained_dir + "/pvt_v2_b1.pth"
channels = [512, 320, 128, 64]
elif encoder == "pvt_v2_b2":
self.backbone = pvt_v2_b2()
path = pretrained_dir + "/pvt_v2_b2.pth"
channels = [512, 320, 128, 64]
elif encoder == "pvt_v2_b3":
self.backbone = pvt_v2_b3()
path = pretrained_dir + "/pvt_v2_b3.pth"
channels = [512, 320, 128, 64]
elif encoder == "pvt_v2_b4":
self.backbone = pvt_v2_b4()
path = pretrained_dir + "/pvt_v2_b4.pth"
channels = [512, 320, 128, 64]
elif encoder == "pvt_v2_b5":
self.backbone = pvt_v2_b5()
path = pretrained_dir + "/pvt_v2_b5.pth"
channels = [512, 320, 128, 64]
elif encoder == "resnet18":
self.backbone = resnet18(pretrained=pretrain)
channels = [512, 256, 128, 64]
elif encoder == "resnet34":
self.backbone = resnet34(pretrained=pretrain)
channels = [512, 256, 128, 64]
elif encoder == "resnet50":
self.backbone = resnet50(pretrained=pretrain)
channels = [2048, 1024, 512, 256]
elif encoder == "resnet101":
self.backbone = resnet101(pretrained=pretrain)
channels = [2048, 1024, 512, 256]
elif encoder == "resnet152":
self.backbone = resnet152(pretrained=pretrain)
channels = [2048, 1024, 512, 256]
else:
print(
"Encoder not implemented! Continuing with default encoder pvt_v2_b2."
)
self.backbone = pvt_v2_b2()
path = pretrained_dir + "/pvt_v2_b2.pth"
channels = [512, 320, 128, 64]
if pretrain and "pvt_v2" in encoder:
save_model = torch.load(path)
model_dict = self.backbone.state_dict()
state_dict = {k: v for k, v in save_model.items() if k in model_dict}
model_dict.update(state_dict)
self.backbone.load_state_dict(model_dict)
print(
"Model %s created, param count: %d"
% (encoder + " backbone: ", sum(m.numel() for m in self.backbone.parameters()))
)
self.decoder = EMCAD(
channels=channels,
kernel_sizes=kernel_sizes,
expansion_factor=expansion_factor,
dw_parallel=dw_parallel,
add=add,
lgag_ks=lgag_ks,
activation=activation,
)
print(
"Model %s created, param count: %d"
% ("EMCAD decoder: ", sum(m.numel() for m in self.decoder.parameters()))
)
self.out_head4 = nn.Conv2d(channels[0], num_classes, 1)
self.out_head3 = nn.Conv2d(channels[1], num_classes, 1)
self.out_head2 = nn.Conv2d(channels[2], num_classes, 1)
self.out_head1 = nn.Conv2d(channels[3], num_classes, 1)
def forward(self, x, mode="test"):
"""前向计算。"""
if x.size()[1] == 1:
x = self.conv(x)
x1, x2, x3, x4 = self.backbone(x)
dec_outs = self.decoder(x4, [x3, x2, x1])
p4 = self.out_head4(dec_outs[0])
p3 = self.out_head3(dec_outs[1])
p2 = self.out_head2(dec_outs[2])
p1 = self.out_head1(dec_outs[3])
p4 = F.interpolate(p4, scale_factor=32, mode="bilinear")
p3 = F.interpolate(p3, scale_factor=16, mode="bilinear")
p2 = F.interpolate(p2, scale_factor=8, mode="bilinear")
p1 = F.interpolate(p1, scale_factor=4, mode="bilinear")
return [p4, p3, p2, p1]
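Each prediction head is interpolated by the inverse of its feature stride, so all four deep-supervision maps return to the input resolution. A quick sanity sketch:

```python
def restored_sizes(img_size, strides=(32, 16, 8, 4)):
    """Feature side at each stride, times the matching interpolate
    scale factor, recovers the full input resolution (when the input
    is divisible by the stride)."""
    return [(img_size // s) * s for s in strides]

# p4..p1 all land back at 224x224 for a 224 input.
assert restored_sizes(224) == [224, 224, 224, 224]
```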
if __name__ == "__main__":
model = EMCADNet().cuda()
input_tensor = torch.randn(1, 3, 352, 352).cuda()
outputs = model(input_tensor)
print(outputs[0].size(), outputs[1].size(), outputs[2].size(), outputs[3].size())

@@ -0,0 +1,442 @@
"""Pyramid Vision Transformer v2 主干网络实现。"""
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
class Mlp(nn.Module):
"""带深度卷积的 MLP 模块。"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
"""带可选空间降采样的多头注意力。"""
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
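With `sr_ratio > 1` the keys and values are drawn from a spatially reduced map, so the attention matrix shrinks from N x N to N x (N / sr_ratio^2). A back-of-the-envelope sketch (hypothetical helper; constant factors and heads ignored):

```python
def sra_attention_scores(H, W, sr_ratio):
    """Number of query-key scores for spatial-reduction attention:
    N queries against N / sr_ratio**2 keys."""
    n = H * W
    n_kv = (H // sr_ratio) * (W // sr_ratio)
    return n * n_kv

# Full attention on a 56x56 token grid vs. the sr_ratio=8 reduction
# used in stage 1: a 64x smaller score matrix.
assert sra_attention_scores(56, 56, 1) == 3136 * 3136
assert sra_attention_scores(56, 56, 8) == 3136 * 49
```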
class Block(nn.Module):
"""包含注意力与 MLP 的 Transformer 块。"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # stochastic depth via DropPath
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
"""重叠卷积的图像 Patch Embedding。"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
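The output grid size follows the usual convolution arithmetic; with `padding = kernel // 2` the stage-1 embedding downsamples by 4 and the later stages by 2. Sketch:

```python
def conv_out_size(size, kernel, stride, padding):
    """Spatial size after the overlapping patch-embedding convolution."""
    return (size + 2 * padding - kernel) // stride + 1

# Stage-1 embedding: 7x7 kernel, stride 4, padding 3 -> 4x downsampling.
assert conv_out_size(224, 7, 4, 3) == 56
# Later stages: 3x3 kernel, stride 2, padding 1 -> 2x downsampling.
assert conv_out_size(56, 3, 2, 1) == 28
```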
class PyramidVisionTransformerImpr(nn.Module):
"""金字塔视觉 Transformer 主干网络。"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
super().__init__()
self.num_classes = num_classes
self.depths = depths
        # patch embeddings
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3])
        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[0])
for i in range(depths[0])])
self.norm1 = norm_layer(embed_dims[0])
cur += depths[0]
self.block2 = nn.ModuleList([Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[1])
for i in range(depths[1])])
self.norm2 = norm_layer(embed_dims[1])
cur += depths[1]
self.block3 = nn.ModuleList([Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[2])
for i in range(depths[2])])
self.norm3 = norm_layer(embed_dims[2])
cur += depths[2]
self.block4 = nn.ModuleList([Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[3])
for i in range(depths[3])])
self.norm4 = norm_layer(embed_dims[3])
        # a classification head could be defined here
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = 1
            # pretrained weights could be loaded here
def reset_drop_path(self, drop_path_rate):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
for i in range(self.depths[0]):
self.block1[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[0]
for i in range(self.depths[1]):
self.block2[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[1]
for i in range(self.depths[2]):
self.block3[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[2]
for i in range(self.depths[3]):
self.block4[i].drop_path.drop_prob = dpr[cur + i]
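`reset_drop_path` re-derives the linear stochastic-depth schedule from `__init__`: drop probabilities rise linearly from 0 to `drop_path_rate` across all blocks. The same values can be computed without torch (illustrative helper):

```python
def drop_path_schedule(drop_path_rate, depths):
    """Linearly increasing per-block drop probabilities, matching
    torch.linspace(0, drop_path_rate, sum(depths))."""
    total = sum(depths)
    if total == 1:
        return [0.0]
    return [drop_path_rate * i / (total - 1) for i in range(total)]

sched = drop_path_schedule(0.1, [2, 2, 2, 2])
assert len(sched) == 8
assert sched[0] == 0.0
assert abs(sched[-1] - 0.1) < 1e-9
```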
def freeze_patch_emb(self):
self.patch_embed1.requires_grad = False
@torch.jit.ignore
def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # exclude positional embeddings from weight decay
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
    # positional-embedding interpolation could be implemented here if needed
def forward_features(self, x):
B = x.shape[0]
outs = []
        # stage 1
x, H, W = self.patch_embed1(x)
for i, blk in enumerate(self.block1):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
        # stage 2
x, H, W = self.patch_embed2(x)
for i, blk in enumerate(self.block2):
x = blk(x, H, W)
x = self.norm2(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
        # stage 3
x, H, W = self.patch_embed3(x)
for i, blk in enumerate(self.block3):
x = blk(x, H, W)
x = self.norm3(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
        # stage 4
x, H, W = self.patch_embed4(x)
for i, blk in enumerate(self.block4):
x = blk(x, H, W)
x = self.norm4(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
        # the mean over tokens could be returned here instead
def forward(self, x):
x = self.forward_features(x)
        # a classification head could be applied here
return x
class DWConv(nn.Module):
"""用于 token 混合的深度卷积。"""
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
def _conv_filter(state_dict, patch_size=16):
"""将 patch embedding 权重转换为卷积格式。"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
@register_model
class pvt_v2_b0(PyramidVisionTransformerImpr):
"""PVTv2-B0 主干。"""
def __init__(self, **kwargs):
super(pvt_v2_b0, self).__init__(
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
@register_model
class pvt_v2_b1(PyramidVisionTransformerImpr):
"""PVTv2-B1 主干。"""
def __init__(self, **kwargs):
super(pvt_v2_b1, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
@register_model
class pvt_v2_b2(PyramidVisionTransformerImpr):
"""PVTv2-B2 主干。"""
def __init__(self, **kwargs):
super(pvt_v2_b2, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
@register_model
class pvt_v2_b3(PyramidVisionTransformerImpr):
"""PVTv2-B3 主干。"""
def __init__(self, **kwargs):
super(pvt_v2_b3, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
@register_model
class pvt_v2_b4(PyramidVisionTransformerImpr):
"""PVTv2-B4 主干。"""
def __init__(self, **kwargs):
super(pvt_v2_b4, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
@register_model
class pvt_v2_b5(PyramidVisionTransformerImpr):
"""PVTv2-B5 backbone."""
def __init__(self, **kwargs):
super(pvt_v2_b5, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1, **kwargs)
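The registered variants above differ only in their stage depths (and B5's MLP ratios). A small reference sketch, with the depth lists transcribed from the constructors:

```python
# Stage depths per PVTv2 variant, transcribed from the constructors above.
PVT_V2_DEPTHS = {
    "b1": [2, 2, 2, 2],
    "b2": [3, 4, 6, 3],
    "b3": [3, 4, 18, 3],
    "b4": [3, 8, 27, 3],
    "b5": [3, 6, 40, 3],
}

def total_blocks(variant):
    """Total number of transformer blocks across the four stages."""
    return sum(PVT_V2_DEPTHS[variant])

for name in sorted(PVT_V2_DEPTHS):
    print(name, total_blocks(name))
```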

@ -0,0 +1,373 @@
"""用于心脏图像分割的 ResNet 编码器主干网络。
该模块实现了多种深度的 ResNet 网络ResNet-18/34/50/101/152
可作为心脏图像分割模型的编码器输出多尺度特征图用于解码器
"""
import math
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
__all__ = [
"ResNet",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
]
model_urls = {
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
def conv3x3(in_planes, out_planes, stride=1):
"""创建 3x3 卷积层(带边界填充)。
用于残差块中的特征提取保持特征图尺寸不变
参数:
in_planes: 输入特征图的通道数
out_planes: 输出特征图的通道数
stride: 卷积步长默认为 1
返回:
nn.Conv2d: 配置好的 3x3 卷积层
"""
return nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
)
class BasicBlock(nn.Module):
"""ResNet-18/34 使用的基础残差块。
包含两个 3x3 卷积层通过跳跃连接实现残差学习
适用于较浅的 ResNet 配置 ResNet-18/34
"""
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
"""初始化心脏图像分割用的基础残差块。
参数:
inplanes: 输入特征图的通道数
planes: 第一个卷积层的输出通道数
stride: 卷积步长控制特征图下采样比例
downsample: 下采样模块用于匹配特征图尺寸和通道数
"""
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""执行残差块的前向传播计算。
参数:
x: 输入特征图形状为 (N, C, H, W)
返回:
torch.Tensor: 输出特征图形状为 (N, C', H', W')。
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
"""ResNet-50/101/152 使用的瓶颈残差块。
采用 1x1-3x3-1x3 的卷积堆叠结构先降维再升维
通过瓶颈设计减少计算量同时保持较强的特征表达能力
适用于较深的 ResNet 配置 ResNet-50/101/152
"""
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
"""初始化心脏图像分割用的瓶颈残差块。
参数:
inplanes: 输入特征图的通道数
planes: 瓶颈层中间卷积的输出通道数
stride: 卷积步长控制特征图下采样比例
downsample: 下采样模块用于匹配特征图尺寸和通道数
"""
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, planes * Bottleneck.expansion, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
"""执行瓶颈残差块的前向传播计算。
参数:
x: 输入特征图形状为 (N, C, H, W)
返回:
torch.Tensor: 输出特征图形状为 (N, C', H', W')。
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
"""用于心脏图像分割的 ResNet 编码器主干网络。
支持输出多个阶段的特征图用于心脏图像分割任务的编码器部分
可选择标准基础网络或深度基础网络配置
"""
def __init__(self, block, layers, num_classes=1000, deep_base=False, stem_width=32):
"""初始化心脏图像分割用的 ResNet 主干网络。
参数:
block: 残差块类型BasicBlock Bottleneck
layers: 每个阶段的残差块数量列表 [2, 2, 2, 2] 表示 ResNet-18
num_classes: 分类输出的类别数用于预训练权重加载
deep_base: 是否使用深度基础网络配置
stem_width: 基础网络的stem宽度
"""
self.inplanes = stem_width * 2 if deep_base else 64
super(ResNet, self).__init__()
if deep_base:
self.conv1 = nn.Sequential(
nn.Conv2d(
3, stem_width, kernel_size=3, stride=2, padding=1, bias=False
),
nn.BatchNorm2d(stem_width),
nn.ReLU(inplace=True),
nn.Conv2d(
stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False
),
nn.BatchNorm2d(stem_width),
nn.ReLU(inplace=True),
nn.Conv2d(
stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False
),
)
else:
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# Classification head; unused by forward() but kept so ImageNet-pretrained
# checkpoints load without missing keys.
self.avgpool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for module in self.modules():
if isinstance(module, nn.Conv2d):
n = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
module.weight.data.normal_(0, math.sqrt(2.0 / n))
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.fill_(1)
module.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1):
"""构建一个 ResNet 阶段(包含多个残差块)。
用于构建编码器中的不同特征提取阶段每个阶段可能包含
多个残差块负责提取不同尺度的心脏图像特征
参数:
block: 残差块类型
planes: 该阶段输出特征的通道数
blocks: 该阶段包含的残差块数量
stride: 第一个残差块的步长控制特征图下采样
返回:
nn.Sequential: 包含多个残差块的序列模块
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
"""执行前向传播,返回心脏图像的多尺度特征图列表。
依次通过 stemlayer1layer2layer3layer4 处理输入心脏图像
返回每个阶段的输出特征用于后续的分割解码器
参数:
x: 输入心脏图像张量形状为 (N, 3, H, W)
返回:
list: 包含四个阶段特征图的列表特征图尺寸递减通道数递增
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
features = []
x = self.layer1(x)
features.append(x)
x = self.layer2(x)
features.append(x)
x = self.layer3(x)
features.append(x)
x = self.layer4(x)
features.append(x)
return features
def resnet18(pretrained=False, **kwargs):
"""构建 ResNet-18 编码器(适用于心脏图像分割)。
参数:
pretrained: 是否加载在 ImageNet 上预训练的权重
**kwargs: 其他传递给 ResNet 的关键字参数
返回:
ResNet: ResNet-18 模型实例
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
return model
def resnet34(pretrained=False, **kwargs):
"""构建 ResNet-34 编码器(适用于心脏图像分割)。
参数:
pretrained: 是否加载在 ImageNet 上预训练的权重
**kwargs: 其他传递给 ResNet 的关键字参数
返回:
ResNet: ResNet-34 模型实例
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
print("Using pretrained weight!")
pretrained_dict = model_zoo.load_url(model_urls["resnet34"])
print("Pretrain model has been loaded!")
model.load_state_dict(pretrained_dict)
return model
def resnet50(pretrained=False, **kwargs):
"""构建 ResNet-50 编码器(适用于心脏图像分割)。
参数:
pretrained: 是否加载在 ImageNet 上预训练的权重
**kwargs: 其他传递给 ResNet 的关键字参数
返回:
ResNet: ResNet-50 模型实例
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
print("Using pretrained weight!")
model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
return model
def resnet101(pretrained=False, **kwargs):
"""构建 ResNet-101 编码器(适用于心脏图像分割)。
参数:
pretrained: 是否加载在 ImageNet 上预训练的权重
**kwargs: 其他传递给 ResNet 的关键字参数
返回:
ResNet: ResNet-101 模型实例
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
return model
def resnet152(pretrained=False, **kwargs):
"""构建 ResNet-152 编码器(适用于心脏图像分割)。
参数:
pretrained: 是否加载在 ImageNet 上预训练的权重
**kwargs: 其他传递给 ResNet 的关键字参数
返回:
ResNet: ResNet-152 模型实例
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]))
return model
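Because forward() returns the four stage outputs rather than class logits, the resolution of each feature map is fixed by the stem (overall stride 4) plus one further halving per later stage. The shape arithmetic can be checked in isolation; a pure-Python sketch, no torch required:

```python
def resnet_stage_shapes(input_hw, channels=(64, 128, 256, 512), expansion=1):
    """Spatial size and channel count of each stage output.

    The stem (stride-2 conv plus stride-2 maxpool) divides the resolution
    by 4; layer2-layer4 each halve it again. `expansion` is 1 for
    BasicBlock and 4 for Bottleneck.
    """
    h, w = input_hw
    shapes = []
    for i, c in enumerate(channels):
        stride = 4 * (2 ** i)  # cumulative stride at this stage
        shapes.append((c * expansion, h // stride, w // stride))
    return shapes

# ResNet-18/34 (BasicBlock) on a 224x224 image:
print(resnet_stage_shapes((224, 224)))
# ResNet-50/101/152 (Bottleneck, expansion=4):
print(resnet_stage_shapes((224, 224), expansion=4))
```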

@ -0,0 +1,18 @@
from .dataset_synapse import Synapse_dataset, RandomGenerator
from .dataset_ACDC import ACDCdataset
from .dataloader import get_loader
from .trainer import trainer_synapse, trainer_ACDC
from .utils import DiceLoss, powerset, val_single_volume, test_single_volume
__all__ = [
"Synapse_dataset",
"RandomGenerator",
"ACDCdataset",
"get_loader",
"trainer_synapse",
"trainer_ACDC",
"DiceLoss",
"powerset",
"val_single_volume",
"test_single_volume",
]

@ -0,0 +1,338 @@
"""
心脏图像分割数据集加载模块
该模块提供了用于心脏图像分割任务的PyTorch数据集类和数据加载器
支持训练和测试阶段的心脏图像预处理与增强操作
"""
import os
import random
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
class PolypDataset(data.Dataset):
"""心脏图像分割数据集。"""
def __init__(self, image_root, gt_root, trainsize, augmentations):
"""初始化心脏图像分割数据集与变换。
参数:
image_root: 心脏图像目录
gt_root: 心脏标注掩码目录
trainsize: 训练时统一的图像尺寸
augmentations: 是否启用数据增强字符串标记
"""
self.trainsize = trainsize
self.augmentations = augmentations
print(self.augmentations)
self.images = [
image_root + f
for f in os.listdir(image_root)
if f.endswith(".jpg") or f.endswith(".png")
]
self.gts = [
gt_root + f
for f in os.listdir(gt_root)
if f.endswith(".png") or f.endswith(".jpg")
]
self.images = sorted(self.images)
self.gts = sorted(self.gts)
self.filter_files()
self.size = len(self.images)
if self.augmentations == "True":
print("Using RandomRotation, RandomFlip")
self.img_transform = transforms.Compose(
[
# `resample`/`fill` follow the pre-0.9 torchvision API; newer
# versions replace `resample` with `interpolation`.
transforms.RandomRotation(
90, resample=False, expand=False, center=None, fill=None
),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Resize((self.trainsize, self.trainsize)),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
),
]
)
self.gt_transform = transforms.Compose(
[
transforms.RandomRotation(
90, resample=False, expand=False, center=None, fill=None
),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Resize((self.trainsize, self.trainsize)),
transforms.ToTensor(),
]
)
else:
print("no augmentation")
self.img_transform = transforms.Compose(
[
transforms.Resize((self.trainsize, self.trainsize)),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
),
]
)
self.gt_transform = transforms.Compose(
[
transforms.Resize((self.trainsize, self.trainsize)),
transforms.ToTensor(),
]
)
def __getitem__(self, index):
"""读取单个训练样本并应用预处理。
使用相同的随机种子确保图像和掩码应用相同的随机变换
以保持空间一致性
参数:
index: 样本在数据集中的索引位置
返回:
tuple: 包含预处理后的图像张量和对应掩码张量的元组
图像张量形状为 (C, H, W)掩码张量形状为 (1, H, W)
"""
image = self.rgb_loader(self.images[index])
gt = self.binary_loader(self.gts[index])
seed = np.random.randint(2147483647)
random.seed(seed)
torch.manual_seed(seed)
if self.img_transform is not None:
image = self.img_transform(image)
random.seed(seed)
torch.manual_seed(seed)
if self.gt_transform is not None:
gt = self.gt_transform(gt)
return image, gt
def filter_files(self):
"""筛选尺寸匹配的图像-掩码对。
仅保留尺寸一致的图像和对应掩码确保训练时不会因尺寸不匹配
导致数据加载错误
"""
assert len(self.images) == len(self.gts)
images = []
gts = []
for img_path, gt_path in zip(self.images, self.gts):
img = Image.open(img_path)
gt = Image.open(gt_path)
if img.size == gt.size:
images.append(img_path)
gts.append(gt_path)
self.images = images
self.gts = gts
def rgb_loader(self, path):
"""从文件路径读取并返回 RGB 格式图像。
参数:
path: 图像文件的完整路径
返回:
PIL.Image.Image: RGB 模式的图像对象
"""
with open(path, "rb") as f:
img = Image.open(f)
return img.convert("RGB")
def binary_loader(self, path):
"""从文件路径读取并返回灰度格式掩码。
参数:
path: 掩码图像文件的完整路径
返回:
PIL.Image.Image: 灰度模式L模式的掩码图像对象
"""
with open(path, "rb") as f:
img = Image.open(f)
return img.convert("L")
def convert2polar(self, img, gt):
"""将笛卡尔坐标系的图像和掩码转换为极坐标系表示。
基于掩码的质心点进行坐标变换可用于增强模型对径向特征的感知能力
参数:
img: PIL Image 图像对象
gt: PIL Image 掩码对象
返回:
tuple: 转换后的 (img, gt) 元组均为极坐标表示形式
"""
center = polar_transformations.centroid(gt)
img = polar_transformations.to_polar(img, center)
gt = polar_transformations.to_polar(gt, center)
return img, gt
def resize(self, img, gt):
"""将图像和掩码尺寸调整至不小于 trainsize。
如果图像尺寸小于 trainsize则按比例放大以满足最小尺寸要求
图像使用双线性插值BILINEAR掩码使用最近邻插值NEAREST
以保持边缘清晰
参数:
img: PIL 图像对象
gt: PIL 掩码对象
返回:
tuple: 调整后的 (img, gt) 元组
"""
assert img.size == gt.size
w, h = img.size
if h < self.trainsize or w < self.trainsize:
h = max(h, self.trainsize)
w = max(w, self.trainsize)
return (
img.resize((w, h), Image.BILINEAR),
gt.resize((w, h), Image.NEAREST),
)
return img, gt
def __len__(self):
"""返回数据集大小。"""
return self.size
def get_loader(
image_root,
gt_root,
batchsize,
trainsize,
shuffle=False,
num_workers=4,
pin_memory=True,
augmentation=False,
):
"""构建心脏图像分割数据集的 DataLoader。
参数:
image_root: 心脏图像目录
gt_root: 心脏标注掩码目录
batchsize: 批大小
trainsize: 训练尺寸
shuffle: 是否打乱数据顺序
num_workers: 数据加载的进程数
pin_memory: 是否使用内存锁页以加速数据传输
augmentation: 是否启用数据增强
返回:
torch.utils.data.DataLoader: 配置好的数据加载器实例
"""
dataset = PolypDataset(image_root, gt_root, trainsize, augmentation)
data_loader = data.DataLoader(
dataset=dataset,
batch_size=batchsize,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=pin_memory,
)
return data_loader
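The seed trick in `__getitem__` above works because both RNGs are re-seeded with the same value before the image transform and again before the mask transform, so each torchvision transform draws identical random parameters for the pair. The mechanism can be illustrated with the standard-library RNG alone; a sketch of the idea, not the torchvision pipeline (`jitter` is a made-up stand-in for a random transform):

```python
import random

def jitter(value):
    """A stand-in for a random transform: shifts the input by a random amount."""
    return value + random.randint(-10, 10)

seed = random.randint(0, 2147483647)

random.seed(seed)
image_shift = jitter(0)   # "transform" the image
random.seed(seed)         # re-seed before transforming the mask
mask_shift = jitter(0)    # the mask receives the identical random draw

print(image_shift, mask_shift)
```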
class test_dataset(object):
"""用于心脏图像分割测试阶段的数据集包装类。
提供按需加载心脏测试图像的功能支持批量推理场景下的高效数据遍历
每次调用 load_data 方法会按顺序返回一个心脏样本
"""
def __init__(self, image_root, gt_root, testsize):
"""初始化心脏图像测试数据集。
参数:
image_root: 包含心脏测试图像的目录路径
gt_root: 包含心脏标注掩码的目录路径
testsize: 测试时统一的图像尺寸
"""
self.testsize = testsize
self.images = [
image_root + f
for f in os.listdir(image_root)
if f.endswith(".jpg") or f.endswith(".png")
]
self.gts = [
gt_root + f
for f in os.listdir(gt_root)
if f.endswith(".tif")
or f.endswith(".png")
or f.endswith(".jpg")
]
self.images = sorted(self.images)
self.gts = sorted(self.gts)
self.transform = transforms.Compose(
[
transforms.Resize((self.testsize, self.testsize)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
self.gt_transform = transforms.ToTensor()
self.size = len(self.images)
self.index = 0
def load_data(self):
"""加载并返回当前索引处的心脏测试样本。
读取心脏图像和对应掩码进行预处理并返回带有样本名称的结果
调用后内部索引自动递增指向下一个样本
返回:
tuple: 包含以下元素的元组
- image: 预处理后的心脏图像张量形状为 (1, C, H, W)
- gt: 心脏标注掩码张量形状为 (1, H, W)
- name: 心脏图像文件名字符串掩码文件名与其对应
"""
image = self.rgb_loader(self.images[self.index])
image = self.transform(image).unsqueeze(0)
gt = self.binary_loader(self.gts[self.index])
name = self.images[self.index].split("/")[-1]
if name.endswith(".jpg"):
name = name.split(".jpg")[0] + ".png"
self.index += 1
return image, gt, name
def rgb_loader(self, path):
"""从文件路径读取并返回 RGB 格式心脏图像。
参数:
path: 心脏图像文件的完整路径
返回:
PIL.Image.Image: RGB 模式的图像对象
"""
with open(path, "rb") as f:
img = Image.open(f)
return img.convert("RGB")
def binary_loader(self, path):
"""从文件路径读取并返回灰度格式心脏掩码。
参数:
path: 心脏标注掩码文件的完整路径
返回:
PIL.Image.Image: 灰度模式L模式的掩码图像对象
"""
with open(path, "rb") as f:
img = Image.open(f)
return img.convert("L")

@ -0,0 +1,33 @@
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
class ACDCdataset(Dataset):
def __init__(self, base_dir, list_dir, split, transform=None):
self.transform = transform # using transform in torch!
self.split = split
self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines()
self.data_dir = base_dir
def __len__(self):
return len(self.sample_list)
def __getitem__(self, idx):
if self.split == "train" or self.split == "valid":
slice_name = self.sample_list[idx].strip('\n')
data_path = os.path.join(self.data_dir, self.split, slice_name)
data = np.load(data_path)
image, label = data['img'], data['label']
else:
vol_name = self.sample_list[idx].strip('\n')
filepath = self.data_dir + "/{}".format(vol_name)
data = np.load(filepath)
image, label = data['img'], data['label']
sample = {'image': image, 'label': label}
if self.transform and self.split == "train":
sample = self.transform(sample)
sample['case_name'] = self.sample_list[idx].strip('\n')
return sample

@ -0,0 +1,55 @@
"""ACDC 数据集的 Dataset 实现。"""
import os
import numpy as np
from torch.utils.data import Dataset
class ACDCdataset(Dataset):
"""用于 ACDC 切片与体数据的 PyTorch Dataset。"""
def __init__(self, base_dir, list_dir, split, transform=None):
"""初始化数据集路径与划分。
参数:
base_dir: 数据集根目录
list_dir: 列表文件目录
split: 划分名称 'train''valid''test'
transform: 可选的数据增强/变换
"""
self.transform = transform
self.split = split
list_path = os.path.join(list_dir, self.split + ".txt")
self.sample_list = open(list_path).readlines()
self.data_dir = base_dir
def __len__(self):
"""返回样本数量。"""
return len(self.sample_list)
def __getitem__(self, idx):
"""按索引读取样本。
参数:
idx: 样本索引
返回:
包含 imagelabel case_name 的字典
"""
if self.split == "train" or self.split == "valid":
slice_name = self.sample_list[idx].strip("\n")
data_path = os.path.join(self.data_dir, self.split, slice_name)
data = np.load(data_path)
image, label = data["img"], data["label"]
else:
vol_name = self.sample_list[idx].strip("\n")
filepath = self.data_dir + "/{}".format(vol_name)
data = np.load(filepath)
image, label = data["img"], data["label"]
sample = {"image": image, "label": label}
if self.transform and self.split == "train":
sample = self.transform(sample)
sample["case_name"] = self.sample_list[idx].strip("\n")
return sample
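The split list files read above are plain text with one sample name per line; `__getitem__` strips the trailing newline and joins the name onto the data directory. A minimal sketch of that lookup against a temporary list file (the patient/slice names are invented for illustration):

```python
import os
import tempfile

# Hypothetical slice names, one per line as in a real train.txt.
names = ["patient001_frame01_slice03", "patient002_frame12_slice07"]

with tempfile.TemporaryDirectory() as list_dir:
    list_path = os.path.join(list_dir, "train.txt")
    with open(list_path, "w") as f:
        f.write("\n".join(names) + "\n")

    # Mirror of the dataset's list handling.
    sample_list = open(list_path).readlines()

slice_name = sample_list[0].strip("\n")
data_path = os.path.join("/data/ACDC", "train", slice_name)
print(data_path)
```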

@ -0,0 +1,100 @@
import os
import random
import h5py
import numpy as np
import torch
import cv2
from scipy import ndimage
from scipy.ndimage import zoom
from torch.utils.data import Dataset
def random_rot_flip(image, label):
k = np.random.randint(0, 4)
image = np.rot90(image, k)
label = np.rot90(label, k)
axis = np.random.randint(0, 2)
image = np.flip(image, axis=axis).copy()
label = np.flip(label, axis=axis).copy()
return image, label
def random_rotate(image, label):
angle = np.random.randint(-20, 20)
image = ndimage.rotate(image, angle, order=0, reshape=False)
label = ndimage.rotate(label, angle, order=0, reshape=False)
return image, label
class RandomGenerator(object):
def __init__(self, output_size):
self.output_size = output_size
def __call__(self, sample):
image, label = sample['image'], sample['label']
if random.random() > 0.5:
image, label = random_rot_flip(image, label)
elif random.random() > 0.5:
image, label = random_rotate(image, label)
x, y = image.shape
if x != self.output_size[0] or y != self.output_size[1]:
image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3)  # cubic for the image
label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)  # nearest for the label
image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
label = torch.from_numpy(label.astype(np.float32))
sample = {'image': image, 'label': label.long()}
return sample
class Synapse_dataset(Dataset):
def __init__(self, base_dir, list_dir, split, nclass=9, transform=None):
self.transform = transform # using transform in torch!
self.split = split
self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines()
self.data_dir = base_dir
self.nclass = nclass
def __len__(self):
return len(self.sample_list)
def __getitem__(self, idx):
if self.split == "train":
slice_name = self.sample_list[idx].strip('\n')
data_path = os.path.join(self.data_dir, slice_name+'.npz')
data = np.load(data_path)
image, label = data['image'], data['label']
else:
vol_name = self.sample_list[idx].strip('\n')
filepath = self.data_dir + "/{}.npy.h5".format(vol_name)
data = h5py.File(filepath, "r")
image, label = data['image'][:], data['label'][:]
if self.nclass == 9:
label[label==5]= 0
label[label==9]= 0
label[label==10]= 0
label[label==12]= 0
label[label==13]= 0
label[label==11]= 5
sample = {'image': image, 'label': label}
if self.transform:
sample = self.transform(sample)
sample['case_name'] = self.sample_list[idx].strip('\n')
return sample

@ -0,0 +1,149 @@
"""Synapse 数据集相关工具。"""
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage import zoom
from torch.utils.data import Dataset
def random_rot_flip(image, label):
"""随机旋转并翻转增强。
参数:
image: 输入图像数组
label: 标签数组
返回:
变换后的 (image, label)
"""
k = np.random.randint(0, 4)
image = np.rot90(image, k)
label = np.rot90(label, k)
axis = np.random.randint(0, 2)
image = np.flip(image, axis=axis).copy()
label = np.flip(label, axis=axis).copy()
return image, label
def random_rotate(image, label):
"""随机角度旋转增强。
参数:
image: 输入图像数组
label: 标签数组
返回:
变换后的 (image, label)
"""
angle = np.random.randint(-20, 20)
image = ndimage.rotate(image, angle, order=0, reshape=False)
label = ndimage.rotate(label, angle, order=0, reshape=False)
return image, label
class RandomGenerator(object):
"""随机增强与缩放生成器。"""
def __init__(self, output_size):
"""初始化输出尺寸。
参数:
output_size: (H, W) 目标尺寸
"""
self.output_size = output_size
def __call__(self, sample):
"""对样本进行增强与缩放。
参数:
sample: 包含 'image' 'label' 的字典
返回:
变换后的样本字典
"""
image, label = sample["image"], sample["label"]
if random.random() > 0.5:
image, label = random_rot_flip(image, label)
elif random.random() > 0.5:
image, label = random_rotate(image, label)
x, y = image.shape
if x != self.output_size[0] or y != self.output_size[1]:
image = zoom(
image,
(self.output_size[0] / x, self.output_size[1] / y),
order=3,
)
label = zoom(
label,
(self.output_size[0] / x, self.output_size[1] / y),
order=0,
)
image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
label = torch.from_numpy(label.astype(np.float32))
sample = {"image": image, "label": label.long()}
return sample
class Synapse_dataset(Dataset):
"""Synapse 数据集的 PyTorch Dataset。"""
def __init__(self, base_dir, list_dir, split, nclass=9, transform=None):
"""初始化数据集。
参数:
base_dir: 数据集根目录
list_dir: 列表文件目录
split: 划分名称
nclass: 类别数
transform: 可选的数据增强
"""
self.transform = transform
self.split = split
list_path = os.path.join(list_dir, self.split + ".txt")
self.sample_list = open(list_path).readlines()
self.data_dir = base_dir
self.nclass = nclass
def __len__(self):
"""返回样本数量。"""
return len(self.sample_list)
def __getitem__(self, idx):
"""按索引读取样本。
参数:
idx: 样本索引
返回:
包含 imagelabel case_name 的字典
"""
if self.split == "train":
slice_name = self.sample_list[idx].strip("\n")
data_path = os.path.join(self.data_dir, slice_name + ".npz")
data = np.load(data_path)
image, label = data["image"], data["label"]
else:
vol_name = self.sample_list[idx].strip("\n")
filepath = self.data_dir + "/{}.npy.h5".format(vol_name)
data = h5py.File(filepath, "r")
image, label = data["image"][:], data["label"][:]
if self.nclass == 9:
label[label == 5] = 0
label[label == 9] = 0
label[label == 10] = 0
label[label == 12] = 0
label[label == 13] = 0
label[label == 11] = 5
sample = {"image": image, "label": label}
if self.transform:
sample = self.transform(sample)
sample["case_name"] = self.sample_list[idx].strip("\n")
return sample
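In the `nclass == 9` branch above, the order of the assignments matters: label 5 is cleared *before* label 11 is renumbered to 5, so only voxels that were originally class 11 end up in the new class 5. A numpy sketch of the same remapping:

```python
import numpy as np

def remap_to_9_classes(label):
    """Drop classes 5, 9, 10, 12, 13 and renumber 11 -> 5 (order matters)."""
    label = label.copy()
    label[label == 5] = 0    # must run before the 11 -> 5 assignment below
    label[label == 9] = 0
    label[label == 10] = 0
    label[label == 12] = 0
    label[label == 13] = 0
    label[label == 11] = 5   # original class 11 takes over slot 5
    return label

lab = np.array([0, 1, 5, 9, 10, 11, 12, 13])
print(remap_to_9_classes(lab))
```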

@ -0,0 +1,40 @@
import os
import shutil
from libtiff import TIFF # pip install libtiff
from scipy import misc  # note: scipy.misc.imsave was removed in SciPy 1.2
import random
def tif2png(_src_path, _dst_path):
"""
Usage:
formatting `tif/tiff` files to `jpg/png` files
:param _src_path:
:param _dst_path:
:return:
"""
tif = TIFF.open(_src_path, mode='r')
image = tif.read_image()
misc.imsave(_dst_path, image)
def data_split(src_list):
"""
Usage:
randomly spliting dataset
:param src_list:
:return:
"""
counter_list = random.sample(range(0, len(src_list)), 550)
return counter_list
if __name__ == '__main__':
src_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks_tif'
dst_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks'
os.makedirs(dst_dir, exist_ok=True)
for img_name in os.listdir(src_dir):
tif2png(os.path.join(src_dir, img_name),
os.path.join(dst_dir, img_name.replace('.tif', '.png')))

@ -0,0 +1,54 @@
"""数据格式转换相关工具。"""
import os
import random
import shutil
from libtiff import TIFF
from scipy import misc  # note: scipy.misc.imsave was removed in SciPy 1.2
def tif2png(src_path, dst_path):
"""将 TIFF 文件转换为 PNG/JPG 格式。
参数:
src_path: TIFF 路径
dst_path: 输出图像路径
返回:
None
"""
tif = TIFF.open(src_path, mode="r")
image = tif.read_image()
misc.imsave(dst_path, image)
def data_split(src_list):
"""随机生成数据划分的索引列表。
参数:
src_list: 原始列表
返回:
采样后的索引列表
"""
counter_list = random.sample(range(0, len(src_list)), 550)
return counter_list
if __name__ == "__main__":
src_dir = (
"../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/"
"test_split/masks_tif"
)
dst_dir = (
"../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks"
)
os.makedirs(dst_dir, exist_ok=True)
for img_name in os.listdir(src_dir):
tif2png(
os.path.join(src_dir, img_name),
os.path.join(dst_dir, img_name.replace(".tif", ".png")),
)

@ -0,0 +1,444 @@
"""分割任务的联合图像/掩码变换。"""
import math
import numbers
import random
import numpy as np
from PIL import Image, ImageOps
class Compose(object):
"""将多个联合变换串联执行。"""
def __init__(self, transforms):
"""初始化变换列表。
参数:
transforms: 可调用对象列表输入为 (img, mask)
"""
self.transforms = transforms
def __call__(self, img, mask):
"""依次应用所有变换。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
变换后的 (img, mask)
"""
assert img.size == mask.size
for transform in self.transforms:
img, mask = transform(img, mask)
return img, mask
class RandomCrop(object):
"""随机裁剪图像与掩码。"""
def __init__(self, size, padding=0):
"""初始化裁剪参数。
参数:
size: 裁剪尺寸int (h, w)
padding: 先进行边缘填充的大小
"""
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.padding = padding
def __call__(self, img, mask):
"""执行随机裁剪。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
裁剪后的 (img, mask)
"""
if self.padding > 0:
img = ImageOps.expand(img, border=self.padding, fill=0)
mask = ImageOps.expand(mask, border=self.padding, fill=0)
assert img.size == mask.size
w, h = img.size
th, tw = self.size
if w == tw and h == th:
return img, mask
if w < tw or h < th:
return (
img.resize((tw, th), Image.BILINEAR),
mask.resize((tw, th), Image.NEAREST),
)
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
return (
img.crop((x1, y1, x1 + tw, y1 + th)),
mask.crop((x1, y1, x1 + tw, y1 + th)),
)
class CenterCrop(object):
"""中心裁剪图像与掩码。"""
def __init__(self, size):
"""初始化裁剪参数。
参数:
size: 裁剪尺寸int (h, w)
"""
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, img, mask):
"""执行中心裁剪。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
裁剪后的 (img, mask)
"""
assert img.size == mask.size
w, h = img.size
th, tw = self.size
x1 = int(round((w - tw) / 2.0))
y1 = int(round((h - th) / 2.0))
return (
img.crop((x1, y1, x1 + tw, y1 + th)),
mask.crop((x1, y1, x1 + tw, y1 + th)),
)
class RandomHorizontallyFlip(object):
"""随机水平翻转图像与掩码。"""
def __call__(self, img, mask):
"""以 50% 概率水平翻转。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
变换后的 (img, mask)
"""
if random.random() < 0.5:
return (
img.transpose(Image.FLIP_LEFT_RIGHT),
mask.transpose(Image.FLIP_LEFT_RIGHT),
)
return img, mask
class FreeScale(object):
"""将图像与掩码缩放到指定大小。"""
def __init__(self, size):
"""初始化目标尺寸。
参数:
size: 目标尺寸 (h, w)
"""
self.size = tuple(reversed(size))
def __call__(self, img, mask):
"""执行缩放。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
缩放后的 (img, mask)
"""
assert img.size == mask.size
return (
img.resize(self.size, Image.BILINEAR),
mask.resize(self.size, Image.NEAREST),
)
class Scale(object):
"""按长边缩放并保持宽高比。"""
def __init__(self, size):
"""初始化缩放尺寸。
参数:
size: 长边目标尺寸
"""
self.size = size
def __call__(self, img, mask):
"""按比例缩放。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
缩放后的 (img, mask)
"""
assert img.size == mask.size
w, h = img.size
if (w >= h and w == self.size) or (h >= w and h == self.size):
return img, mask
if w > h:
ow = self.size
oh = int(self.size * h / w)
return (
img.resize((ow, oh), Image.BILINEAR),
mask.resize((ow, oh), Image.NEAREST),
)
oh = self.size
ow = int(self.size * w / h)
return (
img.resize((ow, oh), Image.BILINEAR),
mask.resize((ow, oh), Image.NEAREST),
)
class RandomSizedCrop(object):
"""随机尺寸裁剪增强。"""
def __init__(self, size):
"""初始化裁剪大小。
参数:
size: 输出尺寸
"""
self.size = size
def __call__(self, img, mask):
"""执行随机尺寸裁剪。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
裁剪并缩放后的 (img, mask)
"""
assert img.size == mask.size
for _ in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.45, 1.0) * area
aspect_ratio = random.uniform(0.5, 2)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5:
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h)
img = img.crop((x1, y1, x1 + w, y1 + h))
mask = mask.crop((x1, y1, x1 + w, y1 + h))
assert img.size == (w, h)
return (
img.resize((self.size, self.size), Image.BILINEAR),
mask.resize((self.size, self.size), Image.NEAREST),
)
scale = Scale(self.size)
crop = CenterCrop(self.size)
return crop(*scale(img, mask))
class RandomRotate(object):
"""随机旋转图像与掩码。"""
def __init__(self, degree):
"""初始化旋转角度范围。
参数:
degree: 最大旋转角度
"""
self.degree = degree
def __call__(self, img, mask):
"""执行随机旋转。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
旋转后的 (img, mask)
"""
rotate_degree = random.random() * 2 * self.degree - self.degree
return (
img.rotate(rotate_degree, Image.BILINEAR),
mask.rotate(rotate_degree, Image.NEAREST),
)
class RandomSized(object):
"""随机缩放后再裁剪。"""
def __init__(self, size):
"""初始化尺寸参数。
参数:
size: 目标尺寸
"""
self.size = size
self.scale = Scale(self.size)
self.crop = RandomCrop(self.size)
def __call__(self, img, mask):
"""执行随机缩放与裁剪。
参数:
img: PIL 图像
mask: PIL 掩码
返回:
变换后的 (img, mask)
"""
assert img.size == mask.size
w = int(random.uniform(0.5, 2) * img.size[0])
h = int(random.uniform(0.5, 2) * img.size[1])
img, mask = (
img.resize((w, h), Image.BILINEAR),
mask.resize((w, h), Image.NEAREST),
)
return self.crop(*self.scale(img, mask))
class SlidingCropOld(object):
"""旧版滑窗裁剪实现。"""
def __init__(self, crop_size, stride_rate, ignore_label):
"""初始化滑窗裁剪参数。"""
self.crop_size = crop_size
self.stride_rate = stride_rate
self.ignore_label = ignore_label
def _pad(self, img, mask):
"""将数组填充到裁剪尺寸。"""
h, w = img.shape[:2]
pad_h = max(self.crop_size - h, 0)
pad_w = max(self.crop_size - w, 0)
img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), "constant")
mask = np.pad(
mask,
((0, pad_h), (0, pad_w)),
"constant",
constant_values=self.ignore_label,
)
return img, mask
def __call__(self, img, mask):
"""执行滑窗裁剪并返回切片列表。"""
assert img.size == mask.size
w, h = img.size
long_size = max(h, w)
img = np.array(img)
mask = np.array(mask)
if long_size > self.crop_size:
stride = int(math.ceil(self.crop_size * self.stride_rate))
h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1
w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1
img_sublist, mask_sublist = [], []
for yy in range(h_step_num):
for xx in range(w_step_num):
sy, sx = yy * stride, xx * stride
ey, ex = sy + self.crop_size, sx + self.crop_size
img_sub = img[sy:ey, sx:ex, :]
mask_sub = mask[sy:ey, sx:ex]
img_sub, mask_sub = self._pad(img_sub, mask_sub)
img_sublist.append(
Image.fromarray(img_sub.astype(np.uint8)).convert("RGB")
)
mask_sublist.append(
Image.fromarray(mask_sub.astype(np.uint8)).convert("P")
)
return img_sublist, mask_sublist
img, mask = self._pad(img, mask)
img = Image.fromarray(img.astype(np.uint8)).convert("RGB")
mask = Image.fromarray(mask.astype(np.uint8)).convert("P")
return img, mask
class SlidingCrop(object):
"""滑窗裁剪并返回切片元信息。"""
def __init__(self, crop_size, stride_rate, ignore_label):
"""初始化滑窗裁剪参数。"""
self.crop_size = crop_size
self.stride_rate = stride_rate
self.ignore_label = ignore_label
def _pad(self, img, mask):
"""填充到裁剪尺寸并返回原始尺寸。"""
h, w = img.shape[:2]
pad_h = max(self.crop_size - h, 0)
pad_w = max(self.crop_size - w, 0)
img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), "constant")
mask = np.pad(
mask,
((0, pad_h), (0, pad_w)),
"constant",
constant_values=self.ignore_label,
)
return img, mask, h, w
def __call__(self, img, mask):
"""执行滑窗裁剪。
返回:
图像切片列表掩码切片列表与切片信息
"""
assert img.size == mask.size
w, h = img.size
long_size = max(h, w)
img = np.array(img)
mask = np.array(mask)
if long_size > self.crop_size:
stride = int(math.ceil(self.crop_size * self.stride_rate))
h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1
w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1
img_slices, mask_slices, slices_info = [], [], []
for yy in range(h_step_num):
for xx in range(w_step_num):
sy, sx = yy * stride, xx * stride
ey, ex = sy + self.crop_size, sx + self.crop_size
img_sub = img[sy:ey, sx:ex, :]
mask_sub = mask[sy:ey, sx:ex]
img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub)
img_slices.append(
Image.fromarray(img_sub.astype(np.uint8)).convert("RGB")
)
mask_slices.append(
Image.fromarray(mask_sub.astype(np.uint8)).convert("P")
)
slices_info.append([sy, ey, sx, ex, sub_h, sub_w])
return img_slices, mask_slices, slices_info
img, mask, sub_h, sub_w = self._pad(img, mask)
img = Image.fromarray(img.astype(np.uint8)).convert("RGB")
mask = Image.fromarray(mask.astype(np.uint8)).convert("P")
return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]]
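The window-grid arithmetic used by both crop classes can be checked standalone. A minimal sketch (the 512-pixel side, 224 crop, and 2/3 stride rate are illustrative values, not defaults from this repo):

```python
import math

def grid_steps(length, crop_size, stride_rate):
    """Step count and stride along one axis, mirroring SlidingCrop.__call__."""
    stride = int(math.ceil(crop_size * stride_rate))
    steps = int(math.ceil((length - crop_size) / float(stride))) + 1
    return steps, stride

# a 512-pixel side with a 224 crop and 2/3 stride rate
steps, stride = grid_steps(512, 224, 2.0 / 3)
# stride = ceil(224 * 2/3) = 150; steps = ceil(288 / 150) + 1 = 3
# the last window spans [300, 524), so _pad fills the 12-pixel overshoot
print(steps, stride)
```

The `+ 1` guarantees at least one window even when `length` barely exceeds the crop size.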

@@ -0,0 +1,339 @@
"""Miscellaneous utilities for training and evaluation."""
import os
from math import ceil
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
def check_mkdir(dir_name):
    """Create the directory if it does not exist.
    Args:
        dir_name: directory path
    """
if not os.path.exists(dir_name):
os.mkdir(dir_name)
def initialize_weights(*models):
    """Initialize the weight parameters of the given models."""
for model in models:
for module in model.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)  # kaiming_normal was renamed in PyTorch 0.4
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.fill_(1)
module.bias.data.zero_()
def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Build a bilinear upsampling weight kernel."""
factor = (kernel_size + 1) // 2
if kernel_size % 2 == 1:
center = factor - 1
else:
center = factor - 0.5
og = np.ogrid[:kernel_size, :kernel_size]
filt = (1 - abs(og[0] - center) / factor) * (
1 - abs(og[1] - center) / factor
)
weight = np.zeros(
(in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64
)
weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt
return torch.from_numpy(weight).float()
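A quick property check of the bilinear kernel construction above, numpy only (the kernel size 4 is an illustrative choice): the 2D window is separable, symmetric, and sums to `factor**2`, so stride-`factor` transposed convolution preserves the mean level.

```python
import numpy as np

def bilinear_filt(kernel_size):
    # same 2D bilinear window that get_upsampling_weight builds per channel
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)

filt = bilinear_filt(4)
# 1D weights are [0.25, 0.75, 0.75, 0.25]; the 2D window is their outer product
```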
class CrossEntropyLoss2d(nn.Module):
    """2D cross-entropy loss wrapper for segmentation."""
    def __init__(self, weight=None, size_average=True, ignore_index=255):
        """Initialize the loss function."""
        super(CrossEntropyLoss2d, self).__init__()
        # nn.NLLLoss2d was removed from PyTorch; nn.NLLLoss handles 2D targets
        self.nll_loss = nn.NLLLoss(weight, size_average, ignore_index)
    def forward(self, inputs, targets):
        """Compute the loss."""
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
class FocalLoss2d(nn.Module):
    """2D focal loss wrapper for segmentation."""
    def __init__(self, gamma=2, weight=None, size_average=True, ignore_index=255):
        """Initialize the loss function."""
        super(FocalLoss2d, self).__init__()
        self.gamma = gamma
        # nn.NLLLoss2d was removed from PyTorch; nn.NLLLoss handles 2D targets
        self.nll_loss = nn.NLLLoss(weight, size_average, ignore_index)
    def forward(self, inputs, targets):
        """Compute the loss."""
        return self.nll_loss(
            (1 - F.softmax(inputs, dim=1)) ** self.gamma
            * F.log_softmax(inputs, dim=1),
            targets,
        )
def _fast_hist(label_pred, label_true, num_classes):
    """Compute the confusion matrix."""
mask = (label_true >= 0) & (label_true < num_classes)
hist = np.bincount(
num_classes * label_true[mask].astype(int) + label_pred[mask],
minlength=num_classes ** 2,
).reshape(num_classes, num_classes)
return hist
def evaluate(predictions, gts, num_classes):
    """Evaluate segmentation metrics from predictions and ground truths."""
hist = np.zeros((num_classes, num_classes))
for lp, lt in zip(predictions, gts):
hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes)
acc = np.diag(hist).sum() / hist.sum()
acc_cls = np.diag(hist) / hist.sum(axis=1)
acc_cls = np.nanmean(acc_cls)
iu = np.diag(hist) / (
hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)
)
mean_iu = np.nanmean(iu)
freq = hist.sum(axis=1) / hist.sum()
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
return acc, acc_cls, mean_iu, fwavacc
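All four metrics above reduce to one confusion matrix. A self-contained restatement of the histogram logic on a toy 2-class example (values are illustrative only):

```python
import numpy as np

def fast_hist(pred, true, n):
    # same bincount trick as _fast_hist above
    mask = (true >= 0) & (true < n)
    return np.bincount(n * true[mask].astype(int) + pred[mask],
                       minlength=n ** 2).reshape(n, n)

pred = np.array([0, 1, 1, 1])
true = np.array([0, 1, 0, 1])
hist = fast_hist(pred, true, 2)          # rows = truth, cols = prediction
acc = np.diag(hist).sum() / hist.sum()   # overall pixel accuracy
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))  # per-class IoU
```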
class AverageMeter(object):
    """Meter that tracks a running average."""
def __init__(self):
        """Initialize the meter."""
self.reset()
def reset(self):
        """Reset the meter."""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
        """Update the meter.
        Args:
            val: new value
            n: weight (e.g. batch size)
        """
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
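The meter above maintains a sample-weighted running mean, which differs from naively averaging per-batch means when batch sizes vary. A minimal restatement of the same arithmetic:

```python
class Meter:
    """Same bookkeeping as AverageMeter above, stripped to the essentials."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
    @property
    def avg(self):
        return self.sum / self.count

m = Meter()
m.update(2.0, n=3)   # e.g. a mean loss of 2.0 over a batch of 3
m.update(4.0, n=1)
# avg is the weighted mean (2*3 + 4*1) / 4 = 2.5, not (2 + 4) / 2 = 3.0
```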
class PolyLR(object):
    """Polynomial ("poly") learning-rate scheduler."""
def __init__(self, optimizer, curr_iter, max_iter, lr_decay):
        """Initialize the scheduler."""
self.max_iter = float(max_iter)
self.init_lr_groups = []
for params in optimizer.param_groups:
self.init_lr_groups.append(params["lr"])
self.param_groups = optimizer.param_groups
self.curr_iter = curr_iter
self.lr_decay = lr_decay
def step(self):
        """Apply one learning-rate update (the caller advances curr_iter)."""
for idx, params in enumerate(self.param_groups):
params["lr"] = self.init_lr_groups[idx] * (
1 - self.curr_iter / self.max_iter
) ** self.lr_decay
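The update above implements the standard "poly" schedule, `lr = lr0 * (1 - iter/max_iter)**p`. With illustrative numbers (base lr 0.01 and power 0.9 are assumptions for the sketch, not values from this file):

```python
def poly_lr(base_lr, curr_iter, max_iter, power):
    # the same decay formula PolyLR.step applies to each param group
    return base_lr * (1 - curr_iter / float(max_iter)) ** power

# the rate starts at base_lr and decays smoothly to 0 at max_iter
lrs = [poly_lr(0.01, it, 100, 0.9) for it in (0, 50, 100)]
```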
class Conv2dDeformable(nn.Module):
    """Experimental deformable-convolution wrapper."""
def __init__(self, regular_filter, cuda=True):
        """Initialize the deformable convolution."""
super(Conv2dDeformable, self).__init__()
assert isinstance(regular_filter, nn.Conv2d)
self.regular_filter = regular_filter
self.offset_filter = nn.Conv2d(
regular_filter.in_channels,
2 * regular_filter.in_channels,
kernel_size=3,
padding=1,
bias=False,
)
self.offset_filter.weight.data.normal_(0, 0.0005)
self.input_shape = None
self.grid_w = None
self.grid_h = None
self.cuda = cuda
def forward(self, x):
        """Run the deformable-convolution forward pass."""
x_shape = x.size()
offset = self.offset_filter(x)
offset_w, offset_h = torch.split(
offset, self.regular_filter.in_channels, 1
)
offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
if not self.input_shape or self.input_shape != x_shape:
self.input_shape = x_shape
grid_w, grid_h = np.meshgrid(
np.linspace(-1, 1, x_shape[3]),
np.linspace(-1, 1, x_shape[2]),
)
grid_w = torch.Tensor(grid_w)
grid_h = torch.Tensor(grid_h)
if self.cuda:
grid_w = grid_w.cuda()
grid_h = grid_h.cuda()
self.grid_w = nn.Parameter(grid_w)
self.grid_h = nn.Parameter(grid_h)
offset_w = offset_w + self.grid_w
offset_h = offset_h + self.grid_h
x = (
x.contiguous()
.view(-1, int(x_shape[2]), int(x_shape[3]))
.unsqueeze(1)
)
x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3))
x = x.contiguous().view(
-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3])
)
x = self.regular_filter(x)
return x
def sliced_forward(single_forward):
    """Wrap a single forward pass with sliding-window, multi-scale inference."""
def _pad(x, crop_size):
        """Pad the input tensor up to the crop size."""
h, w = x.size()[2:]
pad_h = max(crop_size - h, 0)
pad_w = max(crop_size - w, 0)
x = F.pad(x, (0, pad_w, 0, pad_h))
return x, pad_h, pad_w
def wrapper(self, x):
        """The wrapped forward function."""
batch_size, _, ori_h, ori_w = x.size()
if self.training and self.use_aux:
outputs_all_scales = Variable(
torch.zeros((batch_size, self.num_classes, ori_h, ori_w))
).cuda()
aux_all_scales = Variable(
torch.zeros((batch_size, self.num_classes, ori_h, ori_w))
).cuda()
for scale in self.scales:
new_size = (int(ori_h * scale), int(ori_w * scale))
scaled_x = F.upsample(x, size=new_size, mode="bilinear")
scaled_x = Variable(scaled_x).cuda()
scaled_h, scaled_w = scaled_x.size()[2:]
long_size = max(scaled_h, scaled_w)
if long_size > self.crop_size:
count = torch.zeros((scaled_h, scaled_w))
outputs = Variable(
torch.zeros(
(batch_size, self.num_classes, scaled_h, scaled_w)
)
).cuda()
aux_outputs = Variable(
torch.zeros(
(batch_size, self.num_classes, scaled_h, scaled_w)
)
).cuda()
stride = int(ceil(self.crop_size * self.stride_rate))
h_step_num = (
int(ceil((scaled_h - self.crop_size) / stride)) + 1
)
w_step_num = (
int(ceil((scaled_w - self.crop_size) / stride)) + 1
)
for yy in range(h_step_num):
for xx in range(w_step_num):
sy, sx = yy * stride, xx * stride
ey, ex = sy + self.crop_size, sx + self.crop_size
x_sub = scaled_x[:, :, sy:ey, sx:ex]
x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size)
outputs_sub, aux_sub = single_forward(self, x_sub)
if sy + self.crop_size > scaled_h:
outputs_sub = outputs_sub[:, :, :-pad_h, :]
aux_sub = aux_sub[:, :, :-pad_h, :]
if sx + self.crop_size > scaled_w:
outputs_sub = outputs_sub[:, :, :, :-pad_w]
aux_sub = aux_sub[:, :, :, :-pad_w]
outputs[:, :, sy:ey, sx:ex] = outputs_sub
aux_outputs[:, :, sy:ey, sx:ex] = aux_sub
count[sy:ey, sx:ex] += 1
count = Variable(count).cuda()
outputs = outputs / count
                    aux_outputs = aux_outputs / count
else:
scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size)
outputs, aux_outputs = single_forward(self, scaled_x)
                    # slice only when padding was added; x[..., :-0] would be empty
                    if pad_h > 0:
                        outputs = outputs[:, :, :-pad_h, :]
                        aux_outputs = aux_outputs[:, :, :-pad_h, :]
                    if pad_w > 0:
                        outputs = outputs[:, :, :, :-pad_w]
                        aux_outputs = aux_outputs[:, :, :, :-pad_w]
outputs_all_scales += outputs
aux_all_scales += aux_outputs
return outputs_all_scales / len(self.scales), aux_all_scales
outputs_all_scales = Variable(
torch.zeros((batch_size, self.num_classes, ori_h, ori_w))
).cuda()
for scale in self.scales:
new_size = (int(ori_h * scale), int(ori_w * scale))
scaled_x = F.upsample(x, size=new_size, mode="bilinear")
scaled_h, scaled_w = scaled_x.size()[2:]
long_size = max(scaled_h, scaled_w)
if long_size > self.crop_size:
count = torch.zeros((scaled_h, scaled_w))
outputs = Variable(
torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))
).cuda()
stride = int(ceil(self.crop_size * self.stride_rate))
h_step_num = int(ceil((scaled_h - self.crop_size) / stride)) + 1
w_step_num = int(ceil((scaled_w - self.crop_size) / stride)) + 1
for yy in range(h_step_num):
for xx in range(w_step_num):
sy, sx = yy * stride, xx * stride
ey, ex = sy + self.crop_size, sx + self.crop_size
x_sub = scaled_x[:, :, sy:ey, sx:ex]
x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size)
outputs_sub = single_forward(self, x_sub)
if sy + self.crop_size > scaled_h:
outputs_sub = outputs_sub[:, :, :-pad_h, :]
if sx + self.crop_size > scaled_w:
outputs_sub = outputs_sub[:, :, :, :-pad_w]
outputs[:, :, sy:ey, sx:ex] = outputs_sub
count[sy:ey, sx:ex] += 1
count = Variable(count).cuda()
outputs = outputs / count
else:
scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size)
outputs = single_forward(self, scaled_x)
                # slice only when padding was added; x[..., :-0] would be empty
                if pad_h > 0:
                    outputs = outputs[:, :, :-pad_h, :]
                if pad_w > 0:
                    outputs = outputs[:, :, :, :-pad_w]
outputs_all_scales += outputs
return outputs_all_scales
return wrapper
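The per-pixel `count` map above compensates for window overlap before averaging. Its behavior on one axis, with toy sizes (length 5, crop 3, stride 2 are illustrative):

```python
import math
import numpy as np

length, crop, stride = 5, 3, 2
steps = int(math.ceil((length - crop) / float(stride))) + 1
count = np.zeros(length)
for i in range(steps):
    s = i * stride
    count[s:s + crop] += 1
# positions covered by two windows accumulate 2, so dividing the summed
# logits by count yields a proper average in the overlap region
```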

@@ -0,0 +1,86 @@
"""Preprocess the Synapse dataset into NPZ/H5 format."""
import os
from time import time
import h5py
import nibabel as nib
import numpy as np
def process_split(split, ct_path, seg_path, save_path, upper, lower):
    """Process one data split and save the results.
    Args:
        split: split name ("train" or "test")
        ct_path: CT image directory
        seg_path: label directory
        save_path: output directory
        upper: upper clipping threshold (HU)
        lower: lower clipping threshold (HU)
    """
if not os.path.exists(save_path):
os.mkdir(save_path)
start_time = time()
for ct_file in os.listdir(ct_path):
ct = nib.load(os.path.join(ct_path, ct_file))
seg = nib.load(os.path.join(seg_path, ct_file.replace("img", "label")))
ct_array = ct.get_fdata()
seg_array = seg.get_fdata()
ct_array = np.clip(ct_array, lower, upper)
ct_array = (ct_array - lower) / (upper - lower)
ct_array = np.transpose(ct_array, (2, 0, 1))
seg_array = np.transpose(seg_array, (2, 0, 1))
print("file name:", ct_file)
print("shape:", ct_array.shape)
ct_number = ct_file.split(".")[0]
if split == "test":
new_ct_name = ct_number.replace("img", "case") + ".npy.h5"
with h5py.File(os.path.join(save_path, new_ct_name), "w") as hf:
hf.create_dataset("image", data=ct_array)
hf.create_dataset("label", data=seg_array)
continue
for s_idx in range(ct_array.shape[0]):
ct_array_s = ct_array[s_idx, :, :]
seg_array_s = seg_array[s_idx, :, :]
slice_no = "{:03d}".format(s_idx)
new_ct_name = ct_number.replace("img", "case") + "_slice" + slice_no
np.savez(
os.path.join(save_path, new_ct_name),
image=ct_array_s,
label=seg_array_s,
)
print("already use {:.3f} min".format((time() - start_time) / 60))
print("-----------")
def main():
    """Program entry point."""
splits = ["train", "test"]
for split in splits:
if split == "train":
ct_path = "./data/synapse/Abdomen/RawData/TrainSet/img"
seg_path = "./data/synapse/Abdomen/RawData/TrainSet/label"
save_path = "./data/synapse/train_npz_new/"
else:
ct_path = "./data/synapse/Abdomen/RawData/TestSet/img"
seg_path = "./data/synapse/Abdomen/RawData/TestSet/label"
save_path = "./data/synapse/test_vol_h5_new/"
upper = 275
lower = -125
process_split(split, ct_path, seg_path, save_path, upper, lower)
if __name__ == "__main__":
main()
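The intensity preprocessing in `process_split` is a standard HU window-and-scale: clip to the [-125, 275] soft-tissue window, then map linearly to [0, 1]. Sketched standalone with toy Hounsfield values:

```python
import numpy as np

def window_normalize(ct, lower=-125, upper=275):
    # clip Hounsfield units to the window, then scale to [0, 1],
    # matching the np.clip + (x - lower) / (upper - lower) steps above
    ct = np.clip(ct.astype(np.float64), lower, upper)
    return (ct - lower) / (upper - lower)

vol = np.array([-1000.0, -125.0, 75.0, 275.0, 3000.0])  # toy HU values
out = window_normalize(vol)
# air (-1000) saturates to 0, bone (3000) to 1, mid-window maps linearly
```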

@@ -0,0 +1,95 @@
"""Preprocess the Synapse dataset into multi-frame NPZ/H5 format."""
import os
from time import time
import h5py
import nibabel as nib
import numpy as np
def process_split(split, ct_path, seg_path, save_path, upper, lower):
    """Process one data split and save the results.
    Args:
        split: split name ("train" or "test")
        ct_path: CT image directory
        seg_path: label directory
        save_path: output directory
        upper: upper clipping threshold (HU)
        lower: lower clipping threshold (HU)
    Returns:
        the smallest slice count observed
    """
if not os.path.exists(save_path):
os.mkdir(save_path)
start_time = time()
min_size = 10000
for ct_file in os.listdir(ct_path):
ct = nib.load(os.path.join(ct_path, ct_file))
seg = nib.load(os.path.join(seg_path, ct_file.replace("img", "label")))
ct_array = ct.get_fdata()
seg_array = seg.get_fdata()
ct_array = np.clip(ct_array, lower, upper)
ct_array = (ct_array - lower) / (upper - lower)
ct_array = np.transpose(ct_array, (2, 0, 1))
seg_array = np.transpose(seg_array, (2, 0, 1))
print("file name:", ct_file)
print("shape:", ct_array.shape)
if ct_array.shape[0] < min_size:
min_size = ct_array.shape[0]
ct_number = ct_file.split(".")[0]
if split == "test":
new_ct_name = ct_number.replace("img", "case") + ".npy.h5"
with h5py.File(os.path.join(save_path, new_ct_name), "w") as hf:
hf.create_dataset("image", data=ct_array)
hf.create_dataset("label", data=seg_array)
continue
for s_idx in range(ct_array.shape[0] - 2):
ct_array_s = np.transpose(ct_array, (1, 2, 0))[:, :, s_idx : s_idx + 3]
seg_array_s = seg_array[s_idx + 1, :, :]
slice_no = "{:03d}".format(s_idx)
new_ct_name = ct_number.replace("img", "case") + "_slice" + slice_no
np.savez(
os.path.join(save_path, new_ct_name),
image=ct_array_s,
label=seg_array_s,
)
print("already use {:.3f} min".format((time() - start_time) / 60))
print("-----------")
    print("min_size " + str(min_size))
return min_size
def main():
    """Program entry point."""
splits = ["train", "test"]
for split in splits:
if split == "train":
ct_path = "./data/synapse/Abdomen/RawData/TrainSet/img"
seg_path = "./data/synapse/Abdomen/RawData/TrainSet/label"
save_path = "./data/synapse/train_npz_mframes/"
else:
ct_path = "./data/synapse/Abdomen/RawData/TestSet/img"
seg_path = "./data/synapse/Abdomen/RawData/TestSet/label"
save_path = "./data/synapse/test_vol_h5_mframes/"
upper = 275
lower = -125
process_split(split, ct_path, seg_path, save_path, upper, lower)
if __name__ == "__main__":
main()
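The slicing loop in this script builds 2.5D samples: the label at slice `s_idx + 1` is paired with the three image slices around it. The shape arithmetic, standalone (a 10-slice 4x4 volume is an illustrative size):

```python
import numpy as np

depth, h, w = 10, 4, 4
ct = np.zeros((depth, h, w))  # volume already transposed to (D, H, W)
samples = []
for s_idx in range(ct.shape[0] - 2):
    # three consecutive slices become channels, as in the loop above
    stack = np.transpose(ct, (1, 2, 0))[:, :, s_idx:s_idx + 3]
    samples.append(stack)
# a D-slice volume yields D - 2 three-channel (H, W, 3) samples
```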

@@ -0,0 +1,170 @@
import argparse
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils.dataset_synapse import Synapse_dataset
from utils.utils import test_single_volume
from lib.networks import EMCADNet
parser = argparse.ArgumentParser()
parser.add_argument('--volume_path', type=str,
default='./data/synapse/test_vol_h5_new', help='root dir for validation volume data')
parser.add_argument('--dataset', type=str,
default='Synapse', help='experiment_name')
parser.add_argument('--num_classes', type=int,
default=9, help='output channel of network')
parser.add_argument('--list_dir', type=str,
default='./lists/lists_Synapse', help='list dir')
# network related parameters
parser.add_argument('--encoder', type=str,
default='pvt_v2_b2', help='Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...')
parser.add_argument('--expansion_factor', type=int,
default=2, help='expansion factor in MSCB block')
parser.add_argument('--kernel_sizes', type=int, nargs='+',
default=[1, 3, 5], help='multi-scale kernel sizes in MSDC block')
parser.add_argument('--lgag_ks', type=int,
default=3, help='Kernel size in LGAG')
parser.add_argument('--activation_mscb', type=str,
default='relu6', help='activation used in MSCB: relu6 or relu')
parser.add_argument('--no_dw_parallel', action='store_true',
default=False, help='use this flag to disable depth-wise parallel convolutions')
parser.add_argument('--concatenation', action='store_true',
default=False, help='use this flag to concatenate feature maps in MSDC block')
parser.add_argument('--no_pretrain', action='store_true',
                    default=False, help='use this flag to turn off loading pretrained encoder weights')
parser.add_argument('--pretrained_dir', type=str,
default='./pretrained_pth/pvt/', help='path to pretrained encoder dir')
parser.add_argument('--supervision', type=str,
default='mutation', help='loss supervision: mutation, deep_supervision or last_layer')
parser.add_argument('--max_iterations', type=int, default=30000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int, default=300, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int, default=6,
help='batch_size per gpu')
parser.add_argument('--base_lr', type=float, default=0.0001, help='segmentation network learning rate')
parser.add_argument('--img_size', type=int, default=224, help='input patch size of network input')
parser.add_argument('--is_savenii', action="store_true", default=True, help='whether to save results during inference')
parser.add_argument('--test_save_dir', type=str, default='predictions', help='directory for saving predictions as nii')
parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
parser.add_argument('--seed', type=int, default=2222, help='random seed')
args = parser.parse_args()
if(args.num_classes == 14):
classes = ['spleen', 'right kidney', 'left kidney', 'gallbladder', 'esophagus', 'liver', 'stomach', 'aorta', 'inferior vena cava', 'portal vein and splenic vein', 'pancreas', 'right adrenal gland', 'left adrenal gland']
else:
classes = ['spleen', 'right kidney', 'left kidney', 'gallbladder', 'pancreas', 'liver', 'stomach', 'aorta']
def inference(args, model, test_save_path=None):
db_test = args.Dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir, nclass=args.num_classes)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
logging.info("{} test iterations per epoch".format(len(testloader)))
model.eval()
metric_list = 0.0
for i_batch, sampled_batch in tqdm(enumerate(testloader)):
h, w = sampled_batch["image"].size()[2:]
image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0]
metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size],
test_save_path=test_save_path, case=case_name, z_spacing=1, class_names=classes)
metric_list += np.array(metric_i)
logging.info('idx %d case %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1], np.mean(metric_i, axis=0)[2], np.mean(metric_i, axis=0)[3]))
metric_list = metric_list / len(db_test)
for i in range(1, args.num_classes):
logging.info('Mean class (%d) %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f' % (i, classes[i-1], metric_list[i-1][0], metric_list[i-1][1], metric_list[i-1][2], metric_list[i-1][3]))
performance = np.mean(metric_list, axis=0)[0]
mean_hd95 = np.mean(metric_list, axis=0)[1]
mean_jacard = np.mean(metric_list, axis=0)[2]
mean_asd = np.mean(metric_list, axis=0)[3]
logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f, mean_jacard : %f mean_asd : %f' % (performance, mean_hd95, mean_jacard, mean_asd))
return "Testing Finished!"
if __name__ == "__main__":
if not args.deterministic:
cudnn.benchmark = True
cudnn.deterministic = False
else:
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
dataset_config = {
'Synapse': {
'Dataset': Synapse_dataset,
'volume_path': args.volume_path,
'list_dir': args.list_dir,
'num_classes': args.num_classes,
'z_spacing': 1,
},
}
dataset_name = args.dataset
args.num_classes = dataset_config[dataset_name]['num_classes']
args.volume_path = dataset_config[dataset_name]['volume_path']
args.Dataset = dataset_config[dataset_name]['Dataset']
args.list_dir = dataset_config[dataset_name]['list_dir']
args.z_spacing = dataset_config[dataset_name]['z_spacing']
print(args.no_pretrain)
if args.concatenation:
aggregation = 'concat'
else:
aggregation = 'add'
if args.no_dw_parallel:
dw_mode = 'series'
else:
dw_mode = 'parallel'
run = 1
args.exp = args.encoder + '_EMCAD_kernel_sizes_' + str(args.kernel_sizes) + '_dw_' + dw_mode + '_' + aggregation + '_lgag_ks_' + str(args.lgag_ks) + '_ef' + str(args.expansion_factor) + '_act_mscb_' + args.activation_mscb + '_loss_' + args.supervision + '_output_final_layer_Run'+str(run)+'_' + dataset_name + str(args.img_size)
snapshot_path = "model_pth/{}/{}".format(args.exp, args.encoder + '_EMCAD_kernel_sizes_' + str(args.kernel_sizes) + '_dw_' + dw_mode + '_' + aggregation + '_lgag_ks_' + str(args.lgag_ks) + '_ef' + str(args.expansion_factor) + '_act_mscb_' + args.activation_mscb + '_loss_' + args.supervision + '_output_final_layer_Run'+str(run))
snapshot_path = snapshot_path.replace('[', '').replace(']', '').replace(', ', '_')
snapshot_path = snapshot_path + '_pretrain' if not args.no_pretrain else snapshot_path
snapshot_path = snapshot_path+'_'+str(args.max_iterations)[0:2]+'k' if args.max_iterations != 50000 else snapshot_path
snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 300 else snapshot_path
snapshot_path = snapshot_path+'_bs'+str(args.batch_size)
snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.0001 else snapshot_path
snapshot_path = snapshot_path + '_'+str(args.img_size)
snapshot_path = snapshot_path + '_s'+str(args.seed) if args.seed!=1234 else snapshot_path
model = EMCADNet(num_classes=args.num_classes, kernel_sizes=args.kernel_sizes, expansion_factor=args.expansion_factor, dw_parallel=not args.no_dw_parallel, add=not args.concatenation, lgag_ks=args.lgag_ks, activation=args.activation_mscb, encoder=args.encoder, pretrain= not args.no_pretrain, pretrained_dir=args.pretrained_dir)
model.cuda()
snapshot = os.path.join(snapshot_path, 'best.pth')
if not os.path.exists(snapshot): snapshot = snapshot.replace('best', 'epoch_'+str(args.max_epochs-1))
model.load_state_dict(torch.load(snapshot))
snapshot_name = snapshot_path.split('/')[-1]
log_folder = 'test_log/test_log_' + args.exp
os.makedirs(log_folder, exist_ok=True)
logging.basicConfig(filename=log_folder + '/'+snapshot_name+".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
logging.info(snapshot_name)
if args.is_savenii:
args.test_save_dir = os.path.join(snapshot_path, "predictions")
test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name+'2')
os.makedirs(test_save_path, exist_ok=True)
else:
test_save_path = None
inference(args, model, test_save_path)

@@ -0,0 +1,287 @@
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from lib.networks import EMCADNet
from trainer import trainer_synapse, trainer_ACDC
parser = argparse.ArgumentParser()
# Note: the defaults below follow the current ACDC directory layout
parser.add_argument(
"--root_path",
type=str,
default="/data/ACDC/train",
help="root dir for training data (ACDC: /data/ACDC/train)",
)
parser.add_argument(
"--volume_path",
type=str,
default="/data/ACDC/test",
help="root dir for validation/test volume data",
)
parser.add_argument(
"--dataset",
type=str,
default="ACDC",
choices=["Synapse", "ACDC"],
help="experiment name",
)
parser.add_argument(
"--list_dir",
type=str,
default="/data/ACDC/lists_ACDC",
help="list dir (ACDC: /data/ACDC/lists_ACDC)",
)
parser.add_argument(
"--num_classes",
type=int,
default=4,
help="output channel of network (ACDC = 4)",
)
# ----------------- Network related parameters -----------------
parser.add_argument(
"--encoder",
type=str,
default="pvt_v2_b2",
help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...",
)
parser.add_argument(
"--expansion_factor", type=int, default=2, help="expansion factor in MSCB block"
)
parser.add_argument(
"--kernel_sizes",
type=int,
nargs="+",
default=[1, 3, 5],
help="multi-scale kernel sizes in MSDC block",
)
parser.add_argument(
"--lgag_ks", type=int, default=3, help="Kernel size in LGAG block"
)
parser.add_argument(
"--activation_mscb",
type=str,
default="relu6",
help="activation used in MSCB: relu6 or relu",
)
parser.add_argument(
"--no_dw_parallel",
action="store_true",
default=False,
help="use this flag to disable depth-wise parallel convolutions",
)
parser.add_argument(
"--concatenation",
action="store_true",
default=False,
help="use this flag to concatenate feature maps in MSDC block",
)
parser.add_argument(
"--no_pretrain",
action="store_true",
default=False,
help="use this flag to turn off loading pretrained encoder weights",
)
parser.add_argument(
"--pretrained_dir",
type=str,
default="./model_pth/",
help="path to pretrained encoder dir, e.g. ./model_pth/",
)
parser.add_argument(
"--supervision",
type=str,
default="mutation",
help="loss supervision: mutation, deep_supervision or last_layer",
)
# ----------------- Training parameters -----------------
parser.add_argument(
"--max_iterations", type=int, default=50000, help="maximum total iterations"
)
parser.add_argument(
"--max_epochs", type=int, default=300, help="maximum epoch number to train"
)
parser.add_argument(
"--batch_size", type=int, default=6, help="batch_size per gpu"
)
parser.add_argument(
"--base_lr",
type=float,
default=0.0001,
help="segmentation network learning rate",
)
parser.add_argument(
"--img_size",
type=int,
default=224,
help="input patch size of network input",
)
parser.add_argument("--n_gpu", type=int, default=1, help="total gpu")
parser.add_argument(
"--deterministic",
type=int,
default=1,
help="whether use deterministic training",
)
parser.add_argument("--seed", type=int, default=2222, help="random seed")
args = parser.parse_args()
if __name__ == "__main__":
    # ----------------- Fix random seeds -----------------
if not args.deterministic:
cudnn.benchmark = True
cudnn.deterministic = False
else:
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
dataset_name = args.dataset
    # --------- Fix root_path for ACDC to avoid paths like /train/train/xxx.npz ---------
acdc_root = args.root_path
if dataset_name == "ACDC":
tmp = args.root_path.rstrip("/")
if os.path.basename(tmp) == "train":
acdc_root = os.path.dirname(tmp)
else:
acdc_root = tmp
    # ----------------- Dataset configuration -----------------
dataset_config = {
"Synapse": {
"root_path": args.root_path,
"volume_path": args.volume_path,
"list_dir": args.list_dir,
"num_classes": args.num_classes,
"z_spacing": 1,
},
"ACDC": {
            # Note: use the corrected acdc_root here, not the raw /data/ACDC/train
"root_path": acdc_root,
            "volume_path": args.volume_path,  # currently unused by ACDC training
"list_dir": args.list_dir,
            "num_classes": args.num_classes,  # default 4
"z_spacing": 1,
},
}
cfg = dataset_config[dataset_name]
args.num_classes = cfg["num_classes"]
args.root_path = cfg["root_path"]
args.volume_path = cfg["volume_path"]
args.z_spacing = cfg["z_spacing"]
args.list_dir = cfg["list_dir"]
    # ----------------- Build experiment name and snapshot path -----------------
if args.concatenation:
aggregation = "concat"
else:
aggregation = "add"
if args.no_dw_parallel:
dw_mode = "series"
else:
dw_mode = "parallel"
run = 1
args.exp = (
args.encoder
+ "_EMCAD_kernel_sizes_"
+ str(args.kernel_sizes)
+ "_dw_"
+ dw_mode
+ "_"
+ aggregation
+ "_lgag_ks_"
+ str(args.lgag_ks)
+ "_ef"
+ str(args.expansion_factor)
+ "_act_mscb_"
+ args.activation_mscb
+ "_loss_"
+ args.supervision
+ "_output_final_layer_Run"
+ str(run)
+ "_"
+ dataset_name
+ str(args.img_size)
)
snapshot_path = "model_pth/{}/{}".format(
args.exp,
args.encoder
+ "_EMCAD_kernel_sizes_"
+ str(args.kernel_sizes)
+ "_dw_"
+ dw_mode
+ "_"
+ aggregation
+ "_lgag_ks_"
+ str(args.lgag_ks)
+ "_ef"
+ str(args.expansion_factor)
+ "_act_mscb_"
+ args.activation_mscb
+ "_loss_"
+ args.supervision
+ "_output_final_layer_Run"
+ str(run),
)
snapshot_path = (
snapshot_path.replace("[", "").replace("]", "").replace(", ", "_")
)
if not args.no_pretrain:
snapshot_path = snapshot_path + "_pretrain"
if args.max_iterations != 50000:
snapshot_path = snapshot_path + "_" + str(args.max_iterations)[0:2] + "k"
if args.max_epochs != 300:
snapshot_path = snapshot_path + "_epo" + str(args.max_epochs)
snapshot_path = snapshot_path + "_bs" + str(args.batch_size)
if args.base_lr != 0.0001:
snapshot_path = snapshot_path + "_lr" + str(args.base_lr)
snapshot_path = snapshot_path + "_" + str(args.img_size)
if args.seed != 1234:
snapshot_path = snapshot_path + "_s" + str(args.seed)
if not os.path.exists(snapshot_path):
os.makedirs(snapshot_path)
    # ----------------- Build the model -----------------
model = EMCADNet(
num_classes=args.num_classes,
kernel_sizes=args.kernel_sizes,
expansion_factor=args.expansion_factor,
dw_parallel=not args.no_dw_parallel,
add=not args.concatenation,
lgag_ks=args.lgag_ks,
activation=args.activation_mscb,
encoder=args.encoder,
pretrain=not args.no_pretrain,
pretrained_dir=args.pretrained_dir,
)
model.cuda()
print("Model successfully created.")
    # ----------------- Select trainer by dataset -----------------
trainer_map = {
"Synapse": trainer_synapse,
"ACDC": trainer_ACDC,
}
trainer_map[dataset_name](args, model, snapshot_path)
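The ACDC root-path correction above strips a trailing `train` component so the dataset loader does not end up reading `/train/train/*.npz`. The same normalization in isolation (the example paths are hypothetical):

```python
import os

def normalize_acdc_root(root_path):
    # mirror of the fix above: /data/ACDC/train -> /data/ACDC
    tmp = root_path.rstrip("/")
    if os.path.basename(tmp) == "train":
        return os.path.dirname(tmp)
    return tmp

# a trailing slash is tolerated, and already-correct roots pass through
print(normalize_acdc_root("/data/ACDC/train/"))
```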

@@ -0,0 +1,380 @@
import logging
import os
import random
import sys
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from utils.dataset_synapse import Synapse_dataset, RandomGenerator
from utils.dataset_ACDC import ACDCdataset
from utils.utils import powerset, one_hot_encoder, DiceLoss, val_single_volume
def inference(args, model, best_performance):
    """
    Validation used only for the Synapse dataset (original code left unchanged).
    """
db_test = Synapse_dataset(
base_dir=args.volume_path,
split="test_vol",
list_dir=args.list_dir,
nclass=args.num_classes,
)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
logging.info("{} test iterations per epoch".format(len(testloader)))
model.eval()
metric_list = 0.0
for i_batch, sampled_batch in tqdm(enumerate(testloader)):
h, w = sampled_batch["image"].size()[2:]
image, label, case_name = (
sampled_batch["image"],
sampled_batch["label"],
sampled_batch["case_name"][0],
)
metric_i = val_single_volume(
image,
label,
model,
classes=args.num_classes,
patch_size=[args.img_size, args.img_size],
case=case_name,
z_spacing=args.z_spacing,
)
metric_list += np.array(metric_i)
metric_list = metric_list / len(db_test)
performance = np.mean(metric_list, axis=0)
logging.info(
"Testing performance in val model: mean_dice : %f, best_dice : %f"
% (performance, best_performance)
)
return performance
def trainer_synapse(args, model, snapshot_path):
    """
    Original Synapse training code, kept essentially unchanged.
    """
logging.basicConfig(
filename=snapshot_path + "/log.txt",
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
base_lr = args.base_lr
num_classes = args.num_classes
batch_size = args.batch_size * args.n_gpu
db_train = Synapse_dataset(
base_dir=args.root_path,
list_dir=args.list_dir,
split="train",
nclass=args.num_classes,
transform=transforms.Compose(
[RandomGenerator(output_size=[args.img_size, args.img_size])]
),
)
print("The length of train set is: {}".format(len(db_train)))
def worker_init_fn(worker_id):
random.seed(args.seed + worker_id)
trainloader = DataLoader(
db_train,
batch_size=batch_size,
shuffle=True,
num_workers=8,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1 and args.n_gpu > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)
model.to(device)
model.train()
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(num_classes)
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
writer = SummaryWriter(snapshot_path + "/log")
iter_num = 0
max_epoch = args.max_epochs
max_iterations = args.max_epochs * len(trainloader)
logging.info(
"{} iterations per epoch. {} max iterations ".format(
len(trainloader), max_iterations
)
)
best_performance = 0.0
iterator = tqdm(range(max_epoch), ncols=70)
for epoch_num in iterator:
for i_batch, sampled_batch in enumerate(trainloader):
image_batch, label_batch = (
sampled_batch["image"],
sampled_batch["label"],
)
image_batch, label_batch = (
image_batch.cuda(),
label_batch.squeeze(1).cuda(),
)
P = model(image_batch, mode="train")
if not isinstance(P, list):
P = [P]
# Determine the supervision combinations ss on the first forward pass
if epoch_num == 0 and i_batch == 0:
n_outs = len(P)
out_idxs = list(np.arange(n_outs))
if args.supervision == "mutation":
ss = [x for x in powerset(out_idxs)]
elif args.supervision == "deep_supervision":
ss = [[x] for x in out_idxs]
else:
ss = [[-1]]
print(ss)
loss = 0.0
w_ce, w_dice = 0.3, 0.7
for s in ss:
iout = 0.0
if s == []:
continue
for idx in range(len(s)):
iout += P[s[idx]]
loss_ce = ce_loss(iout, label_batch[:].long())
loss_dice = dice_loss(iout, label_batch, softmax=True)
loss += w_ce * loss_ce + w_dice * loss_dice
optimizer.zero_grad()
loss.backward()
optimizer.step()
# The original polynomial lr decay was commented out here; a fixed lr is used instead
lr_ = base_lr
for param_group in optimizer.param_groups:
param_group["lr"] = lr_
iter_num += 1
writer.add_scalar("info/lr", lr_, iter_num)
writer.add_scalar("info/total_loss", loss, iter_num)
if iter_num % 50 == 0:
logging.info(
"iteration %d, epoch %d : loss : %f, lr: %f"
% (iter_num, epoch_num, loss.item(), lr_)
)
# Save last.pth every epoch and run validation on Synapse
save_mode_path = os.path.join(snapshot_path, "last.pth")
torch.save(model.state_dict(), save_mode_path)
performance = inference(args, model, best_performance)
save_interval = 50
if best_performance <= performance:
best_performance = performance
save_mode_path = os.path.join(snapshot_path, "best.pth")
torch.save(model.state_dict(), save_mode_path)
logging.info("save model to {}".format(save_mode_path))
if (epoch_num + 1) % save_interval == 0:
save_mode_path = os.path.join(
snapshot_path, "epoch_" + str(epoch_num) + ".pth"
)
torch.save(model.state_dict(), save_mode_path)
logging.info("save model to {}".format(save_mode_path))
if epoch_num >= max_epoch - 1:
save_mode_path = os.path.join(
snapshot_path, "epoch_" + str(epoch_num) + ".pth"
)
torch.save(model.state_dict(), save_mode_path)
logging.info("save model to {}".format(save_mode_path))
iterator.close()
writer.close()
return "Training Finished!"
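As a self-contained illustration of the three `args.supervision` modes selected in the loop above (the `powerset` helper is re-implemented here to match `src.utils.utils.powerset`, so this sketch runs on its own):

```python
def powerset(seq):
    # Recursive powerset, mirroring src.utils.utils.powerset.
    if len(seq) <= 1:
        yield seq
        yield []
    else:
        for item in powerset(seq[1:]):
            yield [seq[0]] + item
            yield item

out_idxs = [0, 1, 2]  # e.g. three decoder outputs

mutation = list(powerset(out_idxs))          # every subset, incl. []
deep_supervision = [[x] for x in out_idxs]   # one singleton per output
last_layer = [[-1]]                          # final output only
```

The training loop skips the empty subset, so "mutation" effectively sums losses over all 2^n - 1 non-empty combinations of decoder outputs.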
def trainer_ACDC(args, model, snapshot_path):
"""
新增ACDC 数据集的训练函数
重点改动
1. 使用 utils.dataset_ACDC.ACDCdataset 读取数据
2. 自动把 /data/ACDC/train 这种 root 修正成 /data/ACDC避免 train/train/xxx.npz
3. 不调用原来的 inference那个是给 Synapse 用的只做训练和保存模型
"""
logging.basicConfig(
filename=snapshot_path + "/log.txt",
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
base_lr = args.base_lr
num_classes = args.num_classes
batch_size = args.batch_size * args.n_gpu
# ------- Key fix: normalize the ACDC root dir to avoid /train/train/xxx.npz -------
# Users typically pass /data/ACDC/train
acdc_root = args.root_path.rstrip("/")
if os.path.basename(acdc_root) == "train":
# e.g. /data/ACDC/train -> /data/ACDC
acdc_root = os.path.dirname(acdc_root)
logging.info(f"Using ACDC root dir: {acdc_root}")
db_train = ACDCdataset(
base_dir=acdc_root,
list_dir=args.list_dir,
split="train",
transform=transforms.Compose(
[RandomGenerator(output_size=[args.img_size, args.img_size])]
),
)
print("The length of train set is: {}".format(len(db_train)))
def worker_init_fn(worker_id):
random.seed(args.seed + worker_id)
trainloader = DataLoader(
db_train,
batch_size=batch_size,
shuffle=True,
num_workers=8,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1 and args.n_gpu > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)
model.to(device)
model.train()
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(num_classes)
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
writer = SummaryWriter(snapshot_path + "/log")
iter_num = 0
max_epoch = args.max_epochs
max_iterations = args.max_epochs * len(trainloader)
logging.info(
"{} iterations per epoch. {} max iterations ".format(
len(trainloader), max_iterations
)
)
best_loss = 1e9
iterator = tqdm(range(max_epoch), ncols=70)
for epoch_num in iterator:
epoch_loss = 0.0
for i_batch, sampled_batch in enumerate(trainloader):
image_batch, label_batch = (
sampled_batch["image"],
sampled_batch["label"],
)
image_batch, label_batch = (
image_batch.to(device),
label_batch.squeeze(1).long().to(device),
)
P = model(image_batch, mode="train")
if not isinstance(P, list):
P = [P]
if epoch_num == 0 and i_batch == 0:
n_outs = len(P)
out_idxs = list(np.arange(n_outs))
if args.supervision == "mutation":
ss = [x for x in powerset(out_idxs)]
elif args.supervision == "deep_supervision":
ss = [[x] for x in out_idxs]
else:
ss = [[-1]]
print(ss)
loss = 0.0
w_ce, w_dice = 0.3, 0.7
for s in ss:
iout = 0.0
if s == []:
continue
for idx in range(len(s)):
iout += P[s[idx]]
loss_ce = ce_loss(iout, label_batch)
loss_dice = dice_loss(iout, label_batch, softmax=True)
loss += w_ce * loss_ce + w_dice * loss_dice
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_ = base_lr
for param_group in optimizer.param_groups:
param_group["lr"] = lr_
iter_num += 1
epoch_loss += loss.item()
writer.add_scalar("info/lr", lr_, iter_num)
writer.add_scalar("info/total_loss", loss, iter_num)
if iter_num % 50 == 0:
logging.info(
"iteration %d, epoch %d : loss : %f, lr: %f"
% (iter_num, epoch_num, loss.item(), lr_)
)
epoch_loss /= max(len(trainloader), 1)
logging.info(
f"[ACDC] Epoch {epoch_num} finished, mean loss = {epoch_loss:.4f}"
)
# Save last.pth every epoch
save_mode_path = os.path.join(snapshot_path, "last.pth")
torch.save(model.state_dict(), save_mode_path)
# If this epoch's mean loss improves, additionally save best.pth
if epoch_loss < best_loss:
best_loss = epoch_loss
best_path = os.path.join(snapshot_path, "best.pth")
torch.save(model.state_dict(), best_path)
logging.info(f"New best model saved to {best_path}, loss={best_loss:.4f}")
# Also save epoch_xxx.pth at fixed intervals
save_interval = 50
if (epoch_num + 1) % save_interval == 0 or epoch_num == max_epoch - 1:
save_mode_path = os.path.join(
snapshot_path, "epoch_" + str(epoch_num) + ".pth"
)
torch.save(model.state_dict(), save_mode_path)
logging.info("save model to {}".format(save_mode_path))
iterator.close()
writer.close()
return "ACDC Training Finished!"
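The root-directory fix in `trainer_ACDC` can be sketched as a standalone helper (the name `normalize_acdc_root` is hypothetical; the trainer inlines the same logic):

```python
import os

def normalize_acdc_root(root_path):
    # Strip a trailing "train" component so list entries such as
    # "train/xxx.npz" do not resolve to ".../train/train/xxx.npz".
    root = root_path.rstrip("/")
    if os.path.basename(root) == "train":
        root = os.path.dirname(root)
    return root
```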

@ -0,0 +1,381 @@
"""Training pipelines for the Synapse and ACDC datasets."""
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from src.utils.dataset_ACDC import ACDCdataset
from src.utils.dataset_synapse import RandomGenerator, Synapse_dataset
from src.utils.utils import DiceLoss, powerset, val_single_volume
def inference(args, model, best_performance):
"""Run validation on the Synapse test set.
Args:
args: training arguments
model: the model
best_performance: current best metric
Returns:
The mean performance metric.
"""
db_test = Synapse_dataset(
base_dir=args.volume_path,
split="test_vol",
list_dir=args.list_dir,
nclass=args.num_classes,
)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
logging.info("%s test iterations per epoch", len(testloader))
model.eval()
metric_list = 0.0
for _, sampled_batch in tqdm(enumerate(testloader)):
image, label, case_name = (
sampled_batch["image"],
sampled_batch["label"],
sampled_batch["case_name"][0],
)
metric_i = val_single_volume(
image,
label,
model,
classes=args.num_classes,
patch_size=[args.img_size, args.img_size],
case=case_name,
z_spacing=args.z_spacing,
)
metric_list += np.array(metric_i)
metric_list = metric_list / len(db_test)
performance = np.mean(metric_list, axis=0)
logging.info(
"Testing performance in val model: mean_dice : %f, best_dice : %f",
performance,
best_performance,
)
return performance
def trainer_synapse(args, model, snapshot_path):
"""Synapse training pipeline."""
logging.basicConfig(
filename=snapshot_path + "/log.txt",
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
base_lr = args.base_lr
num_classes = args.num_classes
batch_size = args.batch_size * args.n_gpu
db_train = Synapse_dataset(
base_dir=args.root_path,
list_dir=args.list_dir,
split="train",
nclass=args.num_classes,
transform=transforms.Compose(
[RandomGenerator(output_size=[args.img_size, args.img_size])]
),
)
print("The length of train set is: {}".format(len(db_train)))
def worker_init_fn(worker_id):
"""Seed each DataLoader worker for reproducibility."""
random.seed(args.seed + worker_id)
trainloader = DataLoader(
db_train,
batch_size=batch_size,
shuffle=True,
num_workers=8,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1 and args.n_gpu > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)
model.to(device)
model.train()
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(num_classes)
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
writer = SummaryWriter(snapshot_path + "/log")
iter_num = 0
max_epoch = args.max_epochs
max_iterations = args.max_epochs * len(trainloader)
logging.info(
"%s iterations per epoch. %s max iterations ",
len(trainloader),
max_iterations,
)
best_performance = 0.0
iterator = tqdm(range(max_epoch), ncols=70)
for epoch_num in iterator:
for i_batch, sampled_batch in enumerate(trainloader):
image_batch, label_batch = (
sampled_batch["image"],
sampled_batch["label"],
)
image_batch, label_batch = (
image_batch.cuda(),
label_batch.squeeze(1).cuda(),
)
outputs = model(image_batch, mode="train")
if not isinstance(outputs, list):
outputs = [outputs]
if epoch_num == 0 and i_batch == 0:
n_outs = len(outputs)
out_idxs = list(np.arange(n_outs))
if args.supervision == "mutation":
supervision_sets = [x for x in powerset(out_idxs)]
elif args.supervision == "deep_supervision":
supervision_sets = [[x] for x in out_idxs]
else:
supervision_sets = [[-1]]
print(supervision_sets)
loss = 0.0
w_ce, w_dice = 0.3, 0.7
for s in supervision_sets:
iout = 0.0
if s == []:
continue
for idx in range(len(s)):
iout += outputs[s[idx]]
loss_ce = ce_loss(iout, label_batch[:].long())
loss_dice = dice_loss(iout, label_batch, softmax=True)
loss += w_ce * loss_ce + w_dice * loss_dice
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_ = base_lr
for param_group in optimizer.param_groups:
param_group["lr"] = lr_
iter_num += 1
writer.add_scalar("info/lr", lr_, iter_num)
writer.add_scalar("info/total_loss", loss, iter_num)
if iter_num % 50 == 0:
logging.info(
"iteration %d, epoch %d : loss : %f, lr: %f",
iter_num,
epoch_num,
loss.item(),
lr_,
)
save_mode_path = os.path.join(snapshot_path, "last.pth")
torch.save(model.state_dict(), save_mode_path)
performance = inference(args, model, best_performance)
save_interval = 50
if best_performance <= performance:
best_performance = performance
save_mode_path = os.path.join(snapshot_path, "best.pth")
torch.save(model.state_dict(), save_mode_path)
logging.info("save model to %s", save_mode_path)
if (epoch_num + 1) % save_interval == 0:
save_mode_path = os.path.join(
snapshot_path, "epoch_" + str(epoch_num) + ".pth"
)
torch.save(model.state_dict(), save_mode_path)
logging.info("save model to %s", save_mode_path)
if epoch_num >= max_epoch - 1:
save_mode_path = os.path.join(
snapshot_path, "epoch_" + str(epoch_num) + ".pth"
)
torch.save(model.state_dict(), save_mode_path)
logging.info("save model to %s", save_mode_path)
iterator.close()
writer.close()
return "Training Finished!"
def trainer_ACDC(args, model, snapshot_path):
"""ACDC training pipeline."""
logging.basicConfig(
filename=snapshot_path + "/log.txt",
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
base_lr = args.base_lr
num_classes = args.num_classes
batch_size = args.batch_size * args.n_gpu
acdc_root = args.root_path.rstrip("/")
if os.path.basename(acdc_root) == "train":
acdc_root = os.path.dirname(acdc_root)
logging.info("Using ACDC root dir: %s", acdc_root)
db_train = ACDCdataset(
base_dir=acdc_root,
list_dir=args.list_dir,
split="train",
transform=transforms.Compose(
[RandomGenerator(output_size=[args.img_size, args.img_size])]
),
)
print("The length of train set is: {}".format(len(db_train)))
def worker_init_fn(worker_id):
"""Seed each DataLoader worker for reproducibility."""
random.seed(args.seed + worker_id)
trainloader = DataLoader(
db_train,
batch_size=batch_size,
shuffle=True,
num_workers=8,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1 and args.n_gpu > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = nn.DataParallel(model)
model.to(device)
model.train()
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(num_classes)
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
writer = SummaryWriter(snapshot_path + "/log")
iter_num = 0
max_epoch = args.max_epochs
max_iterations = args.max_epochs * len(trainloader)
logging.info(
"%s iterations per epoch. %s max iterations ",
len(trainloader),
max_iterations,
)
best_loss = 1e9
iterator = tqdm(range(max_epoch), ncols=70)
for epoch_num in iterator:
epoch_loss = 0.0
for i_batch, sampled_batch in enumerate(trainloader):
image_batch, label_batch = (
sampled_batch["image"],
sampled_batch["label"],
)
image_batch, label_batch = (
image_batch.to(device),
label_batch.squeeze(1).long().to(device),
)
outputs = model(image_batch, mode="train")
if not isinstance(outputs, list):
outputs = [outputs]
if epoch_num == 0 and i_batch == 0:
n_outs = len(outputs)
out_idxs = list(np.arange(n_outs))
if args.supervision == "mutation":
supervision_sets = [x for x in powerset(out_idxs)]
elif args.supervision == "deep_supervision":
supervision_sets = [[x] for x in out_idxs]
else:
supervision_sets = [[-1]]
print(supervision_sets)
loss = 0.0
w_ce, w_dice = 0.3, 0.7
for s in supervision_sets:
iout = 0.0
if s == []:
continue
for idx in range(len(s)):
iout += outputs[s[idx]]
loss_ce = ce_loss(iout, label_batch)
loss_dice = dice_loss(iout, label_batch, softmax=True)
loss += w_ce * loss_ce + w_dice * loss_dice
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_ = base_lr
for param_group in optimizer.param_groups:
param_group["lr"] = lr_
iter_num += 1
epoch_loss += loss.item()
writer.add_scalar("info/lr", lr_, iter_num)
writer.add_scalar("info/total_loss", loss, iter_num)
if iter_num % 50 == 0:
logging.info(
"iteration %d, epoch %d : loss : %f, lr: %f",
iter_num,
epoch_num,
loss.item(),
lr_,
)
epoch_loss /= max(len(trainloader), 1)
logging.info(
"[ACDC] Epoch %s finished, mean loss = %.4f",
epoch_num,
epoch_loss,
)
save_mode_path = os.path.join(snapshot_path, "last.pth")
torch.save(model.state_dict(), save_mode_path)
if epoch_loss < best_loss:
best_loss = epoch_loss
best_path = os.path.join(snapshot_path, "best.pth")
torch.save(model.state_dict(), best_path)
logging.info(
"New best model saved to %s, loss=%.4f", best_path, best_loss
)
save_interval = 50
if (epoch_num + 1) % save_interval == 0 or epoch_num == max_epoch - 1:
save_mode_path = os.path.join(
snapshot_path, "epoch_" + str(epoch_num) + ".pth"
)
torch.save(model.state_dict(), save_mode_path)
logging.info("save model to %s", save_mode_path)
iterator.close()
writer.close()
return "ACDC Training Finished!"
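The `worker_init_fn` closures in both trainers give each DataLoader worker a reproducible seed derived from the base seed. A minimal factory version of the same idea (the name `make_worker_init_fn` is illustrative, not part of the codebase):

```python
import random

def make_worker_init_fn(base_seed):
    # Each worker id gets a distinct, deterministic seed, mirroring the
    # worker_init_fn closures defined inside the trainers above.
    def worker_init_fn(worker_id):
        random.seed(base_seed + worker_id)
    return worker_init_fn

init_fn = make_worker_init_fn(2222)
init_fn(0)
first = random.random()
init_fn(0)
second = random.random()  # same worker id, same seed -> same value
```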

@ -0,0 +1,126 @@
"""Preprocessing transforms for images and masks."""
import random
import numpy as np
import torch
from PIL import Image, ImageFilter
from skimage.filters import gaussian
class RandomVerticalFlip(object):
"""Vertically flip a PIL image with probability 0.5."""
def __call__(self, img):
"""Apply the transform to a PIL image.
Args:
img: the PIL image to process
Returns:
The flipped image when triggered, otherwise the original image.
"""
if random.random() < 0.5:
return img.transpose(Image.FLIP_TOP_BOTTOM)
return img
class DeNormalize(object):
"""Undo normalization of a tensor using mean and std."""
def __init__(self, mean, std):
"""Initialize the de-normalization parameters.
Args:
mean: per-channel means
std: per-channel standard deviations
"""
self.mean = mean
self.std = std
def __call__(self, tensor):
"""De-normalize in place.
Args:
tensor: tensor of shape (C, H, W)
Returns:
The de-normalized tensor.
"""
for t, m, s in zip(tensor, self.mean, self.std):
t.mul_(s).add_(m)
return tensor
class MaskToTensor(object):
"""Convert a PIL mask to a long tensor."""
def __call__(self, img):
"""Convert a PIL mask image to a tensor.
Args:
img: the PIL mask image
Returns:
A tensor of dtype long.
"""
return torch.from_numpy(np.array(img, dtype=np.int32)).long()
class FreeScale(object):
"""Resize an image to a given size."""
def __init__(self, size, interpolation=Image.BILINEAR):
"""Initialize the resize parameters.
Args:
size: target size (h, w)
interpolation: PIL interpolation mode
"""
self.size = tuple(reversed(size))  # input size is (h, w); PIL expects (w, h)
self.interpolation = interpolation
def __call__(self, img):
"""Resize the image to the target size.
Args:
img: a PIL image
Returns:
The resized image.
"""
return img.resize(self.size, self.interpolation)
class FlipChannels(object):
"""Reverse the channel order (swap RGB and BGR)."""
def __call__(self, img):
"""Reverse the channel order of an image.
Args:
img: a PIL image
Returns:
The image with channel order reversed.
"""
img = np.array(img)[:, :, ::-1]
return Image.fromarray(img.astype(np.uint8))
class RandomGaussianBlur(object):
"""Apply Gaussian blur with a random sigma."""
def __call__(self, img):
"""Apply a random Gaussian blur to an image.
Args:
img: a PIL image
Returns:
The blurred image.
"""
sigma = 0.15 + random.random() * 1.15
blurred_img = gaussian(np.array(img), sigma=sigma, channel_axis=-1)  # channel_axis replaces the multichannel flag removed in newer skimage
blurred_img *= 255
return Image.fromarray(blurred_img.astype(np.uint8))
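A quick sanity check of the `DeNormalize` transform above: normalizing a tensor and then de-normalizing it should recover the original values. The class is repeated here so the sketch is self-contained:

```python
import torch

class DeNormalize:
    # Same in-place inverse normalization as the transform above.
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
        return tensor

mean, std = [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]
x = torch.rand(3, 4, 4)
normalized = (x - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
restored = DeNormalize(mean, std)(normalized.clone())
```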

@ -0,0 +1,510 @@
"""Utility functions for training, evaluation, and visualization."""
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import torch
import torch.nn as nn
from medpy import metric
from ptflops import get_model_complexity_info
from scipy.ndimage import zoom
from segmentation_mask_overlay import overlay_masks
from thop import clever_format
from thop import profile
def powerset(seq):
"""Generate all subsets of a sequence.
Args:
seq: the input sequence
Yields:
Every subset of the input sequence.
"""
if len(seq) <= 1:
yield seq
yield []
else:
for item in powerset(seq[1:]):
yield [seq[0]] + item
yield item
def clip_gradient(optimizer, grad_clip):
"""Clamp the gradients of the optimizer's parameters.
Args:
optimizer: the optimizer to process
grad_clip: absolute bound on gradient values
Returns:
None
"""
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is not None:
param.grad.data.clamp_(-grad_clip, grad_clip)
def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
"""Step-decay the learning rate by epoch.
Note: the decay multiplies the current lr on every call, so calling
this once per epoch compounds the decay within a decay window.
Args:
optimizer: the optimizer to update
init_lr: initial learning rate (kept for interface compatibility)
epoch: current epoch
decay_rate: decay multiplier
decay_epoch: decay interval
Returns:
None
"""
decay = decay_rate ** (epoch // decay_epoch)
for param_group in optimizer.param_groups:
param_group["lr"] *= decay
class AvgMeter(object):
"""Meter that tracks a running and windowed average."""
def __init__(self, num=40):
"""Initialize the meter.
Args:
num: length of the statistics window
"""
self.num = num
self.reset()
def reset(self):
"""Reset the statistics."""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.losses = []
def update(self, val, n=1):
"""Update the meter.
Args:
val: the new value
n: weight for the value
"""
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
self.losses.append(val)
def show(self):
"""Return the mean over the most recent window.
Returns:
The mean as a tensor.
"""
start = np.maximum(len(self.losses) - self.num, 0)
return torch.mean(torch.stack(self.losses[start:]))
def CalParams(model, input_tensor):
"""Report FLOPs and parameter count via THOP.
Args:
model: the model to profile
input_tensor: an example input
Returns:
None
"""
flops, params = profile(model, inputs=(input_tensor,))
flops, params = clever_format([flops, params], "%.3f")
print("[Statistics Information]\nFLOPs: {}\nParams: {}".format(
flops, params
))
def one_hot_encoder(input_tensor, dataset, n_classes=None):
"""One-hot encode a label tensor.
Args:
input_tensor: the label tensor
dataset: dataset name
n_classes: number of classes
Returns:
The one-hot encoded tensor.
"""
tensor_list = []
if dataset == "MMWHS":
label_values = [0, 205, 420, 500, 550, 600, 820, 850]
for value in label_values:
temp_prob = input_tensor == value
tensor_list.append(temp_prob.unsqueeze(1))
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()
for i in range(n_classes):
temp_prob = input_tensor == i
tensor_list.append(temp_prob.unsqueeze(1))
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()
class DiceLoss(nn.Module):
"""Dice loss for multi-class segmentation."""
def __init__(self, n_classes):
"""Initialize the Dice loss.
Args:
n_classes: number of classes
"""
super(DiceLoss, self).__init__()
self.n_classes = n_classes
def _one_hot_encoder(self, input_tensor):
"""One-hot encode the labels."""
tensor_list = []
for i in range(self.n_classes):
temp_prob = input_tensor == i
tensor_list.append(temp_prob.unsqueeze(1))
output_tensor = torch.cat(tensor_list, dim=1)
return output_tensor.float()
def _dice_loss(self, score, target):
"""Compute the Dice loss for a single class."""
target = target.float()
smooth = 1e-5
intersect = torch.sum(score * target)
y_sum = torch.sum(target * target)
z_sum = torch.sum(score * score)
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
return 1 - loss
def forward(self, inputs, target, weight=None, softmax=False):
"""Compute the multi-class Dice loss.
Args:
inputs: model outputs
target: labels
weight: per-class weights
softmax: whether to apply softmax to the inputs
Returns:
The scalar loss value.
Raises:
AssertionError: if the input and target shapes do not match.
"""
if softmax:
inputs = torch.softmax(inputs, dim=1)
target = self._one_hot_encoder(target)
if weight is None:
weight = [1] * self.n_classes
assert inputs.size() == target.size(), (
"predict {} & target {} shape do not match".format(
inputs.size(), target.size()
)
)
loss = 0.0
for i in range(self.n_classes):
dice = self._dice_loss(inputs[:, i], target[:, i])
loss += dice * weight[i]
return loss / self.n_classes
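The per-class term inside `DiceLoss._dice_loss` is a soft Dice; as a minimal check, a perfect binary prediction drives the loss to zero while fully disjoint masks push it toward one. The helper below re-implements just that term so the sketch is self-contained:

```python
import torch

def dice_loss_per_class(score, target, smooth=1e-5):
    # Mirrors DiceLoss._dice_loss above: 1 - soft Dice coefficient.
    target = target.float()
    intersect = torch.sum(score * target)
    denom = torch.sum(score * score) + torch.sum(target * target)
    return 1 - (2 * intersect + smooth) / (denom + smooth)

pred = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
perfect = dice_loss_per_class(pred, pred)       # full overlap
disjoint = dice_loss_per_class(pred, 1 - pred)  # no overlap
```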
def calculate_metric_percase(pred, gt):
"""Compute several evaluation metrics for a single case."""
pred[pred > 0] = 1
gt[gt > 0] = 1
if pred.sum() > 0 and gt.sum() > 0:
dice = metric.binary.dc(pred, gt)
hd95 = metric.binary.hd95(pred, gt)
jaccard = metric.binary.jc(pred, gt)
asd = metric.binary.assd(pred, gt)
return dice, hd95, jaccard, asd
if pred.sum() > 0 and gt.sum() == 0:
return 1, 0, 1, 0
return 0, 0, 0, 0
def calculate_dice_percase(pred, gt):
"""Compute the Dice score for a single case."""
pred[pred > 0] = 1
gt[gt > 0] = 1
if pred.sum() > 0 and gt.sum() > 0:
return metric.binary.dc(pred, gt)
if pred.sum() > 0 and gt.sum() == 0:
return 1
return 0
def test_single_volume(
image,
label,
net,
classes,
patch_size=None,
test_save_path=None,
case=None,
z_spacing=1,
class_names=None,
):
"""Run inference on a single volume and optionally save the results.
Args:
image: input image tensor
label: label tensor
net: segmentation model
classes: number of classes
patch_size: target slice size
test_save_path: output directory
case: case name
z_spacing: spacing along the Z axis
class_names: list of class names
Returns:
A list of per-class metrics.
"""
if patch_size is None:
patch_size = [256, 256]
image = image.squeeze(0).cpu().detach().numpy()
label = label.squeeze(0).cpu().detach().numpy()
if class_names is None:
mask_labels = np.arange(1, classes)
else:
mask_labels = class_names
cmaps = mcolors.CSS4_COLORS
my_colors = [
"red",
"darkorange",
"yellow",
"forestgreen",
"blue",
"purple",
"magenta",
"cyan",
"deeppink",
"chocolate",
"olive",
"deepskyblue",
"darkviolet",
]
cmap = {
key: cmaps[key]
for key in sorted(cmaps.keys())
if key in my_colors[: classes - 1]
}
if len(image.shape) == 3:
prediction = np.zeros_like(label)
for ind in range(image.shape[0]):
slice_arr = image[ind, :, :]
x, y = slice_arr.shape[0], slice_arr.shape[1]
if x != patch_size[0] or y != patch_size[1]:
slice_arr = zoom(
slice_arr,
(patch_size[0] / x, patch_size[1] / y),
order=3,
)
input_tensor = (
torch.from_numpy(slice_arr)
.unsqueeze(0)
.unsqueeze(0)
.float()
.cuda()
)
net.eval()
with torch.no_grad():
outputs = net(input_tensor)[-1]
out = torch.argmax(
torch.softmax(outputs, dim=1), dim=1
).squeeze(0)
out = out.cpu().detach().numpy()
if x != patch_size[0] or y != patch_size[1]:
pred = zoom(
out,
(x / patch_size[0], y / patch_size[1]),
order=0,
)
else:
pred = out
prediction[ind] = pred
lbl = label[ind, :, :]
masks = [lbl == i for i in range(1, classes)]
preds_o = [pred == i for i in range(1, classes)]
fig_gt = overlay_masks(
image[ind, :, :],
masks,
labels=mask_labels,
colors=cmap,
mask_alpha=0.5,
)
fig_pred = overlay_masks(
image[ind, :, :],
preds_o,
labels=mask_labels,
colors=cmap,
mask_alpha=0.5,
)
fig_gt.savefig(
test_save_path + "/" + case + "_" + str(ind) + "_gt.png",
bbox_inches="tight",
dpi=300,
)
fig_pred.savefig(
test_save_path
+ "/"
+ case
+ "_"
+ str(ind)
+ "_pred.png",
bbox_inches="tight",
dpi=300,
)
else:
input_tensor = (
torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().cuda()
)
net.eval()
with torch.no_grad():
outputs = net(input_tensor)[-1]
out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
prediction = out.cpu().detach().numpy()
metric_list = []
for i in range(1, classes):
metric_list.append(calculate_metric_percase(prediction == i, label == i))
if test_save_path is not None:
img_itk = sitk.GetImageFromArray(image.astype(np.float32))
prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
img_itk.SetSpacing((1, 1, z_spacing))
prd_itk.SetSpacing((1, 1, z_spacing))
lab_itk.SetSpacing((1, 1, z_spacing))
sitk.WriteImage(prd_itk, test_save_path + "/" + case + "_pred.nii.gz")
sitk.WriteImage(img_itk, test_save_path + "/" + case + "_img.nii.gz")
sitk.WriteImage(lab_itk, test_save_path + "/" + case + "_gt.nii.gz")
return metric_list
def val_single_volume(
image,
label,
net,
classes,
patch_size=None,
test_save_path=None,
case=None,
z_spacing=1,
):
"""Run validation inference on a single volume."""
if patch_size is None:
patch_size = [256, 256]
image = image.squeeze(0).cpu().detach().numpy()
label = label.squeeze(0).cpu().detach().numpy()
if len(image.shape) == 3:
prediction = np.zeros_like(label)
for ind in range(image.shape[0]):
slice_arr = image[ind, :, :]
x, y = slice_arr.shape[0], slice_arr.shape[1]
if x != patch_size[0] or y != patch_size[1]:
slice_arr = zoom(
slice_arr,
(patch_size[0] / x, patch_size[1] / y),
order=3,
)
input_tensor = (
torch.from_numpy(slice_arr)
.unsqueeze(0)
.unsqueeze(0)
.float()
.cuda()
)
net.eval()
with torch.no_grad():
outputs = net(input_tensor)[-1]
out = torch.argmax(
torch.softmax(outputs, dim=1), dim=1
).squeeze(0)
out = out.cpu().detach().numpy()
if x != patch_size[0] or y != patch_size[1]:
pred = zoom(
out,
(x / patch_size[0], y / patch_size[1]),
order=0,
)
else:
pred = out
prediction[ind] = pred
else:
input_tensor = (
torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().cuda()
)
net.eval()
with torch.no_grad():
outputs = net(input_tensor)[-1]
out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
prediction = out.cpu().detach().numpy()
metric_list = []
for i in range(1, classes):
metric_list.append(calculate_dice_percase(prediction == i, label == i))
return metric_list
def horizontal_flip(image):
"""Horizontally flip a numpy image."""
return image[:, ::-1, :]
def vertical_flip(image):
"""Vertically flip a numpy image."""
return image[::-1, :, :]
def tta_model(model, image):
"""Test-time augmentation via flips."""
n_image = image
h_image = horizontal_flip(image)
v_image = vertical_flip(image)
n_mask = model.predict(np.expand_dims(n_image, axis=0))[0]
h_mask = model.predict(np.expand_dims(h_image, axis=0))[0]
v_mask = model.predict(np.expand_dims(v_image, axis=0))[0]
h_mask = horizontal_flip(h_mask)
v_mask = vertical_flip(v_mask)
return (n_mask + h_mask + v_mask) / 3.0
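A quick property check for `tta_model` above: with a model whose prediction equals its input, the flip/unflip pairs cancel and the averaged mask equals the original image. `IdentityModel` here is a stand-in, since `tta_model` expects a Keras-style `.predict`:

```python
import numpy as np

def horizontal_flip(image):
    return image[:, ::-1, :]

def vertical_flip(image):
    return image[::-1, :, :]

class IdentityModel:
    # Stand-in "model" whose prediction is the input itself.
    def predict(self, batch):
        return batch

def tta_model(model, image):
    # Same flip-based test-time augmentation as above.
    n_mask = model.predict(np.expand_dims(image, axis=0))[0]
    h_mask = horizontal_flip(model.predict(np.expand_dims(horizontal_flip(image), axis=0))[0])
    v_mask = vertical_flip(model.predict(np.expand_dims(vertical_flip(image), axis=0))[0])
    return (n_mask + h_mask + v_mask) / 3.0

img = np.arange(12, dtype=float).reshape(2, 2, 3)
averaged = tta_model(IdentityModel(), img)
```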
def cal_params_flops(model, size, logger):
"""Log the model's FLOPs and parameter counts."""
input_tensor = torch.randn(1, 3, size, size).cuda()
flops, params = profile(model, inputs=(input_tensor,))
print("flops", flops / 1e9)  # print FLOPs
print("params", params / 1e6)  # print parameter count
total = sum(p.numel() for p in model.parameters())
print("Total params: %.2fM" % (total / 1e6))
logger.info(
"flops: %s, params: %s, total params: %.4fM",
flops / 1e9,
params / 1e6,
total / 1e6,
)
def print_model_stats(model, input_size=(3, 224, 224)):
"""Print the model's GMACs and parameter count."""
total_params = sum(p.numel() for p in model.parameters())
print(f"Model created, param count: {total_params}")
macs, params = get_model_complexity_info(
model, input_size, as_strings=True, print_per_layer_stat=True
)
print(f"Model: {macs} GMACs, {params} parameters")
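Both `cal_params_flops` and `print_model_stats` above count parameters with `sum(p.numel() for p in model.parameters())`; on a tiny layer the count is easy to verify by hand:

```python
import torch.nn as nn

# A 3->8 channel 3x3 conv: 8*3*3*3 = 216 weights plus 8 biases = 224 params.
layer = nn.Conv2d(3, 8, kernel_size=3)
total_params = sum(p.numel() for p in layer.parameters())
```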

@ -0,0 +1,362 @@
"""Inference entry point for the Synapse dataset."""
import argparse
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.core.networks import EMCADNet
from src.utils.dataset_synapse import Synapse_dataset
from src.utils.utils import test_single_volume
def build_parser():
"""Build the argument parser for inference."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--volume_path",
type=str,
default="./data/synapse/test_vol_h5_new",
help="root dir for validation volume data",
)
parser.add_argument("--dataset", type=str, default="Synapse", help="experiment_name")
parser.add_argument("--num_classes", type=int, default=9, help="output channel of network")
parser.add_argument("--list_dir", type=str, default="./lists/lists_Synapse", help="list dir")
parser.add_argument(
"--encoder",
type=str,
default="pvt_v2_b2",
help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...",
)
parser.add_argument(
"--expansion_factor",
type=int,
default=2,
help="expansion factor in MSCB block",
)
parser.add_argument(
"--kernel_sizes",
type=int,
nargs="+",
default=[1, 3, 5],
help="multi-scale kernel sizes in MSDC block",
)
parser.add_argument("--lgag_ks", type=int, default=3, help="Kernel size in LGAG")
parser.add_argument(
"--activation_mscb",
type=str,
default="relu6",
help="activation used in MSCB: relu6 or relu",
)
parser.add_argument(
"--no_dw_parallel",
action="store_true",
default=False,
help="use this flag to disable depth-wise parallel convolutions",
)
parser.add_argument(
"--concatenation",
action="store_true",
default=False,
help="use this flag to concatenate feature maps in MSDC block",
)
parser.add_argument(
"--no_pretrain",
action="store_true",
default=False,
help="use this flag to turn off loading pretrained encoder weights",
)
parser.add_argument(
"--pretrained_dir",
type=str,
default="./pretrained_pth/pvt/",
help="path to pretrained encoder dir",
)
parser.add_argument(
"--supervision",
type=str,
default="mutation",
help="loss supervision: mutation, deep_supervision or last_layer",
)
parser.add_argument(
"--max_iterations",
type=int,
default=30000,
help="maximum iteration number to train",
)
parser.add_argument(
"--max_epochs", type=int, default=300, help="maximum epoch number to train"
)
parser.add_argument("--batch_size", type=int, default=6, help="batch_size per gpu")
parser.add_argument(
"--base_lr", type=float, default=0.0001, help="segmentation network learning rate"
)
parser.add_argument("--img_size", type=int, default=224, help="input patch size of network input")
parser.add_argument(
"--is_savenii",
action="store_true",
default=True,
help="whether to save results during inference",
)
parser.add_argument(
"--test_save_dir", type=str, default="predictions", help="saving prediction as nii"
)
parser.add_argument(
"--deterministic", type=int, default=1, help="whether use deterministic training"
)
parser.add_argument("--seed", type=int, default=2222, help="random seed")
return parser
def get_class_names(num_classes):
"""Return the list of class names for the given number of classes."""
if num_classes == 14:
return [
"spleen",
"right kidney",
"left kidney",
"gallbladder",
"esophagus",
"liver",
"stomach",
"aorta",
"inferior vena cava",
"portal vein and splenic vein",
"pancreas",
"right adrenal gland",
"left adrenal gland",
]
return [
"spleen",
"right kidney",
"left kidney",
"gallbladder",
"pancreas",
"liver",
"stomach",
"aorta",
]
def inference(args, model, class_names, test_save_path=None):
"""Run inference on Synapse."""
db_test = args.Dataset(
base_dir=args.volume_path,
split="test_vol",
list_dir=args.list_dir,
nclass=args.num_classes,
)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
logging.info("%s test iterations per epoch", len(testloader))
model.eval()
metric_list = 0.0
for i_batch, sampled_batch in tqdm(enumerate(testloader)):
image, label, case_name = (
sampled_batch["image"],
sampled_batch["label"],
sampled_batch["case_name"][0],
)
metric_i = test_single_volume(
image,
label,
model,
classes=args.num_classes,
patch_size=[args.img_size, args.img_size],
test_save_path=test_save_path,
case=case_name,
z_spacing=1,
class_names=class_names,
)
metric_list += np.array(metric_i)
        logging.info(
            "idx %d case %s mean_dice %f mean_hd95 %f, mean_jaccard %f mean_asd %f",
            i_batch,
            case_name,
            np.mean(metric_i, axis=0)[0],
            np.mean(metric_i, axis=0)[1],
            np.mean(metric_i, axis=0)[2],
            np.mean(metric_i, axis=0)[3],
        )
metric_list = metric_list / len(db_test)
for i in range(1, args.num_classes):
        logging.info(
            "Mean class (%d) %s mean_dice %f mean_hd95 %f, mean_jaccard %f mean_asd %f",
            i,
            class_names[i - 1],
            metric_list[i - 1][0],
            metric_list[i - 1][1],
            metric_list[i - 1][2],
            metric_list[i - 1][3],
        )
    performance = np.mean(metric_list, axis=0)[0]
    mean_hd95 = np.mean(metric_list, axis=0)[1]
    mean_jaccard = np.mean(metric_list, axis=0)[2]
    mean_asd = np.mean(metric_list, axis=0)[3]
    logging.info(
        "Testing performance in best val model: mean_dice : %f mean_hd95 : %f, "
        "mean_jaccard : %f mean_asd : %f",
        performance,
        mean_hd95,
        mean_jaccard,
        mean_asd,
    )
return "Testing Finished!"
def build_snapshot_path(args, dataset_name):
    """Build the snapshot (checkpoint) directory path used for inference."""
aggregation = "concat" if args.concatenation else "add"
dw_mode = "series" if args.no_dw_parallel else "parallel"
run = 1
    # Build the configuration tag once; it is reused for both the experiment
    # name and the checkpoint directory.
    config_tag = (
        args.encoder
        + "_EMCAD_kernel_sizes_"
        + str(args.kernel_sizes)
        + "_dw_"
        + dw_mode
        + "_"
        + aggregation
        + "_lgag_ks_"
        + str(args.lgag_ks)
        + "_ef"
        + str(args.expansion_factor)
        + "_act_mscb_"
        + args.activation_mscb
        + "_loss_"
        + args.supervision
        + "_output_final_layer_Run"
        + str(run)
    )
    args.exp = config_tag + "_" + dataset_name + str(args.img_size)
    snapshot_path = "model_pth/{}/{}".format(args.exp, config_tag)
snapshot_path = snapshot_path.replace("[", "").replace("]", "").replace(", ", "_")
if not args.no_pretrain:
snapshot_path = snapshot_path + "_pretrain"
if args.max_iterations != 50000:
snapshot_path = snapshot_path + "_" + str(args.max_iterations)[0:2] + "k"
if args.max_epochs != 300:
snapshot_path = snapshot_path + "_epo" + str(args.max_epochs)
snapshot_path = snapshot_path + "_bs" + str(args.batch_size)
if args.base_lr != 0.0001:
snapshot_path = snapshot_path + "_lr" + str(args.base_lr)
snapshot_path = snapshot_path + "_" + str(args.img_size)
if args.seed != 1234:
snapshot_path = snapshot_path + "_s" + str(args.seed)
return snapshot_path
def main():
    """Main entry point: parse arguments, build the model, and run inference."""
parser = build_parser()
args = parser.parse_args()
if not args.deterministic:
cudnn.benchmark = True
cudnn.deterministic = False
else:
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
dataset_config = {
"Synapse": {
"Dataset": Synapse_dataset,
"volume_path": args.volume_path,
"list_dir": args.list_dir,
"num_classes": args.num_classes,
"z_spacing": 1,
}
}
dataset_name = args.dataset
args.num_classes = dataset_config[dataset_name]["num_classes"]
args.volume_path = dataset_config[dataset_name]["volume_path"]
args.Dataset = dataset_config[dataset_name]["Dataset"]
args.list_dir = dataset_config[dataset_name]["list_dir"]
args.z_spacing = dataset_config[dataset_name]["z_spacing"]
snapshot_path = build_snapshot_path(args, dataset_name)
model = EMCADNet(
num_classes=args.num_classes,
kernel_sizes=args.kernel_sizes,
expansion_factor=args.expansion_factor,
dw_parallel=not args.no_dw_parallel,
add=not args.concatenation,
lgag_ks=args.lgag_ks,
activation=args.activation_mscb,
encoder=args.encoder,
pretrain=not args.no_pretrain,
pretrained_dir=args.pretrained_dir,
)
model.cuda()
snapshot = os.path.join(snapshot_path, "best.pth")
if not os.path.exists(snapshot):
snapshot = snapshot.replace("best", "epoch_" + str(args.max_epochs - 1))
model.load_state_dict(torch.load(snapshot))
snapshot_name = snapshot_path.split("/")[-1]
log_folder = "test_log/test_log_" + args.exp
os.makedirs(log_folder, exist_ok=True)
logging.basicConfig(
filename=log_folder + "/" + snapshot_name + ".txt",
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
logging.info(snapshot_name)
if args.is_savenii:
args.test_save_dir = os.path.join(snapshot_path, "predictions")
test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name + "2")
os.makedirs(test_save_path, exist_ok=True)
else:
test_save_path = None
class_names = get_class_names(args.num_classes)
inference(args, model, class_names, test_save_path)
if __name__ == "__main__":
main()