""" LLM生成器单元测试 本模块针对LLMGenerator类进行单元测试,验证API集成、重试机制、 流式功能和错误处理等核心功能。 """ import os import pytest import asyncio import aiohttp import json import time from unittest.mock import Mock, patch, AsyncMock from typing import Dict, Any, List from src.spec.llm_generator import ( LLMGenerator, GenerationRequest, GenerationResult, APIError, RateLimitError, TimeoutError, RetryExhaustedError, generate_specification_sync, generate_batch_specifications_sync ) from src.utils.logger import get_logger logger = get_logger(__name__) class TestLLMGenerator: """LLM生成器单元测试类""" @pytest.fixture def mock_api_response(self): """模拟API响应""" return { "choices": [ { "message": { "content": "\\requires a >= 0 && b >= 0;\\n\\ensures return == a + b;" } } ], "usage": { "total_tokens": 150 } } @pytest.fixture def mock_stream_response_chunks(self): """模拟流式响应块""" return [ b'data: {"choices": [{"delta": {"content": "\\\\requires"}}]}\\n', b'data: {"choices": [{"delta": {"content": " a >= 0"}}]}\\n', b'data: {"choices": [{"delta": {"content": " && b >= 0;"}}]}\\n', b'data: {"choices": [{"delta": {"content": "\\\\ensures"}}]}\\n', b'data: {"choices": [{"delta": {"content": " return == a + b;"}}]}\\n', b'data: [DONE]\\n' ] @pytest.fixture def sample_function_info(self): """示例函数信息""" return { 'name': 'add', 'return_type': 'int', 'parameters': [ {'name': 'a', 'type': 'int'}, {'name': 'b', 'type': 'int'} ], 'variables': [], 'function_calls': [] } @pytest.fixture def sample_request(self, sample_function_info): """示例生成请求""" return GenerationRequest( function_name='add', function_info=sample_function_info, verification_goals=['functional_correctness', 'memory_safety'], max_retries=2, validate=True, store=False ) def test_llm_generator_initialization(self): """测试LLM生成器初始化""" # 使用环境变量进行测试 with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): generator = LLMGenerator() assert generator.api_key == 'test_key' assert generator.base_url == 'https://api.siliconflow.cn/v1' assert generator.model == 'deepseek-ai/DeepSeek-V3.1' assert generator.timeout > 0 assert generator.max_retries > 0 assert generator.temperature >= 0.0 assert generator.max_tokens > 0 def test_llm_generator_missing_api_key(self): """测试缺少API密钥时的错误处理""" with patch.dict(os.environ, {}, clear=True): with pytest.raises(Exception): # 应该抛出LLMError LLMGenerator() @pytest.mark.asyncio async def test_context_manager(self): """测试异步上下文管理器""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): async with LLMGenerator() as generator: assert generator.session is not None assert generator._session_managed is False # 退出上下文后session应该被关闭 assert generator.session is None @pytest.mark.asyncio async def test_generate_specification_success(self, sample_request, mock_api_response): """测试成功的规范生成""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value=mock_api_response) mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: result = await generator.generate_specification(sample_request) assert result.specification is not None assert result.specification != "" assert result.tokens_used == 150 assert result.generation_time > 0 assert result.quality_score >= 0.0 assert result.metadata['function_name'] == 'add' @pytest.mark.asyncio async def test_generate_specification_api_error(self, sample_request): """测试API错误处理""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 500 mock_response.text = AsyncMock(return_value="Internal Server Error") mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: with pytest.raises(APIError): await generator.generate_specification(sample_request) @pytest.mark.asyncio async def test_generate_specification_rate_limit(self, sample_request): """测试速率限制处理""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 429 mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: with pytest.raises(RateLimitError): await generator.generate_specification(sample_request) @pytest.mark.asyncio async def test_retry_mechanism(self, sample_request, mock_api_response): """测试重试机制""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): call_count = 0 async def mock_post(*args, **kwargs): nonlocal call_count call_count += 1 if call_count < 3: # 前两次失败 mock_response = AsyncMock() mock_response.status = 500 mock_post.return_value.__aenter__.return_value = mock_response return mock_response else: # 第三次成功 mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value=mock_api_response) mock_post.return_value.__aenter__.return_value = mock_response return mock_response with patch('aiohttp.ClientSession.post', side_effect=mock_post): async with LLMGenerator() as generator: result = await generator.generate_specification(sample_request) assert result.specification is not None assert call_count == 3 # 应该重试了3次 @pytest.mark.asyncio async def test_retry_exhausted(self, sample_request): """测试重试次数耗尽""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 500 mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: request = GenerationRequest( function_name='test', function_info={'name': 'test'}, verification_goals=['test'], max_retries=1 # 只重试1次 ) with pytest.raises(RetryExhaustedError): await generator.generate_specification(request) @pytest.mark.asyncio async def test_stream_specification(self, sample_request, mock_stream_response_chunks): """测试流式规范生成""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_response.content.iter_any = AsyncMock(return_value=mock_stream_response_chunks) mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: chunks = [] async for chunk in generator.stream_specification(sample_request): chunks.append(chunk) assert len(chunks) > 0 full_spec = ''.join(chunks) assert '\\requires' in full_spec assert '\\ensures' in full_spec @pytest.mark.asyncio async def test_batch_specifications(self, sample_request, mock_api_response): """测试批量规范生成""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value=mock_api_response) mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: requests = [sample_request, sample_request] # 两个相同的请求 results = await generator.generate_batch_specifications(requests) assert len(results) == 2 for result in results: assert result.specification is not None assert result.tokens_used > 0 @pytest.mark.asyncio async def test_refine_specification(self, sample_request, mock_api_response): """测试规范细化功能""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value=mock_api_response) mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: original_spec = "\\requires a >= 0;\\n\\ensures return == a + b;" issues = ["Missing parameter validation", "Insufficient coverage"] suggestions = ["Add null checks", "Include overflow checks"] result = await generator.refine_specification( original_spec, sample_request.function_info, issues, suggestions ) assert result.specification is not None assert result.tokens_used > 0 assert result.quality_score >= 0.0 @pytest.mark.asyncio async def test_health_check(self): """测试健康检查功能""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: health = await generator.health_check() assert health['status'] == 'healthy' assert 'api_response_time' in health assert 'timestamp' in health @pytest.mark.asyncio async def test_health_check_failure(self): """测试健康检查失败情况""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 500 mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: health = await generator.health_check() assert health['status'] == 'unhealthy' def test_post_process_spec(self): """测试规范后处理功能""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): generator = LLMGenerator() # 测试清理Markdown格式 markdown_spec = """ ```c \\requires a >= 0; \\ensures return == a + b; ``` """ cleaned = generator._post_process_spec(markdown_spec) assert '\\requires' in cleaned assert '\\ensures' in cleaned assert '```' not in cleaned # 测试空输入 empty_result = generator._post_process_spec("") assert empty_result == "" # 测试None输入 none_result = generator._post_process_spec(None) assert none_result == "" def test_generation_stats(self): """测试生成统计功能""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): generator = LLMGenerator() # 模拟存储一些规范 with patch.object(generator.storage, 'list_specifications', return_value=[ { 'metadata': { 'quality_score': 0.8, 'model': 'test-model' } }, { 'metadata': { 'quality_score': 0.9, 'model': 'test-model' } } ]): stats = generator.get_generation_stats() assert stats['total_specifications'] == 2 assert stats['average_quality'] == 0.85 assert 'quality_distribution' in stats assert 'model_usage' in stats def test_synchronous_wrapper(self, sample_request): """测试同步包装器函数""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value={ "choices": [{"message": {"content": "test spec"}}], "usage": {"total_tokens": 100} }) mock_post.return_value.__aenter__.return_value = mock_response # 测试同步函数 result = generate_specification_sync(sample_request) assert result.specification == "test spec" assert result.tokens_used == 100 def test_batch_synchronous_wrapper(self, sample_request): """测试批量同步包装器函数""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value={ "choices": [{"message": {"content": "test spec"}}], "usage": {"total_tokens": 100} }) mock_post.return_value.__aenter__.return_value = mock_response # 测试批量同步函数 requests = [sample_request, sample_request] results = generate_batch_specifications_sync(requests) assert len(results) == 2 for result in results: assert result.specification == "test spec" assert result.tokens_used == 100 @pytest.mark.asyncio async def test_timeout_handling(self, sample_request): """测试超时处理""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: # 模拟超时 mock_post.side_effect = asyncio.TimeoutError() async with LLMGenerator() as generator: with pytest.raises(TimeoutError): await generator.generate_specification(sample_request) @pytest.mark.asyncio async def test_network_error_handling(self, sample_request): """测试网络错误处理""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: # 模拟网络错误 mock_post.side_effect = aiohttp.ClientError("Network error") async with LLMGenerator() as generator: with pytest.raises(APIError): await generator.generate_specification(sample_request) @pytest.mark.asyncio async def test_invalid_json_response(self, sample_request): """测试无效JSON响应处理""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(side_effect=json.JSONDecodeError("Invalid JSON", "", 0)) mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: with pytest.raises(APIError): await generator.generate_specification(sample_request) @pytest.mark.asyncio async def test_missing_choices_in_response(self, sample_request): """测试响应中缺少choices字段""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): with patch('aiohttp.ClientSession.post') as mock_post: mock_response = AsyncMock() mock_response.status = 200 mock_response.json = AsyncMock(return_value={}) # 空响应 mock_post.return_value.__aenter__.return_value = mock_response async with LLMGenerator() as generator: with pytest.raises(APIError): await generator.generate_specification(sample_request) def test_session_management(self): """测试会话管理""" with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}): generator = LLMGenerator() # 测试自动创建会话 assert generator.session is None generator._ensure_session() assert generator.session is not None assert generator._session_managed is True # 测试清理会话 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(generator._cleanup_session()) assert generator.session is None loop.close() if __name__ == "__main__": pytest.main([__file__, "-v"])