cbmc/codedetect/tests/test_llm_generation.py

"""
LLM Generation test framework.

This module provides a comprehensive test framework for the LLM generation stage,
including test case management, result validation, performance measurement, and
reporting. Both automated and interactive test modes are supported.
"""
import os
import json
import time
import asyncio
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional, Union, Tuple, Callable
from dataclasses import dataclass, field
from datetime import datetime
import tempfile
import shutil
import pytest
import yaml
from src.spec.llm_generator import LLMGenerator, GenerationRequest, GenerationResult
from src.parser.c_parser import CParser, CParserFactory
from src.parser.ast_extractor import ASTExtractor
from src.utils.logger import get_logger
from src.utils.config import get_config
@dataclass
class TestCase:
"""测试用例数据类"""
name: str
description: str
source_code: str
file_path: Optional[str] = None
expected_metadata: Dict[str, Any] = field(default_factory=dict)
verification_goals: List[str] = field(default_factory=list)
hints: List[str] = field(default_factory=list)
expected_spec_patterns: List[str] = field(default_factory=list)
complexity_level: str = "basic" # basic, intermediate, advanced
timeout: int = 120
retry_count: int = 3
category: str = "function"
tags: List[str] = field(default_factory=list)
@dataclass
class TestResult:
"""测试结果数据类"""
test_case: TestCase
start_time: float
end_time: float
success: bool
specification: Optional[str] = None
generation_result: Optional[GenerationResult] = None
metadata: Dict[str, Any] = field(default_factory=dict)
validation_result: Dict[str, Any] = field(default_factory=dict)
error_message: Optional[str] = None
performance_metrics: Dict[str, Any] = field(default_factory=dict)
quality_score: float = 0.0
tokens_used: int = 0
retry_attempts: int = 0
class TestRunner:
"""LLM生成测试运行器"""
def __init__(self, config_path: Optional[str] = None):
self.config = self._load_config(config_path)
self.logger = get_logger('test_runner')
self.test_results: List[TestResult] = []
self.temp_dir = tempfile.mkdtemp(prefix='llm_test_')
# Initialize components
self.parser = CParserFactory.create_parser(self.config.get('parser', {}))
self.ast_extractor = ASTExtractor()
self.llm_generator = None
self.logger.info(f"TestRunner initialized with temp dir: {self.temp_dir}")
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""加载测试配置"""
if config_path and os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
return yaml.safe_load(f)
else:
# Fall back to the default configuration
base_config = get_config()
test_config = {
'test': {
'timeout': 120,
'max_retries': 3,
'parallel_workers': 3,
'save_intermediate': True,
'validate_specs': True,
'quality_threshold': 0.7
},
'parser': base_config.get('parser', {}),
'llm': base_config.get('llm', {})
}
return test_config
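# Illustrative example (an assumption, not shipped with the project) of a YAML
# file that _load_config would accept via config_path. The keys mirror the
# default configuration above; which keys the 'parser' and 'llm' sections
# actually honor is determined by get_config() and the respective components.
#
#   test:
#     timeout: 120
#     max_retries: 3
#     parallel_workers: 3
#     save_intermediate: true
#     validate_specs: true
#     quality_threshold: 0.7
#   parser: {}
#   llm: {}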
async def run_single_test(self, test_case: TestCase) -> TestResult:
"""运行单个测试用例"""
start_time = time.time()
result = TestResult(test_case=test_case, start_time=start_time, end_time=0, success=False)
self.logger.info(f"Running test: {test_case.name}")
try:
# Prepare the source code file
source_file = self._prepare_source_file(test_case)
# Parse the source code
ast = self.parser.parse_file(source_file)
metadata = self.ast_extractor.extract_metadata(ast)
result.metadata = {
'ast_info': self.parser.get_ast_info(ast),
'functions_count': len(metadata.functions),
'variables_count': len(metadata.variables),
'total_calls': sum(len(func.function_calls) for func in metadata.functions.values())
}
# Validate the extracted metadata
if test_case.expected_metadata:
self._validate_metadata(metadata, test_case.expected_metadata)
# Pick the first function for testing
if not metadata.functions:
raise ValueError("No functions found in source code")
function_name = list(metadata.functions.keys())[0]
function_info = metadata.functions[function_name].to_dict()
# Build the generation request
request = GenerationRequest(
function_name=function_name,
function_info=function_info,
verification_goals=test_case.verification_goals or [
'memory_safety',
'functional_correctness',
'buffer_overflow'
],
hints=test_case.hints,
max_retries=test_case.retry_count,
validate=True,
store=False
)
# Generate the specification
if not self.llm_generator:
self.llm_generator = LLMGenerator()
generation_result = await self.llm_generator.generate_specification(request)
result.generation_result = generation_result
result.specification = generation_result.specification
result.quality_score = generation_result.quality_score
result.tokens_used = generation_result.tokens_used
# Validate the generated specification
if test_case.expected_spec_patterns:
self._validate_specification(
generation_result.specification,
test_case.expected_spec_patterns
)
# Check the quality score
if generation_result.quality_score < self.config['test']['quality_threshold']:
raise ValueError(
f"Quality score {generation_result.quality_score} below threshold "
f"{self.config['test']['quality_threshold']}"
)
result.success = True
self.logger.info(f"Test passed: {test_case.name} (quality: {generation_result.quality_score:.2f})")
except Exception as e:
result.success = False
result.error_message = str(e)
self.logger.error(f"Test failed: {test_case.name} - {str(e)}")
finally:
result.end_time = time.time()
result.performance_metrics = {
'duration': result.end_time - result.start_time,
'tokens_per_second': result.tokens_used / max(result.end_time - result.start_time, 0.001)
}
# Clean up temporary files (only delete files created inside the temp directory)
if 'source_file' in locals() and source_file.startswith(self.temp_dir):
try:
os.unlink(source_file)
except OSError:
pass
self.test_results.append(result)
return result
async def run_batch_tests(self, test_cases: List[TestCase],
parallel: bool = True) -> List[TestResult]:
"""批量运行测试用例"""
self.logger.info(f"Starting batch test run with {len(test_cases)} cases")
if parallel:
# Parallel execution
semaphore = asyncio.Semaphore(self.config['test']['parallel_workers'])
async def run_with_semaphore(test_case: TestCase):
async with semaphore:
return await self.run_single_test(test_case)
tasks = [run_with_semaphore(tc) for tc in test_cases]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle results that raised exceptions
test_results = []
for result in results:
if isinstance(result, Exception):
self.logger.error(f"Batch test exception: {str(result)}")
else:
test_results.append(result)
return test_results
else:
# Sequential execution
results = []
for test_case in test_cases:
result = await self.run_single_test(test_case)
results.append(result)
return results
def _prepare_source_file(self, test_case: TestCase) -> str:
"""准备源代码文件"""
if test_case.file_path and os.path.exists(test_case.file_path):
return test_case.file_path
# Create a temporary file
temp_file = os.path.join(self.temp_dir, f"{test_case.name}.c")
with open(temp_file, 'w', encoding='utf-8') as f:
f.write(test_case.source_code)
return temp_file
def _validate_metadata(self, metadata, expected_metadata: Dict[str, Any]):
"""验证提取的元数据"""
errors = []
for key, expected_value in expected_metadata.items():
actual_value = getattr(metadata, key, None)
if key == 'functions_count':
actual_value = len(metadata.functions)
elif key == 'variables_count':
actual_value = len(metadata.variables)
if actual_value != expected_value:
errors.append(f"Metadata validation failed: {key} expected {expected_value}, got {actual_value}")
if errors:
raise ValueError("; ".join(errors))
def _validate_specification(self, specification: str, expected_patterns: List[str]):
"""验证生成的规范包含预期模式"""
errors = []
for pattern in expected_patterns:
if pattern not in specification:
errors.append(f"Missing expected pattern: {pattern}")
if errors:
raise ValueError("Specification validation failed: " + "; ".join(errors))
def generate_report(self) -> Dict[str, Any]:
"""生成测试报告"""
if not self.test_results:
return {'error': 'No test results available'}
total_tests = len(self.test_results)
passed_tests = sum(1 for r in self.test_results if r.success)
failed_tests = total_tests - passed_tests
# Performance statistics
durations = [r.performance_metrics['duration'] for r in self.test_results]
avg_duration = sum(durations) / len(durations) if durations else 0
# Quality statistics
quality_scores = [r.quality_score for r in self.test_results if r.success]
avg_quality = sum(quality_scores) / len(quality_scores) if quality_scores else 0
# Token statistics
total_tokens = sum(r.tokens_used for r in self.test_results)
# Group results by complexity level
complexity_stats = {}
for result in self.test_results:
complexity = result.test_case.complexity_level
if complexity not in complexity_stats:
complexity_stats[complexity] = {'total': 0, 'passed': 0}
complexity_stats[complexity]['total'] += 1
if result.success:
complexity_stats[complexity]['passed'] += 1
# Details of failed cases
failed_cases = []
for result in self.test_results:
if not result.success:
failed_cases.append({
'name': result.test_case.name,
'error': result.error_message,
'duration': result.performance_metrics['duration']
})
report = {
'summary': {
'total_tests': total_tests,
'passed_tests': passed_tests,
'failed_tests': failed_tests,
'success_rate': passed_tests / total_tests if total_tests > 0 else 0,
'average_duration': avg_duration,
'average_quality': avg_quality,
'total_tokens': total_tokens,
'timestamp': datetime.now().isoformat()
},
'complexity_breakdown': complexity_stats,
'failed_cases': failed_cases,
'performance_metrics': {
'min_duration': min(durations) if durations else 0,
'max_duration': max(durations) if durations else 0,
'total_duration': sum(durations)
},
'quality_metrics': {
'min_quality': min(quality_scores) if quality_scores else 0,
'max_quality': max(quality_scores) if quality_scores else 0,
'quality_distribution': {
'excellent': len([s for s in quality_scores if s >= 0.9]),
'good': len([s for s in quality_scores if 0.7 <= s < 0.9]),
'fair': len([s for s in quality_scores if 0.5 <= s < 0.7]),
'poor': len([s for s in quality_scores if s < 0.5])
}
}
}
return report
def save_report(self, report: Dict[str, Any], output_path: str):
"""保存测试报告"""
if not output_path or not output_path.strip():
self.logger.error("Output path cannot be empty")
return
if not report:
self.logger.error("Report data cannot be empty")
return
# Ensure the path is absolute
output_path = os.path.abspath(output_path)
try:
# Create the parent directory
parent_dir = os.path.dirname(output_path)
if parent_dir:  # only create it when a parent directory is present
os.makedirs(parent_dir, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False)
self.logger.info(f"Test report saved to: {output_path}")
except Exception as e:
self.logger.error(f"Failed to save report to {output_path}: {e}")
def save_detailed_results(self, output_dir: str):
"""保存详细的测试结果"""
if not output_dir or not output_dir.strip():
self.logger.error("Output directory cannot be empty")
return
# Ensure the path is absolute
output_dir = os.path.abspath(output_dir)
try:
os.makedirs(output_dir, exist_ok=True)
except Exception as e:
self.logger.error(f"Failed to create output directory {output_dir}: {e}")
return
for i, result in enumerate(self.test_results):
result_file = os.path.join(output_dir, f"test_{i:03d}_{result.test_case.name}.json")
result_data = {
'test_case': {
'name': result.test_case.name,
'description': result.test_case.description,
'complexity_level': result.test_case.complexity_level,
'category': result.test_case.category,
'tags': result.test_case.tags
},
'execution': {
'success': result.success,
'start_time': result.start_time,
'end_time': result.end_time,
'duration': result.end_time - result.start_time,
'error_message': result.error_message
},
'metadata': result.metadata,
'validation_result': result.validation_result,
'performance_metrics': result.performance_metrics,
'quality_score': result.quality_score,
'tokens_used': result.tokens_used
}
if result.specification:
result_data['specification'] = result.specification
try:
with open(result_file, 'w', encoding='utf-8') as f:
json.dump(result_data, f, indent=2, ensure_ascii=False)
except Exception as e:
self.logger.error(f"Failed to save result file {result_file}: {e}")
continue
self.logger.info(f"Detailed results saved to: {output_dir} ({len(self.test_results)} files)")
def cleanup(self):
"""清理临时文件"""
if os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir)
self.logger.info(f"Cleaned up temp directory: {self.temp_dir}")
class ResultValidator:
"""测试结果验证器"""
@staticmethod
def validate_specification_syntax(spec: str) -> Dict[str, Any]:
"""验证CBMC规范语法"""
errors = []
warnings = []
# Basic CBMC syntax checks
required_clauses = ['requires', 'ensures']
found_clauses = []
lines = spec.split('\n')
for line in lines:
line = line.strip()
if line.startswith('\\requires'):
found_clauses.append('requires')
elif line.startswith('\\ensures'):
found_clauses.append('ensures')
elif line.startswith('\\assigns'):
found_clauses.append('assigns')
# Check for the required clauses
for clause in required_clauses:
if clause not in found_clauses:
errors.append(f"Missing required clause: {clause}")
# Check the clause format
for i, line in enumerate(lines):
line = line.strip()
if line and not line.startswith('\\') and not line.startswith('//'):
warnings.append(f"Line {i+1}: Invalid CBMC clause format")
return {
'is_valid': len(errors) == 0,
'errors': errors,
'warnings': warnings,
'found_clauses': found_clauses
}
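# Illustrative specification fragment (an assumption about what LLMGenerator
# emits) that validate_specification_syntax accepts: one ACSL-style clause per
# line, each starting with a backslash, with at least one \requires and one
# \ensures clause present.
#
#   \requires size > 0
#   \requires \valid(buf + (0 .. size - 1))
#   \assigns buf[0 .. size - 1]
#   \ensures \result >= 0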
@staticmethod
def validate_quality_threshold(quality_score: float, threshold: float) -> bool:
"""验证质量分数是否达到阈值"""
return quality_score >= threshold
@staticmethod
def validate_performance_metrics(metrics: Dict[str, Any],
max_duration: float = 300.0) -> Dict[str, Any]:
"""验证性能指标"""
errors = []
if metrics.get('duration', 0) > max_duration:
errors.append(f"Duration {metrics['duration']}s exceeds maximum {max_duration}s")
if metrics.get('tokens_per_second', 0) < 1.0:
warnings = ["Low token generation rate"]
else:
warnings = []
return {
'performance_ok': len(errors) == 0,
'errors': errors,
'warnings': warnings
}
class TestReporter:
"""测试报告生成器"""
def __init__(self, output_dir: str = "test_reports"):
self.output_dir = Path(output_dir)
self.output_dir.mkdir(exist_ok=True)
def generate_html_report(self, report: Dict[str, Any], test_results: List[TestResult]) -> str:
"""生成HTML格式报告"""
html_content = self._build_html_template(report, test_results)
report_path = self.output_dir / f"test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html"
with open(report_path, 'w', encoding='utf-8') as f:
f.write(html_content)
return str(report_path)
def _build_html_template(self, report: Dict[str, Any], test_results: List[TestResult]) -> str:
"""构建HTML报告模板"""
summary = report['summary']
html = f"""
<!DOCTYPE html>
<html>
<head>
<title>LLM Generation Test Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
.header {{ background-color: #f0f0f0; padding: 20px; border-radius: 5px; }}
.summary {{ margin: 20px 0; }}
.metric {{ display: inline-block; margin: 10px; padding: 10px;
background-color: #e8f4f8; border-radius: 5px; }}
.success {{ color: #28a745; }}
.failure {{ color: #dc3545; }}
.test-case {{ margin: 10px 0; padding: 10px; border: 1px solid #ddd; border-radius: 5px; }}
.failed {{ background-color: #f8d7da; }}
.passed {{ background-color: #d4edda; }}
</style>
</head>
<body>
<div class="header">
<h1>LLM Generation Test Report</h1>
<p>Generated: {summary['timestamp']}</p>
</div>
<div class="summary">
<h2>Summary</h2>
<div class="metric">
<strong>Total Tests:</strong> {summary['total_tests']}
</div>
<div class="metric">
<strong>Passed:</strong> <span class="success">{summary['passed_tests']}</span>
</div>
<div class="metric">
<strong>Failed:</strong> <span class="failure">{summary['failed_tests']}</span>
</div>
<div class="metric">
<strong>Success Rate:</strong> {summary['success_rate']:.1%}
</div>
<div class="metric">
<strong>Avg Quality:</strong> {summary['average_quality']:.2f}
</div>
<div class="metric">
<strong>Total Tokens:</strong> {summary['total_tokens']}
</div>
</div>
"""
# Add the failed test cases
if summary['failed_tests'] > 0:
html += """
<div class="failed-cases">
<h2>Failed Test Cases</h2>
"""
for failed_case in report['failed_cases']:
html += f"""
<div class="test-case failed">
<h3>{failed_case['name']}</h3>
<p><strong>Error:</strong> {failed_case['error']}</p>
<p><strong>Duration:</strong> {failed_case['duration']:.2f}s</p>
</div>
"""
html += "</div>"
html += """
</body>
</html>
"""
return html
# Convenience function
async def run_llm_tests(test_cases: List[TestCase],
config_path: Optional[str] = None,
parallel: bool = True) -> Tuple[TestRunner, List[TestResult], Dict[str, Any]]:
"""运行LLM生成测试的便捷函数"""
runner = TestRunner(config_path)
try:
results = await runner.run_batch_tests(test_cases, parallel)
report = runner.generate_report()
return runner, results, report
finally:
# Note: do not clean up the runner here; the caller still needs it to save detailed results
pass
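
# Minimal usage sketch (illustrative only): builds one trivial test case and runs
# it through run_llm_tests. The C snippet, the output paths, and running serially
# are assumptions chosen for the example, not requirements of the framework.
if __name__ == "__main__":
    example_case = TestCase(
        name="abs_value",
        description="Absolute value of an int",
        source_code="int abs_value(int x) { return x < 0 ? -x : x; }",
        verification_goals=["functional_correctness"],
        complexity_level="basic",
    )

    async def _demo():
        runner, results, report = await run_llm_tests([example_case], parallel=False)
        try:
            print(f"{sum(1 for r in results if r.success)}/{len(results)} test(s) passed")
            runner.save_report(report, "test_reports/llm_generation_report.json")
            runner.save_detailed_results("test_reports/details")
        finally:
            runner.cleanup()

    asyncio.run(_demo())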