#!/usr/bin/env python3
"""Simple DeepSeek v3.1 code-specification test script.

Demonstrates how to use the DeepSeek v3.1 model to generate formal
specifications for C/C++ code.
"""

import os
import sys
import asyncio
import json
from pathlib import Path

# Add the project root to the import path so that ``src`` resolves when the
# script is run directly.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

from src.spec.llm_generator import LLMGenerator, GenerationRequest


def create_sample_function_info():
    """Return the sample function-info dict for ``safe_add``.

    The metadata (return type and parameter list) mirrors the C signature
    embedded in ``source_code``: ``bool safe_add(int a, int b, int *result)``.
    """
    return {
        'name': 'safe_add',
        'return_type': 'bool',
        'parameters': [
            {'name': 'a', 'type': 'int'},
            {'name': 'b', 'type': 'int'},
            {'name': 'result', 'type': 'int*'},
        ],
        'source_code': '''
bool safe_add(int a, int b, int *result) {
    if (result == NULL) {
        return false;
    }

    if (a > 0 && b > INT_MAX - a) {
        return false; // 正数溢出
    }

    if (a < 0 && b < INT_MIN - a) {
        return false; // 负数溢出
    }

    *result = a + b;
    return true;
}
''',
        # Headers providing INT_MAX/INT_MIN, bool/true/false and NULL, which
        # the snippet above uses.
        'includes': [
            '#include <limits.h>',
            '#include <stdbool.h>',
            '#include <stddef.h>',
        ],
        'complexity': 'intermediate',
        'features': ['overflow_check', 'error_handling', 'pointer_validation'],
    }


def create_sample_function_2():
    """Return the sample function-info dict for ``array_copy``."""
    return {
        'name': 'array_copy',
        'return_type': 'void',
        'parameters': [
            {'name': 'dest', 'type': 'int*'},
            {'name': 'src', 'type': 'const int*'},
            {'name': 'size', 'type': 'size_t'},
        ],
        'source_code': '''
void array_copy(int *dest, const int *src, size_t size) {
    if (dest == NULL || src == NULL) {
        return;
    }

    for (size_t i = 0; i < size; i++) {
        dest[i] = src[i];
    }
}
''',
        # Header providing size_t and NULL, which the snippet above uses.
        'includes': ['#include <stddef.h>'],
        'complexity': 'basic',
        'features': ['array_operation', 'pointer_validation', 'loop'],
    }


def _format_signature(function_info):
    """Build a ``<return_type> <name>(<type> <param>, ...)`` string from a function-info dict."""
    params = ', '.join(
        f'{p["type"]} {p["name"]}' for p in function_info['parameters']
    )
    return f"{function_info['return_type']} {function_info['name']}({params})"


def _print_section(title, width=40):
    """Print a section title framed by ``=`` rules, preceded by a blank line."""
    print("\n" + "=" * width)
    print(title)
    print("=" * width)


