You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
cbmc/codedetect/tests/deepseek/test_deepseek_normalization.py

295 lines
9.7 KiB

#!/usr/bin/env python3
"""
简单的DeepSeek v3.1代码规约化测试脚本
该脚本演示如何使用DeepSeek v3.1模型对C/C++代码进行形式化规约生成。
"""
import os
import sys
import asyncio
import json
from pathlib import Path
# Make the project root importable. The import that follows needs a directory
# containing ``src/spec``, so walk up from this script's directory until one is
# found. If none is found, fall back to the script's own directory, which is
# what the previous ``Path(__file__).parent`` logic always used.
_script_dir = Path(__file__).resolve().parent
project_root = next(
    (
        candidate
        for candidate in (_script_dir, *_script_dir.parents)
        if (candidate / "src" / "spec").is_dir()
    ),
    _script_dir,
)
sys.path.insert(0, str(project_root))
from src.spec.llm_generator import LLMGenerator, GenerationRequest
def create_sample_function_info():
    """Build the metadata record for the sample ``safe_add`` C function.

    The signature fields (``return_type`` and ``parameters``) describe the C
    definition embedded in ``source_code``: the function returns ``bool`` and
    stores the sum through the ``int *result`` out-parameter.

    Returns:
        dict: A new dictionary with the keys ``name``, ``return_type``,
        ``parameters``, ``source_code``, ``includes``, ``complexity`` and
        ``features``.
    """
    return {
        'name': 'safe_add',
        # Must agree with the embedded C signature:
        # bool safe_add(int a, int b, int *result)
        'return_type': 'bool',
        'parameters': [
            {'name': 'a', 'type': 'int'},
            {'name': 'b', 'type': 'int'},
            {'name': 'result', 'type': 'int*'}
        ],
        'source_code': '''
bool safe_add(int a, int b, int *result) {
if (result == NULL) {
return false;
}
if (a > 0 && b > INT_MAX - a) {
return false; // 正数溢出
}
if (a < 0 && b < INT_MIN - a) {
return false; // 负数溢出
}
*result = a + b;
return true;
}
''',
        # limits.h provides INT_MAX/INT_MIN, stdbool.h provides bool/true/false
        # and stddef.h provides NULL -- all of which the snippet above uses.
        'includes': [
            '#include <limits.h>',
            '#include <stdbool.h>',
            '#include <stddef.h>',
        ],
        'complexity': 'intermediate',
        'features': ['overflow_check', 'error_handling', 'pointer_validation']
    }
def create_sample_function_2():
    """Build the metadata record for the sample ``array_copy`` C function.

    Returns:
        dict: A new dictionary with the keys ``name``, ``return_type``,
        ``parameters``, ``source_code``, ``includes``, ``complexity`` and
        ``features``, in that order.
    """
    # C definition stored verbatim under ``source_code``.
    c_source = '''
void array_copy(int *dest, const int *src, size_t size) {
if (dest == NULL || src == NULL) {
return;
}
for (size_t i = 0; i < size; i++) {
dest[i] = src[i];
}
}
'''
    # (name, type) pairs in declaration order.
    signature = (
        ('dest', 'int*'),
        ('src', 'const int*'),
        ('size', 'size_t'),
    )

    info = dict(name='array_copy', return_type='void')
    info['parameters'] = [
        {'name': param_name, 'type': param_type}
        for param_name, param_type in signature
    ]
    info['source_code'] = c_source
    info['includes'] = ['#include <stddef.h>']
    info['complexity'] = 'basic'
    info['features'] = ['array_operation', 'pointer_validation', 'loop']
    return info
def _format_c_signature(function_info):
    """Return ``function_info`` rendered as ``<return_type> <name>(<params>)``."""
    params = ', '.join(
        f'{param["type"]} {param["name"]}'
        for param in function_info['parameters']
    )
    return f"{function_info['return_type']} {function_info['name']}({params})"


def _print_section_header(title, width=40):
    """Print a blank line, then ``title`` framed between two ``=`` rules."""
    rule = "=" * width
    print("\n" + rule)
    print(title)
    print(rule)


