feat: 更新项目代码

main
learnljs 4 months ago
parent f6e9285936
commit 0a9b26b848

@ -0,0 +1,129 @@
# EMCAD project pre-commit hook configuration.
# Runs automatic code-quality checks before each commit.
repos:
  # 1. Generic hygiene hooks (merged the four duplicated stanzas of this
  #    repo into one entry — pre-commit allows duplicates but one is cleaner).
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: check-yaml
        args: [--unsafe]  # allow custom tags
      - id: check-json
      - id: check-toml
      - id: end-of-file-fixer
      - id: trailing-whitespace
        args: [--markdown-linebreak-ext=md]
      - id: mixed-line-ending
        args: [--fix=lf]
      - id: detect-private-key
      - id: check-added-large-files
        args: ['--maxkb=5000']
      - id: debug-statements
  # 2. ruff - linting and formatting (primary tool)
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.6
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
      - id: ruff-format
  # 3. isort - import sorting
  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        args: [--profile=black, --filter-files]
  # 4. mypy - static type checking
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.8.0
    hooks:
      - id: mypy
        additional_dependencies:
          - numpy
          - torch
          - torchmetrics
        args: [--ignore-missing-imports, --warn-unused-configs, --pretty]
  # 5. Black - formatting (alternative; NOTE(review): running both black and
  #    ruff-format may produce conflicting reformats — consider keeping one).
  - repo: https://github.com/psf/black
    rev: 24.1.1
    hooks:
      - id: black
        language_version: python3
        args: [--line-length=88, --target-version=py38]
  # 6. Python-specific grep checks.
  #    Fixes: pygrep hooks take no --ignore-names option (use `exclude`),
  #    and `python-import-lint` is not a hook in this repo — removed.
  - repo: https://github.com/PyCQA/pygrep-hooks
    rev: v1.9.0
    hooks:
      - id: python-check-blanket-noqa
        exclude: ^tests/
      - id: python-check-mock-methods
      - id: python-use-type-annotations
  # 7. Schema validation for YAML configs.
  #    Fixes: hook id is `check-jsonschema` (there is no `check-schema`), and
  #    flow-list args must be comma-separated scalars, not one spaced string.
  #    NOTE(review): the github-workflows schema vs. files ^configs/ looks
  #    mismatched — confirm the intended schema.
  - repo: https://github.com/sirosen/check-jsonschema
    rev: 0.28.3
    hooks:
      - id: check-jsonschema
        name: Check YAML configs against schema
        files: ^configs/
        args: [--builtin-schema, vendor.github-workflows]
# Global pre-commit settings.
default_install_hook_types: [pre-commit, commit-msg]
default_stages: [pre-commit]
# Fix: `language_version` and `require_serial` are hook-level keys, not
# top-level; the supported top-level equivalent is default_language_version.
default_language_version:
  python: python3
fail_fast: false
# Skip generated/checkpoint files (dots escaped so the regex is literal).
exclude: |
  (?x)^(
    tests/test_synapse-checkpoint\.py|
    src/utils/dataset_synapse-checkpoint\.py|
    src/utils/trainer-checkpoint\.py|
    src/utils/format_conversion-checkpoint\.py|
    src/utils/preprocess_synapse_data\.py|
    src/utils/test_synapse-checkpoint\.py|
    src/utils/train_synapse-checkpoint\.py|
    docs/|
    data/|
    model_pth/|
    pretrained_pth/|
    experiments/|
    \.git/
  )

@ -0,0 +1,105 @@
# EMCAD 项目 Makefile
# 提供常用命令的快捷入口
# Fix: .PHONY previously listed nonexistent targets (preprocess, run) and
# omitted most real ones; now every phony target is declared.
.PHONY: install train train-acdc test test-synapse lint format typecheck check \
        preprocess-synapse preprocess-acdc clean clean-experiments clean-all \
        docs docs-serve run-inference setup-dirs help

# 安装依赖
install:
	pip install -r requirements.txt
	pip install -e .

# 运行训练 (Synapse 数据集)
# The first line is a target-specific variable (pins GPU 0); the recipe
# must hang off a separate rule line for the same target.
train: export CUDA_VISIBLE_DEVICES=0
train:
	python scripts/train_synapse.py --cfg configs/default.yaml

# 运行训练 (ACDC 数据集)
# Bug fix: variable was misspelled CUDA_VISIBLE_DATA_DEVICES, so the GPU
# pin silently had no effect.
train-acdc: export CUDA_VISIBLE_DEVICES=0
train-acdc:
	python scripts/train_acdc.py --cfg configs/default.yaml

# 运行测试
test:
	python -m pytest tests/ -v --tb=short

# 运行特定测试文件
test-synapse:
	python -m pytest tests/test_synapse.py -v --tb=short

# 代码检查 (ruff)
lint:
	ruff check src/ tests/
	ruff check scripts/

# 代码格式化 (ruff)
format:
	ruff format src/ tests/ scripts/
	isort src/ tests/ scripts/

# 类型检查 (mypy, 如果安装了的话)
typecheck:
	mypy src/ --ignore-missing-imports || echo "mypy not installed, skipping type check"

# 完整代码质量检查
check: lint format typecheck

# 数据预处理 (Synapse)
preprocess-synapse:
	python src/utils/preprocess_synapse_data.py

# 数据预处理 (ACDC)
preprocess-acdc:
	python src/utils/preprocess_acdc_data.py

# 清理临时文件
clean:
	find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
	find . -type f -name "*.pyc" -delete 2>/dev/null || true
	find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
	find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
	rm -rf .coverage htmlcov/ .mypy_cache/ .ruff_cache/

# 清理实验输出
clean-experiments:
	rm -rf experiments/ logs/ *.log

# 清理所有生成的文件
clean-all: clean clean-experiments

# 构建文档 (docs is phony so the docs/ directory does not shadow it)
docs:
	cd docs && make html || sphinx-build -b html docs/source docs/build

# 查看文档
docs-serve:
	cd docs && make livehtml || sphinx-autobuild docs/source docs/build --host 0.0.0.0 --port 8000

# 运行示例推理
run-inference:
	python scripts/inference.py --cfg configs/default.yaml --checkpoint experiments/latest.pth --input input.png

# 创建实验目录结构
setup-dirs:
	mkdir -p experiments/ logs/ model_pth/ pretrained_pth/

# 显示帮助信息
help:
	@echo "EMCAD 项目 Makefile 命令:"
	@echo ""
	@echo "  install              - 安装项目依赖"
	@echo "  train                - 运行 Synapse 数据集训练"
	@echo "  train-acdc           - 运行 ACDC 数据集训练"
	@echo "  test                 - 运行所有测试"
	@echo "  test-synapse         - 运行 Synapse 测试"
	@echo "  lint                 - 代码检查 (ruff)"
	@echo "  format               - 代码格式化 (ruff, isort)"
	@echo "  typecheck            - 类型检查 (mypy)"
	@echo "  check                - 完整代码质量检查"
	@echo "  preprocess-synapse   - 预处理 Synapse 数据"
	@echo "  preprocess-acdc      - 预处理 ACDC 数据"
	@echo "  clean                - 清理 Python 缓存文件"
	@echo "  clean-experiments    - 清理实验输出"
	@echo "  clean-all            - 清理所有生成的文件"
	@echo "  docs                 - 构建文档"
	@echo "  docs-serve           - 启动文档预览服务器"
	@echo "  run-inference        - 运行示例推理"
	@echo "  setup-dirs           - 创建目录结构"
	@echo "  help                 - 显示此帮助信息"

@ -1,3 +1,4 @@
# 本项目用于论文复现:https://arxiv.org/abs/2405.06880
# EMCAD:高效多尺度卷积注意力解码器,用于医学图像分割
本项目是论文《EMCAD:用于医学图像分割的高效多尺度卷积注意力解码器》的官方 PyTorch 实现,该论文发表于 CVPR 2024。项目作者包括 Md Mostafijur Rahman、Mustafa Munir 和 Radu Marculescu,均来自德克萨斯大学奥斯汀分校。

@ -0,0 +1,68 @@
# EMCAD default configuration file
# Dataset settings
dataset:
  name: "ACDC"             # Synapse or ACDC
  root_path: "/data/ACDC/train"
  volume_path: "/data/ACDC/test"
  list_dir: "/data/ACDC/lists_ACDC"
  num_classes: 4
  z_spacing: 1
  img_size: 224
# Model settings
model:
  encoder: "pvt_v2_b2"     # pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, resnet18, etc.
  num_classes: 4
  kernel_sizes: [1, 3, 5]  # multi-scale convolution kernel sizes
  expansion_factor: 2      # MSCB module expansion factor
  dw_parallel: true        # depthwise-convolution parallel mode
  add: true                # sum feature maps (false: concatenate)
  lgag_ks: 3               # LGAG module kernel size
  activation: "relu"       # activation: relu or relu6
  pretrain: true           # load pretrained weights
  pretrained_dir: "./model_pth/"
# Training settings
training:
  max_epochs: 300
  max_iterations: 50000
  batch_size: 6
  base_lr: 0.0001          # learning rate
  weight_decay: 0.0001
  seed: 2222
  deterministic: true
  n_gpu: 1
  # Loss weights
  loss:
    w_ce: 0.3
    w_dice: 0.7
  # Supervision mode
  supervision: "mutation"  # mutation, deep_supervision, last_layer
  # Early stopping
  early_stopping:
    enabled: false
    patience: 50
    min_delta: 0.001
  # Learning-rate schedule
  lr_scheduler:
    type: "CosineAnnealingLR"
    T_max: 300
    # Fix: written as 1.0e-6 because PyYAML (used via yaml.safe_load in
    # src/utils/config.py) only resolves exponent floats that contain a
    # dot — a bare 1e-6 would load as the STRING "1e-6".
    eta_min: 1.0e-6
# Output settings
output:
  snapshot_path: "./experiments/"
  save_interval: 50        # epochs between checkpoint saves
  log_interval: 50         # interval between log lines
