You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
295 lines
9.7 KiB
295 lines
9.7 KiB
#!/usr/bin/env python3
|
|
"""
|
|
简单的DeepSeek v3.1代码规约化测试脚本
|
|
|
|
该脚本演示如何使用DeepSeek v3.1模型对C/C++代码进行形式化规约生成。
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import asyncio
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# Add the project root (the directory containing this script) to the import
# path so that `src.spec...` can be imported regardless of the current working
# directory. resolve() makes the path absolute and follows symlinks, so a
# relative __file__ or a symlinked script still points at the real project.
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root))
|
|
|
|
from src.spec.llm_generator import LLMGenerator, GenerationRequest
|
|
|
|
|
|
def create_sample_function_info():
    """Return metadata describing the ``safe_add`` sample C function.

    The metadata mirrors the C code stored under ``source_code``:
    ``safe_add`` returns a ``bool`` success flag and writes the sum through
    the ``int *result`` out-parameter. It returns false when ``result`` is
    NULL or when ``a + b`` would overflow ``int``.

    Returns:
        dict: keys ``name``, ``return_type``, ``parameters`` (a list of
        ``{'name': ..., 'type': ...}`` dicts), ``source_code``, ``includes``,
        ``complexity`` and ``features``.
    """
    return {
        'name': 'safe_add',
        # Must match the C source below, whose return type is bool.
        'return_type': 'bool',
        'parameters': [
            {'name': 'a', 'type': 'int'},
            {'name': 'b', 'type': 'int'},
            # Out-parameter receiving a + b; named `result` in the source.
            {'name': 'result', 'type': 'int*'},
        ],
        'source_code': '''
bool safe_add(int a, int b, int *result) {
    if (result == NULL) {
        return false;
    }

    if (a > 0 && b > INT_MAX - a) {
        return false; // 正数溢出
    }
    if (a < 0 && b < INT_MIN - a) {
        return false; // 负数溢出
    }

    *result = a + b;
    return true;
}
''',
        # limits.h: INT_MAX/INT_MIN; stdbool.h: bool/true/false; stddef.h: NULL.
        'includes': [
            '#include <limits.h>',
            '#include <stdbool.h>',
            '#include <stddef.h>',
        ],
        'complexity': 'intermediate',
        'features': ['overflow_check', 'error_handling', 'pointer_validation'],
    }
|
|
|
|
|
|
def create_sample_function_2():
    """Build the metadata dict for the second sample function, ``array_copy``.

    ``array_copy`` copies ``size`` ints from ``src`` into ``dest`` and
    returns early when either pointer is NULL.

    Returns:
        dict: a fresh dict with the same key layout as the first sample
        (``name``, ``return_type``, ``parameters``, ``source_code``,
        ``includes``, ``complexity``, ``features``).
    """
    c_source = '''
void array_copy(int *dest, const int *src, size_t size) {
    if (dest == NULL || src == NULL) {
        return;
    }

    for (size_t i = 0; i < size; i++) {
        dest[i] = src[i];
    }
}
'''
    signature_params = [
        {'name': 'dest', 'type': 'int*'},
        {'name': 'src', 'type': 'const int*'},
        {'name': 'size', 'type': 'size_t'},
    ]
    # Keyword order below fixes the dict's key order.
    return dict(
        name='array_copy',
        return_type='void',
        parameters=signature_params,
        source_code=c_source,
        includes=['#include <stddef.h>'],
        complexity='basic',
        features=['array_operation', 'pointer_validation', 'loop'],
    )
|
|
|
|
|
|
def _print_banner(title, width=40, leading_newline=True):
    """Print ``title`` between two rules made of ``width`` '=' characters."""
    rule = "=" * width
    print(("\n" if leading_newline else "") + rule)
    print(title)
    print(rule)


def _format_signature(function_info):
    """Return a C-style signature string built from a function-info dict."""
    param_str = ', '.join(f'{p["type"]} {p["name"]}' for p in function_info['parameters'])
    return f"{function_info['return_type']} {function_info['name']}({param_str})"


def _print_generation_result(result):
    """Print quality score, timing, token usage and the generated specification."""
    print(f"✅ 规约生成成功 (质量评分: {result.quality_score:.2f})")
    print(f"生成时间: {result.generation_time:.2f}s")
    print(f"使用token数: {result.tokens_used}")

    print("\n生成的CBMC规约:")
    print("-" * 30)
    print(result.specification)
    print("-" * 30)


async def _run_single_generation(generator, request, function_info):
    """Print a request summary, generate one specification, print and return it."""
    print("函数签名:")
    print(f" {_format_signature(function_info)}")
    print(f"验证目标: {', '.join(request.verification_goals)}")

    print("\n正在生成规约...")
    result = await generator.generate_specification(request)
    _print_generation_result(result)
    return result


async def test_basic_normalization():
    """Run the basic specification-generation demo against the live API.

    Steps:
      1. Return early (with a message) when SILICONFLOW_API_KEY is unset.
      2. Run ``generator.health_check()``; abort unless it reports 'healthy'.
      3. Generate and print a specification for each sample function.
      4. Regenerate both as a batch.
      5. Print ``generator.get_generation_stats()``.

    Any exception raised along the way is printed with a traceback instead of
    propagating, so this demo never crashes its caller.
    """
    import traceback

    _print_banner("DeepSeek v3.1 代码规约化测试", width=60, leading_newline=False)

    # The generator needs an API key; bail out with instructions if it is missing.
    api_key = os.getenv('SILICONFLOW_API_KEY')
    if not api_key:
        print("错误: 未设置 SILICONFLOW_API_KEY 环境变量")
        print("请设置API密钥: export SILICONFLOW_API_KEY=your_api_key")
        return

    try:
        async with LLMGenerator() as generator:
            print(f"使用模型: {generator.model}")
            print(f"API地址: {generator.base_url}")

            # Health check: stop early rather than issue requests to a broken API.
            print("\n执行API健康检查...")
            health = await generator.health_check()
            if health['status'] == 'healthy':
                print(f"✅ API健康 (响应时间: {health['api_response_time']:.2f}s)")
            else:
                print(f"❌ API不健康: {health.get('error', '未知错误')}")
                return

            # Test 1: safe_add
            _print_banner("测试1: 安全加法函数规约化")
            function_info = create_sample_function_info()
            request = GenerationRequest(
                function_name=function_info['name'],
                function_info=function_info,
                verification_goals=[
                    'functional_correctness',
                    'memory_safety',
                    'pointer_validity',
                    'integer_overflow',
                    'error_handling'
                ],
                hints=[
                    "检查指针参数有效性",
                    "验证整数溢出条件",
                    "确保错误返回值正确"
                ]
            )
            await _run_single_generation(generator, request, function_info)

            # Test 2: array_copy (now reported exactly like test 1).
            _print_banner("测试2: 数组拷贝函数规约化")
            function_info2 = create_sample_function_2()
            request2 = GenerationRequest(
                function_name=function_info2['name'],
                function_info=function_info2,
                verification_goals=[
                    'functional_correctness',
                    'memory_safety',
                    'array_bounds',
                    'pointer_validity'
                ],
                hints=[
                    "验证源和目标指针有效性",
                    "确保数组边界安全",
                    "检查size参数有效性"
                ]
            )
            await _run_single_generation(generator, request2, function_info2)

            # Test 3: both requests as a single batch.
            _print_banner("测试3: 批量规约化")
            batch_requests = [request, request2]
            print(f"正在批量生成 {len(batch_requests)} 个规约...")

            batch_results = await generator.generate_batch_specifications(batch_requests)

            print(f"✅ 批量生成完成: {len(batch_results)} 个规约")
            for index, batch_result in enumerate(batch_results, start=1):
                print(f" {index}. {batch_result.metadata['function_name']}: 质量 {batch_result.quality_score:.2f}")

            # Aggregate statistics collected by the generator.
            _print_banner("生成统计")
            stats = generator.get_generation_stats()
            print(f"总规约数: {stats['total_specifications']}")
            print(f"平均质量: {stats['average_quality']:.2f}")
            if 'quality_distribution' in stats:
                print("质量分布:")
                for level, count in stats['quality_distribution'].items():
                    print(f" {level}: {count}")

    except Exception as e:
        # Demo boundary: report the failure with full context instead of crashing.
        print(f"❌ 测试失败: {str(e)}")
        traceback.print_exc()
|
|
|
|
|
|
async def test_streaming_normalization():
    """Stream a specification for the ``safe_add`` sample and echo it live.

    Returns early (with a message) when SILICONFLOW_API_KEY is unset, since
    the generator cannot authenticate without it. Chunks yielded by
    ``generator.stream_specification`` are printed as they arrive and
    collected so the total length can be reported at the end. Any exception
    is printed with a traceback instead of propagating.
    """
    import traceback

    print("\n" + "=" * 60)
    print("流式规约化测试")
    print("=" * 60)

    if not os.getenv('SILICONFLOW_API_KEY'):
        print("错误: 未设置 SILICONFLOW_API_KEY 环境变量")
        return

    try:
        async with LLMGenerator() as generator:
            function_info = create_sample_function_info()
            request = GenerationRequest(
                function_name=function_info['name'],
                function_info=function_info,
                verification_goals=['functional_correctness', 'memory_safety']
            )

            print("正在流式生成规约...")
            print("-" * 30)

            # Collect chunks in a list and join once, avoiding quadratic `+=`.
            chunks = []
            async for chunk in generator.stream_specification(request):
                print(chunk, end='', flush=True)
                chunks.append(chunk)
            full_spec = "".join(chunks)

            print("\n" + "-" * 30)
            print(f"✅ 流式生成完成,总长度: {len(full_spec)} 字符")

    except Exception as e:
        print(f"❌ 流式测试失败: {str(e)}")
        traceback.print_exc()
|
|
|
|
|
|
def _validation_summary(validation_result):
    """Return a JSON-ready dict for ``validation_result``, or None when it is falsy."""
    if not validation_result:
        return None
    return {
        'is_valid': validation_result.is_valid,
        'errors': validation_result.errors,
        'warnings': validation_result.warnings,
    }


def save_results_to_file(results, filename="normalization_results.json"):
    """Serialize generation results to a UTF-8 JSON file.

    Each result becomes one record containing:
      - its function name, taken from ``result.metadata['function_name']``;
      - ``specification``, ``quality_score``, ``generation_time`` and
        ``tokens_used``;
      - a ``validation_result`` summary (``is_valid``/``errors``/``warnings``),
        or ``None`` when the result has no validation result.

    The whole payload is serialized *before* the file is opened. A record
    that cannot be converted to JSON therefore leaves any existing file
    untouched instead of truncating it into invalid JSON.

    Saving is best-effort: any failure is reported on stdout, not raised.

    Args:
        results: iterable of generation result objects exposing the
            attributes listed above.
        filename: destination path (str or path-like).
    """
    try:
        records = [
            {
                'function_name': result.metadata['function_name'],
                'specification': result.specification,
                'quality_score': result.quality_score,
                'generation_time': result.generation_time,
                'tokens_used': result.tokens_used,
                'validation_result': _validation_summary(result.validation_result),
            }
            for result in results
        ]

        # Serialize fully first so a TypeError cannot leave a half-written file.
        payload = json.dumps(records, ensure_ascii=False, indent=2)
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(payload)

        print(f"✅ 结果已保存到: {filename}")

    except Exception as e:
        print(f"❌ 保存结果失败: {str(e)}")
|
|
|
|
|
|
async def main():
    """Entry point: print the intro, then run the basic and streaming demos in order."""
    print("DeepSeek v3.1 代码规约化测试工具")
    print("此工具将演示如何使用DeepSeek v3.1模型对C/C++代码进行形式化规约生成\n")

    # Basic (single + batch) demo first, then the streaming demo.
    for demo in (test_basic_normalization, test_streaming_normalization):
        await demo()

    separator = "=" * 60
    print("\n" + separator)
    print("测试完成")
    print(separator)
|
|
|
|
|
|
if "__main__" == __name__:
    # Script entry: drive the async demo on a fresh event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)