""" LLM Generation测试框架 本模块为LLM生成阶段提供全面的测试框架,包括测试用例管理、结果验证、 性能测量和报告功能。支持自动化和交互式测试模式。 """ import os import json import time import asyncio import logging from pathlib import Path from typing import Dict, List, Any, Optional, Union, Tuple, Callable from dataclasses import dataclass, field from datetime import datetime import tempfile import shutil import pytest from pytest import TestReport import aiohttp import yaml from src.spec.llm_generator import LLMGenerator, GenerationRequest, GenerationResult from src.parser.c_parser import CParser, CParserFactory from src.parser.ast_extractor import ASTExtractor from src.utils.logger import get_logger from src.utils.config import get_config @dataclass class TestCase: """测试用例数据类""" name: str description: str source_code: str file_path: Optional[str] = None expected_metadata: Dict[str, Any] = field(default_factory=dict) verification_goals: List[str] = field(default_factory=list) hints: List[str] = field(default_factory=list) expected_spec_patterns: List[str] = field(default_factory=list) complexity_level: str = "basic" # basic, intermediate, advanced timeout: int = 120 retry_count: int = 3 category: str = "function" tags: List[str] = field(default_factory=list) @dataclass class TestResult: """测试结果数据类""" test_case: TestCase start_time: float end_time: float success: bool specification: Optional[str] = None generation_result: Optional[GenerationResult] = None metadata: Dict[str, Any] = field(default_factory=dict) validation_result: Dict[str, Any] = field(default_factory=dict) error_message: Optional[str] = None performance_metrics: Dict[str, Any] = field(default_factory=dict) quality_score: float = 0.0 tokens_used: int = 0 retry_attempts: int = 0 class TestRunner: """LLM生成测试运行器""" def __init__(self, config_path: Optional[str] = None): self.config = self._load_config(config_path) self.logger = get_logger('test_runner') self.test_results: List[TestResult] = [] self.temp_dir = tempfile.mkdtemp(prefix='llm_test_') # 初始化组件 self.parser = CParserFactory.create_parser(self.config.get('parser', {})) self.ast_extractor = ASTExtractor() self.llm_generator = None self.logger.info(f"TestRunner initialized with temp dir: {self.temp_dir}") def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]: """加载测试配置""" if config_path and os.path.exists(config_path): with open(config_path, 'r', encoding='utf-8') as f: return yaml.safe_load(f) else: # 使用默认配置 base_config = get_config() test_config = { 'test': { 'timeout': 120, 'max_retries': 3, 'parallel_workers': 3, 'save_intermediate': True, 'validate_specs': True, 'quality_threshold': 0.7 }, 'parser': base_config.get('parser', {}), 'llm': base_config.get('llm', {}) } return test_config async def run_single_test(self, test_case: TestCase) -> TestResult: """运行单个测试用例""" start_time = time.time() result = TestResult(test_case=test_case, start_time=start_time, end_time=0, success=False) self.logger.info(f"Running test: {test_case.name}") try: # 准备源代码文件 source_file = self._prepare_source_file(test_case) # 解析源代码 ast = self.parser.parse_file(source_file) metadata = self.ast_extractor.extract_metadata(ast) result.metadata = { 'ast_info': self.parser.get_ast_info(ast), 'functions_count': len(metadata.functions), 'variables_count': len(metadata.variables), 'total_calls': sum(len(func.function_calls) for func in metadata.functions.values()) } # 验证元数据 if test_case.expected_metadata: self._validate_metadata(metadata, test_case.expected_metadata) # 选择第一个函数进行测试 if not metadata.functions: raise ValueError("No functions found in source code") function_name = list(metadata.functions.keys())[0] function_info = metadata.functions[function_name].to_dict() # 创建生成请求 request = GenerationRequest( function_name=function_name, function_info=function_info, verification_goals=test_case.verification_goals or [ 'memory_safety', 'functional_correctness', 'buffer_overflow' ], hints=test_case.hints, max_retries=test_case.retry_count, validate=True, store=False ) # 生成规范 if not self.llm_generator: self.llm_generator = LLMGenerator() generation_result = await self.llm_generator.generate_specification(request) result.generation_result = generation_result result.specification = generation_result.specification result.quality_score = generation_result.quality_score result.tokens_used = generation_result.tokens_used # 验证生成的规范 if test_case.expected_spec_patterns: self._validate_specification( generation_result.specification, test_case.expected_spec_patterns ) # 验证质量分数 if generation_result.quality_score < self.config['test']['quality_threshold']: raise ValueError( f"Quality score {generation_result.quality_score} below threshold " f"{self.config['test']['quality_threshold']}" ) result.success = True self.logger.info(f"Test passed: {test_case.name} (quality: {generation_result.quality_score:.2f})") except Exception as e: result.success = False result.error_message = str(e) self.logger.error(f"Test failed: {test_case.name} - {str(e)}") finally: result.end_time = time.time() result.performance_metrics = { 'duration': result.end_time - result.start_time, 'tokens_per_second': result.tokens_used / max(result.end_time - result.start_time, 0.001) } # 清理临时文件(只删除在临时目录中创建的文件) if 'source_file' in locals() and source_file.startswith(self.temp_dir): try: os.unlink(source_file) except: pass self.test_results.append(result) return result async def run_batch_tests(self, test_cases: List[TestCase], parallel: bool = True) -> List[TestResult]: """批量运行测试用例""" self.logger.info(f"Starting batch test run with {len(test_cases)} cases") if parallel: # 并行执行 semaphore = asyncio.Semaphore(self.config['test']['parallel_workers']) async def run_with_semaphore(test_case: TestCase): async with semaphore: return await self.run_single_test(test_case) tasks = [run_with_semaphore(tc) for tc in test_cases] results = await asyncio.gather(*tasks, return_exceptions=True) # 处理异常结果 test_results = [] for result in results: if isinstance(result, Exception): self.logger.error(f"Batch test exception: {str(result)}") else: test_results.append(result) return test_results else: # 串行执行 results = [] for test_case in test_cases: result = await self.run_single_test(test_case) results.append(result) return results def _prepare_source_file(self, test_case: TestCase) -> str: """准备源代码文件""" if test_case.file_path and os.path.exists(test_case.file_path): return test_case.file_path # 创建临时文件 temp_file = os.path.join(self.temp_dir, f"{test_case.name}.c") with open(temp_file, 'w', encoding='utf-8') as f: f.write(test_case.source_code) return temp_file def _validate_metadata(self, metadata, expected_metadata: Dict[str, Any]): """验证提取的元数据""" errors = [] for key, expected_value in expected_metadata.items(): actual_value = getattr(metadata, key, None) if key == 'functions_count': actual_value = len(metadata.functions) elif key == 'variables_count': actual_value = len(metadata.variables) if actual_value != expected_value: errors.append(f"Metadata validation failed: {key} expected {expected_value}, got {actual_value}") if errors: raise ValueError("; ".join(errors)) def _validate_specification(self, specification: str, expected_patterns: List[str]): """验证生成的规范包含预期模式""" errors = [] for pattern in expected_patterns: if pattern not in specification: errors.append(f"Missing expected pattern: {pattern}") if errors: raise ValueError("Specification validation failed: " + "; ".join(errors)) def generate_report(self) -> Dict[str, Any]: """生成测试报告""" if not self.test_results: return {'error': 'No test results available'} total_tests = len(self.test_results) passed_tests = sum(1 for r in self.test_results if r.success) failed_tests = total_tests - passed_tests # 性能统计 durations = [r.performance_metrics['duration'] for r in self.test_results] avg_duration = sum(durations) / len(durations) if durations else 0 # 质量统计 quality_scores = [r.quality_score for r in self.test_results if r.success] avg_quality = sum(quality_scores) / len(quality_scores) if quality_scores else 0 # Token统计 total_tokens = sum(r.tokens_used for r in self.test_results) # 按复杂度分组 complexity_stats = {} for result in self.test_results: complexity = result.test_case.complexity_level if complexity not in complexity_stats: complexity_stats[complexity] = {'total': 0, 'passed': 0} complexity_stats[complexity]['total'] += 1 if result.success: complexity_stats[complexity]['passed'] += 1 # 失败用例详情 failed_cases = [] for result in self.test_results: if not result.success: failed_cases.append({ 'name': result.test_case.name, 'error': result.error_message, 'duration': result.performance_metrics['duration'] }) report = { 'summary': { 'total_tests': total_tests, 'passed_tests': passed_tests, 'failed_tests': failed_tests, 'success_rate': passed_tests / total_tests if total_tests > 0 else 0, 'average_duration': avg_duration, 'average_quality': avg_quality, 'total_tokens': total_tokens, 'timestamp': datetime.now().isoformat() }, 'complexity_breakdown': complexity_stats, 'failed_cases': failed_cases, 'performance_metrics': { 'min_duration': min(durations) if durations else 0, 'max_duration': max(durations) if durations else 0, 'total_duration': sum(durations) }, 'quality_metrics': { 'min_quality': min(quality_scores) if quality_scores else 0, 'max_quality': max(quality_scores) if quality_scores else 0, 'quality_distribution': { 'excellent': len([s for s in quality_scores if s >= 0.9]), 'good': len([s for s in quality_scores if 0.7 <= s < 0.9]), 'fair': len([s for s in quality_scores if 0.5 <= s < 0.7]), 'poor': len([s for s in quality_scores if s < 0.5]) } } } return report def save_report(self, report: Dict[str, Any], output_path: str): """保存测试报告""" if not output_path or not output_path.strip(): self.logger.error("Output path cannot be empty") return if not report: self.logger.error("Report data cannot be empty") return # 确保路径是绝对路径 output_path = os.path.abspath(output_path) try: # 创建父目录 parent_dir = os.path.dirname(output_path) if parent_dir: # 只有当有父目录时才创建 os.makedirs(parent_dir, exist_ok=True) with open(output_path, 'w', encoding='utf-8') as f: json.dump(report, f, indent=2, ensure_ascii=False) self.logger.info(f"Test report saved to: {output_path}") except Exception as e: self.logger.error(f"Failed to save report to {output_path}: {e}") def save_detailed_results(self, output_dir: str): """保存详细的测试结果""" if not output_dir or not output_dir.strip(): self.logger.error("Output directory cannot be empty") return # 确保路径是绝对路径 output_dir = os.path.abspath(output_dir) try: os.makedirs(output_dir, exist_ok=True) except Exception as e: self.logger.error(f"Failed to create output directory {output_dir}: {e}") return for i, result in enumerate(self.test_results): result_file = os.path.join(output_dir, f"test_{i:03d}_{result.test_case.name}.json") result_data = { 'test_case': { 'name': result.test_case.name, 'description': result.test_case.description, 'complexity_level': result.test_case.complexity_level, 'category': result.test_case.category, 'tags': result.test_case.tags }, 'execution': { 'success': result.success, 'start_time': result.start_time, 'end_time': result.end_time, 'duration': result.end_time - result.start_time, 'error_message': result.error_message }, 'metadata': result.metadata, 'validation_result': result.validation_result, 'performance_metrics': result.performance_metrics, 'quality_score': result.quality_score, 'tokens_used': result.tokens_used } if result.specification: result_data['specification'] = result.specification try: with open(result_file, 'w', encoding='utf-8') as f: json.dump(result_data, f, indent=2, ensure_ascii=False) except Exception as e: self.logger.error(f"Failed to save result file {result_file}: {e}") continue self.logger.info(f"Detailed results saved to: {output_dir} ({len(self.test_results)} files)") def cleanup(self): """清理临时文件""" if os.path.exists(self.temp_dir): shutil.rmtree(self.temp_dir) self.logger.info(f"Cleaned up temp directory: {self.temp_dir}") class ResultValidator: """测试结果验证器""" @staticmethod def validate_specification_syntax(spec: str) -> Dict[str, Any]: """验证CBMC规范语法""" errors = [] warnings = [] # 基本的CBMC语法检查 required_clauses = ['requires', 'ensures'] found_clauses = [] lines = spec.split('\n') for line in lines: line = line.strip() if line.startswith('\\requires'): found_clauses.append('requires') elif line.startswith('\\ensures'): found_clauses.append('ensures') elif line.startswith('\\assigns'): found_clauses.append('assigns') # 检查必需的子句 for clause in required_clauses: if clause not in found_clauses: errors.append(f"Missing required clause: {clause}") # 检查语法格式 for i, line in enumerate(lines): line = line.strip() if line and not line.startswith('\\') and not line.startswith('//'): warnings.append(f"Line {i+1}: Invalid CBMC clause format") return { 'is_valid': len(errors) == 0, 'errors': errors, 'warnings': warnings, 'found_clauses': found_clauses } @staticmethod def validate_quality_threshold(quality_score: float, threshold: float) -> bool: """验证质量分数是否达到阈值""" return quality_score >= threshold @staticmethod def validate_performance_metrics(metrics: Dict[str, Any], max_duration: float = 300.0) -> Dict[str, Any]: """验证性能指标""" errors = [] if metrics.get('duration', 0) > max_duration: errors.append(f"Duration {metrics['duration']}s exceeds maximum {max_duration}s") if metrics.get('tokens_per_second', 0) < 1.0: warnings = ["Low token generation rate"] else: warnings = [] return { 'performance_ok': len(errors) == 0, 'errors': errors, 'warnings': warnings } class TestReporter: """测试报告生成器""" def __init__(self, output_dir: str = "test_reports"): self.output_dir = Path(output_dir) self.output_dir.mkdir(exist_ok=True) def generate_html_report(self, report: Dict[str, Any], test_results: List[TestResult]) -> str: """生成HTML格式报告""" html_content = self._build_html_template(report, test_results) report_path = self.output_dir / f"test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.html" with open(report_path, 'w', encoding='utf-8') as f: f.write(html_content) return str(report_path) def _build_html_template(self, report: Dict[str, Any], test_results: List[TestResult]) -> str: """构建HTML报告模板""" summary = report['summary'] html = f""" LLM Generation Test Report

LLM Generation Test Report

Generated: {summary['timestamp']}

Summary

Total Tests: {summary['total_tests']}
Passed: {summary['passed_tests']}
Failed: {summary['failed_tests']}
Success Rate: {summary['success_rate']:.1%}
Avg Quality: {summary['average_quality']:.2f}
Total Tokens: {summary['total_tokens']}
""" # 添加失败的测试用例 if summary['failed_tests'] > 0: html += """

Failed Test Cases

""" for failed_case in report['failed_cases']: html += f"""

{failed_case['name']}

Error: {failed_case['error']}

Duration: {failed_case['duration']:.2f}s

""" html += "
" html += """ """ return html # 便捷函数 async def run_llm_tests(test_cases: List[TestCase], config_path: Optional[str] = None, parallel: bool = True) -> Tuple[TestRunner, List[TestResult], Dict[str, Any]]: """运行LLM生成测试的便捷函数""" runner = TestRunner(config_path) try: results = await runner.run_batch_tests(test_cases, parallel) report = runner.generate_report() return runner, results, report finally: # 注意:不在这里清理runner,因为调用者需要使用它来保存详细结果 pass