parent
d64d987ebe
commit
f6e9285936
@ -0,0 +1,211 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# poetry
|
||||
poetry.lock
|
||||
|
||||
# pdm
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
*.iml
|
||||
*.ipr
|
||||
*.iws
|
||||
.vscode/
|
||||
!.vscode/extensions.json
|
||||
!.vscode/settings.json
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
*.DS_Store
|
||||
|
||||
# Windows
|
||||
Thumbs.db
|
||||
ehthumbs.db
|
||||
Desktop.ini
|
||||
|
||||
# Project specific
|
||||
model_pth/*
|
||||
pretrained_pth/*
|
||||
!model_pth/.gitkeep
|
||||
!pretrained_pth/.gitkeep
|
||||
|
||||
# Data directories
|
||||
data/synapse/*
|
||||
data/ACDC/ACDC/train/*
|
||||
data/ACDC/ACDC/test/*
|
||||
data/ACDC/ACDC/valid/*
|
||||
!data/synapse/.gitkeep
|
||||
!data/ACDC/ACDC/train/.gitkeep
|
||||
!data/ACDC/ACDC/test/.gitkeep
|
||||
!data/ACDC/ACDC/valid/.gitkeep
|
||||
|
||||
# Training outputs
|
||||
snapshots/
|
||||
runs/
|
||||
experiments/
|
||||
predictions/
|
||||
*.log
|
||||
|
||||
# Checkpoint files
|
||||
*.pth
|
||||
*.pt
|
||||
|
||||
# numpy arrays data files
|
||||
*.npy
|
||||
*.npz
|
||||
|
||||
# Medical image formats
|
||||
*.nii
|
||||
*.nii.gz
|
||||
*.mha
|
||||
*.mhd
|
||||
|
||||
# Checkpoints
|
||||
checkpoint-*
|
||||
*.ckpt
|
||||
|
||||
# Virtual environment
|
||||
EMCADenv/
|
||||
emcadenv/
|
||||
|
||||
# SimpleITK wheel
|
||||
SimpleITK.whl
|
||||
|
||||
# TensorBoard
|
||||
events.out.tfevents.*
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.bak
|
||||
*~
|
||||
\#*\#
|
||||
@ -0,0 +1,40 @@
|
||||
case_002_volume_ED.npz
|
||||
case_002_volume_ES.npz
|
||||
case_003_volume_ED.npz
|
||||
case_003_volume_ES.npz
|
||||
case_008_volume_ED.npz
|
||||
case_008_volume_ES.npz
|
||||
case_009_volume_ED.npz
|
||||
case_009_volume_ES.npz
|
||||
case_012_volume_ED.npz
|
||||
case_012_volume_ES.npz
|
||||
case_014_volume_ED.npz
|
||||
case_014_volume_ES.npz
|
||||
case_017_volume_ED.npz
|
||||
case_017_volume_ES.npz
|
||||
case_024_volume_ED.npz
|
||||
case_024_volume_ES.npz
|
||||
case_042_volume_ED.npz
|
||||
case_042_volume_ES.npz
|
||||
case_048_volume_ED.npz
|
||||
case_048_volume_ES.npz
|
||||
case_049_volume_ED.npz
|
||||
case_049_volume_ES.npz
|
||||
case_053_volume_ED.npz
|
||||
case_053_volume_ES.npz
|
||||
case_055_volume_ED.npz
|
||||
case_055_volume_ES.npz
|
||||
case_064_volume_ED.npz
|
||||
case_064_volume_ES.npz
|
||||
case_067_volume_ED.npz
|
||||
case_067_volume_ES.npz
|
||||
case_079_volume_ED.npz
|
||||
case_079_volume_ES.npz
|
||||
case_081_volume_ED.npz
|
||||
case_081_volume_ES.npz
|
||||
case_088_volume_ED.npz
|
||||
case_088_volume_ES.npz
|
||||
case_092_volume_ED.npz
|
||||
case_092_volume_ES.npz
|
||||
case_095_volume_ED.npz
|
||||
case_095_volume_ES.npz
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,182 @@
|
||||
case_019_sliceED_0.npz
|
||||
case_019_sliceED_1.npz
|
||||
case_019_sliceED_10.npz
|
||||
case_019_sliceED_2.npz
|
||||
case_019_sliceED_3.npz
|
||||
case_019_sliceED_4.npz
|
||||
case_019_sliceED_5.npz
|
||||
case_019_sliceED_6.npz
|
||||
case_019_sliceED_7.npz
|
||||
case_019_sliceED_8.npz
|
||||
case_019_sliceED_9.npz
|
||||
case_019_sliceES_0.npz
|
||||
case_019_sliceES_1.npz
|
||||
case_019_sliceES_10.npz
|
||||
case_019_sliceES_2.npz
|
||||
case_019_sliceES_3.npz
|
||||
case_019_sliceES_4.npz
|
||||
case_019_sliceES_5.npz
|
||||
case_019_sliceES_6.npz
|
||||
case_019_sliceES_7.npz
|
||||
case_019_sliceES_8.npz
|
||||
case_019_sliceES_9.npz
|
||||
case_021_sliceED_0.npz
|
||||
case_021_sliceED_1.npz
|
||||
case_021_sliceED_2.npz
|
||||
case_021_sliceED_3.npz
|
||||
case_021_sliceED_4.npz
|
||||
case_021_sliceED_5.npz
|
||||
case_021_sliceED_6.npz
|
||||
case_021_sliceED_7.npz
|
||||
case_021_sliceED_8.npz
|
||||
case_021_sliceED_9.npz
|
||||
case_021_sliceES_0.npz
|
||||
case_021_sliceES_1.npz
|
||||
case_021_sliceES_2.npz
|
||||
case_021_sliceES_3.npz
|
||||
case_021_sliceES_4.npz
|
||||
case_021_sliceES_5.npz
|
||||
case_021_sliceES_6.npz
|
||||
case_021_sliceES_7.npz
|
||||
case_021_sliceES_8.npz
|
||||
case_021_sliceES_9.npz
|
||||
case_029_sliceED_0.npz
|
||||
case_029_sliceED_1.npz
|
||||
case_029_sliceED_10.npz
|
||||
case_029_sliceED_2.npz
|
||||
case_029_sliceED_3.npz
|
||||
case_029_sliceED_4.npz
|
||||
case_029_sliceED_5.npz
|
||||
case_029_sliceED_6.npz
|
||||
case_029_sliceED_7.npz
|
||||
case_029_sliceED_8.npz
|
||||
case_029_sliceED_9.npz
|
||||
case_029_sliceES_0.npz
|
||||
case_029_sliceES_1.npz
|
||||
case_029_sliceES_10.npz
|
||||
case_029_sliceES_2.npz
|
||||
case_029_sliceES_3.npz
|
||||
case_029_sliceES_4.npz
|
||||
case_029_sliceES_5.npz
|
||||
case_029_sliceES_6.npz
|
||||
case_029_sliceES_7.npz
|
||||
case_029_sliceES_8.npz
|
||||
case_029_sliceES_9.npz
|
||||
case_033_sliceED_0.npz
|
||||
case_033_sliceED_1.npz
|
||||
case_033_sliceED_2.npz
|
||||
case_033_sliceED_3.npz
|
||||
case_033_sliceED_4.npz
|
||||
case_033_sliceED_5.npz
|
||||
case_033_sliceED_6.npz
|
||||
case_033_sliceED_7.npz
|
||||
case_033_sliceED_8.npz
|
||||
case_033_sliceED_9.npz
|
||||
case_033_sliceES_0.npz
|
||||
case_033_sliceES_1.npz
|
||||
case_033_sliceES_2.npz
|
||||
case_033_sliceES_3.npz
|
||||
case_033_sliceES_4.npz
|
||||
case_033_sliceES_5.npz
|
||||
case_033_sliceES_6.npz
|
||||
case_033_sliceES_7.npz
|
||||
case_033_sliceES_8.npz
|
||||
case_033_sliceES_9.npz
|
||||
case_041_sliceED_0.npz
|
||||
case_041_sliceED_1.npz
|
||||
case_041_sliceED_2.npz
|
||||
case_041_sliceED_3.npz
|
||||
case_041_sliceED_4.npz
|
||||
case_041_sliceED_5.npz
|
||||
case_041_sliceES_0.npz
|
||||
case_041_sliceES_1.npz
|
||||
case_041_sliceES_2.npz
|
||||
case_041_sliceES_3.npz
|
||||
case_041_sliceES_4.npz
|
||||
case_041_sliceES_5.npz
|
||||
case_050_sliceED_0.npz
|
||||
case_050_sliceED_1.npz
|
||||
case_050_sliceED_2.npz
|
||||
case_050_sliceED_3.npz
|
||||
case_050_sliceED_4.npz
|
||||
case_050_sliceED_5.npz
|
||||
case_050_sliceED_6.npz
|
||||
case_050_sliceED_7.npz
|
||||
case_050_sliceED_8.npz
|
||||
case_050_sliceED_9.npz
|
||||
case_050_sliceES_0.npz
|
||||
case_050_sliceES_1.npz
|
||||
case_050_sliceES_2.npz
|
||||
case_050_sliceES_3.npz
|
||||
case_050_sliceES_4.npz
|
||||
case_050_sliceES_5.npz
|
||||
case_050_sliceES_6.npz
|
||||
case_050_sliceES_7.npz
|
||||
case_050_sliceES_8.npz
|
||||
case_050_sliceES_9.npz
|
||||
case_061_sliceED_0.npz
|
||||
case_061_sliceED_1.npz
|
||||
case_061_sliceED_2.npz
|
||||
case_061_sliceED_3.npz
|
||||
case_061_sliceED_4.npz
|
||||
case_061_sliceED_5.npz
|
||||
case_061_sliceED_6.npz
|
||||
case_061_sliceED_7.npz
|
||||
case_061_sliceED_8.npz
|
||||
case_061_sliceES_0.npz
|
||||
case_061_sliceES_1.npz
|
||||
case_061_sliceES_2.npz
|
||||
case_061_sliceES_3.npz
|
||||
case_061_sliceES_4.npz
|
||||
case_061_sliceES_5.npz
|
||||
case_061_sliceES_6.npz
|
||||
case_061_sliceES_7.npz
|
||||
case_061_sliceES_8.npz
|
||||
case_071_sliceED_0.npz
|
||||
case_071_sliceED_1.npz
|
||||
case_071_sliceED_2.npz
|
||||
case_071_sliceED_3.npz
|
||||
case_071_sliceED_4.npz
|
||||
case_071_sliceED_5.npz
|
||||
case_071_sliceED_6.npz
|
||||
case_071_sliceED_7.npz
|
||||
case_071_sliceED_8.npz
|
||||
case_071_sliceED_9.npz
|
||||
case_071_sliceES_0.npz
|
||||
case_071_sliceES_1.npz
|
||||
case_071_sliceES_2.npz
|
||||
case_071_sliceES_3.npz
|
||||
case_071_sliceES_4.npz
|
||||
case_071_sliceES_5.npz
|
||||
case_071_sliceES_6.npz
|
||||
case_071_sliceES_7.npz
|
||||
case_071_sliceES_8.npz
|
||||
case_071_sliceES_9.npz
|
||||
case_076_sliceED_0.npz
|
||||
case_076_sliceED_1.npz
|
||||
case_076_sliceED_2.npz
|
||||
case_076_sliceED_3.npz
|
||||
case_076_sliceED_4.npz
|
||||
case_076_sliceED_5.npz
|
||||
case_076_sliceED_6.npz
|
||||
case_076_sliceED_7.npz
|
||||
case_076_sliceES_0.npz
|
||||
case_076_sliceES_1.npz
|
||||
case_076_sliceES_2.npz
|
||||
case_076_sliceES_3.npz
|
||||
case_076_sliceES_4.npz
|
||||
case_076_sliceES_5.npz
|
||||
case_076_sliceES_6.npz
|
||||
case_076_sliceES_7.npz
|
||||
case_080_sliceED_0.npz
|
||||
case_080_sliceED_1.npz
|
||||
case_080_sliceED_2.npz
|
||||
case_080_sliceED_3.npz
|
||||
case_080_sliceED_4.npz
|
||||
case_080_sliceED_5.npz
|
||||
case_080_sliceES_0.npz
|
||||
case_080_sliceES_1.npz
|
||||
case_080_sliceES_2.npz
|
||||
case_080_sliceES_3.npz
|
||||
case_080_sliceES_4.npz
|
||||
case_080_sliceES_5.npz
|
||||
@ -0,0 +1,30 @@
|
||||
case0031.npy.h5
|
||||
case0007.npy.h5
|
||||
case0009.npy.h5
|
||||
case0005.npy.h5
|
||||
case0026.npy.h5
|
||||
case0039.npy.h5
|
||||
case0024.npy.h5
|
||||
case0034.npy.h5
|
||||
case0033.npy.h5
|
||||
case0030.npy.h5
|
||||
case0023.npy.h5
|
||||
case0040.npy.h5
|
||||
case0010.npy.h5
|
||||
case0021.npy.h5
|
||||
case0006.npy.h5
|
||||
case0027.npy.h5
|
||||
case0028.npy.h5
|
||||
case0037.npy.h5
|
||||
case0008.npy.h5
|
||||
case0022.npy.h5
|
||||
case0038.npy.h5
|
||||
case0036.npy.h5
|
||||
case0032.npy.h5
|
||||
case0002.npy.h5
|
||||
case0029.npy.h5
|
||||
case0003.npy.h5
|
||||
case0001.npy.h5
|
||||
case0004.npy.h5
|
||||
case0025.npy.h5
|
||||
case0035.npy.h5
|
||||
@ -0,0 +1,12 @@
|
||||
case0008
|
||||
case0022
|
||||
case0038
|
||||
case0036
|
||||
case0032
|
||||
case0002
|
||||
case0029
|
||||
case0003
|
||||
case0001
|
||||
case0004
|
||||
case0025
|
||||
case0035
|
||||
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 1.3 MiB |
|
After Width: | Height: | Size: 120 KiB |
|
After Width: | Height: | Size: 130 KiB |
|
After Width: | Height: | Size: 130 KiB |
|
After Width: | Height: | Size: 92 KiB |
|
After Width: | Height: | Size: 160 KiB |
@ -0,0 +1,33 @@
|
||||
numpy==1.22.4
|
||||
loguru
|
||||
tqdm
|
||||
pyyaml
|
||||
pandas
|
||||
matplotlib
|
||||
scikit-learn
|
||||
scikit-image
|
||||
scipy
|
||||
opencv-python
|
||||
seaborn
|
||||
albumentations==1.1.0
|
||||
tabulate
|
||||
warmup-scheduler
|
||||
transformers==4.21.3
|
||||
torchprofile
|
||||
torchmetrics
|
||||
einops
|
||||
ptflops
|
||||
torchsummary
|
||||
torchsummaryx
|
||||
segmentation-mask-overlay==0.3.4
|
||||
timm==0.6.12
|
||||
tifffile
|
||||
pillow
|
||||
thop
|
||||
simpleitk
|
||||
nibabel
|
||||
h5py
|
||||
huggingface-hub==0.11.0
|
||||
ml_collections
|
||||
tensorboardx
|
||||
medpy
|
||||
@ -0,0 +1,289 @@
|
||||
"""Synapse 与 ACDC 训练入口。"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from src.core.networks import EMCADNet
|
||||
from src.utils.trainer import trainer_ACDC, trainer_synapse
|
||||
|
||||
|
||||
def build_parser():
|
||||
"""构建训练参数解析器。"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--root_path",
|
||||
type=str,
|
||||
default="/data/ACDC/train",
|
||||
help="root dir for training data (ACDC: /data/ACDC/train)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--volume_path",
|
||||
type=str,
|
||||
default="/data/ACDC/test",
|
||||
help="root dir for validation/test volume data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="ACDC",
|
||||
choices=["Synapse", "ACDC"],
|
||||
help="experiment name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_dir",
|
||||
type=str,
|
||||
default="/data/ACDC/lists_ACDC",
|
||||
help="list dir (ACDC: /data/ACDC/lists_ACDC)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_classes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="output channel of network (ACDC = 4)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
type=str,
|
||||
default="pvt_v2_b2",
|
||||
help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expansion_factor",
|
||||
type=int,
|
||||
default=2,
|
||||
help="expansion factor in MSCB block",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kernel_sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 3, 5],
|
||||
help="multi-scale kernel sizes in MSDC block",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lgag_ks", type=int, default=3, help="Kernel size in LGAG block"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--activation_mscb",
|
||||
type=str,
|
||||
default="relu6",
|
||||
help="activation used in MSCB: relu6 or relu",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_dw_parallel",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="use this flag to disable depth-wise parallel convolutions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concatenation",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="use this flag to concatenate feature maps in MSDC block",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_pretrain",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="use this flag to turn off loading pretrained encoder weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_dir",
|
||||
type=str,
|
||||
default="./model_pth/",
|
||||
help="path to pretrained encoder dir, e.g. ./model_pth/",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--supervision",
|
||||
type=str,
|
||||
default="mutation",
|
||||
help="loss supervision: mutation, deep_supervision or last_layer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_iterations", type=int, default=50000, help="maximum total iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_epochs", type=int, default=300, help="maximum epoch number to train"
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=6, help="batch_size per gpu")
|
||||
parser.add_argument(
|
||||
"--base_lr",
|
||||
type=float,
|
||||
default=0.0001,
|
||||
help="segmentation network learning rate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--img_size",
|
||||
type=int,
|
||||
default=224,
|
||||
help="input patch size of network input",
|
||||
)
|
||||
parser.add_argument("--n_gpu", type=int, default=1, help="total gpu")
|
||||
parser.add_argument(
|
||||
"--deterministic",
|
||||
type=int,
|
||||
default=1,
|
||||
help="whether use deterministic training",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=2222, help="random seed")
|
||||
return parser
|
||||
|
||||
|
||||
def set_deterministic(seed, deterministic):
|
||||
"""配置随机种子与确定性行为。"""
|
||||
if not deterministic:
|
||||
cudnn.benchmark = True
|
||||
cudnn.deterministic = False
|
||||
else:
|
||||
cudnn.benchmark = False
|
||||
cudnn.deterministic = True
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
|
||||
def build_snapshot_path(args, dataset_name):
|
||||
"""根据参数生成输出目录。"""
|
||||
aggregation = "concat" if args.concatenation else "add"
|
||||
dw_mode = "series" if args.no_dw_parallel else "parallel"
|
||||
|
||||
run = 1
|
||||
exp = (
|
||||
args.encoder
|
||||
+ "_EMCAD_kernel_sizes_"
|
||||
+ str(args.kernel_sizes)
|
||||
+ "_dw_"
|
||||
+ dw_mode
|
||||
+ "_"
|
||||
+ aggregation
|
||||
+ "_lgag_ks_"
|
||||
+ str(args.lgag_ks)
|
||||
+ "_ef"
|
||||
+ str(args.expansion_factor)
|
||||
+ "_act_mscb_"
|
||||
+ args.activation_mscb
|
||||
+ "_loss_"
|
||||
+ args.supervision
|
||||
+ "_output_final_layer_Run"
|
||||
+ str(run)
|
||||
+ "_"
|
||||
+ dataset_name
|
||||
+ str(args.img_size)
|
||||
)
|
||||
|
||||
snapshot_path = "model_pth/{}/{}".format(
|
||||
exp,
|
||||
args.encoder
|
||||
+ "_EMCAD_kernel_sizes_"
|
||||
+ str(args.kernel_sizes)
|
||||
+ "_dw_"
|
||||
+ dw_mode
|
||||
+ "_"
|
||||
+ aggregation
|
||||
+ "_lgag_ks_"
|
||||
+ str(args.lgag_ks)
|
||||
+ "_ef"
|
||||
+ str(args.expansion_factor)
|
||||
+ "_act_mscb_"
|
||||
+ args.activation_mscb
|
||||
+ "_loss_"
|
||||
+ args.supervision
|
||||
+ "_output_final_layer_Run"
|
||||
+ str(run),
|
||||
)
|
||||
snapshot_path = snapshot_path.replace("[", "").replace("]", "").replace(", ", "_")
|
||||
|
||||
if not args.no_pretrain:
|
||||
snapshot_path = snapshot_path + "_pretrain"
|
||||
if args.max_iterations != 50000:
|
||||
snapshot_path = snapshot_path + "_" + str(args.max_iterations)[0:2] + "k"
|
||||
if args.max_epochs != 300:
|
||||
snapshot_path = snapshot_path + "_epo" + str(args.max_epochs)
|
||||
snapshot_path = snapshot_path + "_bs" + str(args.batch_size)
|
||||
if args.base_lr != 0.0001:
|
||||
snapshot_path = snapshot_path + "_lr" + str(args.base_lr)
|
||||
snapshot_path = snapshot_path + "_" + str(args.img_size)
|
||||
if args.seed != 1234:
|
||||
snapshot_path = snapshot_path + "_s" + str(args.seed)
|
||||
|
||||
return exp, snapshot_path
|
||||
|
||||
|
||||
def main():
|
||||
"""主入口函数。"""
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
set_deterministic(args.seed, args.deterministic)
|
||||
|
||||
dataset_name = args.dataset
|
||||
|
||||
acdc_root = args.root_path
|
||||
if dataset_name == "ACDC":
|
||||
tmp = args.root_path.rstrip("/")
|
||||
if os.path.basename(tmp) == "train":
|
||||
acdc_root = os.path.dirname(tmp)
|
||||
else:
|
||||
acdc_root = tmp
|
||||
|
||||
dataset_config = {
|
||||
"Synapse": {
|
||||
"root_path": args.root_path,
|
||||
"volume_path": args.volume_path,
|
||||
"list_dir": args.list_dir,
|
||||
"num_classes": args.num_classes,
|
||||
"z_spacing": 1,
|
||||
},
|
||||
"ACDC": {
|
||||
"root_path": acdc_root,
|
||||
"volume_path": args.volume_path,
|
||||
"list_dir": args.list_dir,
|
||||
"num_classes": args.num_classes,
|
||||
"z_spacing": 1,
|
||||
},
|
||||
}
|
||||
|
||||
cfg = dataset_config[dataset_name]
|
||||
args.num_classes = cfg["num_classes"]
|
||||
args.root_path = cfg["root_path"]
|
||||
args.volume_path = cfg["volume_path"]
|
||||
args.z_spacing = cfg["z_spacing"]
|
||||
args.list_dir = cfg["list_dir"]
|
||||
|
||||
args.exp, snapshot_path = build_snapshot_path(args, dataset_name)
|
||||
|
||||
if not os.path.exists(snapshot_path):
|
||||
os.makedirs(snapshot_path)
|
||||
|
||||
model = EMCADNet(
|
||||
num_classes=args.num_classes,
|
||||
kernel_sizes=args.kernel_sizes,
|
||||
expansion_factor=args.expansion_factor,
|
||||
dw_parallel=not args.no_dw_parallel,
|
||||
add=not args.concatenation,
|
||||
lgag_ks=args.lgag_ks,
|
||||
activation=args.activation_mscb,
|
||||
encoder=args.encoder,
|
||||
pretrain=not args.no_pretrain,
|
||||
pretrained_dir=args.pretrained_dir,
|
||||
)
|
||||
|
||||
model.cuda()
|
||||
print("Model successfully created.")
|
||||
|
||||
trainer_map = {"Synapse": trainer_synapse, "ACDC": trainer_ACDC}
|
||||
trainer_map[dataset_name](args, model, snapshot_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -0,0 +1,46 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="EMCADNet",
|
||||
version="0.1.0",
|
||||
author="Your Name",
|
||||
author_email="your.email@example.com",
|
||||
description="EMCADNet: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation",
|
||||
long_description=open("README.md").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/yourusername/EMCADNet",
|
||||
packages=find_packages(where="src"),
|
||||
package_dir={"": "src"},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
python_requires=">=3.8",
|
||||
install_requires=[
|
||||
"torch>=1.11.0",
|
||||
"torchvision>=0.12.0",
|
||||
"numpy>=1.22.0",
|
||||
"h5py>=3.0.0",
|
||||
"scipy>=1.5.0",
|
||||
"matplotlib>=3.3.0",
|
||||
"tqdm>=4.50.0",
|
||||
"tensorboardX>=2.2",
|
||||
"nibabel>=3.2.0",
|
||||
"medpy>=0.4.0",
|
||||
"ptflops>=0.6.4",
|
||||
"thop>=0.0.31",
|
||||
"segmentation-mask-overlay>=0.3.0",
|
||||
"timm>=0.6.0",
|
||||
],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"emcad-train=scripts.train_synapse:main",
|
||||
"emcad-test=scripts.test_synapse:main",
|
||||
],
|
||||
},
|
||||
include_package_data=True,
|
||||
package_data={
|
||||
"": ["*.md", "*.txt"],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,3 @@
|
||||
from .networks import EMCADNet
|
||||
|
||||
__all__ = ["EMCADNet"]
|
||||
@ -0,0 +1,145 @@
|
||||
"""EMCADNet 网络定义。"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from src.core.decoders import EMCAD
|
||||
from src.core.pvtv2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b3, pvt_v2_b4, pvt_v2_b5
|
||||
from src.core.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
|
||||
|
||||
|
||||
class EMCADNet(nn.Module):
|
||||
"""EMCAD 端到端网络封装。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes=1,
|
||||
kernel_sizes=None,
|
||||
expansion_factor=2,
|
||||
dw_parallel=True,
|
||||
add=True,
|
||||
lgag_ks=3,
|
||||
activation="relu",
|
||||
encoder="pvt_v2_b2",
|
||||
pretrain=True,
|
||||
pretrained_dir="./pretrained_pth/pvt/",
|
||||
):
|
||||
"""初始化网络。"""
|
||||
super(EMCADNet, self).__init__()
|
||||
if kernel_sizes is None:
|
||||
kernel_sizes = [1, 3, 5]
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(1, 3, kernel_size=1),
|
||||
nn.BatchNorm2d(3),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
if encoder == "pvt_v2_b0":
|
||||
self.backbone = pvt_v2_b0()
|
||||
path = pretrained_dir + "/pvt_v2_b0.pth"
|
||||
channels = [256, 160, 64, 32]
|
||||
elif encoder == "pvt_v2_b1":
|
||||
self.backbone = pvt_v2_b1()
|
||||
path = pretrained_dir + "/pvt_v2_b1.pth"
|
||||
channels = [512, 320, 128, 64]
|
||||
elif encoder == "pvt_v2_b2":
|
||||
self.backbone = pvt_v2_b2()
|
||||
path = pretrained_dir + "/pvt_v2_b2.pth"
|
||||
channels = [512, 320, 128, 64]
|
||||
elif encoder == "pvt_v2_b3":
|
||||
self.backbone = pvt_v2_b3()
|
||||
path = pretrained_dir + "/pvt_v2_b3.pth"
|
||||
channels = [512, 320, 128, 64]
|
||||
elif encoder == "pvt_v2_b4":
|
||||
self.backbone = pvt_v2_b4()
|
||||
path = pretrained_dir + "/pvt_v2_b4.pth"
|
||||
channels = [512, 320, 128, 64]
|
||||
elif encoder == "pvt_v2_b5":
|
||||
self.backbone = pvt_v2_b5()
|
||||
path = pretrained_dir + "/pvt_v2_b5.pth"
|
||||
channels = [512, 320, 128, 64]
|
||||
elif encoder == "resnet18":
|
||||
self.backbone = resnet18(pretrained=pretrain)
|
||||
channels = [512, 256, 128, 64]
|
||||
elif encoder == "resnet34":
|
||||
self.backbone = resnet34(pretrained=pretrain)
|
||||
channels = [512, 256, 128, 64]
|
||||
elif encoder == "resnet50":
|
||||
self.backbone = resnet50(pretrained=pretrain)
|
||||
channels = [2048, 1024, 512, 256]
|
||||
elif encoder == "resnet101":
|
||||
self.backbone = resnet101(pretrained=pretrain)
|
||||
channels = [2048, 1024, 512, 256]
|
||||
elif encoder == "resnet152":
|
||||
self.backbone = resnet152(pretrained=pretrain)
|
||||
channels = [2048, 1024, 512, 256]
|
||||
else:
|
||||
print(
|
||||
"Encoder not implemented! Continuing with default encoder pvt_v2_b2."
|
||||
)
|
||||
self.backbone = pvt_v2_b2()
|
||||
path = pretrained_dir + "/pvt_v2_b2.pth"
|
||||
channels = [512, 320, 128, 64]
|
||||
|
||||
if pretrain and "pvt_v2" in encoder:
|
||||
save_model = torch.load(path)
|
||||
model_dict = self.backbone.state_dict()
|
||||
state_dict = {k: v for k, v in save_model.items() if k in model_dict}
|
||||
model_dict.update(state_dict)
|
||||
self.backbone.load_state_dict(model_dict)
|
||||
|
||||
print(
|
||||
"Model %s created, param count: %d"
|
||||
% (encoder + " backbone: ", sum(m.numel() for m in self.backbone.parameters()))
|
||||
)
|
||||
|
||||
self.decoder = EMCAD(
|
||||
channels=channels,
|
||||
kernel_sizes=kernel_sizes,
|
||||
expansion_factor=expansion_factor,
|
||||
dw_parallel=dw_parallel,
|
||||
add=add,
|
||||
lgag_ks=lgag_ks,
|
||||
activation=activation,
|
||||
)
|
||||
|
||||
print(
|
||||
"Model %s created, param count: %d"
|
||||
% ("EMCAD decoder: ", sum(m.numel() for m in self.decoder.parameters()))
|
||||
)
|
||||
|
||||
self.out_head4 = nn.Conv2d(channels[0], num_classes, 1)
|
||||
self.out_head3 = nn.Conv2d(channels[1], num_classes, 1)
|
||||
self.out_head2 = nn.Conv2d(channels[2], num_classes, 1)
|
||||
self.out_head1 = nn.Conv2d(channels[3], num_classes, 1)
|
||||
|
||||
def forward(self, x, mode="test"):
|
||||
"""前向计算。"""
|
||||
if x.size()[1] == 1:
|
||||
x = self.conv(x)
|
||||
|
||||
x1, x2, x3, x4 = self.backbone(x)
|
||||
|
||||
dec_outs = self.decoder(x4, [x3, x2, x1])
|
||||
|
||||
p4 = self.out_head4(dec_outs[0])
|
||||
p3 = self.out_head3(dec_outs[1])
|
||||
p2 = self.out_head2(dec_outs[2])
|
||||
p1 = self.out_head1(dec_outs[3])
|
||||
|
||||
p4 = F.interpolate(p4, scale_factor=32, mode="bilinear")
|
||||
p3 = F.interpolate(p3, scale_factor=16, mode="bilinear")
|
||||
p2 = F.interpolate(p2, scale_factor=8, mode="bilinear")
|
||||
p1 = F.interpolate(p1, scale_factor=4, mode="bilinear")
|
||||
|
||||
return [p4, p3, p2, p1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = EMCADNet().cuda()
|
||||
input_tensor = torch.randn(1, 3, 352, 352).cuda()
|
||||
|
||||
outputs = model(input_tensor)
|
||||
print(outputs[0].size(), outputs[1].size(), outputs[2].size(), outputs[3].size())
|
||||
@ -0,0 +1,442 @@
|
||||
"""Pyramid Vision Transformer v2 主干网络实现。"""
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
from timm.models.registry import register_model
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""带深度卷积的 MLP 模块。"""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.dwconv = DWConv(hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, H, W):
|
||||
x = self.fc1(x)
|
||||
x = self.dwconv(x, H, W)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""带可选空间降采样的多头注意力。"""
|
||||
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1:
|
||||
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, N, C = x.shape
|
||||
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
|
||||
if self.sr_ratio > 1:
|
||||
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
|
||||
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
|
||||
x_ = self.norm(x_)
|
||||
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
else:
|
||||
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
k, v = kv[0], kv[1]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Transformer block: spatial-reduction attention followed by an MLP."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        """Build one pre-norm attention + MLP block.

        Args:
            dim: Token embedding dimension.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden/input width ratio of the MLP.
            qkv_bias, qk_scale: Attention projection options.
            drop, attn_drop: Dropout rates for projections / attention map.
            drop_path: Stochastic-depth rate for this block.
            act_layer, norm_layer: Activation and normalization factories.
            sr_ratio: Spatial-reduction ratio for the attention keys/values.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # Stochastic depth implemented with DropPath.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linears, identity LayerNorm, He-style convolutions.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Residual attention then residual MLP; (H, W) is the token grid size."""
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x
|
||||
|
||||
|
||||
class OverlapPatchEmbed(nn.Module):
    """Image-to-patch embedding via an overlapping strided convolution."""

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        """Build the embedding layer.

        Args:
            img_size: Nominal input image size (int or pair).
            patch_size: Convolution kernel size (int or pair).
            stride: Convolution stride (the downsampling factor).
            in_chans: Input channels.
            embed_dim: Output token dimension.
        """
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        # Nominal token grid; the actual (H, W) is recomputed in forward.
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # Padding of patch_size // 2 makes neighbouring patches overlap.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linears, identity LayerNorm, He-style convolutions.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Embed image (B, C, H, W) -> tokens (B, N, embed_dim) plus grid size."""
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W
|
||||
|
||||
|
||||
class PyramidVisionTransformerImpr(nn.Module):
    """Pyramid Vision Transformer (PVTv2-style) backbone.

    Four stages, each an overlapping patch embedding followed by
    transformer blocks; ``forward`` returns the four per-stage feature
    maps at strides 4, 8, 16 and 32.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        # Patch embeddings: stage 1 downsamples by 4, later stages by 2 each.
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # Transformer encoder; drop-path rate grows linearly over all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        # Classification head intentionally not created (backbone-only use).

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linears, identity LayerNorm, He-style convolutions.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def init_weights(self, pretrained=None):
        # Placeholder for loading pretrained weights from a checkpoint path.
        # NOTE(review): this is a stub — `logger = 1` has no effect and no
        # weights are actually loaded; confirm whether loading was intended.
        if isinstance(pretrained, str):
            logger = 1

    def reset_drop_path(self, drop_path_rate):
        """Re-assign per-block stochastic-depth rates for a new global rate."""
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        # NOTE(review): this assigns a plain attribute on the Module and does
        # not change the parameters' requires_grad flags — confirm whether
        # `self.patch_embed1.requires_grad_(False)` was intended.
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names excluded from weight decay by the optimizer setup.
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}

    def get_classifier(self):
        # NOTE(review): `self.head` is never created in __init__ (the head is
        # commented out above), so this raises AttributeError unless a head
        # is attached externally.
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        # NOTE(review): `self.embed_dim` is not defined on this class (only
        # the local `embed_dims` list exists in __init__) — this would raise
        # AttributeError if called; confirm intended usage.
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    # Positional-embedding interpolation logic would go here if needed.

    def forward_features(self, x):
        """Run all four stages and collect their NCHW feature maps.

        Args:
            x: Input image batch of shape (B, in_chans, H, W).

        Returns:
            List of four tensors, one per stage, each (B, C_i, H_i, W_i).
        """
        B = x.shape[0]
        outs = []

        # Stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        # Tokens (B, N, C) -> feature map (B, C, H, W).
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # Stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # Stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # Stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    # A token-mean readout could be returned here instead, if desired.

    def forward(self, x):
        """Return the list of four multi-scale feature maps (no classifier)."""
        x = self.forward_features(x)

        return x
|
||||
|
||||
|
||||
class DWConv(nn.Module):
    """Depth-wise 3x3 convolution used as a token mixer.

    Tokens are folded into a feature map, convolved per-channel, and
    unfolded back into the (B, N, C) token layout.
    """

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        # groups=dim makes the convolution depth-wise (one filter per channel).
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Mix tokens (B, N, C) spatially; N must equal H * W."""
        batch, _, channels = x.shape
        feat = x.transpose(1, 2).view(batch, channels, H, W)
        feat = self.dwconv(feat)
        return feat.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
def _conv_filter(state_dict, patch_size=16):
|
||||
"""将 patch embedding 权重转换为卷积格式。"""
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if 'patch_embed.proj.weight' in k:
|
||||
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
||||
out_dict[k] = v
|
||||
|
||||
return out_dict
|
||||
|
||||
|
||||
@register_model
class pvt_v2_b0(PyramidVisionTransformerImpr):
    """PVTv2-B0 backbone (smallest variant, depths 2/2/2/2)."""

    def __init__(self, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to the base
        # class, so caller overrides are silently ignored — confirm intended.
        super(pvt_v2_b0, self).__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
|
||||
|
||||
@register_model
class pvt_v2_b1(PyramidVisionTransformerImpr):
    """PVTv2-B1 backbone (depths 2/2/2/2, wider embeddings than B0)."""

    def __init__(self, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to the base
        # class, so caller overrides are silently ignored — confirm intended.
        super(pvt_v2_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
@register_model
class pvt_v2_b2(PyramidVisionTransformerImpr):
    """PVTv2-B2 backbone (depths 3/4/6/3)."""

    def __init__(self, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to the base
        # class, so caller overrides are silently ignored — confirm intended.
        super(pvt_v2_b2, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
@register_model
class pvt_v2_b3(PyramidVisionTransformerImpr):
    """PVTv2-B3 backbone (depths 3/4/18/3)."""

    def __init__(self, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to the base
        # class, so caller overrides are silently ignored — confirm intended.
        super(pvt_v2_b3, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
@register_model
class pvt_v2_b4(PyramidVisionTransformerImpr):
    """PVTv2-B4 backbone (depths 3/8/27/3)."""

    def __init__(self, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to the base
        # class, so caller overrides are silently ignored — confirm intended.
        super(pvt_v2_b4, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
|
||||
|
||||
|
||||
@register_model
class pvt_v2_b5(PyramidVisionTransformerImpr):
    """PVTv2-B5 backbone (depths 3/6/40/3, mlp_ratios 4/4/4/4)."""

    def __init__(self, **kwargs):
        # NOTE(review): **kwargs is accepted but not forwarded to the base
        # class, so caller overrides are silently ignored — confirm intended.
        super(pvt_v2_b5, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
|
||||
@ -0,0 +1,18 @@
|
||||
from .dataset_synapse import Synapse_dataset, RandomGenerator
|
||||
from .dataset_ACDC import ACDCdataset
|
||||
from .dataloader import get_loader
|
||||
from .trainer import trainer_synapse, trainer_ACDC
|
||||
from .utils import DiceLoss, powerset, val_single_volume, test_single_volume
|
||||
|
||||
__all__ = [
|
||||
"Synapse_dataset",
|
||||
"RandomGenerator",
|
||||
"ACDCdataset",
|
||||
"get_loader",
|
||||
"trainer_synapse",
|
||||
"trainer_ACDC",
|
||||
"DiceLoss",
|
||||
"powerset",
|
||||
"val_single_volume",
|
||||
"test_single_volume",
|
||||
]
|
||||
@ -0,0 +1,33 @@
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class ACDCdataset(Dataset):
    """ACDC cardiac-MRI dataset.

    Expects ``{list_dir}/{split}.txt`` listing one file name per line.
    Train/valid entries resolve to ``{base_dir}/{split}/{name}`` and other
    splits to ``{base_dir}/{name}``; each file is a .npz archive with
    'img' and 'label' arrays.
    """

    def __init__(self, base_dir, list_dir, split, transform=None):
        """Load the sample list for the requested split.

        Args:
            base_dir: Dataset root directory.
            list_dir: Directory holding the per-split list files.
            split: Split name ('train', 'valid', or test-style).
            transform: Optional callable applied to train samples.
        """
        self.transform = transform  # torchvision-style callable on the sample dict
        self.split = split
        # Close the list file instead of leaking the handle (was
        # open(...).readlines()).
        with open(os.path.join(list_dir, self.split + '.txt')) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Load one sample dict with 'image', 'label' and 'case_name'."""
        name = self.sample_list[idx].strip('\n')
        if self.split in ("train", "valid"):
            data_path = os.path.join(self.data_dir, self.split, name)
        else:
            data_path = os.path.join(self.data_dir, name)
        data = np.load(data_path)
        image, label = data['img'], data['label']

        sample = {'image': image, 'label': label}
        # Augmentations apply to training samples only.
        if self.transform and self.split == "train":
            sample = self.transform(sample)
        sample['case_name'] = name
        return sample
|
||||
@ -0,0 +1,100 @@
|
||||
import os
|
||||
import random
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
from scipy import ndimage
|
||||
from scipy.ndimage.interpolation import zoom
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def random_rot_flip(image, label):
    """Apply one random quarter-turn rotation plus one random flip jointly.

    The same transform is applied to both arrays so the mask stays aligned.
    """
    quarter_turns = np.random.randint(0, 4)
    image = np.rot90(image, quarter_turns)
    label = np.rot90(label, quarter_turns)
    flip_axis = np.random.randint(0, 2)
    return (np.flip(image, axis=flip_axis).copy(),
            np.flip(label, axis=flip_axis).copy())
|
||||
|
||||
|
||||
def random_rotate(image, label):
    """Rotate both arrays by the same random angle in [-20, 20) degrees.

    Nearest-neighbour interpolation (order=0) keeps label values discrete;
    reshape=False preserves the array shape.
    """
    angle = np.random.randint(-20, 20)
    rotated = [ndimage.rotate(arr, angle, order=0, reshape=False)
               for arr in (image, label)]
    return rotated[0], rotated[1]
|
||||
|
||||
|
||||
class RandomGenerator(object):
    """Random flip/rotate augmentation plus resize to a fixed output size."""

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # Two independent coin flips; at most one augmentation is applied.
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        height, width = image.shape
        target_h, target_w = self.output_size[0], self.output_size[1]
        if height != target_h or width != target_w:
            # Cubic interpolation for the image, nearest for the label mask.
            image = zoom(image, (target_h / height, target_w / width), order=3)
            label = zoom(label, (target_h / height, target_w / width), order=0)
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32))
        return {'image': image, 'label': label.long()}
|
||||
|
||||
|
||||
class Synapse_dataset(Dataset):
    """Synapse multi-organ segmentation dataset.

    The train split reads per-slice ``{base_dir}/{name}.npz`` archives;
    other splits read per-volume ``{base_dir}/{name}.npy.h5`` files.
    Samples are dicts with 'image', 'label' and 'case_name'.
    """

    def __init__(self, base_dir, list_dir, split, nclass=9, transform=None):
        """Load the sample list for the requested split.

        Args:
            base_dir: Dataset root directory.
            list_dir: Directory holding the per-split list files.
            split: Split name ('train' or volume-level split).
            nclass: When 9, labels are remapped to the 8-organ subset.
            transform: Optional callable applied to each sample dict.
        """
        self.transform = transform  # torchvision-style callable on the sample dict
        self.split = split
        # Close the list file instead of leaking the handle (was
        # open(...).readlines()).
        with open(os.path.join(list_dir, self.split + '.txt')) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir
        self.nclass = nclass

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Load one sample dict with 'image', 'label' and 'case_name'."""
        case_name = self.sample_list[idx].strip('\n')
        if self.split == "train":
            data = np.load(os.path.join(self.data_dir, case_name + '.npz'))
            image, label = data['image'], data['label']
        else:
            filepath = self.data_dir + "/{}.npy.h5".format(case_name)
            # Read arrays into memory, then release the HDF5 handle.
            with h5py.File(filepath, 'r') as data:
                image, label = data['image'][:], data['label'][:]

        if self.nclass == 9:
            # Collapse the 13-organ labelling to the 8-organ subset:
            # drop classes 5, 9, 10, 12, 13 and remap 11 -> 5.
            label[label == 5] = 0
            label[label == 9] = 0
            label[label == 10] = 0
            label[label == 12] = 0
            label[label == 13] = 0
            label[label == 11] = 5

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = case_name
        return sample
|
||||
@ -0,0 +1,149 @@
|
||||
"""Synapse 数据集相关工具。"""
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy import ndimage
|
||||
from scipy.ndimage.interpolation import zoom
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def random_rot_flip(image, label):
    """Random rotate-and-flip augmentation.

    Args:
        image: Input image array.
        label: Matching label array.

    Returns:
        The jointly transformed (image, label) pair.
    """
    turns = np.random.randint(0, 4)
    image, label = np.rot90(image, turns), np.rot90(label, turns)
    axis_to_flip = np.random.randint(0, 2)
    image = np.flip(image, axis=axis_to_flip).copy()
    label = np.flip(label, axis=axis_to_flip).copy()
    return image, label
|
||||
|
||||
|
||||
def random_rotate(image, label):
    """Random-angle rotation augmentation.

    Args:
        image: Input image array.
        label: Matching label array.

    Returns:
        The jointly rotated (image, label) pair; nearest-neighbour
        interpolation keeps label values discrete and the shape unchanged.
    """
    degrees = np.random.randint(-20, 20)
    pair = tuple(ndimage.rotate(arr, degrees, order=0, reshape=False)
                 for arr in (image, label))
    return pair
|
||||
|
||||
|
||||
class RandomGenerator(object):
    """Random augmentation plus resize to a fixed output size."""

    def __init__(self, output_size):
        """Store the target size.

        Args:
            output_size: Target (H, W) pair.
        """
        self.output_size = output_size

    def __call__(self, sample):
        """Augment and resize one sample.

        Args:
            sample: Dict containing 'image' and 'label' arrays.

        Returns:
            Dict with a float (1, H, W) image tensor and a long label tensor.
        """
        image, label = sample["image"], sample["label"]

        # Two independent coin flips; at most one augmentation is applied.
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        height, width = image.shape
        target_h, target_w = self.output_size[0], self.output_size[1]
        if height != target_h or width != target_w:
            # Cubic interpolation for the image, nearest for the label mask.
            image = zoom(image, (target_h / height, target_w / width), order=3)
            label = zoom(label, (target_h / height, target_w / width), order=0)
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(label.astype(np.float32))
        return {"image": image, "label": label.long()}
|
||||
|
||||
|
||||
class Synapse_dataset(Dataset):
    """PyTorch Dataset for the Synapse multi-organ segmentation data.

    The train split reads per-slice ``{base_dir}/{name}.npz`` archives;
    other splits read per-volume ``{base_dir}/{name}.npy.h5`` files.
    """

    def __init__(self, base_dir, list_dir, split, nclass=9, transform=None):
        """Initialise the dataset.

        Args:
            base_dir: Dataset root directory.
            list_dir: Directory holding the per-split list files.
            split: Split name.
            nclass: When 9, labels are remapped to the 8-organ subset.
            transform: Optional augmentation callable.
        """
        self.transform = transform
        self.split = split
        list_path = os.path.join(list_dir, self.split + ".txt")
        # Close the list file instead of leaking the handle (was
        # open(...).readlines()).
        with open(list_path) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir
        self.nclass = nclass

    def __len__(self):
        """Return the number of samples."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Read the sample at ``idx``.

        Args:
            idx: Sample index.

        Returns:
            Dict with 'image', 'label' and 'case_name'.
        """
        case_name = self.sample_list[idx].strip("\n")
        if self.split == "train":
            data_path = os.path.join(self.data_dir, case_name + ".npz")
            data = np.load(data_path)
            image, label = data["image"], data["label"]
        else:
            filepath = self.data_dir + "/{}.npy.h5".format(case_name)
            # Read arrays into memory, then release the HDF5 handle.
            with h5py.File(filepath, "r") as data:
                image, label = data["image"][:], data["label"][:]

        if self.nclass == 9:
            # Collapse the 13-organ labelling to the 8-organ subset:
            # drop classes 5, 9, 10, 12, 13 and remap 11 -> 5.
            label[label == 5] = 0
            label[label == 9] = 0
            label[label == 10] = 0
            label[label == 12] = 0
            label[label == 13] = 0
            label[label == 11] = 5

        sample = {"image": image, "label": label}
        if self.transform:
            sample = self.transform(sample)
        sample["case_name"] = case_name
        return sample
|
||||
@ -0,0 +1,40 @@
|
||||
import os
|
||||
import shutil
|
||||
from libtiff import TIFF # pip install libtiff
|
||||
from scipy import misc
|
||||
import random
|
||||
|
||||
|
||||
def tif2png(_src_path, _dst_path):
    """
    Usage:
        formatting `tif/tiff` files to `jpg/png` files
    :param _src_path: path of the source TIFF file
    :param _dst_path: path of the converted output image
    :return: None
    """
    # NOTE(review): scipy.misc.imsave was removed in SciPy >= 1.2 — confirm
    # the pinned SciPy version, or migrate to imageio.imwrite.
    tif = TIFF.open(_src_path, mode='r')
    image = tif.read_image()
    misc.imsave(_dst_path, image)
|
||||
|
||||
|
||||
def data_split(src_list, sample_size=550):
    """Randomly draw index positions for a dataset split.

    Args:
        src_list: Source sequence being split.
        sample_size: Number of indices to draw; defaults to the previously
            hard-coded 550 for backward compatibility.

    Returns:
        List of unique random indices into ``src_list``.

    Raises:
        ValueError: If ``src_list`` has fewer than ``sample_size`` items.
    """
    return random.sample(range(0, len(src_list)), sample_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Batch-convert every TIFF mask in src_dir to a PNG in dst_dir.
    src_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks_tif'
    dst_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks'

    os.makedirs(dst_dir, exist_ok=True)
    for img_name in os.listdir(src_dir):
        tif2png(os.path.join(src_dir, img_name),
                os.path.join(dst_dir, img_name.replace('.tif', '.png')))
|
||||
@ -0,0 +1,54 @@
|
||||
"""数据格式转换相关工具。"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from libtiff import TIFF
|
||||
from scipy import misc
|
||||
|
||||
|
||||
def tif2png(src_path, dst_path):
    """Convert a TIFF file to PNG/JPG format.

    Args:
        src_path: Source TIFF path.
        dst_path: Output image path.

    Returns:
        None.
    """
    # NOTE(review): scipy.misc.imsave was removed in SciPy >= 1.2 — confirm
    # the pinned SciPy version, or migrate to imageio.imwrite.
    tif = TIFF.open(src_path, mode="r")
    image = tif.read_image()
    misc.imsave(dst_path, image)
|
||||
|
||||
|
||||
def data_split(src_list, sample_size=550):
    """Randomly generate the index list for a dataset split.

    Args:
        src_list: Source sequence being split.
        sample_size: Number of indices to draw; defaults to the previously
            hard-coded 550 for backward compatibility.

    Returns:
        List of unique sampled indices into ``src_list``.

    Raises:
        ValueError: If ``src_list`` has fewer than ``sample_size`` items.
    """
    counter_list = random.sample(range(0, len(src_list)), sample_size)

    return counter_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
src_dir = (
|
||||
"../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/"
|
||||
"test_split/masks_tif"
|
||||
)
|
||||
dst_dir = (
|
||||
"../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks"
|
||||
)
|
||||
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
for img_name in os.listdir(src_dir):
|
||||
tif2png(
|
||||
os.path.join(src_dir, img_name),
|
||||
os.path.join(dst_dir, img_name.replace(".tif", ".png")),
|
||||
)
|
||||
@ -0,0 +1,339 @@
|
||||
"""训练与评估的杂项工具。"""
|
||||
|
||||
import os
|
||||
from math import ceil
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def check_mkdir(dir_name):
    """Create the directory if it does not already exist.

    Args:
        dir_name: Directory path to ensure.
    """
    # makedirs with exist_ok avoids the exists()/mkdir() race of the previous
    # implementation and additionally supports nested paths.
    os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
|
||||
def initialize_weights(*models):
    """Initialise conv/linear layers with He init and BatchNorm to identity.

    Args:
        *models: One or more nn.Module instances to initialise in place.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # kaiming_normal_ is the current in-place spelling; the
                # underscore-less alias is deprecated.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
|
||||
|
||||
|
||||
def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Build a bilinear-interpolation kernel for a transposed convolution.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Square kernel size.

    Returns:
        Float tensor of shape (in_channels, out_channels, k, k) whose paired
        (i, i) channel slots hold the bilinear filter.
    """
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    # Outer product of two 1-D triangle filters -> 2-D bilinear filter.
    filt = (1 - abs(og[0] - center) / factor) * (
        1 - abs(og[1] - center) / factor
    )
    weight = np.zeros(
        (in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64
    )
    # NOTE(review): the paired-index assignment assumes
    # in_channels == out_channels; with unequal counts it would fail or fill
    # only a partial diagonal — confirm callers always pass equal values.
    weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt
    return torch.from_numpy(weight).float()
|
||||
|
||||
|
||||
class CrossEntropyLoss2d(nn.Module):
    """2-D cross-entropy loss for dense segmentation logits.

    Args:
        weight: Optional per-class rescaling weights.
        size_average: Average (True) or sum (False) over elements.
        ignore_index: Target value that does not contribute to the loss.
    """

    def __init__(self, weight=None, size_average=True, ignore_index=255):
        """Initialise the wrapped NLL loss."""
        super(CrossEntropyLoss2d, self).__init__()
        # nn.NLLLoss2d and the size_average flag are deprecated; nn.NLLLoss
        # accepts (B, C, H, W) inputs and a reduction string instead.
        reduction = 'mean' if size_average else 'sum'
        self.nll_loss = nn.NLLLoss(weight, ignore_index=ignore_index,
                                   reduction=reduction)

    def forward(self, inputs, targets):
        """Compute the loss for logits (B, C, H, W) and targets (B, H, W)."""
        # dim=1 is the class axis; relying on the implicit default is deprecated.
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
|
||||
|
||||
|
||||
class FocalLoss2d(nn.Module):
    """2-D focal loss for dense segmentation logits.

    Args:
        gamma: Focusing exponent; larger values down-weight easy pixels more.
        weight: Optional per-class rescaling weights.
        size_average: Average (True) or sum (False) over elements.
        ignore_index: Target value that does not contribute to the loss.
    """

    def __init__(self, gamma=2, weight=None, size_average=True, ignore_index=255):
        """Initialise the wrapped NLL loss and the focusing exponent."""
        super(FocalLoss2d, self).__init__()
        self.gamma = gamma
        # nn.NLLLoss2d and the size_average flag are deprecated; nn.NLLLoss
        # accepts (B, C, H, W) inputs and a reduction string instead.
        reduction = 'mean' if size_average else 'sum'
        self.nll_loss = nn.NLLLoss(weight, ignore_index=ignore_index,
                                   reduction=reduction)

    def forward(self, inputs, targets):
        """Down-weight confident predictions by (1 - p)^gamma before NLL."""
        # dim=1 is the class axis; relying on the implicit default is deprecated.
        log_p = F.log_softmax(inputs, dim=1)
        focal_term = (1 - F.softmax(inputs, dim=1)) ** self.gamma
        return self.nll_loss(focal_term * log_p, targets)
|
||||
|
||||
|
||||
def _fast_hist(label_pred, label_true, num_classes):
|
||||
"""计算混淆矩阵。"""
|
||||
mask = (label_true >= 0) & (label_true < num_classes)
|
||||
hist = np.bincount(
|
||||
num_classes * label_true[mask].astype(int) + label_pred[mask],
|
||||
minlength=num_classes ** 2,
|
||||
).reshape(num_classes, num_classes)
|
||||
return hist
|
||||
|
||||
|
||||
def evaluate(predictions, gts, num_classes):
    """Compute segmentation metrics from prediction/ground-truth pairs.

    Returns:
        (pixel accuracy, mean per-class accuracy, mean IoU,
         frequency-weighted IoU).
    """
    hist = np.zeros((num_classes, num_classes))
    for pred, gt in zip(predictions, gts):
        hist += _fast_hist(pred.flatten(), gt.flatten(), num_classes)
    diag = np.diag(hist)
    acc = diag.sum() / hist.sum()
    acc_cls = np.nanmean(diag / hist.sum(axis=1))
    union = hist.sum(axis=1) + hist.sum(axis=0) - diag
    iu = diag / union
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc, acc_cls, mean_iu, fwavacc
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Tracks the latest value and a running (weighted) average."""

    def __init__(self):
        """Start with empty statistics."""
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record value ``val`` observed ``n`` times.

        Args:
            val: New observation.
            n: Observation weight / count.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
||||
|
||||
|
||||
class PolyLR(object):
    """Polynomial-decay learning-rate schedule over optimizer param groups."""

    def __init__(self, optimizer, curr_iter, max_iter, lr_decay):
        """Capture the optimizer's initial learning rates.

        Args:
            optimizer: Optimizer whose param_groups are updated in place.
            curr_iter: Current iteration counter (mutated externally).
            max_iter: Total number of iterations.
            lr_decay: Decay exponent.
        """
        self.max_iter = float(max_iter)
        self.init_lr_groups = [group["lr"] for group in optimizer.param_groups]
        self.param_groups = optimizer.param_groups
        self.curr_iter = curr_iter
        self.lr_decay = lr_decay

    def step(self):
        """Rescale every group's lr by (1 - curr_iter / max_iter) ** lr_decay."""
        decay = (1 - self.curr_iter / self.max_iter) ** self.lr_decay
        for base_lr, group in zip(self.init_lr_groups, self.param_groups):
            group["lr"] = base_lr * decay
|
||||
|
||||
|
||||
class Conv2dDeformable(nn.Module):
    """Experimental deformable-convolution wrapper around a regular Conv2d."""

    def __init__(self, regular_filter, cuda=True):
        """Wrap ``regular_filter`` with a learned per-pixel sampling offset.

        Args:
            regular_filter: The nn.Conv2d applied after offset resampling.
            cuda: Place the sampling grids on the GPU when True.
        """
        super(Conv2dDeformable, self).__init__()
        assert isinstance(regular_filter, nn.Conv2d)
        self.regular_filter = regular_filter
        # Predicts a 2-channel (w, h) offset for every input channel & pixel.
        self.offset_filter = nn.Conv2d(
            regular_filter.in_channels,
            2 * regular_filter.in_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.offset_filter.weight.data.normal_(0, 0.0005)
        # Base sampling grids, lazily (re)built when the input shape changes.
        self.input_shape = None
        self.grid_w = None
        self.grid_h = None
        self.cuda = cuda

    def forward(self, x):
        """Resample ``x`` at offset grid positions, then apply the wrapped conv."""
        x_shape = x.size()  # (B, C, H, W)
        offset = self.offset_filter(x)
        offset_w, offset_h = torch.split(
            offset, self.regular_filter.in_channels, 1
        )
        offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
        offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
        if not self.input_shape or self.input_shape != x_shape:
            # Rebuild the normalized [-1, 1] base grid for the new input size.
            self.input_shape = x_shape
            grid_w, grid_h = np.meshgrid(
                np.linspace(-1, 1, x_shape[3]),
                np.linspace(-1, 1, x_shape[2]),
            )
            grid_w = torch.Tensor(grid_w)
            grid_h = torch.Tensor(grid_h)
            if self.cuda:
                grid_w = grid_w.cuda()
                grid_h = grid_h.cuda()
            self.grid_w = nn.Parameter(grid_w)
            self.grid_h = nn.Parameter(grid_h)
        offset_w = offset_w + self.grid_w
        offset_h = offset_h + self.grid_h
        # Treat every channel as its own sample so grid_sample warps each
        # channel independently, then restore the (B, C, H, W) layout.
        x = (
            x.contiguous()
            .view(-1, int(x_shape[2]), int(x_shape[3]))
            .unsqueeze(1)
        )
        x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3))
        x = x.contiguous().view(
            -1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3])
        )
        x = self.regular_filter(x)
        return x
|
||||
|
||||
|
||||
def sliced_forward(single_forward):
    """Wrap a single-crop forward pass with multi-scale, sliding-window inference.

    ``single_forward(self, crop)`` is evaluated on windows of at most
    ``self.crop_size`` and the per-window results are stitched back together;
    results are averaged over ``self.scales``.  In training mode with
    ``self.use_aux`` the wrapped function must return ``(outputs, aux)``;
    otherwise it returns a single tensor.

    Fixes over the previous version:
      * the auxiliary map is now normalized by its own accumulation
        (was ``aux_outputs = outputs / count``);
      * un-padding uses ``[:scaled_h, :scaled_w]`` slices, avoiding the empty
        tensor produced by ``[:-0]`` when one of the pads is zero.
    """

    def _pad(x, crop_size):
        """Zero-pad ``x`` on the bottom/right up to ``crop_size`` per side."""
        h, w = x.size()[2:]
        pad_h = max(crop_size - h, 0)
        pad_w = max(crop_size - w, 0)
        x = F.pad(x, (0, pad_w, 0, pad_h))
        return x, pad_h, pad_w

    def wrapper(self, x):
        """Sliding-window / multi-scale forward pass."""
        batch_size, _, ori_h, ori_w = x.size()
        if self.training and self.use_aux:
            outputs_all_scales = Variable(
                torch.zeros((batch_size, self.num_classes, ori_h, ori_w))
            ).cuda()
            aux_all_scales = Variable(
                torch.zeros((batch_size, self.num_classes, ori_h, ori_w))
            ).cuda()
            for scale in self.scales:
                new_size = (int(ori_h * scale), int(ori_w * scale))
                scaled_x = F.upsample(x, size=new_size, mode="bilinear")
                scaled_x = Variable(scaled_x).cuda()
                scaled_h, scaled_w = scaled_x.size()[2:]
                long_size = max(scaled_h, scaled_w)

                if long_size > self.crop_size:
                    # Image larger than one crop: slide windows and average overlaps.
                    count = torch.zeros((scaled_h, scaled_w))
                    outputs = Variable(
                        torch.zeros(
                            (batch_size, self.num_classes, scaled_h, scaled_w)
                        )
                    ).cuda()
                    aux_outputs = Variable(
                        torch.zeros(
                            (batch_size, self.num_classes, scaled_h, scaled_w)
                        )
                    ).cuda()
                    stride = int(ceil(self.crop_size * self.stride_rate))
                    h_step_num = (
                        int(ceil((scaled_h - self.crop_size) / stride)) + 1
                    )
                    w_step_num = (
                        int(ceil((scaled_w - self.crop_size) / stride)) + 1
                    )
                    for yy in range(h_step_num):
                        for xx in range(w_step_num):
                            sy, sx = yy * stride, xx * stride
                            ey, ex = sy + self.crop_size, sx + self.crop_size
                            x_sub = scaled_x[:, :, sy:ey, sx:ex]
                            x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size)
                            outputs_sub, aux_sub = single_forward(self, x_sub)

                            # If the window ran off the image the crop was padded,
                            # so pad_h / pad_w are guaranteed > 0 in these guards.
                            if sy + self.crop_size > scaled_h:
                                outputs_sub = outputs_sub[:, :, :-pad_h, :]
                                aux_sub = aux_sub[:, :, :-pad_h, :]

                            if sx + self.crop_size > scaled_w:
                                outputs_sub = outputs_sub[:, :, :, :-pad_w]
                                aux_sub = aux_sub[:, :, :, :-pad_w]

                            outputs[:, :, sy:ey, sx:ex] = outputs_sub
                            aux_outputs[:, :, sy:ey, sx:ex] = aux_sub

                            count[sy:ey, sx:ex] += 1
                    count = Variable(count).cuda()
                    outputs = outputs / count
                    # BUGFIX: normalize the auxiliary map by its own values
                    # (previously divided the main outputs a second time).
                    aux_outputs = aux_outputs / count
                else:
                    scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size)
                    outputs, aux_outputs = single_forward(self, scaled_x)
                    # BUGFIX: crop back to the unpadded size.  ``[:-pad]``
                    # would produce an empty tensor whenever pad == 0.
                    outputs = outputs[:, :, :scaled_h, :scaled_w]
                    aux_outputs = aux_outputs[:, :, :scaled_h, :scaled_w]
                # NOTE(review): outputs are accumulated at the scaled size; this
                # only matches the ori_h x ori_w accumulators when scale == 1 —
                # confirm self.scales, or resize back before accumulating.
                outputs_all_scales += outputs
                aux_all_scales += aux_outputs
            return outputs_all_scales / len(self.scales), aux_all_scales

        # Inference path: no auxiliary head, single return value per crop.
        outputs_all_scales = Variable(
            torch.zeros((batch_size, self.num_classes, ori_h, ori_w))
        ).cuda()
        for scale in self.scales:
            new_size = (int(ori_h * scale), int(ori_w * scale))
            scaled_x = F.upsample(x, size=new_size, mode="bilinear")
            scaled_h, scaled_w = scaled_x.size()[2:]
            long_size = max(scaled_h, scaled_w)

            if long_size > self.crop_size:
                count = torch.zeros((scaled_h, scaled_w))
                outputs = Variable(
                    torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))
                ).cuda()
                stride = int(ceil(self.crop_size * self.stride_rate))
                h_step_num = int(ceil((scaled_h - self.crop_size) / stride)) + 1
                w_step_num = int(ceil((scaled_w - self.crop_size) / stride)) + 1
                for yy in range(h_step_num):
                    for xx in range(w_step_num):
                        sy, sx = yy * stride, xx * stride
                        ey, ex = sy + self.crop_size, sx + self.crop_size
                        x_sub = scaled_x[:, :, sy:ey, sx:ex]
                        x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size)

                        outputs_sub = single_forward(self, x_sub)

                        # pad_h / pad_w are > 0 whenever these guards fire.
                        if sy + self.crop_size > scaled_h:
                            outputs_sub = outputs_sub[:, :, :-pad_h, :]

                        if sx + self.crop_size > scaled_w:
                            outputs_sub = outputs_sub[:, :, :, :-pad_w]

                        outputs[:, :, sy:ey, sx:ex] = outputs_sub

                        count[sy:ey, sx:ex] += 1
                count = Variable(count).cuda()
                outputs = outputs / count
            else:
                scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size)
                outputs = single_forward(self, scaled_x)
                # BUGFIX: safe un-padding (``[:-0]`` would be empty).
                outputs = outputs[:, :, :scaled_h, :scaled_w]
            outputs_all_scales += outputs
        return outputs_all_scales

    return wrapper
|
||||
@ -0,0 +1,170 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.dataset_synapse import Synapse_dataset
|
||||
from utils.utils import test_single_volume
|
||||
|
||||
from lib.networks import EMCADNet
|
||||
|
||||
# Command-line interface for Synapse test-time inference.
parser = argparse.ArgumentParser()

# Data / experiment options.
parser.add_argument('--volume_path', type=str,
                    default='./data/synapse/test_vol_h5_new', help='root dir for validation volume data')
parser.add_argument('--dataset', type=str,
                    default='Synapse', help='experiment_name')
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')

# network related parameters
parser.add_argument('--encoder', type=str,
                    default='pvt_v2_b2', help='Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...')
parser.add_argument('--expansion_factor', type=int,
                    default=2, help='expansion factor in MSCB block')
parser.add_argument('--kernel_sizes', type=int, nargs='+',
                    default=[1, 3, 5], help='multi-scale kernel sizes in MSDC block')
parser.add_argument('--lgag_ks', type=int,
                    default=3, help='Kernel size in LGAG')
parser.add_argument('--activation_mscb', type=str,
                    default='relu6', help='activation used in MSCB: relu6 or relu')
parser.add_argument('--no_dw_parallel', action='store_true',
                    default=False, help='use this flag to disable depth-wise parallel convolutions')
parser.add_argument('--concatenation', action='store_true',
                    default=False, help='use this flag to concatenate feature maps in MSDC block')
parser.add_argument('--no_pretrain', action='store_true',
                    default=False, help='use this flag to turn off loading pretrained enocder weights')
parser.add_argument('--pretrained_dir', type=str,
                    default='./pretrained_pth/pvt/', help='path to pretrained encoder dir')
parser.add_argument('--supervision', type=str,
                    default='mutation', help='loss supervision: mutation, deep_supervision or last_layer')

# Training-time options; repeated here because they participate in the
# checkpoint-directory naming scheme and must match the training run.
parser.add_argument('--max_iterations', type=int,default=30000, help='maximum epoch number to train')
parser.add_argument('--max_epochs', type=int, default=300, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int, default=6,
                    help='batch_size per gpu')
parser.add_argument('--base_lr', type=float, default=0.0001, help='segmentation network learning rate')
parser.add_argument('--img_size', type=int, default=224, help='input patch size of network input')
# NOTE(review): store_true with default=True means this flag can never be
# turned off from the command line — confirm intended.
parser.add_argument('--is_savenii', action="store_true", default=True, help='whether to save results during inference')

parser.add_argument('--test_save_dir', type=str, default='predictions', help='saving prediction as nii!')
parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
parser.add_argument('--seed', type=int, default=2222, help='random seed')
args = parser.parse_args()

# Foreground organ names (no background entry): 13 names for the 14-class
# Synapse variant, 8 names otherwise.
if(args.num_classes == 14):
    classes = ['spleen', 'right kidney', 'left kidney', 'gallbladder', 'esophagus', 'liver', 'stomach', 'aorta', 'inferior vena cava', 'portal vein and splenic vein', 'pancreas', 'right adrenal gland', 'left adrenal gland']
else:
    classes = ['spleen', 'right kidney', 'left kidney', 'gallbladder', 'pancreas', 'liver', 'stomach', 'aorta']
|
||||
|
||||
def inference(args, model, test_save_path=None):
    """Evaluate ``model`` volume-by-volume on the Synapse test split.

    Args:
        args: parsed CLI arguments; reads Dataset, volume_path, list_dir,
            num_classes and img_size.
        model: network to evaluate; switched to eval mode here.
        test_save_path: directory for saved predictions, or None to skip saving.

    Returns:
        The literal string "Testing Finished!" (metrics are only logged).
    """
    db_test = args.Dataset(base_dir=args.volume_path, split="test_vol", list_dir=args.list_dir, nclass=args.num_classes)
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
    logging.info("%s test iterations per epoch", len(testloader))
    model.eval()
    metric_list = 0.0
    for i_batch, sampled_batch in tqdm(enumerate(testloader)):
        image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0]
        metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size],
                                      test_save_path=test_save_path, case=case_name, z_spacing=1, class_names=classes)
        metric_list += np.array(metric_i)
        # Hoist the per-case mean instead of recomputing it per metric.
        case_mean = np.mean(metric_i, axis=0)
        # Lazy %-style logging args avoid formatting when the level is disabled.
        logging.info('idx %d case %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f',
                     i_batch, case_name, case_mean[0], case_mean[1], case_mean[2], case_mean[3])
    metric_list = metric_list / len(db_test)
    # Per-class summary (classes[i-1]: class 0 is background, not listed).
    for i in range(1, args.num_classes):
        logging.info('Mean class (%d) %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f',
                     i, classes[i-1], metric_list[i-1][0], metric_list[i-1][1], metric_list[i-1][2], metric_list[i-1][3])
    overall = np.mean(metric_list, axis=0)
    performance = overall[0]
    mean_hd95 = overall[1]
    mean_jacard = overall[2]
    mean_asd = overall[3]
    logging.info('Testing performance in best val model: mean_dice : %f mean_hd95 : %f, mean_jacard : %f mean_asd : %f',
                 performance, mean_hd95, mean_jacard, mean_asd)
    return "Testing Finished!"
|
||||
|
||||
if __name__ == "__main__":

    # cuDNN determinism: benchmark mode is faster but non-deterministic.
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True
    # Seed every RNG source for reproducible inference.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Per-dataset configuration table (only Synapse is wired up here).
    dataset_config = {
        'Synapse': {
            'Dataset': Synapse_dataset,
            'volume_path': args.volume_path,
            'list_dir': args.list_dir,
            'num_classes': args.num_classes,
            'z_spacing': 1,
        },
    }
    dataset_name = args.dataset
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.volume_path = dataset_config[dataset_name]['volume_path']
    args.Dataset = dataset_config[dataset_name]['Dataset']
    args.list_dir = dataset_config[dataset_name]['list_dir']
    args.z_spacing = dataset_config[dataset_name]['z_spacing']
    print(args.no_pretrain)

    # Decoder feature-aggregation mode.
    if args.concatenation:
        aggregation = 'concat'
    else:
        aggregation = 'add'

    # Depth-wise convolution arrangement inside the MSDC block.
    if args.no_dw_parallel:
        dw_mode = 'series'
    else:
        dw_mode = 'parallel'

    run = 1

    # Experiment tag and checkpoint directory; this naming must exactly match
    # the scheme used by the training script so the checkpoint can be found.
    args.exp = args.encoder + '_EMCAD_kernel_sizes_' + str(args.kernel_sizes) + '_dw_' + dw_mode + '_' + aggregation + '_lgag_ks_' + str(args.lgag_ks) + '_ef' + str(args.expansion_factor) + '_act_mscb_' + args.activation_mscb + '_loss_' + args.supervision + '_output_final_layer_Run'+str(run)+'_' + dataset_name + str(args.img_size)
    snapshot_path = "model_pth/{}/{}".format(args.exp, args.encoder + '_EMCAD_kernel_sizes_' + str(args.kernel_sizes) + '_dw_' + dw_mode + '_' + aggregation + '_lgag_ks_' + str(args.lgag_ks) + '_ef' + str(args.expansion_factor) + '_act_mscb_' + args.activation_mscb + '_loss_' + args.supervision + '_output_final_layer_Run'+str(run))
    # str(list) renders as "[1, 3, 5]"; strip brackets and join with underscores.
    snapshot_path = snapshot_path.replace('[', '').replace(']', '').replace(', ', '_')

    # Optional suffixes appended only when the value differs from the training default.
    snapshot_path = snapshot_path + '_pretrain' if not args.no_pretrain else snapshot_path
    snapshot_path = snapshot_path+'_'+str(args.max_iterations)[0:2]+'k' if args.max_iterations != 50000 else snapshot_path
    snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 300 else snapshot_path
    snapshot_path = snapshot_path+'_bs'+str(args.batch_size)
    snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.0001 else snapshot_path
    snapshot_path = snapshot_path + '_'+str(args.img_size)
    snapshot_path = snapshot_path + '_s'+str(args.seed) if args.seed!=1234 else snapshot_path

    # Build the network and load the best (or final-epoch) checkpoint.
    model = EMCADNet(num_classes=args.num_classes, kernel_sizes=args.kernel_sizes, expansion_factor=args.expansion_factor, dw_parallel=not args.no_dw_parallel, add=not args.concatenation, lgag_ks=args.lgag_ks, activation=args.activation_mscb, encoder=args.encoder, pretrain= not args.no_pretrain, pretrained_dir=args.pretrained_dir)
    model.cuda()

    snapshot = os.path.join(snapshot_path, 'best.pth')
    # Fall back to the last epoch's checkpoint when best.pth is absent.
    if not os.path.exists(snapshot): snapshot = snapshot.replace('best', 'epoch_'+str(args.max_epochs-1))
    model.load_state_dict(torch.load(snapshot))
    snapshot_name = snapshot_path.split('/')[-1]

    # Log both to a per-experiment file and to stdout.
    log_folder = 'test_log/test_log_' + args.exp
    os.makedirs(log_folder, exist_ok=True)
    logging.basicConfig(filename=log_folder + '/'+snapshot_name+".txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    logging.info(snapshot_name)

    # Optionally save predictions as NIfTI under the snapshot directory.
    if args.is_savenii:
        args.test_save_dir = os.path.join(snapshot_path, "predictions")
        test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name+'2')
        os.makedirs(test_save_path, exist_ok=True)
    else:
        test_save_path = None
    inference(args, model, test_save_path)
|
||||
|
||||
|
||||
@ -0,0 +1,381 @@
|
||||
"""Synapse 与 ACDC 的训练流程。"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from tensorboardX import SummaryWriter
|
||||
from torch.nn.modules.loss import CrossEntropyLoss
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.utils.dataset_ACDC import ACDCdataset
|
||||
from src.utils.dataset_synapse import RandomGenerator, Synapse_dataset
|
||||
from src.utils.utils import DiceLoss, powerset, val_single_volume
|
||||
|
||||
|
||||
def inference(args, model, best_performance):
    """Validate on the Synapse test volumes during training.

    Args:
        args: training arguments; reads volume_path, list_dir, num_classes,
            img_size and z_spacing.
        model: network to evaluate; switched to eval mode here and NOT
            switched back before returning.
        best_performance: current best metric, used only for logging.

    Returns:
        Mean performance metric over all test volumes.
    """
    db_test = Synapse_dataset(
        base_dir=args.volume_path,
        split="test_vol",
        list_dir=args.list_dir,
        nclass=args.num_classes,
    )
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
    logging.info("%s test iterations per epoch", len(testloader))

    model.eval()
    metric_list = 0.0
    for _, sampled_batch in tqdm(enumerate(testloader)):
        image, label, case_name = (
            sampled_batch["image"],
            sampled_batch["label"],
            sampled_batch["case_name"][0],
        )
        metric_i = val_single_volume(
            image,
            label,
            model,
            classes=args.num_classes,
            patch_size=[args.img_size, args.img_size],
            case=case_name,
            z_spacing=args.z_spacing,
        )
        # Accumulate per-volume metrics; divided by the dataset size below.
        metric_list += np.array(metric_i)

    metric_list = metric_list / len(db_test)
    performance = np.mean(metric_list, axis=0)
    logging.info(
        "Testing performance in val model: mean_dice : %f, best_dice : %f",
        performance,
        best_performance,
    )
    return performance
|
||||
|
||||
|
||||
def trainer_synapse(args, model, snapshot_path):
    """Training loop for the Synapse dataset.

    Args:
        args: parsed CLI arguments; reads base_lr, num_classes, batch_size,
            n_gpu, root_path, list_dir, img_size, seed, max_epochs and
            supervision.
        model: segmentation network; called as ``model(x, mode="train")`` and
            may return a list of outputs for deep supervision.
        snapshot_path: directory receiving log.txt, TensorBoard logs and
            checkpoints (last.pth / best.pth / epoch_N.pth).

    Returns:
        The literal string "Training Finished!".
    """
    # Log to a file under snapshot_path and mirror everything to stdout.
    logging.basicConfig(
        filename=snapshot_path + "/log.txt",
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d] %(message)s",
        datefmt="%H:%M:%S",
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    base_lr = args.base_lr
    num_classes = args.num_classes
    # Effective batch size scales with GPU count (DataParallel splits it).
    batch_size = args.batch_size * args.n_gpu

    db_train = Synapse_dataset(
        base_dir=args.root_path,
        list_dir=args.list_dir,
        split="train",
        nclass=args.num_classes,
        transform=transforms.Compose(
            [RandomGenerator(output_size=[args.img_size, args.img_size])]
        ),
    )
    print("The length of train set is: {}".format(len(db_train)))

    def worker_init_fn(worker_id):
        """Seed each DataLoader worker deterministically."""
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(
        db_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1 and args.n_gpu > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model.to(device)
    model.train()

    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)

    writer = SummaryWriter(snapshot_path + "/log")
    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = args.max_epochs * len(trainloader)
    logging.info(
        "%s iterations per epoch. %s max iterations ",
        len(trainloader),
        max_iterations,
    )
    best_performance = 0.0

    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for i_batch, sampled_batch in enumerate(trainloader):
            image_batch, label_batch = (
                sampled_batch["image"],
                sampled_batch["label"],
            )
            image_batch, label_batch = (
                image_batch.cuda(),
                label_batch.squeeze(1).cuda(),
            )

            outputs = model(image_batch, mode="train")
            # Normalize to a list so single-output models share the code path.
            if not isinstance(outputs, list):
                outputs = [outputs]

            # Supervision sets are computed once on the very first batch and
            # reused for the rest of training (the variable persists).
            if epoch_num == 0 and i_batch == 0:
                n_outs = len(outputs)
                out_idxs = list(np.arange(n_outs))
                if args.supervision == "mutation":
                    # All non-empty subsets of decoder outputs (MUTATION loss).
                    supervision_sets = [x for x in powerset(out_idxs)]
                elif args.supervision == "deep_supervision":
                    supervision_sets = [[x] for x in out_idxs]
                else:
                    # last_layer: only the final output is supervised.
                    supervision_sets = [[-1]]
                print(supervision_sets)

            loss = 0.0
            # Fixed CE/Dice weighting.
            w_ce, w_dice = 0.3, 0.7
            for s in supervision_sets:
                iout = 0.0
                if s == []:
                    continue
                # Sum the selected decoder outputs before computing the loss.
                for idx in range(len(s)):
                    iout += outputs[s[idx]]
                loss_ce = ce_loss(iout, label_batch[:].long())
                loss_dice = dice_loss(iout, label_batch, softmax=True)
                loss += w_ce * loss_ce + w_dice * loss_dice

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # lr_ is reset to base_lr every step: effectively no LR decay.
            lr_ = base_lr
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_

            iter_num += 1
            writer.add_scalar("info/lr", lr_, iter_num)
            writer.add_scalar("info/total_loss", loss, iter_num)

            if iter_num % 50 == 0:
                logging.info(
                    "iteration %d, epoch %d : loss : %f, lr: %f",
                    iter_num,
                    epoch_num,
                    loss.item(),
                    lr_,
                )

        # Always keep the most recent checkpoint.
        save_mode_path = os.path.join(snapshot_path, "last.pth")
        torch.save(model.state_dict(), save_mode_path)

        # NOTE(review): inference() calls model.eval() and never restores
        # train mode, so subsequent epochs train in eval mode — confirm intended.
        performance = inference(args, model, best_performance)
        save_interval = 50

        if best_performance <= performance:
            best_performance = performance
            save_mode_path = os.path.join(snapshot_path, "best.pth")
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to %s", save_mode_path)

        # Periodic checkpoint every save_interval epochs.
        if (epoch_num + 1) % save_interval == 0:
            save_mode_path = os.path.join(
                snapshot_path, "epoch_" + str(epoch_num) + ".pth"
            )
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to %s", save_mode_path)

        # Final-epoch checkpoint.
        if epoch_num >= max_epoch - 1:
            save_mode_path = os.path.join(
                snapshot_path, "epoch_" + str(epoch_num) + ".pth"
            )
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to %s", save_mode_path)

    iterator.close()
    writer.close()
    return "Training Finished!"
|
||||
|
||||
|
||||
def trainer_ACDC(args, model, snapshot_path):
    """Training loop for the ACDC dataset.

    Mirrors ``trainer_synapse`` but selects the best checkpoint by lowest
    mean training loss instead of a validation metric.

    Args:
        args: parsed CLI arguments; reads base_lr, num_classes, batch_size,
            n_gpu, root_path, list_dir, img_size, seed, max_epochs and
            supervision.
        model: segmentation network; called as ``model(x, mode="train")`` and
            may return a list of outputs for deep supervision.
        snapshot_path: directory receiving log.txt, TensorBoard logs and
            checkpoints.

    Returns:
        The literal string "ACDC Training Finished!".
    """
    logging.basicConfig(
        filename=snapshot_path + "/log.txt",
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d] %(message)s",
        datefmt="%H:%M:%S",
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    base_lr = args.base_lr
    num_classes = args.num_classes
    # Effective batch size scales with GPU count (DataParallel splits it).
    batch_size = args.batch_size * args.n_gpu

    # Accept either the dataset root or its "train" subdirectory as root_path.
    acdc_root = args.root_path.rstrip("/")
    if os.path.basename(acdc_root) == "train":
        acdc_root = os.path.dirname(acdc_root)
    logging.info("Using ACDC root dir: %s", acdc_root)

    db_train = ACDCdataset(
        base_dir=acdc_root,
        list_dir=args.list_dir,
        split="train",
        transform=transforms.Compose(
            [RandomGenerator(output_size=[args.img_size, args.img_size])]
        ),
    )
    print("The length of train set is: {}".format(len(db_train)))

    def worker_init_fn(worker_id):
        """Seed each DataLoader worker deterministically."""
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(
        db_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        worker_init_fn=worker_init_fn,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1 and args.n_gpu > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model.to(device)
    model.train()

    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)

    writer = SummaryWriter(snapshot_path + "/log")
    iter_num = 0
    max_epoch = args.max_epochs
    max_iterations = args.max_epochs * len(trainloader)
    logging.info(
        "%s iterations per epoch. %s max iterations ",
        len(trainloader),
        max_iterations,
    )

    # Best checkpoint is chosen by lowest mean epoch loss (no validation set here).
    best_loss = 1e9
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        epoch_loss = 0.0
        for i_batch, sampled_batch in enumerate(trainloader):
            image_batch, label_batch = (
                sampled_batch["image"],
                sampled_batch["label"],
            )
            image_batch, label_batch = (
                image_batch.to(device),
                label_batch.squeeze(1).long().to(device),
            )

            outputs = model(image_batch, mode="train")
            # Normalize to a list so single-output models share the code path.
            if not isinstance(outputs, list):
                outputs = [outputs]

            # Supervision sets are computed once on the very first batch and
            # reused for the rest of training (the variable persists).
            if epoch_num == 0 and i_batch == 0:
                n_outs = len(outputs)
                out_idxs = list(np.arange(n_outs))
                if args.supervision == "mutation":
                    # All non-empty subsets of decoder outputs (MUTATION loss).
                    supervision_sets = [x for x in powerset(out_idxs)]
                elif args.supervision == "deep_supervision":
                    supervision_sets = [[x] for x in out_idxs]
                else:
                    # last_layer: only the final output is supervised.
                    supervision_sets = [[-1]]
                print(supervision_sets)

            loss = 0.0
            # Fixed CE/Dice weighting.
            w_ce, w_dice = 0.3, 0.7
            for s in supervision_sets:
                iout = 0.0
                if s == []:
                    continue
                # Sum the selected decoder outputs before computing the loss.
                for idx in range(len(s)):
                    iout += outputs[s[idx]]
                loss_ce = ce_loss(iout, label_batch)
                loss_dice = dice_loss(iout, label_batch, softmax=True)
                loss += w_ce * loss_ce + w_dice * loss_dice

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # lr_ is reset to base_lr every step: effectively no LR decay.
            lr_ = base_lr
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_

            iter_num += 1
            epoch_loss += loss.item()
            writer.add_scalar("info/lr", lr_, iter_num)
            writer.add_scalar("info/total_loss", loss, iter_num)

            if iter_num % 50 == 0:
                logging.info(
                    "iteration %d, epoch %d : loss : %f, lr: %f",
                    iter_num,
                    epoch_num,
                    loss.item(),
                    lr_,
                )

        # Guard against an empty loader when averaging.
        epoch_loss /= max(len(trainloader), 1)
        logging.info(
            "[ACDC] Epoch %s finished, mean loss = %.4f",
            epoch_num,
            epoch_loss,
        )

        # Always keep the most recent checkpoint.
        save_mode_path = os.path.join(snapshot_path, "last.pth")
        torch.save(model.state_dict(), save_mode_path)

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_path = os.path.join(snapshot_path, "best.pth")
            torch.save(model.state_dict(), best_path)
            logging.info(
                "New best model saved to %s, loss=%.4f", best_path, best_loss
            )

        # Periodic and final-epoch checkpoints.
        save_interval = 50
        if (epoch_num + 1) % save_interval == 0 or epoch_num == max_epoch - 1:
            save_mode_path = os.path.join(
                snapshot_path, "epoch_" + str(epoch_num) + ".pth"
            )
            torch.save(model.state_dict(), save_mode_path)
            logging.info("save model to %s", save_mode_path)

    iterator.close()
    writer.close()
    return "ACDC Training Finished!"
|
||||
@ -0,0 +1,362 @@
|
||||
"""Synapse 数据集推理入口。"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.core.networks import EMCADNet
|
||||
from src.utils.dataset_synapse import Synapse_dataset
|
||||
from src.utils.utils import test_single_volume
|
||||
|
||||
|
||||
def build_parser():
    """Build the argument parser for Synapse test-time inference.

    Returns:
        An ``argparse.ArgumentParser`` whose defaults mirror the training
        script's naming scheme (several training-time options are repeated
        here because they participate in the checkpoint path).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--volume_path",
        type=str,
        default="./data/synapse/test_vol_h5_new",
        help="root dir for validation volume data",
    )
    parser.add_argument("--dataset", type=str, default="Synapse", help="experiment_name")
    parser.add_argument("--num_classes", type=int, default=9, help="output channel of network")
    parser.add_argument("--list_dir", type=str, default="./lists/lists_Synapse", help="list dir")

    # Network-architecture options.
    parser.add_argument(
        "--encoder",
        type=str,
        default="pvt_v2_b2",
        help="Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...",
    )
    parser.add_argument(
        "--expansion_factor",
        type=int,
        default=2,
        help="expansion factor in MSCB block",
    )
    parser.add_argument(
        "--kernel_sizes",
        type=int,
        nargs="+",
        default=[1, 3, 5],
        help="multi-scale kernel sizes in MSDC block",
    )
    parser.add_argument("--lgag_ks", type=int, default=3, help="Kernel size in LGAG")
    parser.add_argument(
        "--activation_mscb",
        type=str,
        default="relu6",
        help="activation used in MSCB: relu6 or relu",
    )
    parser.add_argument(
        "--no_dw_parallel",
        action="store_true",
        default=False,
        help="use this flag to disable depth-wise parallel convolutions",
    )
    parser.add_argument(
        "--concatenation",
        action="store_true",
        default=False,
        help="use this flag to concatenate feature maps in MSDC block",
    )
    parser.add_argument(
        "--no_pretrain",
        action="store_true",
        default=False,
        # Typo fixed: "enocder" -> "encoder".
        help="use this flag to turn off loading pretrained encoder weights",
    )
    parser.add_argument(
        "--pretrained_dir",
        type=str,
        default="./pretrained_pth/pvt/",
        help="path to pretrained encoder dir",
    )
    parser.add_argument(
        "--supervision",
        type=str,
        default="mutation",
        help="loss supervision: mutation, deep_supervision or last_layer",
    )

    # Training-time options (used only for checkpoint-path reconstruction).
    parser.add_argument(
        "--max_iterations",
        type=int,
        default=30000,
        help="maximum epoch number to train",
    )
    parser.add_argument(
        "--max_epochs", type=int, default=300, help="maximum epoch number to train"
    )
    parser.add_argument("--batch_size", type=int, default=6, help="batch_size per gpu")
    parser.add_argument(
        "--base_lr", type=float, default=0.0001, help="segmentation network learning rate"
    )
    parser.add_argument("--img_size", type=int, default=224, help="input patch size of network input")
    # NOTE(review): store_true with default=True means this flag can never be
    # switched off from the command line — confirm intended before changing the
    # default (changing it would alter behavior for existing callers).
    parser.add_argument(
        "--is_savenii",
        action="store_true",
        default=True,
        help="whether to save results during inference",
    )

    parser.add_argument(
        "--test_save_dir", type=str, default="predictions", help="saving prediction as nii"
    )
    parser.add_argument(
        "--deterministic", type=int, default=1, help="whether use deterministic training"
    )
    parser.add_argument("--seed", type=int, default=2222, help="random seed")
    return parser
|
||||
|
||||
|
||||
def get_class_names(num_classes):
    """Return the ordered foreground organ names for a Synapse variant.

    Args:
        num_classes: total class count including background; 14 selects the
            13-organ extended list, anything else the 8-organ list.

    Returns:
        A new list of organ-name strings (background is not included).
    """
    extended_organs = (
        "spleen",
        "right kidney",
        "left kidney",
        "gallbladder",
        "esophagus",
        "liver",
        "stomach",
        "aorta",
        "inferior vena cava",
        "portal vein and splenic vein",
        "pancreas",
        "right adrenal gland",
        "left adrenal gland",
    )
    basic_organs = (
        "spleen",
        "right kidney",
        "left kidney",
        "gallbladder",
        "pancreas",
        "liver",
        "stomach",
        "aorta",
    )
    chosen = extended_organs if num_classes == 14 else basic_organs
    return list(chosen)
|
||||
|
||||
|
||||
def inference(args, model, class_names, test_save_path=None):
    """Run volume-by-volume inference on the Synapse test split.

    Args:
        args: parsed CLI arguments; reads Dataset, volume_path, list_dir,
            num_classes and img_size.
        model: network to evaluate; switched to eval mode here.
        class_names: foreground class names (index 0 maps to class id 1).
        test_save_path: directory for saved predictions, or None to skip saving.

    Returns:
        The literal string "Testing Finished!" (metrics are only logged).
    """
    db_test = args.Dataset(
        base_dir=args.volume_path,
        split="test_vol",
        list_dir=args.list_dir,
        nclass=args.num_classes,
    )
    testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)
    logging.info("%s test iterations per epoch", len(testloader))
    model.eval()
    metric_list = 0.0
    for i_batch, sampled_batch in tqdm(enumerate(testloader)):
        image, label, case_name = (
            sampled_batch["image"],
            sampled_batch["label"],
            sampled_batch["case_name"][0],
        )
        metric_i = test_single_volume(
            image,
            label,
            model,
            classes=args.num_classes,
            patch_size=[args.img_size, args.img_size],
            test_save_path=test_save_path,
            case=case_name,
            z_spacing=1,
            class_names=class_names,
        )
        # Accumulate per-volume metrics; averaged over the dataset below.
        metric_list += np.array(metric_i)
        logging.info(
            "idx %d case %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f",
            i_batch,
            case_name,
            np.mean(metric_i, axis=0)[0],
            np.mean(metric_i, axis=0)[1],
            np.mean(metric_i, axis=0)[2],
            np.mean(metric_i, axis=0)[3],
        )
    metric_list = metric_list / len(db_test)
    # Per-class summary; class 0 is background, hence class_names[i - 1].
    for i in range(1, args.num_classes):
        logging.info(
            "Mean class (%d) %s mean_dice %f mean_hd95 %f, mean_jacard %f mean_asd %f",
            i,
            class_names[i - 1],
            metric_list[i - 1][0],
            metric_list[i - 1][1],
            metric_list[i - 1][2],
            metric_list[i - 1][3],
        )
    performance = np.mean(metric_list, axis=0)[0]
    mean_hd95 = np.mean(metric_list, axis=0)[1]
    mean_jacard = np.mean(metric_list, axis=0)[2]
    mean_asd = np.mean(metric_list, axis=0)[3]
    logging.info(
        "Testing performance in best val model: mean_dice : %f mean_hd95 : %f, "
        "mean_jacard : %f mean_asd : %f",
        performance,
        mean_hd95,
        mean_jacard,
        mean_asd,
    )
    return "Testing Finished!"
|
||||
|
||||
|
||||
def build_snapshot_path(args, dataset_name):
    """Build the snapshot (checkpoint/prediction) directory path for inference.

    Side effect: sets ``args.exp`` to the experiment tag (run tag plus dataset
    name and image size) — ``main()`` relies on it for the test-log folder.

    Args:
        args: parsed CLI namespace; model/run hyper-parameters are read and
            ``args.exp`` is written.
        dataset_name: dataset identifier appended to the experiment tag.

    Returns:
        str: path of the form ``model_pth/<exp>/<run tag><setting suffixes>``.
    """
    aggregation = "concat" if args.concatenation else "add"
    dw_mode = "series" if args.no_dw_parallel else "parallel"
    run = 1

    # Shared run tag — previously this exact concatenation was duplicated
    # verbatim for args.exp and for the second path component.
    run_tag = (
        f"{args.encoder}_EMCAD_kernel_sizes_{args.kernel_sizes}"
        f"_dw_{dw_mode}_{aggregation}"
        f"_lgag_ks_{args.lgag_ks}"
        f"_ef{args.expansion_factor}"
        f"_act_mscb_{args.activation_mscb}"
        f"_loss_{args.supervision}"
        f"_output_final_layer_Run{run}"
    )

    # args.exp deliberately keeps the raw "[1, 3, 5]"-style list repr; only
    # the filesystem path below is sanitized (matches original behavior).
    args.exp = f"{run_tag}_{dataset_name}{args.img_size}"

    snapshot_path = "model_pth/{}/{}".format(args.exp, run_tag)
    # Strip list-repr brackets and turn ", " separators into "_" so the
    # kernel-size list is filesystem friendly.
    snapshot_path = snapshot_path.replace("[", "").replace("]", "").replace(", ", "_")

    # Encode any non-default training settings into the directory name.
    if not args.no_pretrain:
        snapshot_path = snapshot_path + "_pretrain"
    if args.max_iterations != 50000:
        # e.g. 30000 -> "_30k" (first two digits only, as in training script)
        snapshot_path = snapshot_path + "_" + str(args.max_iterations)[0:2] + "k"
    if args.max_epochs != 300:
        snapshot_path = snapshot_path + "_epo" + str(args.max_epochs)
    snapshot_path = snapshot_path + "_bs" + str(args.batch_size)
    if args.base_lr != 0.0001:
        snapshot_path = snapshot_path + "_lr" + str(args.base_lr)
    snapshot_path = snapshot_path + "_" + str(args.img_size)
    if args.seed != 1234:
        snapshot_path = snapshot_path + "_s" + str(args.seed)

    return snapshot_path
|
||||
|
||||
|
||||
def main():
    """CLI entry point: configure determinism, build the model, load the best
    checkpoint and run inference on the test volumes."""
    parser = build_parser()
    args = parser.parse_args()

    # cudnn.benchmark trades run-to-run determinism for autotuned kernels;
    # deterministic mode disables it.
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True
    # Seed every RNG source before model construction so initialization
    # (and any sampling) is reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Per-dataset configuration; only "Synapse" is registered here, so any
    # other --dataset value raises KeyError below.
    dataset_config = {
        "Synapse": {
            "Dataset": Synapse_dataset,
            "volume_path": args.volume_path,
            "list_dir": args.list_dir,
            "num_classes": args.num_classes,
            "z_spacing": 1,
        }
    }
    dataset_name = args.dataset
    args.num_classes = dataset_config[dataset_name]["num_classes"]
    args.volume_path = dataset_config[dataset_name]["volume_path"]
    args.Dataset = dataset_config[dataset_name]["Dataset"]
    args.list_dir = dataset_config[dataset_name]["list_dir"]
    args.z_spacing = dataset_config[dataset_name]["z_spacing"]
    print(args.no_pretrain)

    # NOTE: build_snapshot_path also sets args.exp as a side effect,
    # which is used for the log folder name below.
    snapshot_path = build_snapshot_path(args, dataset_name)

    # Decoder options are expressed as negations of the CLI flags
    # (e.g. --no_dw_parallel / --concatenation / --no_pretrain).
    model = EMCADNet(
        num_classes=args.num_classes,
        kernel_sizes=args.kernel_sizes,
        expansion_factor=args.expansion_factor,
        dw_parallel=not args.no_dw_parallel,
        add=not args.concatenation,
        lgag_ks=args.lgag_ks,
        activation=args.activation_mscb,
        encoder=args.encoder,
        pretrain=not args.no_pretrain,
        pretrained_dir=args.pretrained_dir,
    )
    model.cuda()

    # Prefer the best-validation checkpoint; fall back to the last epoch.
    snapshot = os.path.join(snapshot_path, "best.pth")
    if not os.path.exists(snapshot):
        snapshot = snapshot.replace("best", "epoch_" + str(args.max_epochs - 1))
    model.load_state_dict(torch.load(snapshot))
    snapshot_name = snapshot_path.split("/")[-1]

    # Log to a per-run file and mirror everything to stdout.
    log_folder = "test_log/test_log_" + args.exp
    os.makedirs(log_folder, exist_ok=True)
    logging.basicConfig(
        filename=log_folder + "/" + snapshot_name + ".txt",
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d] %(message)s",
        datefmt="%H:%M:%S",
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    logging.info(snapshot_name)

    # Optionally persist predictions (as volumes) under the snapshot dir.
    if args.is_savenii:
        args.test_save_dir = os.path.join(snapshot_path, "predictions")
        test_save_path = os.path.join(args.test_save_dir, args.exp, snapshot_name + "2")
        os.makedirs(test_save_path, exist_ok=True)
    else:
        test_save_path = None

    class_names = get_class_names(args.num_classes)
    inference(args, model, class_names, test_save_path)
|
||||
|
||||
|
||||
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in new issue