# Logging settings
logging:
  level: "INFO"
  format: "[%(asctime)s.%(msecs)03d] %(message)s"
  log_file: "train.log"

@ -0,0 +1,40 @@
API 参考
=========
.. toctree::
:maxdepth: 2
:caption: 模块:
modules/core
modules/utils
核心模块
--------
.. automodule:: src.core.base
:members:
:undoc-members:
:show-inheritance:
.. automodule:: src.core.networks
:members:
:undoc-members:
:show-inheritance:
工具模块
--------
.. automodule:: src.utils.config
:members:
:undoc-members:
:show-inheritance:
.. automodule:: src.utils.trainer
:members:
:undoc-members:
:show-inheritance:
.. automodule:: src.utils.dataloader
:members:
:undoc-members:
:show-inheritance:

@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
# Sphinx configuration for the EMCAD documentation build.
import os
import sys

# Make the project root importable so autodoc can resolve the src package.
sys.path.insert(0, os.path.abspath("../.."))

# -- Project information -----------------------------------------------------
project = "EMCAD"
copyright = "2024, EMCAD Team"
author = "EMCAD Team"
release = "1.0.0"

# -- General configuration ---------------------------------------------------
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.viewcode",
    "sphinx.ext.napoleon",
    "sphinx.ext.autosummary",
    "sphinx_rtd_theme",
]
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# -- HTML output -------------------------------------------------------------
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]
# NOTE(review): path is relative to this conf.py; assumes the architecture
# image exists two levels up under docs/ — TODO confirm.
html_logo = "../../docs/EMCAD_architecture.jpg"
html_theme_options = {
    "navigation_depth": 4,
    "collapse_navigation": False,
    "sticky_navigation": True,
    "includehidden": True,
}

# -- Autodoc / autosummary ---------------------------------------------------
autodoc_default_options = {
    "members": True,
    "undoc-members": True,
    "show-inheritance": True,
}
autosummary_generate = True

@ -0,0 +1,53 @@
EMCAD 项目文档
=================
.. toctree::
:maxdepth: 2
:caption: 目录:
usage/getting_started
usage/training
usage/configuration
api/modules
概述
----
EMCAD(Efficient Multi-scale Convolutional Attention Decoding)是一个用于医学图像分割的高效深度学习模型。
主要特性
--------
- **多尺度特征提取**: 使用 MSCB(Multi-Scale Convolution Block)模块捕获不同尺度的上下文信息
- **轻量级设计**: 通过知识蒸馏和高效卷积操作减少模型参数量
- **多种监督策略**: 支持 mutation、deep_supervision 和 last_layer 三种监督模式
- **灵活的配置文件**: 使用 YAML 格式进行统一的参数管理
安装
----
.. code-block:: bash
pip install -r requirements.txt
pip install -e .
快速开始
--------
.. code-block:: python
from src.core.networks import EMCADNet
from src.utils.config import Config
config = Config.from_yaml("configs/default.yaml")
model = EMCADNet(
num_classes=config.model.num_classes,
encoder=config.model.encoder,
)
引用
----
如果您在研究中使用了 EMCAD请引用我们的工作。
.. [1] M. M. Rahman, M. Munir, and R. Marculescu, "EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation", CVPR 2024.

@ -0,0 +1,99 @@
配置指南
=========
配置文件结构
------------
EMCAD 使用 YAML 格式的配置文件,所有参数集中管理:
.. code-block:: yaml
# 数据集配置
dataset:
name: "Synapse"
root_path: "/data/Synapse/train"
num_classes: 4
img_size: 224
# 模型配置
model:
encoder: "pvt_v2_b2"
num_classes: 4
kernel_sizes: [1, 3, 5]
expansion_factor: 2
# 训练配置
training:
max_epochs: 300
batch_size: 6
base_lr: 0.0001
# 输出配置
output:
snapshot_path: "./experiments/"
# 日志配置
logging:
level: "INFO"
数据集配置
----------
.. list-table::
:header-rows: 1
:widths: 20 80
* - 参数
- 描述
* - ``name``
- 数据集名称 (Synapse 或 ACDC)
* - ``root_path``
- 训练数据根目录
* - ``volume_path``
- 测试数据目录
* - ``num_classes``
- 分割类别数
* - ``img_size``
- 输入图像尺寸
模型配置
--------
.. list-table::
:header-rows: 1
:widths: 20 80
* - 参数
- 描述
* - ``encoder``
- 编码器类型 (pvt_v2_b0/b1/b2, resnet18)
* - ``kernel_sizes``
- MSCB 模块卷积核大小列表
* - ``expansion_factor``
- MSCB 模块扩展因子
* - ``dw_parallel``
- 深度卷积并行模式
* - ``add``
- 特征相加模式 (False 为拼接)
训练配置
--------
.. list-table::
:header-rows: 1
:widths: 20 80
* - 参数
- 描述
* - ``max_epochs``
- 最大训练轮数
* - ``batch_size``
- 批次大小
* - ``base_lr``
- 基础学习率
* - ``weight_decay``
- 权重衰减
* - ``w_ce``
- 交叉熵损失权重
* - ``w_dice``
- Dice 损失权重

@ -0,0 +1,68 @@
快速开始
=========
环境安装
--------
1. 克隆项目仓库
.. code-block:: bash
git clone https://github.com/your-repo/EMCAD.git
cd EMCAD
2. 安装依赖
.. code-block:: bash
pip install -r requirements.txt
pip install -e .
3. 数据准备
按照数据集说明准备 Synapse 或 ACDC 数据集,并将路径配置到 ``configs/default.yaml`` 中。
模型训练
--------
使用 Makefile 运行训练:
.. code-block:: bash
make train
或者直接运行训练脚本:
.. code-block:: bash
python scripts/train_synapse.py --cfg configs/default.yaml
模型推理
--------
.. code-block:: python
import torch
from src.core.networks import EMCADNet
from src.utils.config import Config
config = Config.from_yaml("configs/default.yaml")
model = EMCADNet(num_classes=config.model.num_classes)
checkpoint = torch.load("experiments/latest.pth")
model.load_state_dict(checkpoint["state_dict"])
model.eval()
with torch.no_grad():
output = model(input_tensor)
配置修改
--------
编辑 ``configs/default.yaml`` 文件来调整训练参数:
.. code-block:: yaml
training:
max_epochs: 300
batch_size: 6
base_lr: 0.0001

@ -0,0 +1,6 @@
.. toctree::
:maxdepth: 1
getting_started
training
configuration

@ -0,0 +1,83 @@
训练指南
=========
基本训练
--------
EMCAD 项目使用 Trainer 类封装完整的训练流程。
.. code-block:: python
from src.utils.config import Config
from src.core.networks import EMCADNet
from src.utils.dataloader import get_loader
from src.utils.trainer import Trainer
config = Config.from_yaml("configs/default.yaml")
train_loader = get_loader(
config.dataset.root_path,
config.dataset.list_dir,
config.dataset.volume_path,
split="train",
img_size=config.dataset.img_size,
batch_size=config.training.batch_size,
)
model = EMCADNet(
num_classes=config.model.num_classes,
encoder=config.model.encoder,
kernel_sizes=config.model.kernel_sizes,
expansion_factor=config.model.expansion_factor,
)
trainer = Trainer(
model=model,
config=config,
train_loader=train_loader,
)
trainer.train(epochs=config.training.max_epochs)
自定义训练
----------
早停配置
~~~~~~~~
在配置文件中启用早停:
.. code-block:: yaml
training:
early_stopping:
enabled: true
patience: 50
min_delta: 0.001
学习率调度
~~~~~~~~~~
支持多种学习率调度器:
.. code-block:: yaml
training:
lr_scheduler:
type: "CosineAnnealingLR"
T_max: 300
eta_min: 1e-6
监督策略
--------
EMCAD 支持三种监督策略:
1. **mutation**: 多尺度监督
2. **deep_supervision**: 深度监督
3. **last_layer**: 仅最后一层监督
.. code-block:: yaml
training:
supervision: "mutation"

