|
|
"""
|
|
|
Prompt Builder单元测试
|
|
|
|
|
|
本模块针对PromptBuilder类进行单元测试,验证模板选择、上下文构建、
|
|
|
领域特定指导和提示格式化等功能。
|
|
|
"""
|
|
|
|
|
|
import pytest
|
|
|
from unittest.mock import Mock, patch
|
|
|
from typing import Dict, Any, List
|
|
|
|
|
|
from src.spec.prompt_builder import PromptBuilder
|
|
|
from src.parser.metadata import FunctionInfo, VariableInfo, SourceLocation
|
|
|
|
|
|
|
|
|
class TestPromptBuilder:
    """Unit tests for PromptBuilder.

    Covers prompt construction for several function shapes (plain scalars,
    pointers, arrays, a FreeRTOS-style task entry point), optional context and
    hints, template selection, prompt validation, refinement prompts and
    basic error handling.
    """

    # ------------------------------------------------------------------
    # Private factories (not collected by pytest: names do not start with
    # "test"). They remove the repeated FunctionInfo/VariableInfo literals.
    # ------------------------------------------------------------------

    @staticmethod
    def _param(
        name: str,
        type_name: str,
        column: int,
        *,
        is_pointer: bool = False,
        is_array: bool = False,
    ) -> VariableInfo:
        """Return a VariableInfo for a parameter located on line 1 at *column*.

        ``array_size`` is always ``None``; every parameter used by these tests
        is declared without an explicit array size.
        """
        return VariableInfo(
            name=name,
            type=type_name,
            location=SourceLocation(line=1, column=column),
            is_pointer=is_pointer,
            is_array=is_array,
            array_size=None,
        )

    @staticmethod
    def _function(
        name: str,
        return_type: str,
        parameters: List[VariableInfo],
        *,
        is_static: bool = False,
    ) -> FunctionInfo:
        """Return a FunctionInfo located at line 1, column 1.

        The function has no local variables, no recorded function calls and
        is never inline.
        """
        return FunctionInfo(
            name=name,
            return_type=return_type,
            parameters=parameters,
            variables=[],
            function_calls=[],
            location=SourceLocation(line=1, column=1),
            is_static=is_static,
            is_inline=False,
        )

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def prompt_builder(self):
        """Create a fresh PromptBuilder instance."""
        return PromptBuilder()

    @pytest.fixture
    def simple_function_info(self):
        """Function metadata for ``int add(int a, int b)``."""
        return self._function(
            'add',
            'int',
            [
                self._param('a', 'int', 10),
                self._param('b', 'int', 20),
            ],
        )

    @pytest.fixture
    def pointer_function_info(self):
        """Function metadata for ``void swap(int* a, int* b)``."""
        return self._function(
            'swap',
            'void',
            [
                self._param('a', 'int*', 15, is_pointer=True),
                self._param('b', 'int*', 25, is_pointer=True),
            ],
        )

    @pytest.fixture
    def array_function_info(self):
        """Function metadata for ``int array_sum(int* array, int size)``."""
        return self._function(
            'array_sum',
            'int',
            [
                self._param('array', 'int*', 15, is_pointer=True, is_array=True),
                self._param('size', 'int', 30),
            ],
        )

    @pytest.fixture
    def freertos_function_info(self):
        """Function metadata for ``void vTaskFunction(void* pvParameters)``."""
        return self._function(
            'vTaskFunction',
            'void',
            [
                self._param('pvParameters', 'void*', 20, is_pointer=True),
            ],
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_prompt_builder_initialization(self, prompt_builder):
        """A new builder exists and exposes ``templates`` and ``system_role``."""
        assert prompt_builder is not None
        assert hasattr(prompt_builder, 'templates')
        assert hasattr(prompt_builder, 'system_role')

    def test_build_basic_prompt(self, prompt_builder, simple_function_info):
        """Build a prompt for a plain scalar function."""
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness', 'memory_safety'],
        )

        assert prompt is not None
        assert len(prompt) > 0
        assert 'add' in prompt
        assert 'int' in prompt
        assert 'functional_correctness' in prompt.lower()

    def test_build_pointer_prompt(self, prompt_builder, pointer_function_info):
        """Build a prompt for a function taking pointer parameters."""
        prompt = prompt_builder.build_prompt(
            function_info=pointer_function_info.to_dict(),
            verification_goals=['memory_safety', 'pointer_validity'],
        )

        assert prompt is not None
        assert 'swap' in prompt
        lowered = prompt.lower()
        assert 'pointer' in lowered
        assert 'memory_safety' in lowered

    def test_build_array_prompt(self, prompt_builder, array_function_info):
        """Build a prompt for a function taking an array parameter."""
        prompt = prompt_builder.build_prompt(
            function_info=array_function_info.to_dict(),
            verification_goals=['array_bounds', 'memory_safety'],
        )

        assert prompt is not None
        assert 'array_sum' in prompt
        lowered = prompt.lower()
        assert 'array' in lowered
        assert 'bounds' in lowered

    def test_build_freertos_prompt(self, prompt_builder, freertos_function_info):
        """Build a prompt for a FreeRTOS-style task function."""
        prompt = prompt_builder.build_prompt(
            function_info=freertos_function_info.to_dict(),
            verification_goals=['task_safety', 'concurrency'],
        )

        assert prompt is not None
        assert 'vTaskFunction' in prompt
        lowered = prompt.lower()
        assert 'freertos' in lowered
        assert 'task' in lowered

    def test_build_prompt_with_context(self, prompt_builder, simple_function_info):
        """Build a prompt with an extra project context dictionary."""
        context = {
            'project_type': 'embedded_system',
            'safety_critical': True,
            'coding_standards': ['MISRA', 'ISO26262'],
        }

        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            context=context,
        )

        assert prompt is not None
        lowered = prompt.lower()
        assert 'embedded' in lowered
        assert 'safety' in lowered

    def test_build_prompt_with_hints(self, prompt_builder, simple_function_info):
        """Build a prompt with user-supplied hints; each hint must appear."""
        hints = [
            "Check for integer overflow",
            "Validate input parameters",
            "Consider edge cases",
        ]

        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            hints=hints,
        )

        assert prompt is not None
        lowered = prompt.lower()
        assert 'integer overflow' in lowered
        assert 'input parameters' in lowered
        assert 'edge cases' in lowered

    def test_template_selection(self, prompt_builder):
        """Select templates by key, including an unknown key."""
        # Basic function template
        basic_template = prompt_builder._select_template('basic')
        assert basic_template is not None
        assert 'function' in basic_template.lower()

        # Pointer function template
        pointer_template = prompt_builder._select_template('pointer')
        assert pointer_template is not None
        assert 'pointer' in pointer_template.lower()

        # Array function template
        array_template = prompt_builder._select_template('array')
        assert array_template is not None
        assert 'array' in array_template.lower()

        # An unknown key is expected to yield some default template
        default_template = prompt_builder._select_template('unknown')
        assert default_template is not None

    def test_context_building(self, prompt_builder, simple_function_info):
        """``_build_context`` returns a mapping carrying the expected keys."""
        context = prompt_builder._build_context(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness', 'memory_safety'],
            context={'project_type': 'safety_critical'},
            hints=['Check bounds'],
        )

        assert context is not None
        for key in ('function_name', 'verification_goals', 'project_type', 'hints'):
            assert key in context

    def test_system_role_inclusion(self, prompt_builder, simple_function_info):
        """The builder's ``system_role`` text is included in the prompt."""
        with patch.object(
            prompt_builder, 'system_role', 'You are a formal verification expert.'
        ):
            prompt = prompt_builder.build_prompt(
                function_info=simple_function_info.to_dict(),
                verification_goals=['functional_correctness'],
            )

        assert 'formal verification expert' in prompt.lower()

    def test_multiple_verification_goals(self, prompt_builder, simple_function_info):
        """Build a prompt with several verification goals at once."""
        goals = [
            'functional_correctness',
            'memory_safety',
            'integer_overflow',
            'null_pointer',
            'array_bounds',
        ]

        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=goals,
        )

        assert prompt is not None
        lowered = prompt.lower()
        # NOTE(review): this checks the space-separated form of each goal,
        # whereas other tests check the underscored form - confirm both
        # spellings are really expected in the prompt.
        for goal in goals:
            assert goal.replace('_', ' ') in lowered

    def test_complex_function_parameters(self, prompt_builder):
        """Build a prompt for a function with complex parameter types."""
        complex_function = self._function(
            'complex_function',
            'struct Result*',
            [
                self._param(
                    'input_array', 'const int**', 20,
                    is_pointer=True, is_array=True,
                ),
                self._param(
                    'callback', 'int (*)(int, void*)', 40,
                    is_pointer=True,
                ),
            ],
            is_static=True,
        )

        prompt = prompt_builder.build_prompt(
            function_info=complex_function.to_dict(),
            verification_goals=['memory_safety', 'functional_correctness'],
        )

        assert prompt is not None
        assert 'complex_function' in prompt
        assert 'callback' in prompt.lower()

    def test_empty_verification_goals(self, prompt_builder, simple_function_info):
        """An empty goal list still produces a prompt naming the function."""
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=[],
        )

        assert prompt is not None
        assert 'add' in prompt

    def test_special_characters_in_prompt(self, prompt_builder):
        """Build a prompt for a function taking a ``const char*`` parameter."""
        special_function = self._function(
            'special_chars',
            'int',
            [
                self._param('str', 'const char*', 15, is_pointer=True),
            ],
        )

        prompt = prompt_builder.build_prompt(
            function_info=special_function.to_dict(),
            verification_goals=['string_safety', 'buffer_overflow'],
        )

        assert prompt is not None
        assert 'special_chars' in prompt
        assert 'string' in prompt.lower()

    def test_prompt_validation(self, prompt_builder):
        """``_validate_prompt`` accepts real text and rejects blank input."""
        # A valid prompt
        valid_prompt = "Generate CBMC specification for function test"
        assert prompt_builder._validate_prompt(valid_prompt) is True

        # Invalid: empty string
        assert prompt_builder._validate_prompt("") is False

        # Invalid: whitespace only. Real newline/tab escapes are used here;
        # the previous literal contained backslash characters and therefore
        # was not actually whitespace-only.
        whitespace_only = " \n \t "
        assert prompt_builder._validate_prompt(whitespace_only) is False

    def test_refinement_prompt_building(self, prompt_builder, simple_function_info):
        """Build a refinement prompt from a spec, issues and suggestions."""
        # Two clauses separated by a real newline; the leading backslashes of
        # "\requires" / "\ensures" are kept literal on purpose.
        original_spec = "\\requires a >= 0;\n\\ensures return == a + b;"
        issues = ["Missing overflow check", "Insufficient parameter validation"]
        suggestions = ["Add overflow check", "Add parameter validation"]

        prompt = prompt_builder.build_refinement_prompt(
            original_spec, simple_function_info.to_dict(), issues, suggestions
        )

        assert prompt is not None
        lowered = prompt.lower()
        assert 'refinement' in lowered
        assert 'overflow check' in lowered
        assert 'parameter validation' in lowered

    def test_prompt_length_optimization(self, prompt_builder, simple_function_info):
        """A very long hint list must still yield a prompt."""
        # Build input that could produce a very long prompt
        long_hints = [f"Check for issue {i}" for i in range(100)]

        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            hints=long_hints,
        )

        # Only existence is checked: the builder may truncate or optimise.
        assert prompt is not None

    def test_error_handling(self, prompt_builder):
        """``None`` function info raises; an empty dict is tolerated."""
        # None as function info
        # NOTE(review): Exception is very broad - narrow it once the concrete
        # exception type raised by build_prompt is confirmed.
        with pytest.raises(Exception):
            prompt_builder.build_prompt(None, ['functional_correctness'])

        # Empty function info should be handled gracefully, without raising
        invalid_function_info = {}
        prompt = prompt_builder.build_prompt(
            invalid_function_info, ['functional_correctness']
        )
        assert prompt is not None

    def test_template_customization(self, prompt_builder):
        """A template registered under a new key is returned verbatim."""
        custom_template = "Custom template for {function_name}"

        # patch.dict restores ``templates`` afterwards, so the custom entry
        # cannot leak into other tests if the mapping is shared between
        # builder instances (assumes ``templates`` is dict-like - it already
        # had to support item assignment).
        with patch.dict(prompt_builder.templates, {'custom': custom_template}):
            selected = prompt_builder._select_template('custom')

        assert selected == custom_template

    def test_domain_specific_knowledge(self, prompt_builder, freertos_function_info):
        """FreeRTOS-specific wording appears for a FreeRTOS task function."""
        prompt = prompt_builder.build_prompt(
            function_info=freertos_function_info.to_dict(),
            verification_goals=['task_safety', 'concurrency'],
        )

        # The prompt should carry FreeRTOS-specific knowledge
        lowered = prompt.lower()
        assert 'freertos' in lowered
        assert 'task' in lowered
        assert 'concurrency' in lowered

    def test_prompt_consistency(self, prompt_builder, simple_function_info):
        """Building the same prompt twice yields the same key content."""
        prompts = [
            prompt_builder.build_prompt(
                function_info=simple_function_info.to_dict(),
                verification_goals=['functional_correctness'],
            )
            for _ in range(2)
        ]

        # Only key content is compared; the prompts may differ in incidental
        # details (for example timestamps).
        for prompt in prompts:
            assert 'add' in prompt
        for prompt in prompts:
            assert 'functional_correctness' in prompt.lower()

    def test_prompt_for_void_functions(self, prompt_builder):
        """Build a prompt for a function returning ``void``."""
        void_function = self._function(
            'print_message',
            'void',
            [
                self._param('message', 'const char*', 20, is_pointer=True),
            ],
        )

        prompt = prompt_builder.build_prompt(
            function_info=void_function.to_dict(),
            verification_goals=['memory_safety'],
        )

        assert prompt is not None
        assert 'print_message' in prompt
        assert 'void' in prompt
        assert 'memory_safety' in prompt.lower()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely when executed as a script, and hand
    # pytest's exit code back to the caller so failures yield a non-zero
    # process status instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))