|
|
"""
|
|
|
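Unit tests for the LLM generator.

This module unit-tests the LLMGenerator class, verifying core functionality
such as API integration, the retry mechanism, streaming, and error handling.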
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import pytest
|
|
|
import asyncio
|
|
|
import aiohttp
|
|
|
import json
|
|
|
import time
|
|
|
from unittest.mock import Mock, patch, AsyncMock
|
|
|
from typing import Dict, Any, List
|
|
|
|
|
|
from src.spec.llm_generator import (
|
|
|
LLMGenerator, GenerationRequest, GenerationResult,
|
|
|
APIError, RateLimitError, TimeoutError, RetryExhaustedError,
|
|
|
generate_specification_sync, generate_batch_specifications_sync
|
|
|
)
|
|
|
from src.utils.logger import get_logger
|
|
|
|
|
|
# Module-level logger, obtained from the project's logging helper and keyed by
# this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
class TestLLMGenerator:
    """Unit tests for ``LLMGenerator``.

    Each test injects ``SILICONFLOW_API_KEY`` into ``os.environ`` and, where an
    HTTP call would be made, replaces ``aiohttp.ClientSession.post`` with a
    mock whose ``__aenter__`` yields a canned response object.
    """

    @staticmethod
    async def _iterate_async(items):
        """Yield each element of *items* through an async iterator."""
        for item in items:
            yield item

    @pytest.fixture
    def mock_api_response(self):
        """Return a canned, non-streaming chat-completion payload."""
        return {
            "choices": [
                {
                    "message": {
                        "content": "\\requires a >= 0 && b >= 0;\\n\\ensures return == a + b;"
                    }
                }
            ],
            "usage": {
                "total_tokens": 150
            }
        }

    @pytest.fixture
    def mock_stream_response_chunks(self):
        """Return the raw byte chunks of a canned streaming response."""
        return [
            b'data: {"choices": [{"delta": {"content": "\\\\requires"}}]}\\n',
            b'data: {"choices": [{"delta": {"content": " a >= 0"}}]}\\n',
            b'data: {"choices": [{"delta": {"content": " && b >= 0;"}}]}\\n',
            b'data: {"choices": [{"delta": {"content": "\\\\ensures"}}]}\\n',
            b'data: {"choices": [{"delta": {"content": " return == a + b;"}}]}\\n',
            b'data: [DONE]\\n'
        ]

    @pytest.fixture
    def sample_function_info(self):
        """Return function metadata describing ``int add(int a, int b)``."""
        return {
            'name': 'add',
            'return_type': 'int',
            'parameters': [
                {'name': 'a', 'type': 'int'},
                {'name': 'b', 'type': 'int'}
            ],
            'variables': [],
            'function_calls': []
        }

    @pytest.fixture
    def sample_request(self, sample_function_info):
        """Return a ``GenerationRequest`` for the sample ``add`` function."""
        return GenerationRequest(
            function_name='add',
            function_info=sample_function_info,
            verification_goals=['functional_correctness', 'memory_safety'],
            max_retries=2,
            validate=True,
            store=False
        )

    def test_llm_generator_initialization(self):
        """The generator picks up the API key from the environment and has sane defaults."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            generator = LLMGenerator()

            assert generator.api_key == 'test_key'
            assert generator.base_url == 'https://api.siliconflow.cn/v1'
            assert generator.model == 'deepseek-ai/DeepSeek-V3.1'
            assert generator.timeout > 0
            assert generator.max_retries > 0
            assert generator.temperature >= 0.0
            assert generator.max_tokens > 0

    def test_llm_generator_missing_api_key(self):
        """Constructing the generator with an empty environment must fail."""
        with patch.dict(os.environ, {}, clear=True):
            # NOTE(review): the original author expected LLMError here; that
            # type is not imported by this module, so any Exception is accepted.
            with pytest.raises(Exception):
                LLMGenerator()

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """The async context manager opens a session and clears it on exit."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            async with LLMGenerator() as generator:
                assert generator.session is not None
                assert generator._session_managed is False

            # After leaving the context the session must have been released.
            assert generator.session is None

    @pytest.mark.asyncio
    async def test_generate_specification_success(self, sample_request, mock_api_response):
        """A 200 response yields a populated ``GenerationResult``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value=mock_api_response)
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    result = await generator.generate_specification(sample_request)

                    assert result.specification is not None
                    assert result.specification != ""
                    assert result.tokens_used == 150
                    assert result.generation_time > 0
                    assert result.quality_score >= 0.0
                    assert result.metadata['function_name'] == 'add'

    @pytest.mark.asyncio
    async def test_generate_specification_api_error(self, sample_request):
        """A 500 response surfaces as ``APIError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 500
                mock_response.text = AsyncMock(return_value="Internal Server Error")
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    with pytest.raises(APIError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_generate_specification_rate_limit(self, sample_request):
        """A 429 response surfaces as ``RateLimitError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 429
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    with pytest.raises(RateLimitError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_retry_mechanism(self, sample_request, mock_api_response):
        """Two failing responses followed by a success end in a result after three calls."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            failing_response = AsyncMock()
            failing_response.status = 500

            successful_response = AsyncMock()
            successful_response.status = 200
            successful_response.json = AsyncMock(return_value=mock_api_response)

            with patch('aiohttp.ClientSession.post') as mock_post:
                # Mocked the same way as the other tests in this class (via
                # __aenter__), but with a sequence: the first two entries into
                # the context fail, the third succeeds.
                mock_post.return_value.__aenter__.side_effect = [
                    failing_response,
                    failing_response,
                    successful_response,
                ]

                async with LLMGenerator() as generator:
                    result = await generator.generate_specification(sample_request)

                    assert result.specification is not None
                    # Three POST calls in total: two failures plus the success.
                    assert mock_post.call_count == 3

    @pytest.mark.asyncio
    async def test_retry_exhausted(self, sample_request):
        """Persistent 500 responses with ``max_retries=1`` raise ``RetryExhaustedError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 500
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    request = GenerationRequest(
                        function_name='test',
                        function_info={'name': 'test'},
                        verification_goals=['test'],
                        max_retries=1  # allow only a single retry
                    )

                    with pytest.raises(RetryExhaustedError):
                        await generator.generate_specification(request)

    @pytest.mark.asyncio
    async def test_stream_specification(self, sample_request, mock_stream_response_chunks):
        """Streaming yields chunks that together contain both clauses."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                # aiohttp's ``iter_any()`` is a plain (non-awaited) call that
                # returns an async iterator, so it is mocked with a synchronous
                # Mock handing out a fresh async generator on every call.
                # NOTE(review): assumes the generator consumes
                # ``response.content.iter_any()`` with ``async for`` - confirm.
                mock_response.content.iter_any = Mock(
                    side_effect=lambda *args, **kwargs: self._iterate_async(
                        mock_stream_response_chunks
                    )
                )
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    chunks = []
                    async for chunk in generator.stream_specification(sample_request):
                        chunks.append(chunk)

                    assert len(chunks) > 0
                    full_spec = ''.join(chunks)
                    assert '\\requires' in full_spec
                    assert '\\ensures' in full_spec

    @pytest.mark.asyncio
    async def test_batch_specifications(self, sample_request, mock_api_response):
        """A batch of two requests produces two populated results."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value=mock_api_response)
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    requests = [sample_request, sample_request]  # two identical requests
                    results = await generator.generate_batch_specifications(requests)

                    assert len(results) == 2
                    for result in results:
                        assert result.specification is not None
                        assert result.tokens_used > 0

    @pytest.mark.asyncio
    async def test_refine_specification(self, sample_request, mock_api_response):
        """Refining an existing specification returns a populated result."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value=mock_api_response)
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    original_spec = "\\requires a >= 0;\\n\\ensures return == a + b;"
                    issues = ["Missing parameter validation", "Insufficient coverage"]
                    suggestions = ["Add null checks", "Include overflow checks"]

                    result = await generator.refine_specification(
                        original_spec, sample_request.function_info, issues, suggestions
                    )

                    assert result.specification is not None
                    assert result.tokens_used > 0
                    assert result.quality_score >= 0.0

    @pytest.mark.asyncio
    async def test_health_check(self):
        """A 200 response makes the health check report ``healthy``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    health = await generator.health_check()

                    assert health['status'] == 'healthy'
                    assert 'api_response_time' in health
                    assert 'timestamp' in health

    @pytest.mark.asyncio
    async def test_health_check_failure(self):
        """A 500 response makes the health check report ``unhealthy``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 500
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    health = await generator.health_check()

                    assert health['status'] == 'unhealthy'

    def test_post_process_spec(self):
        """Post-processing strips Markdown fences and maps empty/None input to ``""``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            generator = LLMGenerator()

            # Markdown code fences must be removed while the clauses survive.
            markdown_spec = (
                "\n"
                "```c\n"
                "\\requires a >= 0;\n"
                "\\ensures return == a + b;\n"
                "```\n"
            )
            cleaned = generator._post_process_spec(markdown_spec)
            assert '\\requires' in cleaned
            assert '\\ensures' in cleaned
            assert '```' not in cleaned

            # Empty input.
            empty_result = generator._post_process_spec("")
            assert empty_result == ""

            # None input.
            none_result = generator._post_process_spec(None)
            assert none_result == ""

    def test_generation_stats(self):
        """Statistics are aggregated over the specifications returned by storage."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            generator = LLMGenerator()

            # Pretend two specifications have been stored.
            stored_specs = [
                {
                    'metadata': {
                        'quality_score': 0.8,
                        'model': 'test-model'
                    }
                },
                {
                    'metadata': {
                        'quality_score': 0.9,
                        'model': 'test-model'
                    }
                }
            ]
            with patch.object(generator.storage, 'list_specifications',
                              return_value=stored_specs):
                stats = generator.get_generation_stats()

                assert stats['total_specifications'] == 2
                # (0.8 + 0.9) / 2 is not exactly 0.85 in binary floating
                # point, so compare with a tolerance.
                assert stats['average_quality'] == pytest.approx(0.85)
                assert 'quality_distribution' in stats
                assert 'model_usage' in stats

    def test_synchronous_wrapper(self, sample_request):
        """``generate_specification_sync`` returns the mocked content and token count."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value={
                    "choices": [{"message": {"content": "test spec"}}],
                    "usage": {"total_tokens": 100}
                })
                mock_post.return_value.__aenter__.return_value = mock_response

                # Exercise the synchronous entry point.
                result = generate_specification_sync(sample_request)
                assert result.specification == "test spec"
                assert result.tokens_used == 100

    def test_batch_synchronous_wrapper(self, sample_request):
        """``generate_batch_specifications_sync`` returns one result per request."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value={
                    "choices": [{"message": {"content": "test spec"}}],
                    "usage": {"total_tokens": 100}
                })
                mock_post.return_value.__aenter__.return_value = mock_response

                # Exercise the synchronous batch entry point.
                requests = [sample_request, sample_request]
                results = generate_batch_specifications_sync(requests)

                assert len(results) == 2
                for result in results:
                    assert result.specification == "test spec"
                    assert result.tokens_used == 100

    @pytest.mark.asyncio
    async def test_timeout_handling(self, sample_request):
        """An ``asyncio.TimeoutError`` from the HTTP call surfaces as the generator's ``TimeoutError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                # Simulate a timeout on the POST call itself.
                mock_post.side_effect = asyncio.TimeoutError()

                async with LLMGenerator() as generator:
                    with pytest.raises(TimeoutError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_network_error_handling(self, sample_request):
        """An ``aiohttp.ClientError`` from the HTTP call surfaces as ``APIError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                # Simulate a network-level failure.
                mock_post.side_effect = aiohttp.ClientError("Network error")

                async with LLMGenerator() as generator:
                    with pytest.raises(APIError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_invalid_json_response(self, sample_request):
        """A body that fails JSON decoding surfaces as ``APIError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(side_effect=json.JSONDecodeError("Invalid JSON", "", 0))
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    with pytest.raises(APIError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_missing_choices_in_response(self, sample_request):
        """A payload without a ``choices`` field surfaces as ``APIError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value={})  # empty payload
                mock_post.return_value.__aenter__.return_value = mock_response

                async with LLMGenerator() as generator:
                    with pytest.raises(APIError):
                        await generator.generate_specification(sample_request)

    def test_session_management(self):
        """``_ensure_session`` creates a managed session and ``_cleanup_session`` clears it."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            generator = LLMGenerator()

            # The session is created lazily.
            assert generator.session is None
            generator._ensure_session()
            assert generator.session is not None
            assert generator._session_managed is True

            # Run the cleanup on a private loop that is always closed and is
            # never installed as the thread's current loop, so no closed loop
            # is left behind for other tests.
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(generator._cleanup_session())
                assert generator.session is None
            finally:
                loop.close()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the run's exit code; pass it on so that executing
    # this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))