You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

265 lines
7.7 KiB

"""Synapse 数据集测试模块。"""
import os
import sys
import pytest
import torch
sys.path.insert(0, os.path.abspath("../.."))
from src.core.networks import EMCADNet
from src.utils.config import Config
class TestEMCADNet:
    """Tests for the EMCADNet model."""

    @pytest.fixture
    def config(self):
        """Load the default YAML config, or fall back to ``Config()`` if absent."""
        config_path = os.path.join(
            os.path.dirname(__file__), "..", "configs", "default.yaml"
        )
        if os.path.exists(config_path):
            return Config.from_yaml(config_path)
        return Config()

    @pytest.fixture
    def model(self, config):
        """Build the model under test from ``config.model`` with ``pretrain=False``."""
        return EMCADNet(
            num_classes=config.model.num_classes,
            encoder=config.model.encoder,
            kernel_sizes=config.model.kernel_sizes,
            expansion_factor=config.model.expansion_factor,
            dw_parallel=config.model.dw_parallel,
            add=config.model.add,
            lgag_ks=config.model.lgag_ks,
            activation=config.model.activation,
            pretrain=False,
        )

    @staticmethod
    def _primary_output(outputs):
        """Return the first element of a list output, or the output itself."""
        if isinstance(outputs, list):
            assert len(outputs) > 0
            return outputs[0]
        return outputs

    def test_model_initialization(self, model, config):
        """The model is an ``nn.Module`` exposing the configured class count."""
        assert model is not None
        assert isinstance(model, torch.nn.Module)
        assert model.num_classes == config.model.num_classes

    def test_forward_pass(self, model):
        """A ``mode="test"`` forward pass yields a full-resolution prediction."""
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size)
        with torch.no_grad():
            outputs = model(x, mode="test")
        assert outputs is not None
        prediction = self._primary_output(outputs)
        assert prediction.shape == (
            batch_size,
            model.num_classes,
            img_size,
            img_size,
        )

    def test_training_mode(self, model):
        """A ``mode="train"`` forward pass runs with the module in train mode."""
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size)
        model.train()
        with torch.no_grad():
            outputs = model(x, mode="train")
        assert outputs is not None

    def test_get_loss(self, model):
        """``get_loss`` returns a non-negative tensor for random integer targets."""
        batch_size = 2
        img_size = 224
        num_classes = model.num_classes
        x = torch.randn(batch_size, 3, img_size, img_size)
        outputs = model(x, mode="test")
        targets = torch.randint(0, num_classes, (batch_size, 1, img_size, img_size))
        loss = model.get_loss(outputs, targets)
        assert loss is not None
        assert isinstance(loss, torch.Tensor)
        assert loss.item() >= 0

    def test_multi_scale_outputs(self, model):
        """``mode="train"`` returns a list of four outputs."""
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size)
        with torch.no_grad():
            outputs = model(x, mode="train")
        assert isinstance(outputs, list)
        assert len(outputs) == 4

    def test_single_channel_input(self, model):
        """A one-channel input produces the same output shape as a three-channel one."""
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 1, img_size, img_size)
        with torch.no_grad():
            outputs = model(x, mode="test")
        assert outputs is not None
        # Check the shape for tensor outputs as well as list outputs, matching
        # test_forward_pass; previously a bare tensor went unchecked.
        prediction = self._primary_output(outputs)
        assert prediction.shape == (
            batch_size,
            model.num_classes,
            img_size,
            img_size,
        )

    def test_parameter_count(self, model):
        """``count_parameters`` returns an int above one million."""
        num_params = model.count_parameters()
        assert isinstance(num_params, int)
        assert num_params > 0
        assert num_params > 1e6

    def test_save_load_state_dict(self, model, tmp_path):
        """Weights saved with ``save`` are restored unchanged into a fresh model."""
        expected_state = model.state_dict()
        save_path = os.path.join(tmp_path, "model.pth")
        model.save(save_path)
        assert os.path.exists(save_path)
        new_model = EMCADNet(
            num_classes=model.num_classes,
            encoder="pvt_v2_b2",
            pretrain=False,
        )
        # NOTE(review): a file path is passed here, whereas the stock
        # nn.Module.load_state_dict expects a mapping. Presumably EMCADNet
        # overrides it to accept a path -- confirm against the model class.
        new_model.load_state_dict(save_path)
        # Verify the round trip instead of only checking that loading did not raise.
        restored_state = new_model.state_dict()
        assert set(restored_state) == set(expected_state)
        for name, tensor in expected_state.items():
            assert torch.equal(restored_state[name], tensor), name

    def test_to_device(self, model):
        """The model moves to the available device and still runs a forward pass."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_device = model.to(device)
        # Compare device *types*: a parameter moved to "cuda" reports "cuda:0",
        # which is not equal to the index-less torch.device("cuda").
        assert next(model_device.parameters()).device.type == device.type
        batch_size = 2
        img_size = 224
        x = torch.randn(batch_size, 3, img_size, img_size).to(device)
        with torch.no_grad():
            outputs = model_device(x, mode="test")
        assert outputs is not None
class TestConfig:
    """Tests for configuration management."""

    def test_config_creation(self):
        """A default ``Config`` exposes all top-level sections."""
        config = Config()
        assert config is not None
        assert hasattr(config, "dataset")
        assert hasattr(config, "model")
        assert hasattr(config, "training")
        assert hasattr(config, "output")
        assert hasattr(config, "logging")

    def test_config_from_yaml(self):
        """``Config.from_yaml`` loads the default YAML file when it is present."""
        config_path = os.path.join(
            os.path.dirname(__file__), "..", "configs", "default.yaml"
        )
        if not os.path.exists(config_path):
            # Report a skip instead of passing without asserting anything.
            pytest.skip(f"default config not found: {config_path}")
        config = Config.from_yaml(config_path)
        assert config is not None
        assert config.dataset is not None
        assert config.model is not None
        assert config.training is not None

    def test_config_to_dict(self):
        """``to_dict`` returns a dict containing the main sections."""
        config = Config()
        config_dict = config.to_dict()
        assert isinstance(config_dict, dict)
        assert "dataset" in config_dict
        assert "model" in config_dict
        assert "training" in config_dict

    def test_config_from_dict(self):
        """``from_dict`` maps nested dict values onto config attributes."""
        data = {
            "dataset": {"name": "Test", "num_classes": 4},
            "model": {"encoder": "pvt_v2_b2", "num_classes": 4},
        }
        config = Config.from_dict(data)
        assert config is not None
        assert config.dataset.name == "Test"
        assert config.model.num_classes == 4
class TestInference:
    """Tests for the inference workflow."""

    @pytest.fixture
    def trained_model(self):
        """Build the EMCADNet instance shared by the inference tests.

        The network is constructed with ``pretrain=False`` and fixed
        hyper-parameters (9 classes, ``pvt_v2_b2`` encoder).
        """
        return EMCADNet(
            num_classes=9,
            encoder="pvt_v2_b2",
            kernel_sizes=[1, 3, 5],
            expansion_factor=2,
            dw_parallel=True,
            add=True,
            lgag_ks=3,
            activation="relu",
            pretrain=False,
        )

    def test_inference_pipeline(self, trained_model):
        """An eval-mode forward pass followed by argmax yields a per-pixel map."""
        trained_model.eval()
        batch_size, img_size = 1, 224
        x = torch.randn(batch_size, 3, img_size, img_size)
        with torch.no_grad():
            outputs = trained_model(x, mode="test")
        # The model may hand back a tensor or a sequence; use the first element
        # of a sequence.
        if isinstance(outputs, torch.Tensor):
            logits = outputs
        else:
            logits = outputs[0]
        predictions = torch.argmax(logits, dim=1)
        assert predictions.shape == (batch_size, img_size, img_size)

    def test_dice_calculation(self, trained_model):
        """``get_loss`` on eval-mode outputs returns a tensor."""
        trained_model.eval()
        batch_size, img_size = 1, 224
        x = torch.randn(batch_size, 3, img_size, img_size)
        target = torch.randint(0, 9, (batch_size, 1, img_size, img_size))
        with torch.no_grad():
            outputs = trained_model(x, mode="test")
        loss = trained_model.get_loss(outputs, target)
        assert loss is not None
        assert isinstance(loss, torch.Tensor)
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script reports
    # failures to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))