async def _generate_and_report(generator, function_info, request):
    """Generate the specification for ``request`` and print a report on it.

    The report contains the function signature, the verification goals, the
    quality score, the generation time, the token count and the specification
    text.

    Returns:
        The object returned by ``generator.generate_specification(request)``.
    """
    print("函数签名:")
    print(f" {_format_c_signature(function_info)}")
    print(f"验证目标: {', '.join(request.verification_goals)}")

    print("\n正在生成规约...")
    result = await generator.generate_specification(request)

    print(f"✅ 规约生成成功 (质量评分: {result.quality_score:.2f})")
    print(f"生成时间: {result.generation_time:.2f}s")
    print(f"使用token数: {result.tokens_used}")
    print("\n生成的CBMC规约:")
    print("-" * 30)
    print(result.specification)
    print("-" * 30)
    return result


async def test_basic_normalization():
    """Run the non-streaming demo: health check, two single runs, one batch.

    Steps, in order:

    1. Return early with setup instructions when the
       ``SILICONFLOW_API_KEY`` environment variable is unset or empty.
    2. Open an ``LLMGenerator`` and run its health check; return early if
       the reported status is not ``'healthy'``.
    3. Generate and print a specification for ``safe_add`` and then for
       ``array_copy``. Both runs print the same set of fields.
    4. Submit both requests again as one batch and print a per-item summary.
    5. Print the generator's aggregate statistics.

    Any exception raised after step 1 is reported on stdout together with its
    traceback and is not re-raised, so the caller can continue with the next
    demo.

    Returns:
        None
    """
    banner = "=" * 60
    print(banner)
    print("DeepSeek v3.1 代码规约化测试")
    print(banner)

    # Skip the run entirely when the API key variable is not set.
    if not os.getenv('SILICONFLOW_API_KEY'):
        print("错误: 未设置 SILICONFLOW_API_KEY 环境变量")
        print("请设置API密钥: export SILICONFLOW_API_KEY=your_api_key")
        return

    try:
        async with LLMGenerator() as generator:
            print(f"使用模型: {generator.model}")
            print(f"API地址: {generator.base_url}")

            # Health check: stop before spending any generation calls.
            print("\n执行API健康检查...")
            health = await generator.health_check()
            if health['status'] != 'healthy':
                print(f"❌ API不健康: {health.get('error', '未知错误')}")
                return
            print(f"✅ API健康 (响应时间: {health['api_response_time']:.2f}s)")

            # Test 1: overflow-checked addition.
            _print_section_header("测试1: 安全加法函数规约化")
            function_info = create_sample_function_info()
            request = GenerationRequest(
                function_name=function_info['name'],
                function_info=function_info,
                verification_goals=[
                    'functional_correctness',
                    'memory_safety',
                    'pointer_validity',
                    'integer_overflow',
                    'error_handling'
                ],
                hints=[
                    "检查指针参数有效性",
                    "验证整数溢出条件",
                    "确保错误返回值正确"
                ]
            )
            await _generate_and_report(generator, function_info, request)

            # Test 2: array copy.
            _print_section_header("测试2: 数组拷贝函数规约化")
            function_info2 = create_sample_function_2()
            request2 = GenerationRequest(
                function_name=function_info2['name'],
                function_info=function_info2,
                verification_goals=[
                    'functional_correctness',
                    'memory_safety',
                    'array_bounds',
                    'pointer_validity'
                ],
                hints=[
                    "验证源和目标指针有效性",
                    "确保数组边界安全",
                    "检查size参数有效性"
                ]
            )
            await _generate_and_report(generator, function_info2, request2)

            # Test 3: both requests submitted as one batch.
            _print_section_header("测试3: 批量规约化")
            batch_requests = [request, request2]
            print(f"正在批量生成 {len(batch_requests)} 个规约...")
            batch_results = await generator.generate_batch_specifications(batch_requests)
            print(f"✅ 批量生成完成: {len(batch_results)} 个规约")
            for position, batch_result in enumerate(batch_results, start=1):
                print(
                    f" {position}. {batch_result.metadata['function_name']}: "
                    f"质量 {batch_result.quality_score:.2f}"
                )

            # Aggregate statistics reported by the generator.
            _print_section_header("生成统计")
            stats = generator.get_generation_stats()
            print(f"总规约数: {stats['total_specifications']}")
            print(f"平均质量: {stats['average_quality']:.2f}")
            if 'quality_distribution' in stats:
                print("质量分布:")
                for level, count in stats['quality_distribution'].items():
                    print(f" {level}: {count}")
    except Exception as e:
        # Script-level boundary: report the failure and let the caller go on.
        import traceback
        print(f"❌ 测试失败: {e}")
        traceback.print_exc()