async def test_basic_normalization():
    """Exercise the basic specification-generation workflow.

    Steps, in order: API health check, specification generation for the
    two sample functions, a batch generation over both requests, and a
    printout of the generator statistics.

    Returns early (with a message) when ``SILICONFLOW_API_KEY`` is not set
    or when the health check does not report ``'healthy'``. Any exception
    raised while talking to the generator is printed together with its
    traceback instead of being propagated.
    """
    print("=" * 60)
    print("DeepSeek v3.1 代码规约化测试")
    print("=" * 60)

    # Check the API key before doing anything else.
    api_key = os.getenv('SILICONFLOW_API_KEY')
    if not api_key:
        print("错误: 未设置 SILICONFLOW_API_KEY 环境变量")
        print("请设置API密钥: export SILICONFLOW_API_KEY=your_api_key")
        return

    try:
        # Create the LLM generator.
        async with LLMGenerator() as generator:
            print(f"使用模型: {generator.model}")
            print(f"API地址: {generator.base_url}")

            # Run the health check.
            print("\n执行API健康检查...")
            health = await generator.health_check()

            # .get() so that a payload without 'status' is reported as
            # unhealthy rather than raising KeyError.
            if health.get('status') == 'healthy':
                print(f"✅ API健康 (响应时间: {health['api_response_time']:.2f}s)")
            else:
                print(f"❌ API不健康: {health.get('error', '未知错误')}")
                return

            # Test 1: single-function specification (safe_add).
            _print_section("测试1: 安全加法函数规约化")

            function_info = create_sample_function_info()
            request = GenerationRequest(
                function_name=function_info['name'],
                function_info=function_info,
                verification_goals=[
                    'functional_correctness',
                    'memory_safety',
                    'pointer_validity',
                    'integer_overflow',
                    'error_handling',
                ],
                hints=[
                    "检查指针参数有效性",
                    "验证整数溢出条件",
                    "确保错误返回值正确",
                ],
            )

            print("函数签名:")
            print(f" {_format_signature(function_info)}")
            print(f"验证目标: {', '.join(request.verification_goals)}")

            # Generate the specification.
            print("\n正在生成规约...")
            result = await generator.generate_specification(request)

            print(f"✅ 规约生成成功 (质量评分: {result.quality_score:.2f})")
            print(f"生成时间: {result.generation_time:.2f}s")
            print(f"使用token数: {result.tokens_used}")

            print("\n生成的CBMC规约:")
            print("-" * 30)
            print(result.specification)
            print("-" * 30)

            # Test 2: single-function specification (array_copy).
            _print_section("测试2: 数组拷贝函数规约化")

            function_info2 = create_sample_function_2()
            request2 = GenerationRequest(
                function_name=function_info2['name'],
                function_info=function_info2,
                verification_goals=[
                    'functional_correctness',
                    'memory_safety',
                    'array_bounds',
                    'pointer_validity',
                ],
                hints=[
                    "验证源和目标指针有效性",
                    "确保数组边界安全",
                    "检查size参数有效性",
                ],
            )

            print("函数签名:")
            print(f" {_format_signature(function_info2)}")

            print("\n正在生成规约...")
            result2 = await generator.generate_specification(request2)

            print(f"✅ 规约生成成功 (质量评分: {result2.quality_score:.2f})")
            print(f"生成时间: {result2.generation_time:.2f}s")

            print("\n生成的CBMC规约:")
            print("-" * 30)
            print(result2.specification)
            print("-" * 30)

            # Test 3: batch generation over both requests.
            _print_section("测试3: 批量规约化")

            batch_requests = [request, request2]
            print(f"正在批量生成 {len(batch_requests)} 个规约...")

            batch_results = await generator.generate_batch_specifications(batch_requests)

            print(f"✅ 批量生成完成: {len(batch_results)} 个规约")
            # Distinct loop name so the test-1 ``result`` is not overwritten.
            for index, batch_result in enumerate(batch_results, start=1):
                print(
                    f" {index}. {batch_result.metadata['function_name']}: "
                    f"质量 {batch_result.quality_score:.2f}"
                )

            # Show the generation statistics.
            _print_section("生成统计")

            stats = generator.get_generation_stats()
            print(f"总规约数: {stats['total_specifications']}")
            print(f"平均质量: {stats['average_quality']:.2f}")

            if 'quality_distribution' in stats:
                print("质量分布:")
                for level, count in stats['quality_distribution'].items():
                    print(f" {level}: {count}")

    except Exception as e:
        # Top-level boundary of a demo script: report and keep going so the
        # remaining tests in main() still run.
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()


async def test_streaming_normalization():
    """Exercise streaming specification generation for the ``safe_add`` sample.

    Each chunk yielded by ``generator.stream_specification`` is echoed as it
    arrives, and the total character count is printed at the end. Any
    exception is reported with a message instead of being propagated.
    """
    print("\n" + "=" * 60)
    print("流式规约化测试")
    print("=" * 60)

    try:
        async with LLMGenerator() as generator:
            function_info = create_sample_function_info()
            request = GenerationRequest(
                function_name=function_info['name'],
                function_info=function_info,
                verification_goals=['functional_correctness', 'memory_safety'],
            )

            print("正在流式生成规约...")
            print("-" * 30)

            chunks = []
            async for chunk in generator.stream_specification(request):
                # Flush per chunk so the output appears incrementally.
                print(chunk, end='', flush=True)
                chunks.append(chunk)
            full_spec = "".join(chunks)

            print("\n" + "-" * 30)
            print(f"✅ 流式生成完成,总长度: {len(full_spec)} 字符")

    except Exception as e:
        print(f"❌ 流式测试失败: {e}")


def save_results_to_file(results, filename="normalization_results.json"):
    """Serialize generation results to a UTF-8 JSON file.

    Args:
        results: Iterable of result objects exposing ``metadata``,
            ``specification``, ``quality_score``, ``generation_time``,
            ``tokens_used`` and ``validation_result``.
        filename: Destination path of the JSON file.

    For each result, ``validation_result`` is written as a dict with
    ``is_valid``/``errors``/``warnings`` when it is truthy, and as ``null``
    otherwise. Failures are reported with a message rather than raised.
    """
    try:
        data = []
        for result in results:
            validation = result.validation_result
            data.append({
                'function_name': result.metadata['function_name'],
                'specification': result.specification,
                'quality_score': result.quality_score,
                'generation_time': result.generation_time,
                'tokens_used': result.tokens_used,
                'validation_result': {
                    'is_valid': validation.is_valid,
                    'errors': validation.errors,
                    'warnings': validation.warnings,
                } if validation else None,
            })

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        print(f"✅ 结果已保存到: {filename}")

    except Exception as e:
        # Best-effort save: report the failure without aborting the caller.
        print(f"❌ 保存结果失败: {e}")


async def main():
    """Run the basic test followed by the streaming test."""
    print("DeepSeek v3.1 代码规约化测试工具")
    print("此工具将演示如何使用DeepSeek v3.1模型对C/C++代码进行形式化规约生成\n")

    # Run the basic test.
    await test_basic_normalization()

    # Run the streaming test.
    await test_streaming_normalization()

    print("\n" + "=" * 60)
    print("测试完成")
    print("=" * 60)


if __name__ == "__main__":
    # Run the tests.
    asyncio.run(main())