# NOTE: the following lines were captured from the repository web viewer, not
# from the module itself; they are kept as comments so the file stays valid Python.
# You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# cbmc/codedetect/tests/unit/test_llm_generator.py
#
# 463 lines
# 19 KiB
#
# This file contains ambiguous Unicode characters!
#
# This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
LLM生成器单元测试
本模块针对LLMGenerator类进行单元测试，验证API集成、重试机制、
流式功能和错误处理等核心功能。
"""
import asyncio
import json
import os
import time
from typing import Any, Dict, List
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import aiohttp
import pytest

from src.spec.llm_generator import (
    LLMGenerator, GenerationRequest, GenerationResult,
    APIError, RateLimitError, TimeoutError, RetryExhaustedError,
    generate_specification_sync, generate_batch_specifications_sync
)
from src.utils.logger import get_logger
# Module-level logger from the project's logging helper.
# NOTE(review): appears unused by the tests in this file; kept for ad-hoc debugging.
logger = get_logger(__name__)
class TestLLMGenerator:
    """Unit tests for ``LLMGenerator``.

    Cover API integration, the retry mechanism, streaming, batch generation,
    refinement, health checks and error handling. HTTP traffic is faked by
    patching ``aiohttp.ClientSession.post``. The API key is supplied through
    a patched ``SILICONFLOW_API_KEY`` environment variable.
    """

    @pytest.fixture
    def mock_api_response(self):
        """Fake chat-completion payload: one choice holding a two-clause spec, 150 tokens used."""
        return {
            "choices": [
                {
                    "message": {
                        # A real newline separates the clauses, as a model reply would.
                        "content": "\\requires a >= 0 && b >= 0;\n\\ensures return == a + b;"
                    }
                }
            ],
            "usage": {
                "total_tokens": 150
            }
        }

    @pytest.fixture
    def mock_stream_response_chunks(self):
        """Fake server-sent-event frames of a streamed reply, ending with ``[DONE]``.

        Each frame ends with a real blank line, as SSE event framing requires.
        The doubled backslash inside the JSON text decodes to a single one,
        so the streamed content carries the ACSL keywords with their leading
        backslash.
        """
        return [
            b'data: {"choices": [{"delta": {"content": "\\\\requires"}}]}\n\n',
            b'data: {"choices": [{"delta": {"content": " a >= 0"}}]}\n\n',
            b'data: {"choices": [{"delta": {"content": " && b >= 0;"}}]}\n\n',
            b'data: {"choices": [{"delta": {"content": "\\\\ensures"}}]}\n\n',
            b'data: {"choices": [{"delta": {"content": " return == a + b;"}}]}\n\n',
            b'data: [DONE]\n\n'
        ]

    @pytest.fixture
    def sample_function_info(self):
        """Function metadata for ``int add(int a, int b)``."""
        return {
            'name': 'add',
            'return_type': 'int',
            'parameters': [
                {'name': 'a', 'type': 'int'},
                {'name': 'b', 'type': 'int'}
            ],
            'variables': [],
            'function_calls': []
        }

    @pytest.fixture
    def sample_request(self, sample_function_info):
        """Generation request for ``add`` with two retries, validation on and storage off."""
        return GenerationRequest(
            function_name='add',
            function_info=sample_function_info,
            verification_goals=['functional_correctness', 'memory_safety'],
            max_retries=2,
            validate=True,
            store=False
        )

    def test_llm_generator_initialization(self):
        """The API key is read from the environment; endpoint and model match the SiliconFlow defaults."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            generator = LLMGenerator()
            assert generator.api_key == 'test_key'
            assert generator.base_url == 'https://api.siliconflow.cn/v1'
            assert generator.model == 'deepseek-ai/DeepSeek-V3.1'
            assert generator.timeout > 0
            assert generator.max_retries > 0
            assert generator.temperature >= 0.0
            assert generator.max_tokens > 0

    def test_llm_generator_missing_api_key(self):
        """Construction fails when no API key is present in the environment."""
        with patch.dict(os.environ, {}, clear=True):
            # Expected to be the module's LLMError; asserted broadly because
            # that type is not imported here.
            with pytest.raises(Exception):
                LLMGenerator()

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """Entering the async context opens an externally managed session; leaving it clears the session."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            async with LLMGenerator() as generator:
                assert generator.session is not None
                assert generator._session_managed is False
            # After the context exits, the session reference must be gone.
            assert generator.session is None

    @pytest.mark.asyncio
    async def test_generate_specification_success(self, sample_request, mock_api_response):
        """A 200 response yields a non-empty spec with token usage, timing and metadata filled in."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value=mock_api_response)
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    result = await generator.generate_specification(sample_request)
                    assert result.specification is not None
                    assert result.specification != ""
                    assert result.tokens_used == 150
                    # Mocked I/O can complete within one clock tick, so zero is legitimate.
                    assert result.generation_time >= 0
                    assert result.quality_score >= 0.0
                    assert result.metadata['function_name'] == 'add'

    @pytest.mark.asyncio
    async def test_generate_specification_api_error(self, sample_request):
        """A persistent HTTP 500 response surfaces as ``APIError``.

        NOTE(review): test_retry_exhausted expects RetryExhaustedError for the
        same status. Both can only pass if RetryExhaustedError subclasses
        APIError (or retries are skipped here); confirm against the implementation.
        """
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 500
                mock_response.text = AsyncMock(return_value="Internal Server Error")
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    with pytest.raises(APIError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_generate_specification_rate_limit(self, sample_request):
        """An HTTP 429 response surfaces as ``RateLimitError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 429
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    with pytest.raises(RateLimitError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_retry_mechanism(self, sample_request, mock_api_response):
        """The first two POSTs fail with HTTP 500 and the third succeeds.

        ``sample_request`` sets ``max_retries=2``: one initial attempt plus
        two retries gives three POSTs in total.
        """
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            call_count = 0

            def fake_post(*args, **kwargs):
                # Must be a plain (sync) callable that returns an async context
                # manager. The tests drive session.post(...) through
                # ``async with``, and a coroutine would not support that protocol.
                nonlocal call_count
                call_count += 1
                response = AsyncMock()
                if call_count < 3:
                    response.status = 500
                    response.text = AsyncMock(return_value="Internal Server Error")
                else:
                    response.status = 200
                    response.json = AsyncMock(return_value=mock_api_response)
                context = MagicMock()
                context.__aenter__.return_value = response
                context.__aexit__.return_value = False  # never swallow exceptions
                return context

            with patch('aiohttp.ClientSession.post', side_effect=fake_post):
                async with LLMGenerator() as generator:
                    result = await generator.generate_specification(sample_request)
                    assert result.specification is not None
                    assert call_count == 3  # 1 initial attempt + 2 retries

    @pytest.mark.asyncio
    async def test_retry_exhausted(self, sample_request):
        """With ``max_retries=1`` and a persistent HTTP 500, generation ends in ``RetryExhaustedError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 500
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    request = GenerationRequest(
                        function_name='test',
                        function_info={'name': 'test'},
                        verification_goals=['test'],
                        max_retries=1  # retry only once
                    )
                    with pytest.raises(RetryExhaustedError):
                        await generator.generate_specification(request)

    @pytest.mark.asyncio
    async def test_stream_specification(self, sample_request, mock_stream_response_chunks):
        """Streaming yields text chunks that together contain both ACSL clauses."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200

                async def iter_chunks():
                    for raw_chunk in mock_stream_response_chunks:
                        yield raw_chunk

                # ``content.iter_any()`` must return an async iterator.
                # An AsyncMock would return a coroutine and break ``async for``.
                # A fresh generator is built on every call.
                mock_response.content.iter_any = Mock(side_effect=iter_chunks)
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    chunks = []
                    async for chunk in generator.stream_specification(sample_request):
                        chunks.append(chunk)
                    assert len(chunks) > 0
                    full_spec = ''.join(chunks)
                    assert '\\requires' in full_spec
                    assert '\\ensures' in full_spec

    @pytest.mark.asyncio
    async def test_batch_specifications(self, sample_request, mock_api_response):
        """A batch of two requests returns two successful results."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value=mock_api_response)
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    requests = [sample_request, sample_request]  # two identical requests
                    results = await generator.generate_batch_specifications(requests)
                    assert len(results) == 2
                    for result in results:
                        assert result.specification is not None
                        assert result.tokens_used > 0

    @pytest.mark.asyncio
    async def test_refine_specification(self, sample_request, mock_api_response):
        """Refining an existing spec with issues and suggestions yields a scored result."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value=mock_api_response)
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    # Real newline between the clauses, as in a stored spec.
                    original_spec = "\\requires a >= 0;\n\\ensures return == a + b;"
                    issues = ["Missing parameter validation", "Insufficient coverage"]
                    suggestions = ["Add null checks", "Include overflow checks"]
                    result = await generator.refine_specification(
                        original_spec, sample_request.function_info, issues, suggestions
                    )
                    assert result.specification is not None
                    assert result.tokens_used > 0
                    assert result.quality_score >= 0.0

    @pytest.mark.asyncio
    async def test_health_check(self):
        """A 200 response reports ``healthy`` together with response time and timestamp."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    health = await generator.health_check()
                    assert health['status'] == 'healthy'
                    assert 'api_response_time' in health
                    assert 'timestamp' in health

    @pytest.mark.asyncio
    async def test_health_check_failure(self):
        """A 500 response reports ``unhealthy`` instead of raising."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 500
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    health = await generator.health_check()
                    assert health['status'] == 'unhealthy'

    def test_post_process_spec(self):
        """Post-processing strips Markdown fences and maps empty or ``None`` input to ``""``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            generator = LLMGenerator()
            # Spec wrapped in a Markdown code fence. Built explicitly so that
            # source indentation cannot leak into the fenced lines.
            markdown_spec = (
                "\n"
                "```c\n"
                "\\requires a >= 0;\n"
                "\\ensures return == a + b;\n"
                "```\n"
            )
            cleaned = generator._post_process_spec(markdown_spec)
            assert '\\requires' in cleaned
            assert '\\ensures' in cleaned
            assert '```' not in cleaned
            # Empty input
            empty_result = generator._post_process_spec("")
            assert empty_result == ""
            # None input
            none_result = generator._post_process_spec(None)
            assert none_result == ""

    def test_generation_stats(self):
        """Statistics aggregate count, mean quality, quality distribution and model usage."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            generator = LLMGenerator()
            # Pretend two specifications are already stored.
            with patch.object(generator.storage, 'list_specifications', return_value=[
                {
                    'metadata': {
                        'quality_score': 0.8,
                        'model': 'test-model'
                    }
                },
                {
                    'metadata': {
                        'quality_score': 0.9,
                        'model': 'test-model'
                    }
                }
            ]):
                stats = generator.get_generation_stats()
                assert stats['total_specifications'] == 2
                # (0.8 + 0.9) / 2 is 0.8500000000000001 in binary floating point.
                assert stats['average_quality'] == pytest.approx(0.85)
                assert 'quality_distribution' in stats
                assert 'model_usage' in stats

    def test_synchronous_wrapper(self, sample_request):
        """``generate_specification_sync`` returns the mocked spec and token count."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value={
                    "choices": [{"message": {"content": "test spec"}}],
                    "usage": {"total_tokens": 100}
                })
                mock_post.return_value.__aenter__.return_value = mock_response
                result = generate_specification_sync(sample_request)
                assert result.specification == "test spec"
                assert result.tokens_used == 100

    def test_batch_synchronous_wrapper(self, sample_request):
        """``generate_batch_specifications_sync`` returns one mocked result per request."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value={
                    "choices": [{"message": {"content": "test spec"}}],
                    "usage": {"total_tokens": 100}
                })
                mock_post.return_value.__aenter__.return_value = mock_response
                requests = [sample_request, sample_request]
                results = generate_batch_specifications_sync(requests)
                assert len(results) == 2
                for result in results:
                    assert result.specification == "test spec"
                    assert result.tokens_used == 100

    @pytest.mark.asyncio
    async def test_timeout_handling(self, sample_request):
        """An ``asyncio.TimeoutError`` from the transport surfaces as the module's ``TimeoutError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_post.side_effect = asyncio.TimeoutError()
                async with LLMGenerator() as generator:
                    # ``TimeoutError`` here is the one imported from
                    # src.spec.llm_generator, which shadows the builtin.
                    with pytest.raises(TimeoutError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_network_error_handling(self, sample_request):
        """An ``aiohttp.ClientError`` surfaces as ``APIError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_post.side_effect = aiohttp.ClientError("Network error")
                async with LLMGenerator() as generator:
                    with pytest.raises(APIError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_invalid_json_response(self, sample_request):
        """A 200 response whose body is not valid JSON surfaces as ``APIError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(side_effect=json.JSONDecodeError("Invalid JSON", "", 0))
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    with pytest.raises(APIError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_missing_choices_in_response(self, sample_request):
        """A 200 response without a ``choices`` field surfaces as ``APIError``."""
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            with patch('aiohttp.ClientSession.post') as mock_post:
                mock_response = AsyncMock()
                mock_response.status = 200
                mock_response.json = AsyncMock(return_value={})  # empty payload
                mock_post.return_value.__aenter__.return_value = mock_response
                async with LLMGenerator() as generator:
                    with pytest.raises(APIError):
                        await generator.generate_specification(sample_request)

    @pytest.mark.asyncio
    async def test_session_management(self):
        """A lazily created session is flagged as self-managed and cleared by cleanup.

        The test runs inside its own running event loop, so the aiohttp
        session is created and torn down on the same loop. aiohttp expects
        sessions to be created inside a running loop. No global event loop
        is installed and left behind for later tests.
        """
        with patch.dict(os.environ, {'SILICONFLOW_API_KEY': 'test_key'}):
            generator = LLMGenerator()
            # No session exists until one is requested.
            assert generator.session is None
            generator._ensure_session()
            assert generator.session is not None
            assert generator._session_managed is True
            # Cleanup must drop the session reference.
            await generator._cleanup_session()
            assert generator.session is None
if __name__ == "__main__":
    # Allow running this file directly.
    # Propagate pytest's exit status so callers and CI can see failures.
    raise SystemExit(pytest.main([__file__, "-v"]))