cbmc/codedetect/tests/tools/test_llm_cli.py


#!/usr/bin/env python3
"""
LLM生成CLI测试工具
本工具提供交互式命令行界面用于测试LLM生成管道的各个方面。
支持单个函数测试、批量文件处理、模板比较和详细分析功能。
"""
import os
import sys
import asyncio
import json
import time
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, asdict
import click
from colorama import init, Fore, Style
# Initialize colorama
init()
# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from src.spec.llm_generator import LLMGenerator, GenerationRequest, GenerationResult
from src.parser.c_parser import CParser, CParserFactory
from src.parser.ast_extractor import ASTExtractor
from src.utils.logger import get_logger, setup_logging
from test_llm_generation import TestRunner, TestCase
from test_data.simple_c_examples import (
    get_basic_test_suite,
    get_test_cases_by_complexity,
    get_test_cases_by_category
)
from utils.cbmc_spec_validator import CBMCSpecificationValidator, validate_cbmc_specification
# Configure logging
setup_logging(level=logging.INFO)
logger = get_logger(__name__)
@dataclass
class TestConfig:
"""测试配置"""
api_key: str
model: str = "deepseek-ai/DeepSeek-V3.1"
base_url: str = "https://api.siliconflow.cn/v1"
timeout: int = 120
max_retries: int = 3
parallel_workers: int = 3
quality_threshold: float = 0.7
output_dir: str = "test_results"
class LLMTestCLI:
"""LLM测试CLI类"""
def __init__(self, config: TestConfig):
self.config = config
self.logger = logger
self.parser = CParserFactory.create_parser()
self.ast_extractor = ASTExtractor()
self.validator = CBMCSpecificationValidator()
self.test_runner = TestRunner()
# 确保输出目录存在
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

    def print_header(self):
        """Print the banner."""
        print(f"{Fore.CYAN}{'='*60}")
        print(f"{Fore.CYAN}LLM Generation Test CLI")
        print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")

    def print_success(self, message: str):
        """Print a success message."""
        print(f"{Fore.GREEN}{message}{Style.RESET_ALL}")

    def print_error(self, message: str):
        """Print an error message."""
        print(f"{Fore.RED}{message}{Style.RESET_ALL}")

    def print_warning(self, message: str):
        """Print a warning message."""
        print(f"{Fore.YELLOW}{message}{Style.RESET_ALL}")

    def print_info(self, message: str):
        """Print an informational message."""
        print(f"{Fore.BLUE} {message}{Style.RESET_ALL}")

    async def test_single_function(self, c_file: str, function_name: Optional[str] = None) -> GenerationResult:
        """Test a single function from a C file."""
        self.print_info(f"Testing function from {c_file}")
        try:
            # Parse the file
            ast = self.parser.parse_file(c_file)
            metadata = self.ast_extractor.extract_metadata(ast)

            # Select the target function
            if function_name and function_name in metadata.functions:
                func_info = metadata.functions[function_name]
            else:
                # Fall back to the first function in the file
                if not metadata.functions:
                    raise ValueError("No functions found in file")
                func_name = list(metadata.functions.keys())[0]
                func_info = metadata.functions[func_name]
                function_name = func_name
            self.print_info(f"Selected function: {function_name}")

            # Build the generation request
            request = GenerationRequest(
                function_name=function_name,
                function_info=func_info.to_dict(),
                verification_goals=['functional_correctness', 'memory_safety'],
                max_retries=self.config.max_retries,
                validate=True,
                store=False
            )

            # Generate the specification
            async with LLMGenerator(
                api_key=self.config.api_key,
                base_url=self.config.base_url
            ) as generator:
                result = await generator.generate_specification(request)

            # Display the generation result
            self._display_generation_result(result)

            # Validate the specification
            validation_result = validate_cbmc_specification(
                result.specification,
                func_info.to_dict()
            )
            self._display_validation_result(validation_result)
            return result
        except Exception as e:
            self.print_error(f"Test failed: {str(e)}")
            raise

    async def test_batch_files(self, file_paths: List[str]) -> List[GenerationResult]:
        """Test a batch of C files."""
        self.print_info(f"Testing {len(file_paths)} files")
        results = []
        for file_path in file_paths:
            try:
                self.print_info(f"Processing {file_path}")
                result = await self.test_single_function(file_path)
                results.append(result)
            except Exception as e:
                self.print_error(f"Failed to process {file_path}: {str(e)}")

        # Display the batch summary
        self._display_batch_summary(results)
        return results

    async def run_preset_tests(self, test_type: str = "basic") -> List[GenerationResult]:
        """Run a preset test suite."""
        self.print_info(f"Running {test_type} test suite")

        # Collect the test cases
        if test_type == "basic":
            test_cases = get_basic_test_suite()
        elif test_type == "intermediate":
            test_cases = get_test_cases_by_complexity("intermediate")
        elif test_type == "advanced":
            test_cases = get_test_cases_by_complexity("advanced")
        elif test_type == "comprehensive":
            test_cases = get_test_cases_by_complexity("basic") + \
                         get_test_cases_by_complexity("intermediate") + \
                         get_test_cases_by_complexity("advanced")
        else:
            raise ValueError(f"Unknown test type: {test_type}")
        self.print_info(f"Loaded {len(test_cases)} test cases")

        # Delegate to the shared test runner
        results = await self.test_runner.run_batch_tests(test_cases, parallel=False)

        # Display the suite results
        self._display_test_suite_results(results)
        return results

    async def compare_templates(self, c_file: str, function_name: Optional[str] = None) -> Dict[str, GenerationResult]:
        """Compare generation results across different verification-goal templates."""
        self.print_info(f"Comparing templates for {c_file}")
        try:
            # Parse the file
            ast = self.parser.parse_file(c_file)
            metadata = self.ast_extractor.extract_metadata(ast)

            # Select the target function
            if function_name and function_name in metadata.functions:
                func_info = metadata.functions[function_name]
            else:
                if not metadata.functions:
                    raise ValueError("No functions found in file")
                func_name = list(metadata.functions.keys())[0]
                func_info = metadata.functions[func_name]
                function_name = func_name

            # Different combinations of verification goals
            goal_combinations = [
                ['functional_correctness'],
                ['memory_safety'],
                ['functional_correctness', 'memory_safety'],
                ['functional_correctness', 'memory_safety', 'buffer_overflow']
            ]
            results = {}
            for i, goals in enumerate(goal_combinations):
                self.print_info(f"Testing with goals: {', '.join(goals)}")
                request = GenerationRequest(
                    function_name=function_name,
                    function_info=func_info.to_dict(),
                    verification_goals=goals,
                    max_retries=self.config.max_retries,
                    validate=True,
                    store=False
                )
                async with LLMGenerator(
                    api_key=self.config.api_key,
                    base_url=self.config.base_url
                ) as generator:
                    result = await generator.generate_specification(request)
                results[f"template_{i+1}"] = result

                # Show a one-line summary for this combination
                self.print_info(f" Quality: {result.quality_score:.2f}, "
                                f"Tokens: {result.tokens_used}, "
                                f"Time: {result.generation_time:.2f}s")

            # Display the comparison
            self._display_template_comparison(results)
            return results
        except Exception as e:
            self.print_error(f"Template comparison failed: {str(e)}")
            raise

    async def analyze_specification(self, spec_file: str) -> Dict[str, Any]:
        """Analyze an existing specification file."""
        self.print_info(f"Analyzing specification file: {spec_file}")
        try:
            # Read the specification
            with open(spec_file, 'r', encoding='utf-8') as f:
                spec_content = f.read()

            # Placeholder function info (a real run would need the user to supply this)
            function_info = {
                'name': 'analyzed_function',
                'return_type': 'int',
                'parameters': [
                    {'name': 'param1', 'type': 'int'},
                    {'name': 'param2', 'type': 'int'}
                ]
            }

            # Validate the specification
            validation_result = validate_cbmc_specification(spec_content, function_info)
            self._display_validation_result(validation_result)

            # Analyze structural features of the specification
            analysis = self._analyze_spec_features(spec_content)
            self._display_spec_analysis(analysis)
            return {
                'validation': validation_result.to_dict(),
                'analysis': analysis
            }
        except Exception as e:
            self.print_error(f"Specification analysis failed: {str(e)}")
            raise

    def _display_generation_result(self, result: GenerationResult):
        """Display a generation result."""
        print(f"\n{Fore.CYAN}Generation Result:{Style.RESET_ALL}")
        print(f" Function: {result.metadata['function_name']}")
        print(f" Quality Score: {result.quality_score:.2f}")
        print(f" Generation Time: {result.generation_time:.2f}s")
        print(f" Tokens Used: {result.tokens_used}")
        print(f" Token Rate: {result.tokens_used/result.generation_time:.1f} tokens/s")
        if result.specification:
            print(f"\n{Fore.YELLOW}Generated Specification:{Style.RESET_ALL}")
            print(result.specification)

    def _display_validation_result(self, validation_result):
        """Display a validation result."""
        print(f"\n{Fore.CYAN}Validation Result:{Style.RESET_ALL}")
        print(f" Valid: {validation_result.is_valid}")
        print(f" Quality Score: {validation_result.quality_score:.2f}")
        if validation_result.errors:
            print(f"\n{Fore.RED}Errors:{Style.RESET_ALL}")
            for error in validation_result.errors:
                print(f" - {error}")
        if validation_result.warnings:
            print(f"\n{Fore.YELLOW}Warnings:{Style.RESET_ALL}")
            for warning in validation_result.warnings:
                print(f" - {warning}")
        if validation_result.suggestions:
            print(f"\n{Fore.BLUE}Suggestions:{Style.RESET_ALL}")
            for suggestion in validation_result.suggestions:
                print(f" - {suggestion}")

    def _display_batch_summary(self, results: List[GenerationResult]):
        """Display the batch test summary."""
        if not results:
            return
        total = len(results)
        successful = sum(1 for r in results if r.quality_score > 0)
        avg_quality = sum(r.quality_score for r in results) / total
        total_tokens = sum(r.tokens_used for r in results)
        total_time = sum(r.generation_time for r in results)
        print(f"\n{Fore.CYAN}Batch Test Summary:{Style.RESET_ALL}")
        print(f" Total Tests: {total}")
        print(f" Successful: {successful}")
        print(f" Average Quality: {avg_quality:.2f}")
        print(f" Total Tokens: {total_tokens}")
        print(f" Total Time: {total_time:.2f}s")
        # Guard against a zero total time (e.g. when every request failed immediately)
        if total_time > 0:
            print(f" Average Token Rate: {total_tokens/total_time:.1f} tokens/s")

    def _display_test_suite_results(self, results):
        """Display the test suite results."""
        if not results:
            return
        successful = [r for r in results if r.success]
        failed = [r for r in results if not r.success]
        print(f"\n{Fore.CYAN}Test Suite Results:{Style.RESET_ALL}")
        print(f" Total: {len(results)}")
        print(f" Passed: {len(successful)}")
        print(f" Failed: {len(failed)}")
        if successful:
            avg_quality = sum(r.quality_score for r in successful) / len(successful)
            print(f" Average Quality (passed): {avg_quality:.2f}")
        if failed:
            print(f"\n{Fore.RED}Failed Tests:{Style.RESET_ALL}")
            for result in failed:
                print(f" - {result.test_case.name}: {result.error_message}")

    def _display_template_comparison(self, results: Dict[str, GenerationResult]):
        """Display the template comparison results."""
        print(f"\n{Fore.CYAN}Template Comparison:{Style.RESET_ALL}")
        for template_name, result in results.items():
            print(f" {template_name}:")
            print(f" Quality: {result.quality_score:.2f}")
            print(f" Tokens: {result.tokens_used}")
            print(f" Time: {result.generation_time:.2f}s")

        # Pick the best result by quality score
        best_result = max(results.values(), key=lambda r: r.quality_score)
        best_template = [k for k, v in results.items() if v == best_result][0]
        print(f"\n{Fore.GREEN}Best Template: {best_template} (Quality: {best_result.quality_score:.2f}){Style.RESET_ALL}")

    def _analyze_spec_features(self, spec: str) -> Dict[str, Any]:
        """Analyze structural features of a specification."""
        lines = spec.split('\n')
        # Clauses are lines that start with a backslash keyword (e.g. \requires, \ensures)
        clauses = [line.strip() for line in lines if line.strip().startswith('\\')]
        return {
            'total_lines': len(lines),
            'total_clauses': len(clauses),
            'clause_types': list(set(clause.split()[0].lstrip('\\') for clause in clauses if clause.split())),
            'has_requires': any('\\requires' in clause for clause in clauses),
            'has_ensures': any('\\ensures' in clause for clause in clauses),
            'has_assigns': any('\\assigns' in clause for clause in clauses),
            'has_valid': any('\\valid' in clause for clause in clauses),
            'complexity_indicators': {
                'logical_operators': spec.count('&&') + spec.count('||'),
                'quantifiers': spec.count('\\forall') + spec.count('\\exists'),
                'arithmetic_ops': sum(spec.count(op) for op in ['+', '-', '*', '/'])
            }
        }

    def _display_spec_analysis(self, analysis: Dict[str, Any]):
        """Display the specification analysis results."""
        print(f"\n{Fore.CYAN}Specification Analysis:{Style.RESET_ALL}")
        print(f" Total Lines: {analysis['total_lines']}")
        print(f" Total Clauses: {analysis['total_clauses']}")
        print(f" Clause Types: {', '.join(analysis['clause_types'])}")
        print(f"\n Contains:")
        print(f" Requires: {analysis['has_requires']}")
        print(f" Ensures: {analysis['has_ensures']}")
        print(f" Assigns: {analysis['has_assigns']}")
        print(f" Valid: {analysis['has_valid']}")
        indicators = analysis['complexity_indicators']
        print(f"\n Complexity Indicators:")
        print(f" Logical Operators: {indicators['logical_operators']}")
        print(f" Quantifiers: {indicators['quantifiers']}")
        print(f" Arithmetic Operations: {indicators['arithmetic_ops']}")

    async def run_health_check(self):
        """Run an API health check."""
        self.print_info("Running API health check...")
        try:
            async with LLMGenerator(
                api_key=self.config.api_key,
                base_url=self.config.base_url
            ) as generator:
                health = await generator.health_check()
                if health['status'] == 'healthy':
                    self.print_success(f"API is healthy (response time: {health['api_response_time']:.2f}s)")
                else:
                    self.print_error(f"API is unhealthy: {health.get('error', 'Unknown error')}")
        except Exception as e:
            self.print_error(f"Health check failed: {str(e)}")


