"""Tests for EMCADNet, the Config object and the inference flow (Synapse setup)."""

import os
import sys

import pytest
import torch

# NOTE(review): "../.." is resolved against the current working directory, not
# against this file, so whether ``src`` is importable depends on where the tests
# are launched from. Consider anchoring it to ``os.path.dirname(__file__)`` once
# the project layout is confirmed.
sys.path.insert(0, os.path.abspath("../.."))

from src.core.networks import EMCADNet  # noqa: E402  (needs the sys.path tweak)
from src.utils.config import Config  # noqa: E402

# Optional YAML config read by the fixtures/tests below (may be absent).
_DEFAULT_CONFIG_PATH = os.path.join(
    os.path.dirname(__file__), "..", "configs", "default.yaml"
)

# Height/width of the square random images fed to the model.
_IMG_SIZE = 224


def _build_model(config):
    """Build an EMCADNet from ``config.model`` without pretrained weights.

    Shared by the ``model`` fixture and the save/load test so both instances
    have exactly the same architecture (and therefore the same state-dict keys).
    """
    return EMCADNet(
        num_classes=config.model.num_classes,
        encoder=config.model.encoder,
        kernel_sizes=config.model.kernel_sizes,
        expansion_factor=config.model.expansion_factor,
        dw_parallel=config.model.dw_parallel,
        add=config.model.add,
        lgag_ks=config.model.lgag_ks,
        activation=config.model.activation,
        pretrain=False,
    )


def _primary_output(outputs):
    """Return ``outputs[0]`` when the model returns a list/tuple, else ``outputs``."""
    if isinstance(outputs, (list, tuple)):
        return outputs[0]
    return outputs