@ -8,7 +8,7 @@ setup(
description="EMCADNet: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/yourusername/EMCADNet",
url="https://github.com/learnljs/EMCAD",
packages=find_packages(where="src"),
package_dir={"": "src"},
classifiers=[

@ -1,3 +1,4 @@
from .base import BaseModel
from .networks import EMCADNet
__all__ = ["EMCADNet"]
__all__ = ["BaseModel", "EMCADNet"]

@ -0,0 +1,184 @@
"""六、模型开发规范 - 抽象基类 BaseModel。"""
from abc import ABC, abstractmethod
import os
import torch
import torch.nn as nn
from typing import Any, Dict, List, Optional, Union
class BaseModel(nn.Module, ABC):
    """Abstract base class that every model in the project must inherit.

    Subclasses must implement:
        forward(x): forward pass.
        get_loss(outputs, targets): loss computation.

    Provided utilities:
        save(filepath): save a checkpoint.
        load(filepath): load a checkpoint.
        count_parameters(): parameter counting.
        freeze_layers(layer_names): freeze selected layers.
    """

    def __init__(self) -> None:
        """Initialize the module and the internal training-mode flag."""
        super().__init__()
        # Mirrors nn.Module.training; kept for code that reads this flag.
        self._is_training = True

    @abstractmethod
    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Forward pass.

        Args:
            x: input tensor.

        Returns:
            Model output (tensor or list of tensors).
        """

    @abstractmethod
    def get_loss(
        self,
        outputs: Union[torch.Tensor, List[torch.Tensor]],
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the training loss.

        Args:
            outputs: model outputs.
            targets: target labels.

        Returns:
            Scalar loss tensor.
        """

    def save(self, filepath: str, optimizer: Optional[torch.optim.Optimizer] = None) -> None:
        """Save model (and optionally optimizer) state to ``filepath``.

        Args:
            filepath: destination path.
            optimizer: optional optimizer whose state is saved alongside.
        """
        checkpoint: Dict[str, Any] = {"model_state_dict": self.state_dict()}
        if optimizer is not None:
            checkpoint["optimizer_state_dict"] = optimizer.state_dict()
        # Bug fix: os.makedirs("") raises FileNotFoundError when filepath has
        # no directory component (e.g. "model.pth"); only create real dirs.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(checkpoint, filepath)

    def load(self, filepath: str, device: Optional[torch.device] = None) -> None:
        """Load model state from ``filepath``.

        Args:
            filepath: checkpoint path.
            device: map location; defaults to CUDA if available, else CPU.

        Raises:
            FileNotFoundError: if ``filepath`` does not exist.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Model file not found: {filepath}")
        checkpoint = torch.load(filepath, map_location=device)
        # Accept both wrapped checkpoints (dict with model_state_dict) and
        # bare state dicts.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            self.load_state_dict(checkpoint["model_state_dict"])
        else:
            self.load_state_dict(checkpoint)

    def count_parameters(self, trainable: bool = True) -> int:
        """Count model parameters.

        Args:
            trainable: if True, count only parameters with requires_grad.

        Returns:
            Number of parameters.
        """
        if trainable:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    def get_parameter_groups(self) -> Dict[str, int]:
        """Aggregate trainable parameter counts per layer type and module.

        Returns:
            Mapping "{layer_type}_{top_level_module}" -> parameter count.
        """
        groups: Dict[str, int] = {}
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            # Classify by tensor rank: 4-D -> conv kernels, 2-D -> linear
            # weights, 1-D -> norm/bias vectors, anything else -> other.
            if param.dim() == 4:
                layer_type = "conv"
            elif param.dim() == 2:
                layer_type = "linear"
            elif param.dim() == 1:
                layer_type = "norm"
            else:
                layer_type = "other"
            key = f"{layer_type}_{name.split('.')[0]}"
            groups[key] = groups.get(key, 0) + param.numel()
        return groups

    def _set_requires_grad(self, layer_names: List[str], requires_grad: bool) -> None:
        """Toggle requires_grad on parameters whose name contains any entry."""
        for name, param in self.named_parameters():
            if any(layer_name in name for layer_name in layer_names):
                param.requires_grad = requires_grad

    def freeze_layers(self, layer_names: List[str]) -> None:
        """Freeze all parameters whose name contains any of ``layer_names``.

        Args:
            layer_names: substrings of parameter names to freeze.
        """
        self._set_requires_grad(layer_names, False)

    def unfreeze_layers(self, layer_names: List[str]) -> None:
        """Unfreeze all parameters whose name contains any of ``layer_names``.

        Args:
            layer_names: substrings of parameter names to unfreeze.
        """
        self._set_requires_grad(layer_names, True)

    def get_gradients(self) -> Dict[str, torch.Tensor]:
        """Return gradients of trainable parameters that currently have one.

        Returns:
            Mapping parameter name -> gradient tensor.
        """
        return {
            name: param.grad
            for name, param in self.named_parameters()
            if param.requires_grad and param.grad is not None
        }

    def set_train_mode(self, mode: bool = True) -> None:
        """Switch train/eval mode (wraps nn.Module.train, updates the flag).

        Args:
            mode: True for training mode, False for evaluation mode.
        """
        self._is_training = mode
        self.train(mode)

    def get_info(self) -> Dict[str, Any]:
        """Summarize the model.

        Returns:
            Dict with model name, total/trainable parameter counts, and the
            number of submodules.
        """
        return {
            "model_name": self.__class__.__name__,
            "total_parameters": self.count_parameters(trainable=False),
            "trainable_parameters": self.count_parameters(trainable=True),
            "layers": len(list(self.modules())),
        }

@ -1,16 +1,29 @@
"""EMCADNet 网络定义。"""
import logging
from typing import List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.core.base import BaseModel
from src.core.decoders import EMCAD
from src.core.pvtv2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b3, pvt_v2_b4, pvt_v2_b5
from src.core.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from src.utils.utils import DiceLoss
class EMCADNet(BaseModel):
"""EMCAD 端到端网络封装。
class EMCADNet(nn.Module):
"""EMCAD 端到端网络封装。"""
继承自 BaseModel实现了必须的方法
forward(x): 前向传播
get_loss(outputs, targets): 计算损失
同时继承了 BaseModel 提供的通用方法
save(filepath), load(filepath), count_parameters(), freeze_layers(layer_names)
"""
def __init__(
self,
@ -30,6 +43,10 @@ class EMCADNet(nn.Module):
if kernel_sizes is None:
kernel_sizes = [1, 3, 5]
self.num_classes = num_classes
self.encoder = encoder
self.pretrain = pretrain
self.conv = nn.Sequential(
nn.Conv2d(1, 3, kernel_size=1),
nn.BatchNorm2d(3),
@ -76,7 +93,7 @@ class EMCADNet(nn.Module):
self.backbone = resnet152(pretrained=pretrain)
channels = [2048, 1024, 512, 256]
else:
print(
logging.warning(
"Encoder not implemented! Continuing with default encoder pvt_v2_b2."
)
self.backbone = pvt_v2_b2()
@ -90,9 +107,9 @@ class EMCADNet(nn.Module):
model_dict.update(state_dict)
self.backbone.load_state_dict(model_dict)
print(
"Model %s created, param count: %d"
% (encoder + " backbone: ", sum(m.numel() for m in self.backbone.parameters()))
logging.info(
"Model %s created, param count: %d",
encoder + " backbone: ", sum(m.numel() for m in self.backbone.parameters())
)
self.decoder = EMCAD(
@ -105,9 +122,9 @@ class EMCADNet(nn.Module):
activation=activation,
)
print(
"Model %s created, param count: %d"
% ("EMCAD decoder: ", sum(m.numel() for m in self.decoder.parameters()))
logging.info(
"Model %s created, param count: %d",
"EMCAD decoder: ", sum(m.numel() for m in self.decoder.parameters())
)
self.out_head4 = nn.Conv2d(channels[0], num_classes, 1)
@ -115,8 +132,16 @@ class EMCADNet(nn.Module):
self.out_head2 = nn.Conv2d(channels[2], num_classes, 1)
self.out_head1 = nn.Conv2d(channels[3], num_classes, 1)
def forward(self, x, mode="test"):
"""前向计算。"""
def forward(self, x: torch.Tensor, mode: str = "test") -> List[torch.Tensor]:
"""前向计算。
参数:
x: 输入张量
mode: 模式 ('train' 'test')
返回:
输出张量列表 [p4, p3, p2, p1]
"""
if x.size()[1] == 1:
x = self.conv(x)
@ -136,10 +161,43 @@ class EMCADNet(nn.Module):
return [p4, p3, p2, p1]
def get_loss(
self,
outputs: List[torch.Tensor],
targets: torch.Tensor,
w_ce: float = 0.3,
w_dice: float = 0.7,
) -> torch.Tensor:
"""计算损失。
参数:
outputs: 模型输出列表
targets: 目标标签
w_ce: 交叉熵损失权重
w_dice: Dice 损失权重
返回:
总损失值
"""
from torch.nn.modules.loss import CrossEntropyLoss
ce_loss_fn = CrossEntropyLoss()
dice_loss_fn = DiceLoss(self.num_classes)
loss = 0.0
for output in outputs:
loss_ce = ce_loss_fn(output, targets.squeeze(1).long())
loss_dice = dice_loss_fn(output, targets, softmax=True)
loss += w_ce * loss_ce + w_dice * loss_dice
return loss
if __name__ == "__main__":
logger.info("Testing EMCADNet forward pass")
model = EMCADNet().cuda()
input_tensor = torch.randn(1, 3, 352, 352).cuda()
outputs = model(input_tensor)
print(outputs[0].size(), outputs[1].size(), outputs[2].size(), outputs[3].size())
logger.info("Output sizes: %s, %s, %s, %s",
outputs[0].size(), outputs[1].size(), outputs[2].size(), outputs[3].size())

@ -4,11 +4,14 @@
可作为心脏图像分割模型的编码器输出多尺度特征图用于解码器
"""
import logging
import math
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
logger = logging.getLogger(__name__)
__all__ = [
"ResNet",
"resnet18",
@ -317,9 +320,9 @@ def resnet34(pretrained=False, **kwargs):
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
print("Using pretrained weight!")
logger.info("Using pretrained weight!")
pretrained_dict = model_zoo.load_url(model_urls["resnet34"])
print("Pretrain model has been loaded!")
logger.info("Pretrain model has been loaded!")
model.load_state_dict(pretrained_dict)
return model

@ -1,10 +1,14 @@
from .config import Config, ConfigFactory
from .dataset_synapse import Synapse_dataset, RandomGenerator
from .dataset_ACDC import ACDCdataset
from .dataloader import get_loader
from .trainer import trainer_synapse, trainer_ACDC
from .trainer import Trainer, trainer_synapse, trainer_ACDC
from .utils import DiceLoss, powerset, val_single_volume, test_single_volume
__all__ = [
"Config",
"ConfigFactory",
"Trainer",
"Synapse_dataset",
"RandomGenerator",
"ACDCdataset",

@ -0,0 +1,302 @@
"""五、配置管理 - 使用 dataclass 封装配置。
提供 from_yaml() to_dict() 方法
"""
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import yaml
@dataclass
class DatasetConfig:
    """Dataset configuration (paths, class count, geometry)."""

    name: str = "ACDC"
    root_path: str = "/data/ACDC/train"
    volume_path: str = "/data/ACDC/test"
    list_dir: str = "/data/ACDC/lists_ACDC"
    num_classes: int = 4
    z_spacing: int = 1
    img_size: int = 224

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict."""
        keys = (
            "name",
            "root_path",
            "volume_path",
            "list_dir",
            "num_classes",
            "z_spacing",
            "img_size",
        )
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DatasetConfig":
        """Construct from a mapping, silently dropping unknown keys."""
        accepted = cls.__dataclass_fields__
        return cls(**{key: value for key, value in data.items() if key in accepted})
@dataclass
class ModelConfig:
    """Model configuration."""

    encoder: str = "pvt_v2_b2"  # backbone name, e.g. pvt_v2_b0/b1/b2 or a resnet
    num_classes: int = 4  # number of segmentation classes
    kernel_sizes: List[int] = field(default_factory=lambda: [1, 3, 5])  # MSCB kernel sizes
    expansion_factor: int = 2  # MSCB module expansion factor
    dw_parallel: bool = True  # depthwise-convolution parallel mode
    add: bool = True  # True: sum feature maps; False: concatenate
    lgag_ks: int = 3  # LGAG module kernel size
    activation: str = "relu"  # "relu" or "relu6"
    pretrain: bool = True  # load pretrained backbone weights
    pretrained_dir: str = "./model_pth/"  # directory with pretrained weights

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict."""
        return {
            "encoder": self.encoder,
            "num_classes": self.num_classes,
            "kernel_sizes": self.kernel_sizes,
            "expansion_factor": self.expansion_factor,
            "dw_parallel": self.dw_parallel,
            "add": self.add,
            "lgag_ks": self.lgag_ks,
            "activation": self.activation,
            "pretrain": self.pretrain,
            "pretrained_dir": self.pretrained_dir,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
        """Build from a dict; keys not declared on the dataclass are ignored."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class LossConfig:
    """Loss-weight configuration (cross-entropy vs. Dice)."""

    w_ce: float = 0.3
    w_dice: float = 0.7

    def to_dict(self) -> Dict[str, Any]:
        """Return the weights as a plain dict."""
        return {name: getattr(self, name) for name in ("w_ce", "w_dice")}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LossConfig":
        """Construct from a mapping, silently dropping unknown keys."""
        accepted = cls.__dataclass_fields__
        return cls(**{key: value for key, value in data.items() if key in accepted})
@dataclass
class EarlyStoppingConfig:
    """Early-stopping configuration."""

    enabled: bool = False  # whether early stopping is active
    patience: int = 50  # used as the wait counter limit by the Trainer
    min_delta: float = 0.001  # minimum change treated as an improvement

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict."""
        return {
            "enabled": self.enabled,
            "patience": self.patience,
            "min_delta": self.min_delta,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EarlyStoppingConfig":
        """Build from a dict; unknown keys are ignored."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class LRSchedulerConfig:
    """Learning-rate scheduler configuration."""

    type: str = "CosineAnnealingLR"  # scheduler class name
    T_max: int = 300  # cosine period (epochs)
    eta_min: float = 1e-6  # minimum learning rate

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict."""
        return {
            "type": self.type,
            "T_max": self.T_max,
            "eta_min": self.eta_min,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LRSchedulerConfig":
        """Build from a dict; unknown keys are ignored."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class TrainingConfig:
    """Training configuration (optimizer, schedule, loss weights, early stopping)."""

    max_epochs: int = 300
    max_iterations: int = 50000
    batch_size: int = 6
    base_lr: float = 0.0001  # base learning rate
    weight_decay: float = 0.0001
    seed: int = 2222
    deterministic: bool = True
    n_gpu: int = 1
    loss: LossConfig = field(default_factory=LossConfig)
    supervision: str = "mutation"  # mutation | deep_supervision | last_layer
    early_stopping: EarlyStoppingConfig = field(default_factory=EarlyStoppingConfig)
    lr_scheduler: LRSchedulerConfig = field(default_factory=LRSchedulerConfig)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict (sub-configs converted recursively)."""
        return {
            "max_epochs": self.max_epochs,
            "max_iterations": self.max_iterations,
            "batch_size": self.batch_size,
            "base_lr": self.base_lr,
            "weight_decay": self.weight_decay,
            "seed": self.seed,
            "deterministic": self.deterministic,
            "n_gpu": self.n_gpu,
            "loss": self.loss.to_dict(),
            "supervision": self.supervision,
            "early_stopping": self.early_stopping.to_dict(),
            "lr_scheduler": self.lr_scheduler.to_dict(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TrainingConfig":
        """Build from a dict, ignoring unknown keys.

        Bug fix: works on a shallow copy so the caller's mapping is not
        mutated when nested sections are promoted to config objects.
        """
        data = dict(data)
        if isinstance(data.get("loss"), dict):
            data["loss"] = LossConfig.from_dict(data["loss"])
        if isinstance(data.get("early_stopping"), dict):
            data["early_stopping"] = EarlyStoppingConfig.from_dict(data["early_stopping"])
        if isinstance(data.get("lr_scheduler"), dict):
            data["lr_scheduler"] = LRSchedulerConfig.from_dict(data["lr_scheduler"])
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class OutputConfig:
    """Output/checkpoint configuration."""

    snapshot_path: str = "./experiments/"
    save_interval: int = 50
    log_interval: int = 50

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict."""
        keys = ("snapshot_path", "save_interval", "log_interval")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OutputConfig":
        """Construct from a mapping, silently dropping unknown keys."""
        accepted = cls.__dataclass_fields__
        return cls(**{key: value for key, value in data.items() if key in accepted})
@dataclass
class LoggingConfig:
    """Logging configuration."""

    level: str = "INFO"  # logging level name
    format: str = "[%(asctime)s.%(msecs)03d] %(message)s"  # logging format string
    log_file: str = "train.log"  # log file name

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict."""
        return {
            "level": self.level,
            "format": self.format,
            "log_file": self.log_file,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LoggingConfig":
        """Build from a dict; unknown keys are ignored."""
        return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
@dataclass
class Config:
    """Top-level project configuration (nested dataclass sections).

    Wraps all sub-configurations and provides dict conversion plus YAML
    round-tripping via from_yaml() / to_yaml().
    """

    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    output: OutputConfig = field(default_factory=OutputConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the whole configuration to a nested plain dict."""
        return {
            "dataset": self.dataset.to_dict(),
            "model": self.model.to_dict(),
            "training": self.training.to_dict(),
            "output": self.output.to_dict(),
            "logging": self.logging.to_dict(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Config":
        """Build from a nested dict; missing sections keep their defaults."""
        config = cls()
        section_types = {
            "dataset": DatasetConfig,
            "model": ModelConfig,
            "training": TrainingConfig,
            "output": OutputConfig,
            "logging": LoggingConfig,
        }
        for key, section_cls in section_types.items():
            if isinstance(data.get(key), dict):
                setattr(config, key, section_cls.from_dict(data[key]))
        return config

    @classmethod
    def from_yaml(cls, filepath: str) -> "Config":
        """Load configuration from a YAML file.

        参数 (args):
            filepath: YAML file path.

        Returns:
            Config instance; an empty file yields all defaults.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return cls.from_dict(data) if data else cls()

    def to_yaml(self, filepath: str) -> None:
        """Save configuration to a YAML file.

        Args:
            filepath: YAML file path; parent directories are created.
        """
        # Bug fix: os.makedirs("") raises when filepath has no directory
        # component (e.g. "config.yaml"); only create real directories.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(filepath, "w", encoding="utf-8") as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False, allow_unicode=True)

    def update(self, **kwargs: Any) -> None:
        """Update configuration entries.

        Sub-config keys given as dicts are merged over the CURRENT values
        (bug fix: previously a partial dict reset every unspecified field of
        that section to its default). Unknown keys are ignored, matching the
        original behavior.
        """
        for key, value in kwargs.items():
            if not hasattr(self, key):
                continue
            current = getattr(self, key)
            if isinstance(
                current,
                (DatasetConfig, ModelConfig, TrainingConfig, OutputConfig, LoggingConfig),
            ):
                if isinstance(value, dict):
                    merged = current.to_dict()
                    merged.update(value)  # shallow merge; nested dicts replace wholesale
                    setattr(self, key, type(current).from_dict(merged))
                elif isinstance(value, type(current)):
                    # Backward-compatible extension: accept a ready-made sub-config.
                    setattr(self, key, value)
                # Any other value type is ignored (original behavior was a no-op).
            else:
                setattr(self, key, value)
import os

@ -5,6 +5,7 @@
支持训练和测试阶段的心脏图像预处理与增强操作
"""
import logging
import os
import random
@ -14,6 +15,8 @@ import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
logger = logging.getLogger(__name__)
class PolypDataset(data.Dataset):
"""心脏图像分割数据集。"""
@ -50,7 +53,7 @@ class PolypDataset(data.Dataset):
self.size = len(self.images)
if self.augmentations == "True":
print("Using RandomRotation, RandomFlip")
logger.info("Using RandomRotation, RandomFlip")
self.img_transform = transforms.Compose(
[
transforms.RandomRotation(

@ -1,5 +1,6 @@
"""将 Synapse 数据集预处理为 NPZ/H5 格式。"""
import logging
import os
from time import time
@ -7,6 +8,8 @@ import h5py
import nibabel as nib
import numpy as np
logger = logging.getLogger(__name__)
def process_split(split, ct_path, seg_path, save_path, upper, lower):
"""处理一个数据划分并保存结果。
@ -59,8 +62,8 @@ def process_split(split, ct_path, seg_path, save_path, upper, lower):
label=seg_array_s,
)
print("already use {:.3f} min".format((time() - start_time) / 60))
print("-----------")
logger.info("Processing completed in %.3f min", (time() - start_time) / 60)
logger.info("---")
def main():

@ -1,5 +1,6 @@
"""将 Synapse 数据集预处理为多帧 NPZ/H5 格式。"""
import logging
import os
from time import time
@ -7,6 +8,8 @@ import h5py
import nibabel as nib
import numpy as np
logger = logging.getLogger(__name__)
def process_split(split, ct_path, seg_path, save_path, upper, lower):
"""处理一个数据划分并保存结果。
@ -66,9 +69,9 @@ def process_split(split, ct_path, seg_path, save_path, upper, lower):
label=seg_array_s,
)
print("already use {:.3f} min".format((time() - start_time) / 60))
print("-----------")
print("max_size " + str(min_size))
logger.info("Processing completed in %.3f min", (time() - start_time) / 60)
logger.info("---")
logger.info("max_size: %s", min_size)
return min_size

@ -1,78 +1,432 @@
"""Synapse 与 ACDC 的训练流程。"""
"""七、训练流程 - 使用 Trainer 类封装训练循环。
支持:
- 训练/验证分离
- 学习率调度
- 早停early stopping
- 实验跟踪自动创建 experiments/exp_时间/ 目录保存 config + best model
"""
import logging
import os
import random
import sys
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from src.utils.dataset_ACDC import ACDCdataset
from src.utils.dataset_synapse import RandomGenerator, Synapse_dataset
from src.core.base import BaseModel
from src.utils.config import Config, TrainingConfig
from src.utils.utils import DiceLoss, powerset, val_single_volume
def inference(args, model, best_performance):
"""在 Synapse 测试集上进行验证。
参数:
参数: 训练参数
model: 模型
best_performance: 当前最佳指标
class Trainer:
"""训练器类,封装完整的训练流程。
返回:
平均性能指标
支持:
- 训练/验证分离
- 学习率调度
- 早停
- 实验跟踪自动创建 experiments/exp_时间/ 目录
"""
db_test = Synapse_dataset(
base_dir=args.volume_path,
split="test_vol",
list_dir=args.list_dir,
nclass=args.num_classes,
)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
logging.info("%s test iterations per epoch", len(testloader))
model.eval()
metric_list = 0.0
for _, sampled_batch in tqdm(enumerate(testloader)):
image, label, case_name = (
sampled_batch["image"],
sampled_batch["label"],
sampled_batch["case_name"][0],
def __init__(
self,
model: nn.Module,
config: Union[Config, TrainingConfig],
train_loader: DataLoader,
val_loader: Optional[DataLoader] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
criterion: Optional[Callable] = None,
device: Optional[torch.device] = None,
experiment_name: Optional[str] = None,
):
"""初始化训练器。
参数:
model: 模型
config: 配置
train_loader: 训练数据加载器
val_loader: 验证数据加载器可选
optimizer: 优化器可选
criterion: 损失函数可选
device: 训练设备
experiment_name: 实验名称可选
"""
self.config = config if isinstance(config, TrainingConfig) else config.training
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
self.optimizer = optimizer or optim.AdamW(
self.model.parameters(),
lr=self.config.base_lr,
weight_decay=self.config.weight_decay,
)
metric_i = val_single_volume(
image,
label,
model,
classes=args.num_classes,
patch_size=[args.img_size, args.img_size],
case=case_name,
z_spacing=args.z_spacing,
if criterion is None:
self.criterion = self._default_criterion
else:
self.criterion = criterion
self.lr_scheduler = self._setup_lr_scheduler()
self.scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
self.writer = None
self.experiment_path = None
self.experiment_name = experiment_name or f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self.current_epoch = 0
self.current_iteration = 0
self.best_metric = float("-inf")
self.early_stopping_patience = self.config.early_stopping.patience
self.early_stopping_counter = 0
self.early_stopping_min_delta = self.config.early_stopping.min_delta
self.logger = logging.getLogger(__name__)
def _default_criterion(
    self,
    outputs: Union[torch.Tensor, List[torch.Tensor]],
    targets: torch.Tensor,
) -> torch.Tensor:
    """Default loss: delegate to the model's own ``get_loss`` when it
    exists, otherwise sum weighted cross-entropy and Dice losses over
    every output head.

    Args:
        outputs: A single prediction tensor or a list of them.
        targets: Ground-truth labels; squeezed on dim 1 for CE.

    Returns:
        Scalar loss tensor.
    """
    if hasattr(self.model, "get_loss"):
        return self.model.get_loss(outputs, targets)

    # Fall back to a weighted CE + Dice combination.
    ce_weight = self.config.loss.w_ce
    dice_weight = self.config.loss.w_dice
    n_classes = getattr(self.model, "num_classes", 2)
    cross_entropy = CrossEntropyLoss()
    dice = DiceLoss(n_classes)

    prediction_list = outputs if isinstance(outputs, list) else [outputs]
    total = 0.0
    for prediction in prediction_list:
        total += ce_weight * cross_entropy(prediction, targets.squeeze(1).long())
        total += dice_weight * dice(prediction, targets, softmax=True)
    return total
def _setup_lr_scheduler(self):
    """Build the learning-rate scheduler selected by ``config.lr_scheduler.type``.

    Supported types: ``CosineAnnealingLR``, ``StepLR`` and
    ``ReduceLROnPlateau``; any other value disables scheduling.

    Fix vs. original: StepLR/ReduceLROnPlateau hyper-parameters were
    hard-coded; they are now read from the config with the old values as
    backward-compatible defaults. The previous ``_LRScheduler`` return
    annotation was dropped because ReduceLROnPlateau does not subclass it
    in older torch releases.

    Returns:
        A torch scheduler instance, or ``None`` when no known type is
        configured.
    """
    sched_cfg = self.config.lr_scheduler
    scheduler_type = sched_cfg.type
    if scheduler_type == "CosineAnnealingLR":
        return optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=sched_cfg.T_max,
            eta_min=sched_cfg.eta_min,
        )
    if scheduler_type == "StepLR":
        # Defaults preserve the previously hard-coded values.
        return optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=getattr(sched_cfg, "step_size", 30),
            gamma=getattr(sched_cfg, "gamma", 0.1),
        )
    if scheduler_type == "ReduceLROnPlateau":
        return optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="max",  # the validation metric is maximized
            factor=getattr(sched_cfg, "factor", 0.5),
            patience=getattr(sched_cfg, "patience", 10),
        )
    return None
def _setup_experiment(self) -> str:
    """Create the experiment directory tree, the TensorBoard writer and
    file logging; dump the config when a full ``Config`` instance is used.

    Returns:
        Path of the experiment directory.
    """
    root = getattr(self.config, "snapshot_path", "./experiments/")
    self.experiment_path = os.path.join(root, self.experiment_name)
    for sub in ("", "checkpoints", "logs"):
        os.makedirs(os.path.join(self.experiment_path, sub), exist_ok=True)
    self.writer = SummaryWriter(os.path.join(self.experiment_path, "logs"))
    self._setup_logging()
    # Only a full Config object knows how to serialize itself.
    if isinstance(self.config, Config):
        self.config.to_yaml(os.path.join(self.experiment_path, "config.yaml"))
    self.logger.info(f"Experiment path: {self.experiment_path}")
    return self.experiment_path
def _setup_logging(self) -> None:
"""配置日志记录。"""
log_file = os.path.join(self.experiment_path, "logs", "train.log")
file_handler = logging.FileHandler(log_file, encoding="utf-8")
console_handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
"[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
)
metric_list += np.array(metric_i)
metric_list = metric_list / len(db_test)
performance = np.mean(metric_list, axis=0)
logging.info(
"Testing performance in val model: mean_dice : %f, best_dice : %f",
performance,
best_performance,
)
return performance
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
self.logger.addHandler(console_handler)
self.logger.setLevel(logging.INFO)
def train(self, epochs: Optional[int] = None) -> Dict[str, Any]:
    """Run the full training loop with optional validation, LR scheduling,
    early stopping and periodic checkpointing.

    Fix vs. original: when a ReduceLROnPlateau scheduler was combined with
    a val_loader, ``_validate()`` ran twice per epoch and
    ``history["val_metric"]`` received two entries per epoch. Validation
    now runs exactly once per epoch and its result is reused everywhere.

    Args:
        epochs: Number of epochs; overrides ``config.max_epochs`` when given.

    Returns:
        History dict with ``train_loss``, ``val_metric`` and ``lr`` lists.
    """
    if self.experiment_path is None:
        self._setup_experiment()
    max_epochs = epochs or self.config.max_epochs
    save_interval = getattr(self.config, 'save_interval', 50)
    log_interval = getattr(self.config, 'log_interval', 50)
    history = {
        "train_loss": [],
        "val_metric": [],
        "lr": [],
    }
    self.logger.info(f"Starting training for {max_epochs} epochs")
    self.logger.info(f"Training samples: {len(self.train_loader)}")
    for epoch in range(max_epochs):
        self.current_epoch = epoch
        train_loss = self._train_epoch()
        history["train_loss"].append(train_loss)

        # Validate at most once per epoch.
        val_metric = None
        if self.val_loader is not None:
            val_metric = self._validate()
            history["val_metric"].append(val_metric)

        if self.lr_scheduler is not None:
            # Record the LR in effect for this epoch, before stepping.
            history["lr"].append(self.optimizer.param_groups[0]["lr"])
            if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                if val_metric is not None:
                    self.lr_scheduler.step(val_metric)
            else:
                self.lr_scheduler.step()

        early_stop = False
        if val_metric is not None:
            if val_metric > self.best_metric + self.early_stopping_min_delta:
                self.best_metric = val_metric
                self.early_stopping_counter = 0
                self._save_checkpoint("best.pth")
            else:
                self.early_stopping_counter += 1
                if self.early_stopping_counter >= self.early_stopping_patience:
                    self.logger.info(f"Early stopping at epoch {epoch}")
                    early_stop = True

        if (epoch + 1) % save_interval == 0 or epoch == max_epochs - 1:
            self._save_checkpoint(f"epoch_{epoch}.pth")
        if (epoch + 1) % log_interval == 0:
            self.logger.info(
                f"Epoch {epoch}/{max_epochs-1} - Loss: {train_loss:.4f}"
            )
        if self.writer is not None:
            self.writer.add_scalar("loss/train", train_loss, epoch)
            if val_metric is not None:
                self.writer.add_scalar("metric/val", val_metric, epoch)
        if early_stop:
            break
    self._save_checkpoint("last.pth")
    if self.writer is not None:
        self.writer.close()
    self.logger.info("Training finished!")
    return history
def _train_epoch(self) -> float:
    """Train for one epoch; returns the mean per-batch loss.

    Fix vs. original: ``supervision_sets`` was only computed when
    ``current_epoch == 0 and batch_idx == 0``, so every epoch after the
    first raised a NameError. It is now derived once per epoch from the
    first batch's outputs. The empty-subset guard also accepts empty
    tuples (``powerset`` may yield tuples), not just empty lists.
    """
    self.model.train()
    total_loss = 0.0
    supervision_sets = None
    for batch_idx, sampled_batch in enumerate(self.train_loader):
        image_batch = sampled_batch["image"].to(self.device)
        label_batch = sampled_batch["label"].to(self.device)
        self.optimizer.zero_grad()
        outputs = self.model(image_batch, mode="train")
        if not isinstance(outputs, list):
            outputs = [outputs]
        if supervision_sets is None:
            # Decide which output combinations are supervised, based on
            # the number of heads observed in the first batch.
            out_idxs = list(range(len(outputs)))
            if self.config.supervision == "mutation":
                supervision_sets = list(powerset(out_idxs))
            elif self.config.supervision == "deep_supervision":
                supervision_sets = [[x] for x in out_idxs]
            else:
                supervision_sets = [[-1]]
        loss = 0.0
        w_ce = self.config.loss.w_ce
        w_dice = self.config.loss.w_dice
        # NOTE(review): the criterion is applied twice per supervision set
        # (once weighted by w_ce, once by w_dice), mirroring the original
        # CE+Dice pass structure; kept as-is for backward compatibility.
        for subset in supervision_sets:
            if not subset:
                continue
            fused = 0.0
            for idx in subset:
                fused += outputs[idx]
            loss += w_ce * self.criterion(fused, label_batch)
        for subset in supervision_sets:
            if not subset:
                continue
            fused = 0.0
            for idx in subset:
                fused += outputs[idx]
            loss += w_dice * self.criterion(fused, label_batch)
        loss.backward()
        self.optimizer.step()
        total_loss += loss.item()
        self.current_iteration += 1
    return total_loss / max(len(self.train_loader), 1)
@torch.no_grad()
def _validate(self) -> float:
    """Evaluate on the validation loader; returns the mean metric."""
    self.model.eval()
    metric_sum = 0.0
    for batch in self.val_loader:
        images = batch["image"].to(self.device)
        labels = batch["label"].to(self.device)
        preds = self.model(images, mode="test")
        # Multi-head models return a list; score the final head only.
        if isinstance(preds, list):
            preds = preds[-1]
        metric_sum += self._compute_metric(preds, labels)
    return metric_sum / max(len(self.val_loader), 1)
def _compute_metric(self, output: torch.Tensor, target: torch.Tensor) -> float:
    """Mean per-class Dice score between predicted and target masks.

    Fix vs. original: the score list mixed 0-dim tensors with Python
    floats, so ``np.mean`` could return a non-float object despite the
    declared ``-> float``; each score is now converted via ``.item()``.

    Args:
        output: Logits; argmax is taken over dim 1.
        target: Labels; squeezed on dim 1 (assumes a (N, 1, H, W)-style
            layout — TODO confirm against the dataset).

    Returns:
        Mean Dice over all classes as a plain float. A class absent from
        both prediction and target scores a perfect 1.0.
    """
    num_classes = getattr(self.model, "num_classes", 4)
    pred_mask = output.argmax(dim=1)
    true_mask = target.squeeze(1)
    dice_per_class = []
    for cls in range(num_classes):
        pred_c = (pred_mask == cls).float()
        true_c = (true_mask == cls).float()
        overlap = (pred_c * true_c).sum()
        denom = pred_c.sum() + true_c.sum()
        if denom > 0:
            dice_per_class.append((2.0 * overlap / denom).item())
        else:
            dice_per_class.append(1.0)
    return float(np.mean(dice_per_class))
def _save_checkpoint(self, filename: str) -> None:
    """Serialize the current training state to ``checkpoints/<filename>``.

    The checkpoint stores the epoch/iteration counters, model and
    optimizer state dicts, the best validation metric so far and — when a
    scheduler is configured — its state dict as well.

    Args:
        filename: File name inside the experiment's ``checkpoints`` dir.
    """
    checkpoint_dir = os.path.join(self.experiment_path, "checkpoints")
    filepath = os.path.join(checkpoint_dir, filename)
    checkpoint = {
        "epoch": self.current_epoch,
        "iteration": self.current_iteration,
        "model_state_dict": self.model.state_dict(),
        "optimizer_state_dict": self.optimizer.state_dict(),
        "best_metric": self.best_metric,
    }
    # Scheduler state is optional; only saved when one is configured.
    if self.lr_scheduler is not None:
        checkpoint["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()
    torch.save(checkpoint, filepath)
    self.logger.info(f"Checkpoint saved: {filepath}")
def load_checkpoint(self, filepath: str) -> None:
    """Restore trainer state from a checkpoint file.

    Loads model and optimizer state dicts and, when present, the epoch /
    iteration counters, best metric and scheduler state.

    Args:
        filepath: Path of a checkpoint produced by ``_save_checkpoint``.
    """
    checkpoint = torch.load(filepath, map_location=self.device)
    self.model.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if "epoch" in checkpoint:
        # Resume from the epoch after the one that was saved.
        self.current_epoch = checkpoint["epoch"] + 1
    if "iteration" in checkpoint:
        self.current_iteration = checkpoint["iteration"]
    if "best_metric" in checkpoint:
        self.best_metric = checkpoint["best_metric"]
    if self.lr_scheduler is not None and "lr_scheduler_state_dict" in checkpoint:
        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
    self.logger.info(f"Checkpoint loaded: {filepath}")
def predict(self, x: torch.Tensor) -> torch.Tensor:
    """Run inference on a single input tensor.

    Args:
        x: Input tensor; moved onto the trainer's device.

    Returns:
        The model's prediction (final head when a list is returned).
    """
    self.model.eval()
    with torch.no_grad():
        result = self.model(x.to(self.device), mode="test")
    if isinstance(result, list):
        result = result[-1]
    return result
def trainer_synapse(args, model, snapshot_path):
"""Synapse 训练流程。"""
"""Synapse 训练流程(兼容旧接口)。"""
from src.utils.dataset_synapse import RandomGenerator, Synapse_dataset
from torchvision import transforms
logging.basicConfig(
filename=snapshot_path + "/log.txt",
filename=os.path.join(snapshot_path, "log.txt"),
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
@ -93,10 +447,9 @@ def trainer_synapse(args, model, snapshot_path):
[RandomGenerator(output_size=[args.img_size, args.img_size])]
),
)
print("The length of train set is: {}".format(len(db_train)))
logging.info("The length of train set is: %d", len(db_train))
def worker_init_fn(worker_id):
"""为数据加载器设置随机种子。"""
random.seed(args.seed + worker_id)
trainloader = DataLoader(
@ -110,25 +463,20 @@ def trainer_synapse(args, model, snapshot_path):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1 and args.n_gpu > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
logging.info("Let's use %d GPUs!", torch.cuda.device_count())
model = nn.DataParallel(model)
model.to(device)
model.train()
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(num_classes)
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
writer = SummaryWriter(snapshot_path + "/log")
writer = SummaryWriter(os.path.join(snapshot_path, "log"))
iter_num = 0
max_epoch = args.max_epochs
max_iterations = args.max_epochs * len(trainloader)
logging.info(
"%s iterations per epoch. %s max iterations ",
len(trainloader),
max_iterations,
)
logging.info("%s iterations per epoch. %s max iterations", len(trainloader), max_iterations)
best_performance = 0.0
iterator = tqdm(range(max_epoch), ncols=70)
@ -156,7 +504,7 @@ def trainer_synapse(args, model, snapshot_path):
supervision_sets = [[x] for x in out_idxs]
else:
supervision_sets = [[-1]]
print(supervision_sets)
logging.info("Supervision sets: %s", supervision_sets)
loss = 0.0
w_ce, w_dice = 0.3, 0.7
@ -222,10 +570,56 @@ def trainer_synapse(args, model, snapshot_path):
return "Training Finished!"
def inference(args, model, best_performance):
    """Validate on the Synapse test set (legacy-compatible interface).

    Args:
        args: Namespace providing volume_path, list_dir, num_classes,
            img_size and z_spacing.
        model: Network to evaluate; switched to eval mode here.
        best_performance: Best metric so far — only used for logging.

    Returns:
        Mean performance over the test volumes.
    """
    from src.utils.dataset_synapse import Synapse_dataset
    from torch.utils.data import DataLoader

    test_set = Synapse_dataset(
        base_dir=args.volume_path,
        split="test_vol",
        list_dir=args.list_dir,
        nclass=args.num_classes,
    )
    loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1)
    logging.info("%s test iterations per epoch", len(loader))
    model.eval()
    metric_sum = 0.0
    for _, batch in tqdm(enumerate(loader), total=len(loader)):
        image = batch["image"]
        label = batch["label"]
        case_name = batch["case_name"][0]
        per_case = val_single_volume(
            image,
            label,
            model,
            classes=args.num_classes,
            patch_size=[args.img_size, args.img_size],
            case=case_name,
            z_spacing=args.z_spacing,
        )
        metric_sum += np.array(per_case)
    metric_sum = metric_sum / len(test_set)
    performance = np.mean(metric_sum, axis=0)
    logging.info(
        "Testing performance in val model: mean_dice : %f, best_dice : %f",
        performance,
        best_performance,
    )
    return performance
def trainer_ACDC(args, model, snapshot_path):
"""ACDC 训练流程。"""
"""ACDC 训练流程(兼容旧接口)。"""
from src.utils.dataset_ACDC import ACDCdataset
from torchvision import transforms
logging.basicConfig(
filename=snapshot_path + "/log.txt",
filename=os.path.join(snapshot_path, "log.txt"),
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
@ -250,10 +644,9 @@ def trainer_ACDC(args, model, snapshot_path):
[RandomGenerator(output_size=[args.img_size, args.img_size])]
),
)
print("The length of train set is: {}".format(len(db_train)))
logging.info("The length of train set is: %d", len(db_train))
def worker_init_fn(worker_id):
"""为数据加载器设置随机种子。"""
random.seed(args.seed + worker_id)
trainloader = DataLoader(
@ -267,25 +660,20 @@ def trainer_ACDC(args, model, snapshot_path):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1 and args.n_gpu > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
logging.info("Let's use %d GPUs!", torch.cuda.device_count())
model = nn.DataParallel(model)
model.to(device)
model.train()
ce_loss = CrossEntropyLoss()
dice_loss = DiceLoss(num_classes)
optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
writer = SummaryWriter(snapshot_path + "/log")
writer = SummaryWriter(os.path.join(snapshot_path, "log"))
iter_num = 0
max_epoch = args.max_epochs
max_iterations = args.max_epochs * len(trainloader)
logging.info(
"%s iterations per epoch. %s max iterations ",
len(trainloader),
max_iterations,
)
logging.info("%s iterations per epoch. %s max iterations", len(trainloader), max_iterations)
best_loss = 1e9
iterator = tqdm(range(max_epoch), ncols=70)
@ -314,7 +702,7 @@ def trainer_ACDC(args, model, snapshot_path):
supervision_sets = [[x] for x in out_idxs]
else:
supervision_sets = [[-1]]
print(supervision_sets)
logging.info("Supervision sets: %s", supervision_sets)
loss = 0.0
w_ce, w_dice = 0.3, 0.7
@ -364,9 +752,7 @@ def trainer_ACDC(args, model, snapshot_path):
best_loss = epoch_loss
best_path = os.path.join(snapshot_path, "best.pth")
torch.save(model.state_dict(), best_path)
logging.info(
"New best model saved to %s, loss=%.4f", best_path, best_loss
)
logging.info("New best model saved to %s, loss=%.4f", best_path, best_loss)
save_interval = 50
if (epoch_num + 1) % save_interval == 0 or epoch_num == max_epoch - 1:

@ -1,5 +1,8 @@
"""训练、评估与可视化相关的工具函数。"""
import logging
import os
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
@ -13,6 +16,8 @@ from segmentation_mask_overlay import overlay_masks
from thop import clever_format
from thop import profile
logger = logging.getLogger(__name__)
def powerset(seq):
"""生成序列的所有子集。
@ -121,9 +126,7 @@ def CalParams(model, input_tensor):
"""
flops, params = profile(model, inputs=(input_tensor,))
flops, params = clever_format([flops, params], "%.3f")
print("[Statistics Information]\nFLOPs: {}\nParams: {}".format(
flops, params
))
logger.info("FLOPs: %s, Params: %s", flops, params)
def one_hot_encoder(input_tensor, dataset, n_classes=None):
@ -486,25 +489,18 @@ def cal_params_flops(model, size, logger):
"""记录模型 FLOPs 与参数量信息。"""
input_tensor = torch.randn(1, 3, size, size).cuda()
flops, params = profile(model, inputs=(input_tensor,))
print("flops", flops / 1e9) # 打印计算量
print("params", params / 1e6) # 打印参数量
logger.info("FLOPs: %.3f G, Params: %.3f M", flops / 1e9, params / 1e6)
total = sum(p.numel() for p in model.parameters())
print("Total params: %.2fM" % (total / 1e6))
logger.info(
"flops: %s, params: %s, Total params: : %.4f",
flops / 1e9,
params / 1e6,
total / 1e6,
)
logger.info("Total params: %.2fM", total / 1e6)
def print_model_stats(model, input_size=(3, 224, 224)):
"""打印模型的 GMACs 与参数量。"""
total_params = sum(p.numel() for p in model.parameters())
print(f"Model created, param count: {total_params}")
logger.info("Model created, param count: %d", total_params)
macs, params = get_model_complexity_info(
model, input_size, as_strings=True, print_per_layer_stat=True
)
print(f"Model: {macs} GMACs, {params} parameters")
logger.info("Model: %s GMACs, %s parameters", macs, params)

@ -1,362 +1,264 @@
"""Synapse 数据集推理入口"""
"""Synapse 数据集测试模块"""
import argparse
import logging
import os
import random
import sys
import numpy as np
import pytest
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm
sys.path.insert(0, os.path.abspath("../.."))
from src.core.networks import EMCADNet
from src.utils.dataset_synapse import Synapse_dataset
from src.utils.utils import test_single_volume
def build_parser():
    """Create the argument parser for Synapse inference.

    Returns:
        argparse.ArgumentParser configured with all data, architecture and
        run-time options (defaults match the original script).
    """
    p = argparse.ArgumentParser()
    # --- data options ---
    p.add_argument("--volume_path", type=str,
                   default="./data/synapse/test_vol_h5_new",
                   help="root dir for validation volume data")
    p.add_argument("--dataset", type=str, default="Synapse", help="experiment_name")
    p.add_argument("--num_classes", type=int, default=9,
                   help="output channel of network")
    p.add_argument("--list_dir", type=str, default="./lists/lists_Synapse",
                   help="list dir")
    # --- architecture options ---
    p.add_argument("--encoder", type=str, default="pvt_v2_b2",
                   help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...")
    p.add_argument("--expansion_factor", type=int, default=2,
                   help="expansion factor in MSCB block")
    p.add_argument("--kernel_sizes", type=int, nargs="+", default=[1, 3, 5],
                   help="multi-scale kernel sizes in MSDC block")
    p.add_argument("--lgag_ks", type=int, default=3, help="Kernel size in LGAG")
    p.add_argument("--activation_mscb", type=str, default="relu6",
                   help="activation used in MSCB: relu6 or relu")
    p.add_argument("--no_dw_parallel", action="store_true", default=False,
                   help="use this flag to disable depth-wise parallel convolutions")
    p.add_argument("--concatenation", action="store_true", default=False,
                   help="use this flag to concatenate feature maps in MSDC block")
    p.add_argument("--no_pretrain", action="store_true", default=False,
                   help="use this flag to turn off loading pretrained enocder weights")
    p.add_argument("--pretrained_dir", type=str, default="./pretrained_pth/pvt/",
                   help="path to pretrained encoder dir")
    p.add_argument("--supervision", type=str, default="mutation",
                   help="loss supervision: mutation, deep_supervision or last_layer")
    # --- run-time options ---
    p.add_argument("--max_iterations", type=int, default=30000,
                   help="maximum epoch number to train")
    p.add_argument("--max_epochs", type=int, default=300,
                   help="maximum epoch number to train")
    p.add_argument("--batch_size", type=int, default=6, help="batch_size per gpu")
    p.add_argument("--base_lr", type=float, default=0.0001,
                   help="segmentation network learning rate")
    p.add_argument("--img_size", type=int, default=224,
                   help="input patch size of network input")
    p.add_argument("--is_savenii", action="store_true", default=True,
                   help="whether to save results during inference")
    p.add_argument("--test_save_dir", type=str, default="predictions",
                   help="saving prediction as nii")
    p.add_argument("--deterministic", type=int, default=1,
                   help="whether use deterministic training")
    p.add_argument("--seed", type=int, default=2222, help="random seed")
    return p
def get_class_names(num_classes):
    """Return the ordered organ names for the given class count.

    Args:
        num_classes: Segmentation class count; 14 selects the extended
            13-organ list, any other value the 8-organ Synapse list.

    Returns:
        List of organ-name strings.
    """
    extended = [
        "spleen", "right kidney", "left kidney", "gallbladder", "esophagus",
        "liver", "stomach", "aorta", "inferior vena cava",
        "portal vein and splenic vein", "pancreas", "right adrenal gland",
        "left adrenal gland",
    ]
    standard = [
        "spleen", "right kidney", "left kidney", "gallbladder",
        "pancreas", "liver", "stomach", "aorta",
    ]
    return extended if num_classes == 14 else standard
def inference(args, model, class_names, test_save_path=None):
"""在 Synapse 上执行推理。"""
db_test = args.Dataset(
base_dir=args.volume_path,
split="test_vol",
list_dir=args.list_dir,
nclass=args.num_classes,
)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
logging.info("%s test iterations per epoch", len(testloader))
model.eval()
metric_list = 0.0
for i_batch, sampled_batch in tqdm(enumerate(testloader)):
image, label, case_name = (
sampled_batch["image"],
sampled_batch["label"],
sampled_batch["case_name"][0],
from src.utils.config import Config
class TestEMCADNet:
"""EMCADNet 模型测试类。"""
@pytest.fixture
def config(self):
"""加载默认配置。"""
config_path = os.path.join(
os.path.dirname(__file__), "..", "configs", "default.yaml"
)
metric_i = test_single_volume(
image,
label,
model,
classes=args.num_classes,
patch_size=[args.img_size, args.img_size],
test_save_path=test_save_path,
case=case_name,
z_spacing=1,
class_names=class_names,
if os.path.exists(config_path):
return Config.from_yaml(config_path)
return Config()
@pytest.fixture
def model(self, config):
"""创建测试用模型。"""
return EMCADNet(
num_classes=config.model.num_classes,
encoder=config.model.encoder,
kernel_sizes=config.model.kernel_sizes,
expansion_factor=config.model.expansion_factor,
dw_parallel=config.model.dw_parallel,
add=config.model.add,
lgag_ks=config.model.lgag_ks,
activation=config.model.activation,
pretrain=False,
)
metric_list += np.array(metric_i)
logging.info(
"idx %d case %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f",
i_batch,
case_name,
np.mean(metric_i, axis=0)[0],
np.mean(metric_i, axis=0)[1],
np.mean(metric_i, axis=0)[2],
np.mean(metric_i, axis=0)[3],
def test_model_initialization(self, model, config):
"""测试模型初始化。"""
assert model is not None
assert isinstance(model, torch.nn.Module)
assert model.num_classes == config.model.num_classes
def test_forward_pass(self, model):
"""测试前向传播。"""
batch_size = 2
img_size = 224
x = torch.randn(batch_size, 3, img_size, img_size)
with torch.no_grad():
outputs = model(x, mode="test")
assert outputs is not None
if isinstance(outputs, list):
assert len(outputs) > 0
assert outputs[0].shape == (batch_size, model.num_classes, img_size, img_size)
else:
assert outputs.shape == (batch_size, model.num_classes, img_size, img_size)
def test_training_mode(self, model):
"""测试训练模式前向传播。"""
batch_size = 2
img_size = 224
x = torch.randn(batch_size, 3, img_size, img_size)
model.train()
with torch.no_grad():
outputs = model(x, mode="train")
assert outputs is not None
def test_get_loss(self, model):
"""测试损失计算。"""
batch_size = 2
img_size = 224
num_classes = model.num_classes
x = torch.randn(batch_size, 3, img_size, img_size)
outputs = model(x, mode="test")
targets = torch.randint(0, num_classes, (batch_size, 1, img_size, img_size))
loss = model.get_loss(outputs, targets)
assert loss is not None
assert isinstance(loss, torch.Tensor)
assert loss.item() >= 0
def test_multi_scale_outputs(self, model):
"""测试多尺度输出。"""
batch_size = 2
img_size = 224
x = torch.randn(batch_size, 3, img_size, img_size)
with torch.no_grad():
outputs = model(x, mode="train")
assert isinstance(outputs, list)
assert len(outputs) == 4
def test_single_channel_input(self, model):
"""测试单通道输入。"""
batch_size = 2
img_size = 224
x = torch.randn(batch_size, 1, img_size, img_size)
with torch.no_grad():
outputs = model(x, mode="test")
assert outputs is not None
if isinstance(outputs, list):
assert outputs[0].shape == (batch_size, model.num_classes, img_size, img_size)
def test_parameter_count(self, model):
"""测试参数计数。"""
num_params = model.count_parameters()
assert isinstance(num_params, int)
assert num_params > 0
assert num_params > 1e6
def test_save_load_state_dict(self, model, tmp_path):
"""测试模型状态保存和加载。"""
state_dict = model.state_dict()
save_path = os.path.join(tmp_path, "model.pth")
model.save(save_path)
assert os.path.exists(save_path)
new_model = EMCADNet(
num_classes=model.num_classes,
encoder="pvt_v2_b2",
pretrain=False,
)
metric_list = metric_list / len(db_test)
for i in range(1, args.num_classes):
logging.info(
"Mean class (%d) %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f",
i,
class_names[i - 1],
metric_list[i - 1][0],
metric_list[i - 1][1],
metric_list[i - 1][2],
metric_list[i - 1][3],
new_model.load_state_dict(save_path)
def test_to_device(self, model):
"""测试设备迁移。"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_device = model.to(device)
assert next(model_device.parameters()).device == device
batch_size = 2
img_size = 224
x = torch.randn(batch_size, 3, img_size, img_size).to(device)
with torch.no_grad():
outputs = model_device(x, mode="test")
assert outputs is not None
class TestConfig:
"""配置管理测试类。"""
def test_config_creation(self):
"""测试配置创建。"""
config = Config()
assert config is not None
assert hasattr(config, "dataset")
assert hasattr(config, "model")
assert hasattr(config, "training")
assert hasattr(config, "output")
assert hasattr(config, "logging")
def test_config_from_yaml(self):
"""测试从 YAML 加载配置。"""
config_path = os.path.join(
os.path.dirname(__file__), "..", "configs", "default.yaml"
)
performance = np.mean(metric_list, axis=0)[0]
mean_hd95 = np.mean(metric_list, axis=0)[1]
mean_jacard = np.mean(metric_list, axis=0)[2]
mean_asd = np.mean(metric_list, axis=0)[3]
logging.info(
"Testing performance in best val model: mean_dice : %f mean_hd95 : %f, "
"mean_jacard : %f mean_asd : %f",
performance,
mean_hd95,
mean_jacard,
mean_asd,
)
return "Testing Finished!"
def build_snapshot_path(args, dataset_name):
    """Reconstruct the snapshot directory path used at training time.

    Side effect: stores the full experiment tag in ``args.exp`` (with the
    kernel-size brackets intact — only the returned path has them
    stripped).

    Args:
        args: Parsed CLI namespace (encoder, kernel_sizes, flags, ...).
        dataset_name: Dataset identifier appended to the experiment tag.

    Returns:
        Path under ``model_pth/`` with bracket/comma characters stripped
        and suffixes encoding any non-default hyper-parameters.
    """
    aggregation = "concat" if args.concatenation else "add"
    dw_mode = "series" if args.no_dw_parallel else "parallel"
    run = 1
    tag = "".join([
        args.encoder,
        "_EMCAD_kernel_sizes_", str(args.kernel_sizes),
        "_dw_", dw_mode,
        "_", aggregation,
        "_lgag_ks_", str(args.lgag_ks),
        "_ef", str(args.expansion_factor),
        "_act_mscb_", args.activation_mscb,
        "_loss_", args.supervision,
        "_output_final_layer_Run", str(run),
    ])
    args.exp = tag + "_" + dataset_name + str(args.img_size)
    path = "model_pth/{}/{}".format(args.exp, tag)
    # Make the path filesystem-friendly: "[1, 3, 5]" -> "1_3_5".
    path = path.replace("[", "").replace("]", "").replace(", ", "_")
    if not args.no_pretrain:
        path += "_pretrain"
    if args.max_iterations != 50000:
        path += "_" + str(args.max_iterations)[0:2] + "k"
    if args.max_epochs != 300:
        path += "_epo" + str(args.max_epochs)
    path += "_bs" + str(args.batch_size)
    if args.base_lr != 0.0001:
        path += "_lr" + str(args.base_lr)
    path += "_" + str(args.img_size)
    if args.seed != 1234:
        path += "_s" + str(args.seed)
    return path
def main():
"""主入口函数。"""
parser = build_parser()
args = parser.parse_args()
if not args.deterministic:
cudnn.benchmark = True
cudnn.deterministic = False
else:
cudnn.benchmark = False
cudnn.deterministic = True
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
dataset_config = {
"Synapse": {
"Dataset": Synapse_dataset,
"volume_path": args.volume_path,
"list_dir": args.list_dir,
"num_classes": args.num_classes,
"z_spacing": 1,
if os.path.exists(config_path):
config = Config.from_yaml(config_path)
assert config is not None
assert config.dataset is not None
assert config.model is not None
assert config.training is not None
def test_config_to_dict(self):
"""测试配置序列化。"""
config = Config()
config_dict = config.to_dict()
assert isinstance(config_dict, dict)
assert "dataset" in config_dict
assert "model" in config_dict
assert "training" in config_dict
def test_config_from_dict(self):
"""测试配置反序列化。"""
data = {
"dataset": {"name": "Test", "num_classes": 4},
"model": {"encoder": "pvt_v2_b2", "num_classes": 4},
}
}
dataset_name = args.dataset
args.num_classes = dataset_config[dataset_name]["num_classes"]
args.volume_path = dataset_config[dataset_name]["volume_path"]
args.Dataset = dataset_config[dataset_name]["Dataset"]
args.list_dir = dataset_config[dataset_name]["list_dir"]
args.z_spacing = dataset_config[dataset_name]["z_spacing"]
print(args.no_pretrain)
snapshot_path = build_snapshot_path(args, dataset_name)
model = EMCADNet(
num_classes=args.num_classes,
kernel_sizes=args.kernel_sizes,
expansion_factor=args.expansion_factor,
dw_parallel=not args.no_dw_parallel,
add=not args.concatenation,
lgag_ks=args.lgag_ks,
activation=args.activation_mscb,
encoder=args.encoder,
pretrain=not args.no_pretrain,
pretrained_dir=args.pretrained_dir,
)
model.cuda()
snapshot = os.path.join(snapshot_path, "best.pth")
if not os.path.exists(snapshot):
snapshot = snapshot.replace("best", "epoch_" + str(args.max_epochs - 1))
model.load_state_dict(torch.load(snapshot))
snapshot_name = snapshot_path.split("/")[-1]
log_folder = "test_log/test_log_" + args.exp
os.makedirs(log_folder, exist_ok=True)
logging.basicConfig(
filename=log_folder + "/" + snapshot_name + ".txt",
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d] %(message)s",
datefmt="%H:%M:%S",
)
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
logging.info(snapshot_name)
if args.is_savenii:
args.test_save_dir = os.path.join(snapshot_path, "predictions")
test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name + "2")
os.makedirs(test_save_path, exist_ok=True)
else:
test_save_path = None
class_names = get_class_names(args.num_classes)
inference(args, model, class_names, test_save_path)
config = Config.from_dict(data)
assert config is not None
assert config.dataset.name == "Test"
assert config.model.num_classes == 4
class TestInference:
    """Inference-pipeline tests for EMCADNet."""

    @pytest.fixture
    def trained_model(self):
        """Build an EMCADNet instance (pretrain disabled) as a stand-in
        for a trained model."""
        model = EMCADNet(
            num_classes=9,
            encoder="pvt_v2_b2",
            kernel_sizes=[1, 3, 5],
            expansion_factor=2,
            dw_parallel=True,
            add=True,
            lgag_ks=3,
            activation="relu",
            pretrain=False,
        )
        return model

    def test_inference_pipeline(self, trained_model):
        """Eval-mode forward pass yields per-pixel class predictions of
        shape (batch, H, W)."""
        trained_model.eval()
        batch_size = 1
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size)
        with torch.no_grad():
            outputs = trained_model(x, mode="test")
        # List-valued outputs: argmax over the first head; tensor: direct.
        predictions = torch.argmax(outputs, dim=1) if isinstance(outputs, torch.Tensor) else torch.argmax(outputs[0], dim=1)
        assert predictions.shape == (batch_size, img_size, img_size)

    def test_dice_calculation(self, trained_model):
        """get_loss runs on random input/target and returns a tensor."""
        trained_model.eval()
        batch_size = 1
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size)
        target = torch.randint(0, 9, (batch_size, 1, img_size, img_size))
        with torch.no_grad():
            outputs = trained_model(x, mode="test")
        loss = trained_model.get_loss(outputs, target)
        assert loss is not None
        assert isinstance(loss, torch.Tensor)
if __name__ == "__main__":
main()
pytest.main([__file__, "-v"])

Loading…
Cancel
Save