You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
265 lines
7.7 KiB
265 lines
7.7 KiB
"""Synapse 数据集测试模块。"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
# Prepend a directory to the import search path before the ``src.*`` imports.
# NOTE(review): os.path.abspath("../..") is resolved against the current working
# directory, not this file's location — confirm the directory tests are run from.
sys.path.insert(0, os.path.abspath("../.."))
|
|
|
|
from src.core.networks import EMCADNet
|
|
from src.utils.config import Config
|
|
|
|
|
|
class TestEMCADNet:
    """Tests for the EMCADNet model."""

    @staticmethod
    def _build_model(config):
        """Build an EMCADNet from ``config.model`` with ``pretrain=False``."""
        model_cfg = config.model
        return EMCADNet(
            num_classes=model_cfg.num_classes,
            encoder=model_cfg.encoder,
            kernel_sizes=model_cfg.kernel_sizes,
            expansion_factor=model_cfg.expansion_factor,
            dw_parallel=model_cfg.dw_parallel,
            add=model_cfg.add,
            lgag_ks=model_cfg.lgag_ks,
            activation=model_cfg.activation,
            pretrain=False,
        )

    @staticmethod
    def _first_output(outputs):
        """Return ``outputs[0]`` for a non-empty list, otherwise ``outputs`` itself."""
        if isinstance(outputs, list):
            assert len(outputs) > 0
            return outputs[0]
        return outputs

    @pytest.fixture
    def config(self):
        """Load the default configuration.

        Reads ``../configs/default.yaml`` relative to this file when it exists
        and falls back to a default-constructed ``Config`` otherwise.
        """
        config_path = os.path.join(
            os.path.dirname(__file__), "..", "configs", "default.yaml"
        )
        if os.path.exists(config_path):
            return Config.from_yaml(config_path)
        return Config()

    @pytest.fixture
    def model(self, config):
        """Create the model under test from the ``config`` fixture."""
        return self._build_model(config)

    def test_model_initialization(self, model, config):
        """The model is a ``torch.nn.Module`` exposing the configured class count."""
        assert model is not None
        assert isinstance(model, torch.nn.Module)
        assert model.num_classes == config.model.num_classes

    def test_forward_pass(self, model):
        """A test-mode forward pass yields a (batch, classes, H, W) prediction."""
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size)

        with torch.no_grad():
            outputs = model(x, mode="test")

        assert outputs is not None
        prediction = self._first_output(outputs)
        assert prediction.shape == (batch_size, model.num_classes, img_size, img_size)

    def test_training_mode(self, model):
        """A forward pass with ``mode="train"`` on a model in train() state returns output."""
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size)

        model.train()
        with torch.no_grad():
            outputs = model(x, mode="train")

        assert outputs is not None

    def test_get_loss(self, model):
        """``get_loss`` returns a non-negative tensor for random integer targets."""
        batch_size = 2
        img_size = 224
        num_classes = model.num_classes

        x = torch.randn(batch_size, 3, img_size, img_size)
        outputs = model(x, mode="test")
        targets = torch.randint(0, num_classes, (batch_size, 1, img_size, img_size))

        loss = model.get_loss(outputs, targets)

        assert loss is not None
        assert isinstance(loss, torch.Tensor)
        assert loss.item() >= 0

    def test_multi_scale_outputs(self, model):
        """A forward pass with ``mode="train"`` returns a list of four outputs."""
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size)

        with torch.no_grad():
            outputs = model(x, mode="train")

        assert isinstance(outputs, list)
        assert len(outputs) == 4

    def test_single_channel_input(self, model):
        """A single-channel input still yields a (batch, classes, H, W) prediction."""
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 1, img_size, img_size)

        with torch.no_grad():
            outputs = model(x, mode="test")

        assert outputs is not None
        # Check the shape for both list and tensor outputs, mirroring
        # test_forward_pass, so a tensor result is not silently accepted.
        prediction = self._first_output(outputs)
        assert prediction.shape == (batch_size, model.num_classes, img_size, img_size)

    def test_parameter_count(self, model):
        """``count_parameters`` returns an int larger than one million."""
        num_params = model.count_parameters()

        assert isinstance(num_params, int)
        assert num_params > 0
        assert num_params > 1e6

    def test_save_load_state_dict(self, model, config, tmp_path):
        """Save the model, reload the file into a fresh model and compare weights.

        NOTE(review): assumes ``model.save`` writes a plain ``state_dict`` that
        ``torch.load`` can read back — confirm against ``EMCADNet.save``.
        """
        state_dict = model.state_dict()

        save_path = os.path.join(tmp_path, "model.pth")
        model.save(save_path)

        assert os.path.exists(save_path)

        # Build the second model from the same config so parameter shapes match.
        new_model = self._build_model(config)
        # load_state_dict expects a mapping of tensors, not a file path.
        new_model.load_state_dict(torch.load(save_path, map_location="cpu"))

        reloaded = new_model.state_dict()
        assert reloaded.keys() == state_dict.keys()
        for key, value in state_dict.items():
            assert torch.equal(reloaded[key].cpu(), value.cpu()), key

    def test_to_device(self, model):
        """The model can be moved to the available device and run there."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_device = model.to(device)

        # Compare device types: a parameter moved to "cuda" reports "cuda:0",
        # which does not compare equal to the index-less torch.device("cuda").
        assert next(model_device.parameters()).device.type == device.type

        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size).to(device)

        with torch.no_grad():
            outputs = model_device(x, mode="test")

        assert outputs is not None