class TestEMCADNet:
    """Unit tests for the EMCADNet model."""

    @pytest.fixture
    def config(self):
        """Load ``configs/default.yaml`` if it exists, otherwise ``Config()``."""
        if os.path.exists(_DEFAULT_CONFIG_PATH):
            return Config.from_yaml(_DEFAULT_CONFIG_PATH)
        return Config()

    @pytest.fixture
    def model(self, config):
        """Model under test, built from the ``config`` fixture."""
        return _build_model(config)

    def test_model_initialization(self, model, config):
        """The model is an ``nn.Module`` carrying the configured class count."""
        assert model is not None
        assert isinstance(model, torch.nn.Module)
        assert model.num_classes == config.model.num_classes

    def test_forward_pass(self, model):
        """A test-mode forward pass yields (B, num_classes, H, W) predictions."""
        batch_size = 2
        x = torch.randn(batch_size, 3, _IMG_SIZE, _IMG_SIZE)

        with torch.no_grad():
            outputs = model(x, mode="test")

        assert outputs is not None
        if isinstance(outputs, (list, tuple)):
            assert len(outputs) > 0
        expected = (batch_size, model.num_classes, _IMG_SIZE, _IMG_SIZE)
        assert _primary_output(outputs).shape == expected

    def test_training_mode(self, model):
        """A forward pass with ``mode="train"`` in training mode returns output."""
        batch_size = 2
        x = torch.randn(batch_size, 3, _IMG_SIZE, _IMG_SIZE)

        model.train()
        with torch.no_grad():
            outputs = model(x, mode="train")

        assert outputs is not None

    def test_get_loss(self, model):
        """``get_loss`` returns a non-negative scalar tensor for random targets."""
        batch_size = 2
        num_classes = model.num_classes
        x = torch.randn(batch_size, 3, _IMG_SIZE, _IMG_SIZE)

        outputs = model(x, mode="test")
        targets = torch.randint(0, num_classes, (batch_size, 1, _IMG_SIZE, _IMG_SIZE))
        loss = model.get_loss(outputs, targets)

        assert loss is not None
        assert isinstance(loss, torch.Tensor)
        assert loss.item() >= 0

    def test_multi_scale_outputs(self, model):
        """With ``mode="train"`` the model returns a list of 4 outputs."""
        batch_size = 2
        x = torch.randn(batch_size, 3, _IMG_SIZE, _IMG_SIZE)

        with torch.no_grad():
            outputs = model(x, mode="train")

        assert isinstance(outputs, list)
        assert len(outputs) == 4

    def test_single_channel_input(self, model):
        """A 1-channel input still yields (B, num_classes, H, W) predictions."""
        batch_size = 2
        x = torch.randn(batch_size, 1, _IMG_SIZE, _IMG_SIZE)

        with torch.no_grad():
            outputs = model(x, mode="test")

        assert outputs is not None
        # Check the shape for both list and plain-tensor outputs, as in
        # test_forward_pass (previously only the list case was checked).
        expected = (batch_size, model.num_classes, _IMG_SIZE, _IMG_SIZE)
        assert _primary_output(outputs).shape == expected

    def test_parameter_count(self, model):
        """``count_parameters`` returns an int above one million."""
        num_params = model.count_parameters()
        assert isinstance(num_params, int)
        assert num_params > 1_000_000

    def test_save_load_state_dict(self, model, config, tmp_path):
        """``save`` writes a file, and the state dict round-trips into a new model."""
        save_path = os.path.join(tmp_path, "model.pth")
        model.save(save_path)
        assert os.path.exists(save_path)
        # NOTE(review): the on-disk format written by EMCADNet.save is not
        # visible here, so that file is only checked for existence.

        # nn.Module.load_state_dict takes a mapping, not a file path, so the
        # state dict is serialized and read back explicitly.
        state_path = os.path.join(tmp_path, "state_dict.pth")
        torch.save(model.state_dict(), state_path)
        loaded_state = torch.load(state_path, map_location="cpu")

        # Rebuild from the same config so keys and shapes match strictly.
        new_model = _build_model(config)
        new_model.load_state_dict(loaded_state)

        original_state = model.state_dict()
        restored_state = new_model.state_dict()
        assert original_state.keys() == restored_state.keys()
        for key, tensor in original_state.items():
            assert torch.equal(tensor, restored_state[key]), key

    def test_to_device(self, model):
        """The model moves to CUDA when available (else CPU) and runs there."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_device = model.to(device)

        assert next(model_device.parameters()).device == device

        batch_size = 2
        x = torch.randn(batch_size, 3, _IMG_SIZE, _IMG_SIZE).to(device)

        with torch.no_grad():
            outputs = model_device(x, mode="test")

        assert outputs is not None


class TestConfig:
    """Tests for configuration creation and (de)serialization."""

    def test_config_creation(self):
        """A default ``Config`` exposes all top-level sections."""
        config = Config()
        assert config is not None
        assert hasattr(config, "dataset")
        assert hasattr(config, "model")
        assert hasattr(config, "training")
        assert hasattr(config, "output")
        assert hasattr(config, "logging")

    def test_config_from_yaml(self):
        """Loading the default YAML populates the main sections."""
        if not os.path.exists(_DEFAULT_CONFIG_PATH):
            # Report a skip instead of silently passing without checking anything.
            pytest.skip(f"default config not found: {_DEFAULT_CONFIG_PATH}")

        config = Config.from_yaml(_DEFAULT_CONFIG_PATH)
        assert config is not None
        assert config.dataset is not None
        assert config.model is not None
        assert config.training is not None

    def test_config_to_dict(self):
        """``to_dict`` returns a dict containing the main sections."""
        config = Config()
        config_dict = config.to_dict()

        assert isinstance(config_dict, dict)
        assert "dataset" in config_dict
        assert "model" in config_dict
        assert "training" in config_dict

    def test_config_from_dict(self):
        """``from_dict`` maps nested dict values onto config attributes."""
        data = {
            "dataset": {"name": "Test", "num_classes": 4},
            "model": {"encoder": "pvt_v2_b2", "num_classes": 4},
        }
        config = Config.from_dict(data)

        assert config is not None
        assert config.dataset.name == "Test"
        assert config.model.num_classes == 4


class TestInference:
    """Tests for the eval-mode inference flow."""

    @pytest.fixture
    def trained_model(self):
        """A 9-class EMCADNet with ``pretrain=False``.

        Despite the fixture name, no pretrained or trained weights are loaded here.
        """
        return EMCADNet(
            num_classes=9,
            encoder="pvt_v2_b2",
            kernel_sizes=[1, 3, 5],
            expansion_factor=2,
            dw_parallel=True,
            add=True,
            lgag_ks=3,
            activation="relu",
            pretrain=False,
        )

    def test_inference_pipeline(self, trained_model):
        """Argmax over the class axis yields a (B, H, W) label map."""
        trained_model.eval()
        batch_size = 1
        x = torch.randn(batch_size, 3, _IMG_SIZE, _IMG_SIZE)

        with torch.no_grad():
            outputs = trained_model(x, mode="test")

        predictions = torch.argmax(_primary_output(outputs), dim=1)
        assert predictions.shape == (batch_size, _IMG_SIZE, _IMG_SIZE)

    def test_dice_calculation(self, trained_model):
        """``get_loss`` returns a tensor for eval-mode predictions.

        NOTE(review): despite its name, this test does not compute a Dice score.
        It only checks that ``get_loss`` runs on the model's test-mode output.
        """
        trained_model.eval()
        batch_size = 1
        x = torch.randn(batch_size, 3, _IMG_SIZE, _IMG_SIZE)
        target = torch.randint(0, 9, (batch_size, 1, _IMG_SIZE, _IMG_SIZE))

        with torch.no_grad():
            outputs = trained_model(x, mode="test")
            loss = trained_model.get_loss(outputs, target)

        assert loss is not None
        assert isinstance(loss, torch.Tensor)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])