You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
cbmc/codedetect/tests/deepseek/scripts/simple_deepseek_test.py

289 lines
9.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python3
"""
简化的DeepSeek v3.1代码规约化测试脚本
直接使用aiohttp调用DeepSeek API无需复杂的依赖
"""
import os
import asyncio
import aiohttp
import json
import time
class SimpleDeepSeekTest:
    """Minimal async client asking DeepSeek (via SiliconFlow) for CBMC specs.

    Calls the OpenAI-compatible ``/chat/completions`` endpoint directly with
    ``aiohttp``, so no SDK is required.
    """

    # Stripped reply lines starting with any of these are kept by
    # ``_extract_cbmc_spec`` (``str.startswith`` accepts a tuple).
    _CBMC_PREFIXES = (
        r'\requires',
        r'\ensures',
        r'\assigns',
        r'\valid',
        r'\forall',
        r'\exists',
    )

    # Labels substituted for known verification-goal keys in the prompt;
    # unknown keys are inserted verbatim.
    _GOAL_DESCRIPTIONS = {
        'functional_correctness': '功能正确性验证',
        'memory_safety': '内存安全性验证',
        'pointer_validity': '指针有效性验证',
        'integer_overflow': '整数溢出检查',
        'array_bounds': '数组边界检查',
        'error_handling': '错误处理验证',
    }

    # System message sent with every generation request.
    _SYSTEM_PROMPT = (
        "你是一个专业的形式化验证专家专门为C/C++代码生成CBMC格式的形式化规约。"
        "请生成准确、完整的CBMC规约只包含CBMC条款不要包含其他解释。"
    )

    def __init__(self, api_key=None, base_url=None):
        """Configure credentials and endpoint.

        Args:
            api_key: API key; falls back to the ``SILICONFLOW_API_KEY``
                environment variable.
            base_url: API root, default ``https://api.siliconflow.cn/v1``.
                A trailing ``/`` is removed so request URLs join cleanly.

        Raises:
            ValueError: if no API key was given or found in the environment.
        """
        self.api_key = api_key or os.getenv('SILICONFLOW_API_KEY')
        self.base_url = (base_url or 'https://api.siliconflow.cn/v1').rstrip('/')
        self.model = 'deepseek-ai/DeepSeek-V3.1'
        if not self.api_key:
            raise ValueError("需要设置 SILICONFLOW_API_KEY 环境变量")

    def _headers(self):
        """Return the bearer-auth + JSON headers used by every request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _describe_error(exc):
        """Return ``str(exc)``, or the exception's type name if that is empty.

        A bare ``asyncio.TimeoutError()`` stringifies to ``''``; without this
        fallback the reported error would carry no detail at all.
        """
        return str(exc) or type(exc).__name__

    async def generate_specification(self, function_info, verification_goals):
        """Ask the model for a CBMC specification of one C function.

        Args:
            function_info: dict with ``name``, ``return_type``, ``parameters``
                (list of ``{'name', 'type'}`` dicts) and ``source_code``.
            verification_goals: iterable of goal keys (see ``_GOAL_DESCRIPTIONS``).

        Returns:
            On success ``{'specification', 'raw_content', 'generation_time',
            'tokens_used', 'success': True}``; on an HTTP error, transport
            failure or malformed reply ``{'success': False, 'error': str}``.
            Errors raised while building the prompt (e.g. a parameter dict
            without ``'type'``) still propagate.
        """
        prompt = self._build_prompt(function_info, verification_goals)
        data = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self._SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "temperature": 0.3,
            "max_tokens": 1024,
            "stream": False,
        }
        start_time = time.perf_counter()
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/chat/completions",
                    headers=self._headers(),
                    json=data,
                    timeout=aiohttp.ClientTimeout(total=60),
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        return {
                            'success': False,
                            'error': f"API错误: {response.status} - {error_text}",
                        }
                    result = await response.json()
            content = result['choices'][0]['message']['content']
            # 'usage' may be absent or JSON null; treat both as "no usage info".
            tokens_used = (result.get('usage') or {}).get('total_tokens', 0)
        except Exception as e:  # script boundary: report, don't crash the run
            return {
                'success': False,
                'error': f"请求失败: {self._describe_error(e)}",
            }

        specification = self._extract_cbmc_spec(content)
        return {
            'specification': specification,
            'raw_content': content,
            'generation_time': time.perf_counter() - start_time,
            'tokens_used': tokens_used,
            'success': True,
        }

    def _build_prompt(self, function_info, verification_goals):
        """Render the user prompt for one function.

        Missing ``name``/``return_type``/``parameters``/``source_code`` default
        to ``'unknown'``/``'void'``/``[]``/``''``. Each entry of
        ``parameters`` must provide ``'type'`` and ``'name'``.
        """
        func_name = function_info.get('name', 'unknown')
        return_type = function_info.get('return_type', 'void')
        parameters = function_info.get('parameters', [])
        source_code = function_info.get('source_code', '')

        param_str = ', '.join(f"{p['type']} {p['name']}" for p in parameters)
        function_signature = f"{return_type} {func_name}({param_str})"
        goals_block = '\n'.join(
            f"- {self._GOAL_DESCRIPTIONS.get(goal, goal)}"
            for goal in verification_goals
        )

        # Built line-by-line so code indentation never leaks into the prompt;
        # the leading/trailing '' produce the prompt's blank first/last line.
        lines = [
            '',
            '请为以下C函数生成CBMC形式化验证规约',
            f'函数签名:{function_signature}',
            '源代码:',
            '```c',
            source_code,
            '```',
            '验证目标:',
            goals_block,
            '请生成完整的CBMC规约包含',
            '1. \\requires 子句:前置条件',
            '2. \\ensures 子句:后置条件',
            '3. \\assigns 子句:赋值说明(如果需要)',
            '只输出CBMC规约不要包含其他解释。',
            '',
        ]
        return '\n'.join(lines)

    def _extract_cbmc_spec(self, content):
        """Keep only the specification-clause lines of a model reply.

        Every line is stripped; those starting with one of ``_CBMC_PREFIXES``
        are kept and joined with newlines. If none match, the whole reply is
        returned stripped. ``None`` or empty content yields ``''``.
        """
        if not content:
            return ''
        kept = [
            stripped
            for stripped in (line.strip() for line in content.split('\n'))
            if stripped.startswith(self._CBMC_PREFIXES)
        ]
        return '\n'.join(kept) if kept else content.strip()

    async def health_check(self):
        """Send a tiny completion request to check that the API is usable.

        Returns:
            ``{'status': 'healthy'|'unhealthy', 'response_time': float,
            'status_code': int}`` when a response arrives, where
            ``response_time`` is the seconds until the status line was
            received; non-200 replies also get an ``'error'`` with the body.
            On a transport failure: ``{'status': 'unhealthy', 'error': str}``.
        """
        start = time.perf_counter()
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.base_url}/chat/completions",
                    headers=self._headers(),
                    json={
                        "model": self.model,
                        "messages": [{"role": "user", "content": "test"}],
                        "max_tokens": 5,
                    },
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response:
                    report = {
                        'status': 'healthy' if response.status == 200 else 'unhealthy',
                        'response_time': time.perf_counter() - start,
                        'status_code': response.status,
                    }
                    if response.status != 200:
                        # Surface the server's reason instead of leaving callers
                        # with only a bare status code.
                        error_text = await response.text()
                        report['error'] = f"API错误: {response.status} - {error_text}"
                    return report
        except Exception as e:  # any failure just means "unhealthy"
            return {
                'status': 'unhealthy',
                'error': self._describe_error(e),
            }
def create_test_functions():
    """Build the sample C functions exercised by the smoke test.

    Every entry has ``name``, ``return_type``, ``parameters`` (a list of
    ``{'name', 'type'}`` dicts), ``source_code`` and ``verification_goals``.
    A fresh list is constructed on each call.
    """
    def param(name, ctype='int'):
        # One parameter descriptor: {'name': ..., 'type': ...}.
        return {'name': name, 'type': ctype}

    add_numbers = {
        'name': 'add_numbers',
        'return_type': 'int',
        'parameters': [param('a'), param('b')],
        'source_code': 'int add_numbers(int a, int b) { return a + b; }',
        'verification_goals': ['functional_correctness'],
    }

    # Leading '\n' and unindented body lines mirror the original literal.
    safe_divide_source = (
        '\n'
        'bool safe_divide(int a, int b, int *result) {\n'
        'if (b == 0 || result == NULL) {\n'
        'return false;\n'
        '}\n'
        '*result = a / b;\n'
        'return true;\n'
        '}'
    )
    safe_divide = {
        'name': 'safe_divide',
        'return_type': 'bool',
        'parameters': [param('a'), param('b'), param('result', 'int*')],
        'source_code': safe_divide_source,
        'verification_goals': [
            'functional_correctness',
            'memory_safety',
            'error_handling',
        ],
    }

    return [add_numbers, safe_divide]
async def main():
    """Run the DeepSeek CBMC-specification smoke test end to end.

    Checks that ``SILICONFLOW_API_KEY`` is set, probes the API with
    ``health_check``, then requests a specification for every sample from
    ``create_test_functions`` and prints the results.

    Returns:
        0 when every specification was generated; 1 when the key is missing,
        the API is unhealthy, any generation failed, or an unexpected
        exception occurred. Suitable as a process exit status.
    """
    print("🚀 DeepSeek v3.1 代码规约化简化测试")
    print("=" * 50)

    # Fail fast with setup instructions before touching the network.
    if not os.getenv('SILICONFLOW_API_KEY'):
        print("❌ 错误: 未设置 SILICONFLOW_API_KEY 环境变量")
        print("请设置: export SILICONFLOW_API_KEY=your_api_key")
        return 1

    try:
        tester = SimpleDeepSeekTest()
        print(f"📡 连接模型: {tester.model}")

        print("\n🔍 检查API状态...")
        health = await tester.health_check()
        if health['status'] != 'healthy':
            # A non-200 reply may carry only 'status_code'; show it rather
            # than falling straight back to "unknown error".
            detail = health.get('error')
            if not detail and 'status_code' in health:
                detail = f"HTTP {health['status_code']}"
            print(f"❌ API不健康: {detail or '未知错误'}")
            return 1
        print("✅ API健康")

        failures = 0
        for i, func in enumerate(create_test_functions(), 1):
            print(f"\n{'='*20} 测试 {i} {'='*20}")
            print(f"📝 函数: {func['name']}")
            print(f"🎯 目标: {', '.join(func['verification_goals'])}")
            print("\n⏳ 正在生成规约...")
            result = await tester.generate_specification(func, func['verification_goals'])
            if not result['success']:
                failures += 1
                print(f"❌ 生成失败: {result['error']}")
                continue

            print("✅ 生成成功!")
            print(f"⏱️ 时间: {result['generation_time']:.2f}s")
            print(f"🔤 Token: {result['tokens_used']}")
            print("\n📋 生成的CBMC规约:")
            print("-" * 40)
            print(result['specification'])
            print("-" * 40)
            print("\n📄 原始输出:")
            print("-" * 40)
            print(result['raw_content'])
            print("-" * 40)

        print("\n🎉 测试完成!")
        return 1 if failures else 0
    except Exception as e:  # top-level boundary: report with traceback
        print(f"❌ 测试失败: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    import sys

    # Propagate main()'s result as the exit status so CI can detect failures;
    # sys.exit(None) exits with 0, so a main() without a return value still succeeds.
    sys.exit(asyncio.run(main()))