|
|
|
|
|
|
class TestConfig:
    """Tests for configuration handling."""

    def test_config_creation(self):
        """A default ``Config`` exposes the expected top-level sections."""
        config = Config()

        assert config is not None
        assert hasattr(config, "dataset")
        assert hasattr(config, "model")
        assert hasattr(config, "training")
        assert hasattr(config, "output")
        assert hasattr(config, "logging")

    def test_config_from_yaml(self):
        """Loading ``../configs/default.yaml`` populates the main sections.

        The test is skipped (rather than silently passing) when the YAML file
        is not present next to the test tree.
        """
        config_path = os.path.join(
            os.path.dirname(__file__), "..", "configs", "default.yaml"
        )

        if not os.path.exists(config_path):
            pytest.skip(f"default config not found: {config_path}")

        config = Config.from_yaml(config_path)

        assert config is not None
        assert config.dataset is not None
        assert config.model is not None
        assert config.training is not None

    def test_config_to_dict(self):
        """``to_dict`` returns a dict containing the main sections."""
        config = Config()
        config_dict = config.to_dict()

        assert isinstance(config_dict, dict)
        assert "dataset" in config_dict
        assert "model" in config_dict
        assert "training" in config_dict

    def test_config_from_dict(self):
        """``from_dict`` carries the supplied values onto the config object."""
        data = {
            "dataset": {"name": "Test", "num_classes": 4},
            "model": {"encoder": "pvt_v2_b2", "num_classes": 4},
        }

        config = Config.from_dict(data)

        assert config is not None
        assert config.dataset.name == "Test"
        assert config.model.num_classes == 4
|
|
|
|
|
|
class TestInference:
    """Tests for the inference flow."""

    @pytest.fixture
    def trained_model(self):
        """Build the model used by the inference tests.

        The model is constructed with ``pretrain=False`` and fixed
        hyper-parameters; no weights are loaded here.
        """
        return EMCADNet(
            num_classes=9,
            encoder="pvt_v2_b2",
            kernel_sizes=[1, 3, 5],
            expansion_factor=2,
            dw_parallel=True,
            add=True,
            lgag_ks=3,
            activation="relu",
            pretrain=False,
        )

    def test_inference_pipeline(self, trained_model):
        """Argmax over the class axis yields a (batch, H, W) prediction map."""
        trained_model.eval()

        n_images, side = 1, 224
        images = torch.randn(n_images, 3, side, side)

        with torch.no_grad():
            result = trained_model(images, mode="test")

        # The model may return a tensor or a sequence; use the first entry
        # of a sequence as the prediction logits.
        if isinstance(result, torch.Tensor):
            logits = result
        else:
            logits = result[0]
        predictions = torch.argmax(logits, dim=1)

        assert predictions.shape == (n_images, side, side)

    def test_dice_calculation(self, trained_model):
        """``get_loss`` on test-mode outputs returns a tensor.

        Named after the Dice metric; whether the loss includes a Dice term
        is not visible from this test.
        """
        trained_model.eval()

        n_images, side = 1, 224
        images = torch.randn(n_images, 3, side, side)
        labels = torch.randint(0, 9, (n_images, 1, side, side))

        with torch.no_grad():
            result = trained_model(images, mode="test")

        loss_value = trained_model.get_loss(result, labels)

        assert loss_value is not None
        assert isinstance(loss_value, torch.Tensor)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely and propagate pytest's exit status, so a
    # direct ``python <file>`` run exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|