# CLI commands
@click.group()
@click.option('--api-key', envvar='SILICONFLOW_API_KEY', required=True, help='SiliconFlow API key')
@click.option('--model', default='deepseek-ai/DeepSeek-V3.1', help='Model to use')
@click.option('--base-url', default='https://api.siliconflow.cn/v1', help='API base URL')
@click.option('--timeout', default=120, help='Request timeout in seconds')
@click.option('--max-retries', default=3, help='Maximum retry attempts')
@click.option('--output-dir', default='test_results', help='Output directory')
@click.pass_context
def cli(ctx, api_key, model, base_url, timeout, max_retries, output_dir):
"""LLM Generation Test CLI"""
ctx.ensure_object(dict)
ctx.obj['config'] = TestConfig(
api_key=api_key,
model=model,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
output_dir=output_dir
)
@cli.command()
@click.argument('c_file', type=click.Path(exists=True))
@click.option('--function', '-f', help='Specific function to test')
@click.pass_context
def test_function(ctx, c_file, function):
"""Test a single function from a C file"""
config = ctx.obj['config']
cli_tool = LLMTestCLI(config)
cli_tool.print_header()
asyncio.run(cli_tool.test_single_function(c_file, function))
@cli.command()
@click.argument('files', nargs=-1, type=click.Path(exists=True))
@click.pass_context
def test_batch(ctx, files):
"""Test multiple C files"""
if not files:
click.echo("Please provide at least one C file")
return
config = ctx.obj['config']
cli_tool = LLMTestCLI(config)
cli_tool.print_header()
asyncio.run(cli_tool.test_batch_files(list(files)))
@cli.command()
@click.option('--type', '-t', type=click.Choice(['basic', 'intermediate', 'advanced', 'comprehensive']),
              default='basic', help='Test suite type')