async def test_streaming_normalization():
    """Stream one ``safe_add`` specification and echo it chunk by chunk.

    Each chunk yielded by ``generator.stream_specification`` is printed
    immediately, without a trailing newline and with a flush, and the total
    number of characters received is reported at the end.

    The run is skipped, with setup instructions, when the
    ``SILICONFLOW_API_KEY`` environment variable is unset or empty. This is
    the same precondition ``test_basic_normalization`` applies. Any exception
    raised during streaming is reported on stdout and not re-raised.

    Returns:
        None
    """
    print("\n" + "=" * 60)
    print("流式规约化测试")
    print("=" * 60)

    # Skip the run entirely when the API key variable is not set.
    if not os.getenv('SILICONFLOW_API_KEY'):
        print("错误: 未设置 SILICONFLOW_API_KEY 环境变量")
        print("请设置API密钥: export SILICONFLOW_API_KEY=your_api_key")
        return

    try:
        async with LLMGenerator() as generator:
            function_info = create_sample_function_info()
            request = GenerationRequest(
                function_name=function_info['name'],
                function_info=function_info,
                verification_goals=['functional_correctness', 'memory_safety']
            )
            print("正在流式生成规约...")
            print("-" * 30)

            # Only the total length is reported, so keep a running count
            # instead of rebuilding an ever-growing string on every chunk.
            total_chars = 0
            async for chunk in generator.stream_specification(request):
                # Flush per chunk so the output appears as it is streamed.
                print(chunk, end='', flush=True)
                total_chars += len(chunk)

            print("\n" + "-" * 30)
            print(f"✅ 流式生成完成,总长度: {total_chars} 字符")
    except Exception as e:
        # Script-level boundary: report the failure and let the caller go on.
        print(f"❌ 流式测试失败: {e}")
def save_results_to_file(results, filename="normalization_results.json"):
    """Write generation results to ``filename`` as pretty-printed UTF-8 JSON.

    The whole payload is serialized in memory before the file is opened, so a
    serialization error neither creates a partial file nor truncates an
    existing one.

    Args:
        results: Iterable of result objects exposing ``metadata`` (a mapping
            with a ``'function_name'`` key), ``specification``,
            ``quality_score``, ``generation_time``, ``tokens_used`` and
            ``validation_result``. A truthy ``validation_result`` must expose
            ``is_valid``, ``errors`` and ``warnings``; a falsy one is stored
            as ``null``.
        filename: Destination path. Defaults to
            ``"normalization_results.json"``.

    Returns:
        None. Success or failure is reported on stdout; exceptions are caught
        and not re-raised (best-effort save).
    """
    try:
        records = []
        for result in results:
            validation = result.validation_result
            records.append({
                'function_name': result.metadata['function_name'],
                'specification': result.specification,
                'quality_score': result.quality_score,
                'generation_time': result.generation_time,
                'tokens_used': result.tokens_used,
                'validation_result': {
                    'is_valid': validation.is_valid,
                    'errors': validation.errors,
                    'warnings': validation.warnings,
                } if validation else None,
            })

        # Serialize first: opening with 'w' truncates the target immediately,
        # so nothing that can fail should happen between open() and write().
        payload = json.dumps(records, ensure_ascii=False, indent=2)
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(payload)
        print(f"✅ 结果已保存到: {filename}")
    except Exception as e:
        # Best-effort save: report the problem instead of aborting the caller.
        print(f"❌ 保存结果失败: {e}")
async def main():
    """Print the intro, run the basic and streaming demos in order, print the footer."""
    intro_lines = (
        "DeepSeek v3.1 代码规约化测试工具",
        "此工具将演示如何使用DeepSeek v3.1模型对C/C++代码进行形式化规约生成\n",
    )
    for line in intro_lines:
        print(line)

    # Basic run first, then the streaming run; each is awaited to completion.
    for scenario in (test_basic_normalization, test_streaming_normalization):
        await scenario()

    rule = "=" * 60
    print(f"\n{rule}")
    print("测试完成")
    print(rule)
if __name__ == "__main__":
    # Run the tests
    asyncio.run(main())