You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
cbmc/codedetect/tests/integration/test_parser_llm_pipeline.py

541 lines
17 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
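Parser-to-LLM pipeline integration tests.

This module implements end-to-end pipeline integration tests that verify the
whole flow from parsing C/C++ source code to generating CBMC specifications.
It covers different verification goals, error-handling scenarios, and edge cases.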
"""
import pytest
import asyncio
import tempfile
import os
from pathlib import Path
from typing import List, Dict, Any, Optional
from src.parser.c_parser import CParser, CParserFactory
from src.parser.ast_extractor import ASTExtractor
from src.spec.llm_generator import LLMGenerator, GenerationRequest, GenerationResult
from utils.cbmc_spec_validator import CBMCSpecificationValidator
from src.utils.logger import get_logger
from test_data.simple_c_examples import (
get_basic_test_suite,
get_test_cases_by_complexity,
get_test_cases_by_category
)
from test_llm_generation import TestRunner, TestCase
# Module-level logger, obtained from the project's get_logger helper using this module's __name__.
logger = get_logger(__name__)
class TestParserLLMPipeline:
    """Integration tests for the parser -> AST extractor -> LLM generator pipeline.

    Each test writes a small C source to a temporary file, parses it, extracts
    function metadata, and asks the LLM generator for a specification.
    """

    # ------------------------------------------------------------------
    # Private helpers (the leading underscore keeps pytest from collecting them)
    # ------------------------------------------------------------------
    @staticmethod
    def _write_temp_c_file(code: str) -> str:
        """Write ``code`` to a new temporary ``.c`` file and return its path.

        The file is created with ``delete=False``, so the caller is responsible
        for removing it. It is written as UTF-8 explicitly because some of the
        sources contain non-ASCII comments, which would fail to encode under a
        non-UTF-8 locale default.

        Raises:
            Exception: whatever the write raised; the partially written file
                is removed before the exception propagates.
        """
        handle = tempfile.NamedTemporaryFile(
            mode='w', suffix='.c', delete=False, encoding='utf-8'
        )
        try:
            with handle:
                handle.write(code)
        except Exception:
            # Do not leave an orphaned temp file behind when the write fails.
            os.unlink(handle.name)
            raise
        return handle.name

    @staticmethod
    def _first_function(metadata):
        """Return ``(name, info_dict)`` for the first function in ``metadata.functions``.

        Fails with an assertion (rather than an IndexError) when no function
        was extracted.
        """
        assert len(metadata.functions) > 0, "no functions were extracted from the source"
        function_name = next(iter(metadata.functions.keys()))
        return function_name, metadata.functions[function_name].to_dict()

    @staticmethod
    def _build_request(function_name, function_info, goals, max_retries=2):
        """Build a validating, non-storing GenerationRequest for the given goals."""
        return GenerationRequest(
            function_name=function_name,
            function_info=function_info,
            verification_goals=goals,
            max_retries=max_retries,
            validate=True,
            store=False
        )

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------
    @pytest.fixture
    def parser(self):
        """Provide a C parser created by ``CParserFactory.create_parser()``."""
        return CParserFactory.create_parser()

    @pytest.fixture
    def ast_extractor(self):
        """Provide an ``ASTExtractor`` instance."""
        return ASTExtractor()

    @pytest.fixture
    def llm_generator(self):
        """Provide an ``LLMGenerator`` instance."""
        return LLMGenerator()

    @pytest.fixture
    def validator(self):
        """Provide a ``CBMCSpecificationValidator`` instance."""
        return CBMCSpecificationValidator()

    @pytest.fixture
    def test_runner(self):
        """Provide a ``TestRunner`` instance."""
        return TestRunner()

    @pytest.fixture
    def basic_test_cases(self):
        """Provide the basic test suite returned by ``get_basic_test_suite()``."""
        return get_basic_test_suite()

    @pytest.fixture
    def temp_c_file(self):
        """Yield the path of a temporary C file holding one small function.

        The file is removed after the test that uses the fixture finishes.
        """
        test_code = """
int test_function(int x) {
return x * 2;
}
"""
        temp_path = self._write_temp_c_file(test_code)
        yield temp_path
        # Cleanup
        os.unlink(temp_path)

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------
    @pytest.mark.asyncio
    async def test_basic_pipeline_integration(self, parser, ast_extractor, llm_generator, validator, temp_c_file):
        """Run the basic pipeline: parse, extract, generate, then validate."""
        # 1. Parse the C file
        ast = parser.parse_file(temp_c_file)
        assert ast is not None
        assert parser.validate_ast(ast)

        # 2. Extract metadata and 3. pick the first function
        metadata = ast_extractor.extract_metadata(ast)
        function_name, function_info = self._first_function(metadata)

        # 4. Build the generation request
        request = self._build_request(
            function_name,
            function_info,
            ['functional_correctness', 'memory_safety'],
        )

        # 5. Generate the specification
        result = await llm_generator.generate_specification(request)
        assert result.specification is not None
        assert result.specification != ""
        assert result.generation_time > 0
        assert result.tokens_used > 0

        # 6. Validate the specification (warnings alone are accepted)
        validation_result = validator.validate_specification(result.specification, function_info)
        assert validation_result.is_valid or len(validation_result.warnings) > 0
        logger.info(f"Pipeline integration test passed for {function_name}")

    @pytest.mark.asyncio
    async def test_multiple_verification_goals(self, parser, ast_extractor, llm_generator):
        """Generate specifications for several combinations of verification goals."""
        test_code = """
#include <stddef.h>
int safe_array_access(int *array, size_t size, size_t index) {
if (array == NULL || size == 0) {
return -1;
}
if (index >= size) {
return -2;
}
return array[index];
}
"""
        # Keywords expected (but not required) in a specification for each goal.
        # Built once here instead of once per goal per combination.
        goal_keywords = {
            'memory_safety': ['valid', 'memory'],
            'functional_correctness': ['return', 'ensures'],
            'buffer_overflow': ['bounds', 'size'],
            'null_pointer': ['null', '\\0']
        }
        # Goal combinations to exercise
        test_cases = [
            ['memory_safety'],
            ['functional_correctness'],
            ['memory_safety', 'functional_correctness'],
            ['memory_safety', 'buffer_overflow', 'null_pointer']
        ]

        temp_path = self._write_temp_c_file(test_code)
        try:
            # Parse and extract metadata
            ast = parser.parse_file(temp_path)
            metadata = ast_extractor.extract_metadata(ast)
            function_name, function_info = self._first_function(metadata)

            for goals in test_cases:
                request = self._build_request(function_name, function_info, goals)
                result = await llm_generator.generate_specification(request)
                assert result.specification is not None
                assert result.quality_score >= 0.0

                # Look for goal-related keywords in the specification text
                spec_text = result.specification.lower()
                for goal in goals:
                    keywords = goal_keywords.get(goal)
                    if keywords is None:
                        continue
                    # Not a hard requirement: the LLM may phrase things differently,
                    # so a miss is only logged.
                    if not any(keyword in spec_text for keyword in keywords):
                        logger.warning(f"Goal {goal} keywords not found in specification")
        finally:
            os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_error_handling_scenarios(self, parser, ast_extractor, llm_generator):
        """Check behavior for a missing file and for an empty file."""
        # A nonexistent file must raise
        with pytest.raises(Exception):
            parser.parse_file("nonexistent_file.c")

        # An empty file must yield no functions
        temp_path = self._write_temp_c_file("")
        try:
            ast = parser.parse_file(temp_path)
            metadata = ast_extractor.extract_metadata(ast)
            assert len(metadata.functions) == 0
        finally:
            os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_complexity_scaling(self, parser, ast_extractor, llm_generator):
        """Run the pipeline on functions of increasing complexity and log timings."""
        test_cases = [
            # Simple function
            """
int simple_add(int a, int b) {
return a + b;
}
""",
            # Medium-complexity function
            """
int array_sum(int *arr, int size) {
if (arr == NULL || size <= 0) return 0;
int sum = 0;
for (int i = 0; i < size; i++) {
sum += arr[i];
}
return sum;
}
""",
            # Complex function
            """
#include <stdlib.h>
typedef struct Node {
int data;
struct Node *next;
} Node;
Node* create_list(int size) {
if (size <= 0) return NULL;
Node *head = (Node*)malloc(sizeof(Node));
if (head == NULL) return NULL;
head->data = 0;
head->next = NULL;
Node *current = head;
for (int i = 1; i < size; i++) {
current->next = (Node*)malloc(sizeof(Node));
if (current->next == NULL) {
// 内存分配失败,清理已分配的内存
while (head != NULL) {
Node *temp = head;
head = head->next;
free(temp);
}
return NULL;
}
current->next->data = i;
current->next->next = NULL;
current = current->next;
}
return head;
}
"""
        ]
        # We are inside a coroutine, so the running loop is the clock source.
        loop = asyncio.get_running_loop()

        for i, code in enumerate(test_cases):
            temp_path = self._write_temp_c_file(code)
            try:
                start_time = loop.time()

                # Parse
                ast = parser.parse_file(temp_path)
                metadata = ast_extractor.extract_metadata(ast)
                if len(metadata.functions) == 0:
                    # Keep the original tolerance (no failure), but make the skip visible.
                    logger.warning(f"Complexity test {i}: no functions extracted, generation skipped")
                    continue

                function_name, function_info = self._first_function(metadata)

                # Generate the specification
                request = self._build_request(
                    function_name,
                    function_info,
                    ['memory_safety', 'functional_correctness'],
                )
                result = await llm_generator.generate_specification(request)
                duration = loop.time() - start_time

                assert result.specification is not None
                assert result.generation_time > 0
                assert result.quality_score >= 0.0

                # Record performance
                logger.info(f"Complexity test {i}: duration={duration:.2f}s, quality={result.quality_score:.2f}")
            finally:
                os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_batch_processing(self, parser, ast_extractor, llm_generator):
        """Generate specifications for several files in one batch call."""
        test_files = []
        try:
            # Create the test files inside the try block so that files already
            # created are still removed if a later creation fails.
            for i in range(3):
                test_code = f"""
int test_function_{i}(int x) {{
return x + {i};
}}
"""
                test_files.append(self._write_temp_c_file(test_code))

            # Parse every file and build one request per file that has a function
            requests = []
            for temp_path in test_files:
                ast = parser.parse_file(temp_path)
                metadata = ast_extractor.extract_metadata(ast)
                if len(metadata.functions) == 0:
                    logger.warning(f"Batch test: no functions extracted from {temp_path}")
                    continue
                function_name, function_info = self._first_function(metadata)
                requests.append(
                    self._build_request(function_name, function_info, ['functional_correctness'])
                )

            # Generate all specifications in one batch
            results = await llm_generator.generate_batch_specifications(requests)

            # Check the results
            assert len(results) == len(requests)
            for result in results:
                assert result.specification is not None
                assert result.quality_score >= 0.0
            logger.info(f"Batch processing test passed: {len(results)} functions processed")
        finally:
            for temp_path in test_files:
                os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_edge_cases(self, parser, ast_extractor, llm_generator):
        """Run the pipeline on code containing escaped character literals."""
        # Code with special characters
        edge_case_code = """
int special_chars(int *ptr, char ch) {
if (ptr == NULL) return -1;
switch (ch) {
case '\\\\': return 1;
case '\\n': return 2;
case '\\t': return 3;
default: return 0;
}
}
"""
        temp_path = self._write_temp_c_file(edge_case_code)
        try:
            ast = parser.parse_file(temp_path)
            metadata = ast_extractor.extract_metadata(ast)
            if len(metadata.functions) == 0:
                # Keep the original tolerance (no failure), but make the skip visible.
                logger.warning("Edge case test: no functions extracted, generation skipped")
                return
            function_name, function_info = self._first_function(metadata)
            request = self._build_request(
                function_name,
                function_info,
                ['memory_safety', 'functional_correctness'],
            )
            result = await llm_generator.generate_specification(request)
            assert result.specification is not None
        finally:
            os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_with_test_runner(self, test_runner, basic_test_cases):
        """Run the basic suite through the TestRunner and check its report."""
        results = await test_runner.run_batch_tests(basic_test_cases, parallel=False)

        # Check the results: at least one test must have succeeded
        assert len(results) > 0
        successful_results = [r for r in results if r.success]
        assert len(successful_results) > 0

        # Generate the report
        report = test_runner.generate_report()
        assert report['summary']['total_tests'] > 0
        assert report['summary']['success_rate'] >= 0.0
        logger.info(f"Integration test with TestRunner: {len(successful_results)}/{len(results)} passed")

    @pytest.mark.asyncio
    async def test_performance_metrics(self, parser, ast_extractor, llm_generator):
        """Check that generation results carry usable performance metrics."""
        test_code = """
int performance_test(int *array, int size) {
if (array == NULL || size <= 0) return 0;
int sum = 0;
for (int i = 0; i < size; i++) {
sum += array[i];
}
return sum;
}
"""
        temp_path = self._write_temp_c_file(test_code)
        try:
            ast = parser.parse_file(temp_path)
            metadata = ast_extractor.extract_metadata(ast)
            function_name, function_info = self._first_function(metadata)
            request = self._build_request(
                function_name,
                function_info,
                ['memory_safety', 'functional_correctness'],
            )
            result = await llm_generator.generate_specification(request)

            # Check the raw metrics
            assert result.generation_time > 0
            assert result.tokens_used > 0
            assert result.quality_score >= 0.0

            # Derive throughput; generation_time was asserted positive above
            tokens_per_second = result.tokens_used / result.generation_time
            assert tokens_per_second > 0
            logger.info(f"Performance metrics: {result.generation_time:.2f}s, "
                        f"{result.tokens_used} tokens, {tokens_per_second:.1f} tokens/s")
        finally:
            os.unlink(temp_path)

    @pytest.mark.asyncio
    async def test_retry_mechanism(self, parser, ast_extractor, llm_generator):
        """Generate a specification with a reduced retry budget (max_retries=1)."""
        test_code = """
int retry_test(int x) {
return x * 2;
}
"""
        temp_path = self._write_temp_c_file(test_code)
        try:
            ast = parser.parse_file(temp_path)
            metadata = ast_extractor.extract_metadata(ast)
            function_name, function_info = self._first_function(metadata)

            # Use fewer retries than the other tests
            request = self._build_request(
                function_name,
                function_info,
                ['functional_correctness'],
                max_retries=1,
            )
            result = await llm_generator.generate_specification(request)
            assert result.specification is not None
        finally:
            os.unlink(temp_path)

    def test_ast_validation(self, parser):
        """Check AST validation and the AST info returned for valid C code."""
        # Valid C code
        valid_code = """
int valid_function(int a, int b) {
return a + b;
}
"""
        temp_path = self._write_temp_c_file(valid_code)
        try:
            ast = parser.parse_file(temp_path)
            assert parser.validate_ast(ast)

            # Check the AST info
            ast_info = parser.get_ast_info(ast)
            assert 'node_count' in ast_info
            assert 'functions' in ast_info
            assert len(ast_info['functions']) > 0
        finally:
            os.unlink(temp_path)
if __name__ == "__main__":
    # Run this module's integration tests when executed as a script.
    # pytest.main() returns the exit status; raising SystemExit with it makes
    # the process exit non-zero when tests fail instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))