@click.pass_context
def test_suite(ctx, type):
"""Run a predefined test suite"""
config = ctx.obj['config']
cli_tool = LLMTestCLI(config)
cli_tool.print_header()
asyncio.run(cli_tool.run_preset_tests(type))
@cli.command()
@click.argument('c_file', type=click.Path(exists=True))
@click.option('--function', '-f', help='Specific function to test')
@click.pass_context
def compare_templates(ctx, c_file, function):
"""Compare different prompt templates"""
config = ctx.obj['config']
cli_tool = LLMTestCLI(config)
cli_tool.print_header()
asyncio.run(cli_tool.compare_templates(c_file, function))
@cli.command()
@click.argument('spec_file', type=click.Path(exists=True))
@click.pass_context
def analyze_spec(ctx, spec_file):
"""Analyze an existing specification file"""
config = ctx.obj['config']
cli_tool = LLMTestCLI(config)
cli_tool.print_header()
asyncio.run(cli_tool.analyze_specification(spec_file))
@cli.command()
@click.pass_context
def health_check(ctx):
"""Check API health"""
config = ctx.obj['config']
cli_tool = LLMTestCLI(config)
cli_tool.print_header()
async def check_health():
async with LLMGenerator(
api_key=config.api_key,
base_url=config.base_url
) as generator:
health = await generator.health_check()
if health['status'] == 'healthy':
cli_tool.print_success(f"API is healthy (response time: {health['api_response_time']:.2f}s)")
else:
cli_tool.print_error(f"API is unhealthy: {health.get('error', 'Unknown error')}")
asyncio.run(check_health())
@cli.command()
@click.pass_context
def interactive(ctx):
"""Start interactive mode"""
config = ctx.obj['config']
cli_tool = LLMTestCLI(config)
cli_tool.print_header()
cli_tool.print_info("Interactive mode started. Type 'help' for available commands.")
while True:
try:
command = input(f"\\n{Fore.GREEN}llm-test> {Style.RESET_ALL}").strip()
if not command:
continue
if command.lower() in ['exit', 'quit']:
cli_tool.print_info("Goodbye!")
break
elif command.lower() == 'help':
cli_tool.print_info("Available commands:")
cli_tool.print_info(" test <file> - Test a C file")
cli_tool.print_info(" batch <file1> <file2> ... - Test multiple files")
cli_tool.print_info(" suite <type> - Run test suite (basic/intermediate/advanced/comprehensive)")
cli_tool.print_info(" compare <file> - Compare templates for a file")
cli_tool.print_info(" health - Check API health")
cli_tool.print_info(" help - Show this help")
cli_tool.print_info(" exit - Exit interactive mode")
elif command.startswith('test '):
file_path = command[5:].strip()
if os.path.exists(file_path):
asyncio.run(cli_tool.test_single_function(file_path))
else:
cli_tool.print_error(f"File not found: {file_path}")
elif command.startswith('health'):
asyncio.run(cli_tool.run_health_check())
else:
cli_tool.print_error(f"Unknown command: {command}")
except KeyboardInterrupt:
cli_tool.print_info("\\nGoodbye!")
break
except EOFError:
cli_tool.print_info("\\nGoodbye!")
break
if __name__ == '__main__':
    cli()