parent
f6e9285936
commit
0a9b26b848
@ -0,0 +1,129 @@
|
||||
# EMCAD 项目预提交钩子配置
|
||||
# 用于在提交前自动执行代码质量检查
|
||||
|
||||
repos:
|
||||
# 1. 检查 YAML 语法
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
args: [--unsafe] # 允许自定义标签
|
||||
|
||||
- id: check-json
|
||||
|
||||
- id: check-toml
|
||||
|
||||
- id: end-of-file-fixer
|
||||
|
||||
- id: trailing-whitespace
|
||||
args: [--markdown-linebreak-ext=md]
|
||||
|
||||
- id: mixed-line-ending
|
||||
args: [--fix=lf]
|
||||
|
||||
# 2. ruff - 代码检查和格式化 (主工具)
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.1.6
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
|
||||
- id: ruff-format
|
||||
|
||||
# 3. isort - 导入排序
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
args: [--profile=black, --filter-files]
|
||||
|
||||
# 4. Python 类型检查
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.8.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
- numpy
|
||||
- torch
|
||||
- torchmetrics
|
||||
args: [--ignore-missing-imports, --warn-unused-configs, --pretty]
|
||||
|
||||
# 5. 检测秘密密钥
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: detect-private-key
|
||||
|
||||
# 6. 检测巨大的文件
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=5000']
|
||||
|
||||
# 7. Black - 代码格式化 (备选)
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.1.1
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3
|
||||
args: [--line-length=88, --target-version=py38]
|
||||
|
||||
# 8. 禁止调试语句
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: debug-statements
|
||||
|
||||
# 9. Python 特定检查
|
||||
- repo: https://github.com/PyCQA/pygrep-hooks
|
||||
rev: v1.9.0
|
||||
hooks:
|
||||
- id: python-check-blanket-noqa
|
||||
args: [--ignore-names=test_*.py]
|
||||
|
||||
- id: python-check-mock-methods
|
||||
|
||||
- id: python-use-type-annotations
|
||||
|
||||
- id: python-import-lint
|
||||
|
||||
# 10. 单元测试快速检查 (可选)
|
||||
- repo: https://github.com/sirosen/check-jsonschema
|
||||
rev: 0.28.3
|
||||
hooks:
|
||||
- id: check-schema
|
||||
name: Check YAML configs against schema
|
||||
files: ^configs/
|
||||
args: [--builtin-schema vendor.github-workflows]
|
||||
|
||||
# 预提交配置
|
||||
default_install_hook_types: [pre-commit, commit-msg]
|
||||
default_stages: [pre-commit]
|
||||
|
||||
# 跳过特定文件
|
||||
exclude: |
|
||||
(?x)^(
|
||||
tests/test_synapse-checkpoint\.py|
|
||||
src/utils/dataset_synapse-checkpoint\.py|
|
||||
src/utils/trainer-checkpoint\.py|
|
||||
src/utils/format_conversion-checkpoint\.py|
|
||||
src/utils/preprocess_synapse_data\.py|
|
||||
src/utils/test_synapse-checkpoint\.py|
|
||||
src/utils/train_synapse-checkpoint\.py|
|
||||
docs/|
|
||||
data/|
|
||||
model_pth/|
|
||||
pretrained_pth/|
|
||||
experiments/|
|
||||
.git/
|
||||
)
|
||||
|
||||
# 并行执行
|
||||
require_serial: false
|
||||
|
||||
# 失败策略
|
||||
fail_fast: false
|
||||
|
||||
# 语言版本
|
||||
language_version: python3
|
||||
@ -0,0 +1,105 @@
|
||||
.PHONY: install train test lint format clean docs preprocess run help
|
||||
|
||||
# EMCAD 项目 Makefile
|
||||
# 提供常用命令的快捷入口
|
||||
|
||||
# 安装依赖
|
||||
install:
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
|
||||
# 运行训练 (Synapse 数据集)
|
||||
train: export CUDA_VISIBLE_DEVICES=0
|
||||
python scripts/train_synapse.py --cfg configs/default.yaml
|
||||
|
||||
# 运行训练 (ACDC 数据集)
|
||||
train-acdc: export CUDA_VISIBLE_DATA_DEVICES=0
|
||||
python scripts/train_acdc.py --cfg configs/default.yaml
|
||||
|
||||
# 运行测试
|
||||
test:
|
||||
python -m pytest tests/ -v --tb=short
|
||||
|
||||
# 运行特定测试文件
|
||||
test-synapse:
|
||||
python -m pytest tests/test_synapse.py -v --tb=short
|
||||
|
||||
# 代码检查 (ruff)
|
||||
lint:
|
||||
ruff check src/ tests/
|
||||
ruff check scripts/
|
||||
|
||||
# 代码格式化 (ruff)
|
||||
format:
|
||||
ruff format src/ tests/ scripts/
|
||||
isort src/ tests/ scripts/
|
||||
|
||||
# 类型检查 (mypy, 如果安装了的话)
|
||||
typecheck:
|
||||
mypy src/ --ignore-missing-imports || echo "mypy not installed, skipping type check"
|
||||
|
||||
# 完整代码质量检查
|
||||
check: lint format typecheck
|
||||
|
||||
# 数据预处理 (Synapse)
|
||||
preprocess-synapse:
|
||||
python src/utils/preprocess_synapse_data.py
|
||||
|
||||
# 数据预处理 (ACDC)
|
||||
preprocess-acdc:
|
||||
python src/utils/preprocess_acdc_data.py
|
||||
|
||||
# 清理临时文件
|
||||
clean:
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
|
||||
rm -rf .coverage htmlcov/ .mypy_cache/ .ruff_cache/
|
||||
|
||||
# 清理实验输出
|
||||
clean-experiments:
|
||||
rm -rf experiments/ logs/ *.log
|
||||
|
||||
# 清理所有生成的文件
|
||||
clean-all: clean clean-experiments
|
||||
|
||||
# 构建文档
|
||||
docs:
|
||||
cd docs && make html || sphinx-build -b html docs/source docs/build
|
||||
|
||||
# 查看文档
|
||||
docs-serve:
|
||||
cd docs && make livehtml || sphinx-autobuild docs/source docs/build --host 0.0.0.0 --port 8000
|
||||
|
||||
# 运行示例推理
|
||||
run-inference:
|
||||
python scripts/inference.py --cfg configs/default.yaml --checkpoint experiments/latest.pth --input input.png
|
||||
|
||||
# 创建实验目录结构
|
||||
setup-dirs:
|
||||
mkdir -p experiments/ logs/ model_pth/ pretrained_pth/
|
||||
|
||||
# 显示帮助信息
|
||||
help:
|
||||
@echo "EMCAD 项目 Makefile 命令:"
|
||||
@echo ""
|
||||
@echo " install - 安装项目依赖"
|
||||
@echo " train - 运行 Synapse 数据集训练"
|
||||
@echo " train-acdc - 运行 ACDC 数据集训练"
|
||||
@echo " test - 运行所有测试"
|
||||
@echo " test-synapse - 运行 Synapse 测试"
|
||||
@echo " lint - 代码检查 (ruff)"
|
||||
@echo " format - 代码格式化 (ruff, isort)"
|
||||
@echo " typecheck - 类型检查 (mypy)"
|
||||
@echo " check - 完整代码质量检查"
|
||||
@echo " preprocess-synapse - 预处理 Synapse 数据"
|
||||
@echo " preprocess-acdc - 预处理 ACDC 数据"
|
||||
@echo " clean - 清理 Python 缓存文件"
|
||||
@echo " clean-experiments - 清理实验输出"
|
||||
@echo " clean-all - 清理所有生成的文件"
|
||||
@echo " docs - 构建文档"
|
||||
@echo " docs-serve - 启动文档预览服务器"
|
||||
@echo " run-inference - 运行示例推理"
|
||||
@echo " setup-dirs - 创建目录结构"
|
||||
@echo " help - 显示此帮助信息"
|
||||
@ -0,0 +1,40 @@
|
||||
API 参考
|
||||
=========
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: 模块:
|
||||
|
||||
modules/core
|
||||
modules/utils
|
||||
|
||||
核心模块
|
||||
--------
|
||||
|
||||
.. automodule:: src.core.base
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: src.core.networks
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
工具模块
|
||||
--------
|
||||
|
||||
.. automodule:: src.utils.config
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: src.utils.trainer
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. automodule:: src.utils.dataloader
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -0,0 +1,40 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
project = "EMCAD"
|
||||
copyright = "2024, EMCAD Team"
|
||||
author = "EMCAD Team"
|
||||
release = "1.0.0"
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx.ext.napoleon",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx_rtd_theme",
|
||||
]
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
html_static_path = ["_static"]
|
||||
html_logo = "../../docs/EMCAD_architecture.jpg"
|
||||
|
||||
html_theme_options = {
|
||||
"navigation_depth": 4,
|
||||
"collapse_navigation": False,
|
||||
"sticky_navigation": True,
|
||||
"includehidden": True,
|
||||
}
|
||||
|
||||
autodoc_default_options = {
|
||||
"members": True,
|
||||
"undoc-members": True,
|
||||
"show-inheritance": True,
|
||||
}
|
||||
|
||||
autosummary_generate = True
|
||||
@ -0,0 +1,99 @@
|
||||
配置指南
|
||||
=========
|
||||
|
||||
配置文件结构
|
||||
------------
|
||||
|
||||
EMCAD 使用 YAML 格式的配置文件,所有参数集中管理:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
# 数据集配置
|
||||
dataset:
|
||||
name: "Synapse"
|
||||
root_path: "/data/Synapse/train"
|
||||
num_classes: 4
|
||||
img_size: 224
|
||||
|
||||
# 模型配置
|
||||
model:
|
||||
encoder: "pvt_v2_b2"
|
||||
num_classes: 4
|
||||
kernel_sizes: [1, 3, 5]
|
||||
expansion_factor: 2
|
||||
|
||||
# 训练配置
|
||||
training:
|
||||
max_epochs: 300
|
||||
batch_size: 6
|
||||
base_lr: 0.0001
|
||||
|
||||
# 输出配置
|
||||
output:
|
||||
snapshot_path: "./experiments/"
|
||||
|
||||
# 日志配置
|
||||
logging:
|
||||
level: "INFO"
|
||||
|
||||
数据集配置
|
||||
----------
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
:widths: 20 80
|
||||
|
||||
* - 参数
|
||||
- 描述
|
||||
* - ``name``
|
||||
- 数据集名称 (Synapse 或 ACDC)
|
||||
* - ``root_path``
|
||||
- 训练数据根目录
|
||||
* - ``volume_path``
|
||||
- 测试数据目录
|
||||
* - ``num_classes``
|
||||
- 分割类别数
|
||||
* - ``img_size``
|
||||
- 输入图像尺寸
|
||||
|
||||
模型配置
|
||||
--------
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
:widths: 20 80
|
||||
|
||||
* - 参数
|
||||
- 描述
|
||||
* - ``encoder``
|
||||
- 编码器类型 (pvt_v2_b0/b1/b2, resnet18)
|
||||
* - ``kernel_sizes``
|
||||
- MSCB 模块卷积核大小列表
|
||||
* - ``expansion_factor``
|
||||
- MSCB 模块扩展因子
|
||||
* - ``dw_parallel``
|
||||
- 深度卷积并行模式
|
||||
* - ``add``
|
||||
- 特征相加模式 (False 为拼接)
|
||||
|
||||
训练配置
|
||||
--------
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
:widths: 20 80
|
||||
|
||||
* - 参数
|
||||
- 描述
|
||||
* - ``max_epochs``
|
||||
- 最大训练轮数
|
||||
* - ``batch_size``
|
||||
- 批次大小
|
||||
* - ``base_lr``
|
||||
- 基础学习率
|
||||
* - ``weight_decay``
|
||||
- 权重衰减
|
||||
* - ``w_ce``
|
||||
- 交叉熵损失权重
|
||||
* - ``w_dice``
|
||||
- Dice 损失权重
|
||||
@ -0,0 +1,68 @@
|
||||
快速开始
|
||||
=========
|
||||
|
||||
环境安装
|
||||
--------
|
||||
|
||||
1. 克隆项目仓库
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/your-repo/EMCAD.git
|
||||
cd EMCAD
|
||||
|
||||
2. 安装依赖
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
|
||||
3. 数据准备
|
||||
|
||||
按照数据集说明准备 Synapse 或 ACDC 数据集,并将路径配置到 ``configs/default.yaml`` 中。
|
||||
|
||||
模型训练
|
||||
--------
|
||||
|
||||
使用 Makefile 运行训练:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
make train
|
||||
|
||||
或者直接运行训练脚本:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/train_synapse.py --cfg configs/default.yaml
|
||||
|
||||
模型推理
|
||||
--------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
from src.core.networks import EMCADNet
|
||||
from src.utils.config import Config
|
||||
|
||||
config = Config.from_yaml("configs/default.yaml")
|
||||
model = EMCADNet(num_classes=config.model.num_classes)
|
||||
checkpoint = torch.load("experiments/latest.pth")
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
output = model(input_tensor)
|
||||
|
||||
配置修改
|
||||
--------
|
||||
|
||||
编辑 ``configs/default.yaml`` 文件来调整训练参数:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
training:
|
||||
max_epochs: 300
|
||||
batch_size: 6
|
||||
base_lr: 0.0001
|
||||
@ -0,0 +1,6 @@
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
getting_started
|
||||
training
|
||||
configuration
|
||||
@ -0,0 +1,83 @@
|
||||
训练指南
|
||||
=========
|
||||
|
||||
基本训练
|
||||
--------
|
||||
|
||||
EMCAD 项目使用 Trainer 类封装完整的训练流程。
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from src.utils.config import Config
|
||||
from src.core.networks import EMCADNet
|
||||
from src.utils.dataloader import get_loader
|
||||
from src.utils.trainer import Trainer
|
||||
|
||||
config = Config.from_yaml("configs/default.yaml")
|
||||
|
||||
train_loader = get_loader(
|
||||
config.dataset.root_path,
|
||||
config.dataset.list_dir,
|
||||
config.dataset.volume_path,
|
||||
split="train",
|
||||
img_size=config.dataset.img_size,
|
||||
batch_size=config.training.batch_size,
|
||||
)
|
||||
|
||||
model = EMCADNet(
|
||||
num_classes=config.model.num_classes,
|
||||
encoder=config.model.encoder,
|
||||
kernel_sizes=config.model.kernel_sizes,
|
||||
expansion_factor=config.model.expansion_factor,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
config=config,
|
||||
train_loader=train_loader,
|
||||
)
|
||||
|
||||
trainer.train(epochs=config.training.max_epochs)
|
||||
|
||||
自定义训练
|
||||
----------
|
||||
|
||||
早停配置
|
||||
~~~~~~~~
|
||||
|
||||
在配置文件中启用早停:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
training:
|
||||
early_stopping:
|
||||
enabled: true
|
||||
patience: 50
|
||||
min_delta: 0.001
|
||||
|
||||
学习率调度
|
||||
~~~~~~~~~~
|
||||
|
||||
支持多种学习率调度器:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
training:
|
||||
lr_scheduler:
|
||||
type: "CosineAnnealingLR"
|
||||
T_max: 300
|
||||
eta_min: 1e-6
|
||||
|
||||
监督策略
|
||||
--------
|
||||
|
||||
EMCAD 支持三种监督策略:
|
||||
|
||||
1. **mutation**: 多尺度监督
|
||||
2. **deep_supervision**: 深度监督
|
||||
3. **last_layer**: 仅最后一层监督
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
training:
|
||||
supervision: "mutation"
|
||||
@ -1,3 +1,4 @@
|
||||
from .base import BaseModel
|
||||
from .networks import EMCADNet
|
||||
|
||||
__all__ = ["EMCADNet"]
|
||||
__all__ = ["BaseModel", "EMCADNet"]
|
||||
|
||||
@ -1,362 +1,264 @@
|
||||
"""Synapse 数据集推理入口。"""
|
||||
"""Synapse 数据集测试模块。"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
||||
from src.core.networks import EMCADNet
|
||||
from src.utils.dataset_synapse import Synapse_dataset
|
||||
from src.utils.utils import test_single_volume
|
||||
|
||||
|
||||
def build_parser():
|
||||
"""构建推理参数解析器。"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--volume_path",
|
||||
type=str,
|
||||
default="./data/synapse/test_vol_h5_new",
|
||||
help="root dir for validation volume data",
|
||||
)
|
||||
parser.add_argument("--dataset", type=str, default="Synapse", help="experiment_name")
|
||||
parser.add_argument("--num_classes", type=int, default=9, help="output channel of network")
|
||||
parser.add_argument("--list_dir", type=str, default="./lists/lists_Synapse", help="list dir")
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
type=str,
|
||||
default="pvt_v2_b2",
|
||||
help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expansion_factor",
|
||||
type=int,
|
||||
default=2,
|
||||
help="expansion factor in MSCB block",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kernel_sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 3, 5],
|
||||
help="multi-scale kernel sizes in MSDC block",
|
||||
)
|
||||
parser.add_argument("--lgag_ks", type=int, default=3, help="Kernel size in LGAG")
|
||||
parser.add_argument(
|
||||
"--activation_mscb",
|
||||
type=str,
|
||||
default="relu6",
|
||||
help="activation used in MSCB: relu6 or relu",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_dw_parallel",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="use this flag to disable depth-wise parallel convolutions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concatenation",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="use this flag to concatenate feature maps in MSDC block",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_pretrain",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="use this flag to turn off loading pretrained enocder weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_dir",
|
||||
type=str,
|
||||
default="./pretrained_pth/pvt/",
|
||||
help="path to pretrained encoder dir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--supervision",
|
||||
type=str,
|
||||
default="mutation",
|
||||
help="loss supervision: mutation, deep_supervision or last_layer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_iterations",
|
||||
type=int,
|
||||
default=30000,
|
||||
help="maximum epoch number to train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_epochs", type=int, default=300, help="maximum epoch number to train"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=6, help="batch_size per gpu")
|
||||
parser.add_argument(
|
||||
"--base_lr", type=float, default=0.0001, help="segmentation network learning rate"
|
||||
)
|
||||
parser.add_argument("--img_size", type=int, default=224, help="input patch size of network input")
|
||||
parser.add_argument(
|
||||
"--is_savenii",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="whether to save results during inference",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test_save_dir", type=str, default="predictions", help="saving prediction as nii"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--deterministic", type=int, default=1, help="whether use deterministic training"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=2222, help="random seed")
|
||||
return parser
|
||||
|
||||
|
||||
def get_class_names(num_classes):
|
||||
"""根据类别数量返回类别名称列表。"""
|
||||
if num_classes == 14:
|
||||
return [
|
||||
"spleen",
|
||||
"right kidney",
|
||||
"left kidney",
|
||||
"gallbladder",
|
||||
"esophagus",
|
||||
"liver",
|
||||
"stomach",
|
||||
"aorta",
|
||||
"inferior vena cava",
|
||||
"portal vein and splenic vein",
|
||||
"pancreas",
|
||||
"right adrenal gland",
|
||||
"left adrenal gland",
|
||||
]
|
||||
return [
|
||||
"spleen",
|
||||
"right kidney",
|
||||
"left kidney",
|
||||
"gallbladder",
|
||||
"pancreas",
|
||||
"liver",
|
||||
"stomach",
|
||||
"aorta",
|
||||
]
|
||||
|
||||
|
||||
def inference(args, model, class_names, test_save_path=None):
|
||||
"""在 Synapse 上执行推理。"""
|
||||
db_test = args.Dataset(
|
||||
base_dir=args.volume_path,
|
||||
split="test_vol",
|
||||
list_dir=args.list_dir,
|
||||
nclass=args.num_classes,
|
||||
)
|
||||
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
|
||||
logging.info("%s test iterations per epoch", len(testloader))
|
||||
model.eval()
|
||||
metric_list = 0.0
|
||||
for i_batch, sampled_batch in tqdm(enumerate(testloader)):
|
||||
image, label, case_name = (
|
||||
sampled_batch["image"],
|
||||
sampled_batch["label"],
|
||||
sampled_batch["case_name"][0],
|
||||
from src.utils.config import Config
|
||||
|
||||
|
||||
class TestEMCADNet:
|
||||
"""EMCADNet 模型测试类。"""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
"""加载默认配置。"""
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "configs", "default.yaml"
|
||||
)
|
||||
metric_i = test_single_volume(
|
||||
image,
|
||||
label,
|
||||
model,
|
||||
classes=args.num_classes,
|
||||
patch_size=[args.img_size, args.img_size],
|
||||
test_save_path=test_save_path,
|
||||
case=case_name,
|
||||
z_spacing=1,
|
||||
class_names=class_names,
|
||||
if os.path.exists(config_path):
|
||||
return Config.from_yaml(config_path)
|
||||
return Config()
|
||||
|
||||
@pytest.fixture
|
||||
def model(self, config):
|
||||
"""创建测试用模型。"""
|
||||
return EMCADNet(
|
||||
num_classes=config.model.num_classes,
|
||||
encoder=config.model.encoder,
|
||||
kernel_sizes=config.model.kernel_sizes,
|
||||
expansion_factor=config.model.expansion_factor,
|
||||
dw_parallel=config.model.dw_parallel,
|
||||
add=config.model.add,
|
||||
lgag_ks=config.model.lgag_ks,
|
||||
activation=config.model.activation,
|
||||
pretrain=False,
|
||||
)
|
||||
metric_list += np.array(metric_i)
|
||||
logging.info(
|
||||
"idx %d case %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f",
|
||||
i_batch,
|
||||
case_name,
|
||||
np.mean(metric_i, axis=0)[0],
|
||||
np.mean(metric_i, axis=0)[1],
|
||||
np.mean(metric_i, axis=0)[2],
|
||||
np.mean(metric_i, axis=0)[3],
|
||||
|
||||
def test_model_initialization(self, model, config):
|
||||
"""测试模型初始化。"""
|
||||
assert model is not None
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
assert model.num_classes == config.model.num_classes
|
||||
|
||||
def test_forward_pass(self, model):
|
||||
"""测试前向传播。"""
|
||||
batch_size = 2
|
||||
img_size = 224
|
||||
x = torch.randn(batch_size, 3, img_size, img_size)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(x, mode="test")
|
||||
|
||||
assert outputs is not None
|
||||
if isinstance(outputs, list):
|
||||
assert len(outputs) > 0
|
||||
assert outputs[0].shape == (batch_size, model.num_classes, img_size, img_size)
|
||||
else:
|
||||
assert outputs.shape == (batch_size, model.num_classes, img_size, img_size)
|
||||
|
||||
def test_training_mode(self, model):
|
||||
"""测试训练模式前向传播。"""
|
||||
batch_size = 2
|
||||
img_size = 224
|
||||
x = torch.randn(batch_size, 3, img_size, img_size)
|
||||
|
||||
model.train()
|
||||
with torch.no_grad():
|
||||
outputs = model(x, mode="train")
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
def test_get_loss(self, model):
|
||||
"""测试损失计算。"""
|
||||
batch_size = 2
|
||||
img_size = 224
|
||||
num_classes = model.num_classes
|
||||
|
||||
x = torch.randn(batch_size, 3, img_size, img_size)
|
||||
outputs = model(x, mode="test")
|
||||
targets = torch.randint(0, num_classes, (batch_size, 1, img_size, img_size))
|
||||
|
||||
loss = model.get_loss(outputs, targets)
|
||||
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.item() >= 0
|
||||
|
||||
def test_multi_scale_outputs(self, model):
|
||||
"""测试多尺度输出。"""
|
||||
batch_size = 2
|
||||
img_size = 224
|
||||
x = torch.randn(batch_size, 3, img_size, img_size)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(x, mode="train")
|
||||
|
||||
assert isinstance(outputs, list)
|
||||
assert len(outputs) == 4
|
||||
|
||||
def test_single_channel_input(self, model):
|
||||
"""测试单通道输入。"""
|
||||
batch_size = 2
|
||||
img_size = 224
|
||||
x = torch.randn(batch_size, 1, img_size, img_size)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(x, mode="test")
|
||||
|
||||
assert outputs is not None
|
||||
if isinstance(outputs, list):
|
||||
assert outputs[0].shape == (batch_size, model.num_classes, img_size, img_size)
|
||||
|
||||
def test_parameter_count(self, model):
|
||||
"""测试参数计数。"""
|
||||
num_params = model.count_parameters()
|
||||
|
||||
assert isinstance(num_params, int)
|
||||
assert num_params > 0
|
||||
assert num_params > 1e6
|
||||
|
||||
def test_save_load_state_dict(self, model, tmp_path):
|
||||
"""测试模型状态保存和加载。"""
|
||||
state_dict = model.state_dict()
|
||||
|
||||
save_path = os.path.join(tmp_path, "model.pth")
|
||||
model.save(save_path)
|
||||
|
||||
assert os.path.exists(save_path)
|
||||
|
||||
new_model = EMCADNet(
|
||||
num_classes=model.num_classes,
|
||||
encoder="pvt_v2_b2",
|
||||
pretrain=False,
|
||||
)
|
||||
metric_list = metric_list / len(db_test)
|
||||
for i in range(1, args.num_classes):
|
||||
logging.info(
|
||||
"Mean class (%d) %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f",
|
||||
i,
|
||||
class_names[i - 1],
|
||||
metric_list[i - 1][0],
|
||||
metric_list[i - 1][1],
|
||||
metric_list[i - 1][2],
|
||||
metric_list[i - 1][3],
|
||||
new_model.load_state_dict(save_path)
|
||||
|
||||
def test_to_device(self, model):
|
||||
"""测试设备迁移。"""
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model_device = model.to(device)
|
||||
|
||||
assert next(model_device.parameters()).device == device
|
||||
|
||||
batch_size = 2
|
||||
img_size = 224
|
||||
x = torch.randn(batch_size, 3, img_size, img_size).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model_device(x, mode="test")
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
|
||||
class TestConfig:
|
||||
"""配置管理测试类。"""
|
||||
|
||||
def test_config_creation(self):
|
||||
"""测试配置创建。"""
|
||||
config = Config()
|
||||
|
||||
assert config is not None
|
||||
assert hasattr(config, "dataset")
|
||||
assert hasattr(config, "model")
|
||||
assert hasattr(config, "training")
|
||||
assert hasattr(config, "output")
|
||||
assert hasattr(config, "logging")
|
||||
|
||||
def test_config_from_yaml(self):
|
||||
"""测试从 YAML 加载配置。"""
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "configs", "default.yaml"
|
||||
)
|
||||
performance = np.mean(metric_list, axis=0)[0]
|
||||
mean_hd95 = np.mean(metric_list, axis=0)[1]
|
||||
mean_jacard = np.mean(metric_list, axis=0)[2]
|
||||
mean_asd = np.mean(metric_list, axis=0)[3]
|
||||
logging.info(
|
||||
"Testing performance in best val model: mean_dice : %f mean_hd95 : %f, "
|
||||
"mean_jacard : %f mean_asd : %f",
|
||||
performance,
|
||||
mean_hd95,
|
||||
mean_jacard,
|
||||
mean_asd,
|
||||
)
|
||||
return "Testing Finished!"
|
||||
|
||||
|
||||
def build_snapshot_path(args, dataset_name):
|
||||
"""生成推理输出路径。"""
|
||||
aggregation = "concat" if args.concatenation else "add"
|
||||
dw_mode = "series" if args.no_dw_parallel else "parallel"
|
||||
run = 1
|
||||
|
||||
args.exp = (
|
||||
args.encoder
|
||||
+ "_EMCAD_kernel_sizes_"
|
||||
+ str(args.kernel_sizes)
|
||||
+ "_dw_"
|
||||
+ dw_mode
|
||||
+ "_"
|
||||
+ aggregation
|
||||
+ "_lgag_ks_"
|
||||
+ str(args.lgag_ks)
|
||||
+ "_ef"
|
||||
+ str(args.expansion_factor)
|
||||
+ "_act_mscb_"
|
||||
+ args.activation_mscb
|
||||
+ "_loss_"
|
||||
+ args.supervision
|
||||
+ "_output_final_layer_Run"
|
||||
+ str(run)
|
||||
+ "_"
|
||||
+ dataset_name
|
||||
+ str(args.img_size)
|
||||
)
|
||||
|
||||
snapshot_path = "model_pth/{}/{}".format(
|
||||
args.exp,
|
||||
args.encoder
|
||||
+ "_EMCAD_kernel_sizes_"
|
||||
+ str(args.kernel_sizes)
|
||||
+ "_dw_"
|
||||
+ dw_mode
|
||||
+ "_"
|
||||
+ aggregation
|
||||
+ "_lgag_ks_"
|
||||
+ str(args.lgag_ks)
|
||||
+ "_ef"
|
||||
+ str(args.expansion_factor)
|
||||
+ "_act_mscb_"
|
||||
+ args.activation_mscb
|
||||
+ "_loss_"
|
||||
+ args.supervision
|
||||
+ "_output_final_layer_Run"
|
||||
+ str(run),
|
||||
)
|
||||
snapshot_path = snapshot_path.replace("[", "").replace("]", "").replace(", ", "_")
|
||||
|
||||
if not args.no_pretrain:
|
||||
snapshot_path = snapshot_path + "_pretrain"
|
||||
if args.max_iterations != 50000:
|
||||
snapshot_path = snapshot_path + "_" + str(args.max_iterations)[0:2] + "k"
|
||||
if args.max_epochs != 300:
|
||||
snapshot_path = snapshot_path + "_epo" + str(args.max_epochs)
|
||||
snapshot_path = snapshot_path + "_bs" + str(args.batch_size)
|
||||
if args.base_lr != 0.0001:
|
||||
snapshot_path = snapshot_path + "_lr" + str(args.base_lr)
|
||||
snapshot_path = snapshot_path + "_" + str(args.img_size)
|
||||
if args.seed != 1234:
|
||||
snapshot_path = snapshot_path + "_s" + str(args.seed)
|
||||
|
||||
return snapshot_path
|
||||
|
||||
|
||||
def main():
|
||||
"""主入口函数。"""
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.deterministic:
|
||||
cudnn.benchmark = True
|
||||
cudnn.deterministic = False
|
||||
else:
|
||||
cudnn.benchmark = False
|
||||
cudnn.deterministic = True
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
|
||||
dataset_config = {
|
||||
"Synapse": {
|
||||
"Dataset": Synapse_dataset,
|
||||
"volume_path": args.volume_path,
|
||||
"list_dir": args.list_dir,
|
||||
"num_classes": args.num_classes,
|
||||
"z_spacing": 1,
|
||||
|
||||
if os.path.exists(config_path):
|
||||
config = Config.from_yaml(config_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.dataset is not None
|
||||
assert config.model is not None
|
||||
assert config.training is not None
|
||||
|
||||
def test_config_to_dict(self):
|
||||
"""测试配置序列化。"""
|
||||
config = Config()
|
||||
config_dict = config.to_dict()
|
||||
|
||||
assert isinstance(config_dict, dict)
|
||||
assert "dataset" in config_dict
|
||||
assert "model" in config_dict
|
||||
assert "training" in config_dict
|
||||
|
||||
def test_config_from_dict(self):
|
||||
"""测试配置反序列化。"""
|
||||
data = {
|
||||
"dataset": {"name": "Test", "num_classes": 4},
|
||||
"model": {"encoder": "pvt_v2_b2", "num_classes": 4},
|
||||
}
|
||||
}
|
||||
dataset_name = args.dataset
|
||||
args.num_classes = dataset_config[dataset_name]["num_classes"]
|
||||
args.volume_path = dataset_config[dataset_name]["volume_path"]
|
||||
args.Dataset = dataset_config[dataset_name]["Dataset"]
|
||||
args.list_dir = dataset_config[dataset_name]["list_dir"]
|
||||
args.z_spacing = dataset_config[dataset_name]["z_spacing"]
|
||||
print(args.no_pretrain)
|
||||
|
||||
snapshot_path = build_snapshot_path(args, dataset_name)
|
||||
|
||||
model = EMCADNet(
|
||||
num_classes=args.num_classes,
|
||||
kernel_sizes=args.kernel_sizes,
|
||||
expansion_factor=args.expansion_factor,
|
||||
dw_parallel=not args.no_dw_parallel,
|
||||
add=not args.concatenation,
|
||||
lgag_ks=args.lgag_ks,
|
||||
activation=args.activation_mscb,
|
||||
encoder=args.encoder,
|
||||
pretrain=not args.no_pretrain,
|
||||
pretrained_dir=args.pretrained_dir,
|
||||
)
|
||||
model.cuda()
|
||||
|
||||
snapshot = os.path.join(snapshot_path, "best.pth")
|
||||
if not os.path.exists(snapshot):
|
||||
snapshot = snapshot.replace("best", "epoch_" + str(args.max_epochs - 1))
|
||||
model.load_state_dict(torch.load(snapshot))
|
||||
snapshot_name = snapshot_path.split("/")[-1]
|
||||
|
||||
log_folder = "test_log/test_log_" + args.exp
|
||||
os.makedirs(log_folder, exist_ok=True)
|
||||
logging.basicConfig(
|
||||
filename=log_folder + "/" + snapshot_name + ".txt",
|
||||
level=logging.INFO,
|
||||
format="[%(asctime)s.%(msecs)03d] %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
|
||||
logging.info(str(args))
|
||||
logging.info(snapshot_name)
|
||||
|
||||
if args.is_savenii:
|
||||
args.test_save_dir = os.path.join(snapshot_path, "predictions")
|
||||
test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name + "2")
|
||||
os.makedirs(test_save_path, exist_ok=True)
|
||||
else:
|
||||
test_save_path = None
|
||||
|
||||
class_names = get_class_names(args.num_classes)
|
||||
inference(args, model, class_names, test_save_path)
|
||||
|
||||
config = Config.from_dict(data)
|
||||
|
||||
assert config is not None
|
||||
assert config.dataset.name == "Test"
|
||||
assert config.model.num_classes == 4
|
||||
|
||||
|
||||
class TestInference:
|
||||
"""推理流程测试类。"""
|
||||
|
||||
@pytest.fixture
|
||||
def trained_model(self):
|
||||
"""创建预训练模型(用于推理测试)。"""
|
||||
model = EMCADNet(
|
||||
num_classes=9,
|
||||
encoder="pvt_v2_b2",
|
||||
kernel_sizes=[1, 3, 5],
|
||||
expansion_factor=2,
|
||||
dw_parallel=True,
|
||||
add=True,
|
||||
lgag_ks=3,
|
||||
activation="relu",
|
||||
pretrain=False,
|
||||
)
|
||||
return model
|
||||
|
||||
def test_inference_pipeline(self, trained_model):
|
||||
"""测试推理流程。"""
|
||||
trained_model.eval()
|
||||
|
||||
batch_size = 1
|
||||
img_size = 224
|
||||
x = torch.randn(batch_size, 3, img_size, img_size)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = trained_model(x, mode="test")
|
||||
|
||||
predictions = torch.argmax(outputs, dim=1) if isinstance(outputs, torch.Tensor) else torch.argmax(outputs[0], dim=1)
|
||||
|
||||
assert predictions.shape == (batch_size, img_size, img_size)
|
||||
|
||||
def test_dice_calculation(self, trained_model):
|
||||
"""测试 Dice 指标计算。"""
|
||||
trained_model.eval()
|
||||
|
||||
batch_size = 1
|
||||
img_size = 224
|
||||
x = torch.randn(batch_size, 3, img_size, img_size)
|
||||
target = torch.randint(0, 9, (batch_size, 1, img_size, img_size))
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = trained_model(x, mode="test")
|
||||
|
||||
loss = trained_model.get_loss(outputs, target)
|
||||
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Loading…
Reference in new issue