You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
cbmc/codedetect/tests/unit/test_prompt_builder.py

511 lines
18 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
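Unit tests for the Prompt Builder.
This module unit-tests the PromptBuilder class, verifying template selection,
context building, domain-specific guidance, and prompt formatting.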
"""
import pytest
from unittest.mock import Mock, patch
from typing import Dict, Any, List
from src.spec.prompt_builder import PromptBuilder
from src.parser.metadata import FunctionInfo, VariableInfo, SourceLocation
class TestPromptBuilder:
    """Unit tests for ``PromptBuilder``.

    The fixtures describe small C-style functions (plain integers, pointers,
    arrays, a FreeRTOS-style task entry point) as ``FunctionInfo`` objects.
    Each test feeds ``FunctionInfo.to_dict()`` into the builder and checks
    the text of the resulting prompt.
    """

    # ------------------------------------------------------------------
    # Private construction helpers (not collected by pytest: no "test" prefix)
    # ------------------------------------------------------------------

    @staticmethod
    def _param(name, type_name, column, *, is_pointer=False, is_array=False):
        """Return a ``VariableInfo`` for a parameter declared on line 1.

        Args:
            name: Parameter name.
            type_name: Type spelling, passed as ``VariableInfo(type=...)``.
            column: Column of the parameter on source line 1.
            is_pointer: Value for the ``is_pointer`` flag.
            is_array: Value for the ``is_array`` flag.
        """
        return VariableInfo(
            name=name,
            type=type_name,
            location=SourceLocation(line=1, column=column),
            is_pointer=is_pointer,
            is_array=is_array,
            array_size=None,
        )

    @staticmethod
    def _function(name, return_type, parameters, *, is_static=False):
        """Return a ``FunctionInfo`` located at line 1, column 1.

        The function has no local variables and no calls, and is never
        inline; only the name, return type, parameters and ``is_static``
        flag vary between the test cases in this class.
        """
        return FunctionInfo(
            name=name,
            return_type=return_type,
            parameters=parameters,
            variables=[],
            function_calls=[],
            location=SourceLocation(line=1, column=1),
            is_static=is_static,
            is_inline=False,
        )

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def prompt_builder(self):
        """Create a ``PromptBuilder`` instance."""
        return PromptBuilder()

    @pytest.fixture
    def simple_function_info(self):
        """Function info for ``int add(int a, int b)``."""
        return self._function(
            'add',
            'int',
            [
                self._param('a', 'int', 10),
                self._param('b', 'int', 20),
            ],
        )

    @pytest.fixture
    def pointer_function_info(self):
        """Function info for ``void swap(int* a, int* b)``."""
        return self._function(
            'swap',
            'void',
            [
                self._param('a', 'int*', 15, is_pointer=True),
                self._param('b', 'int*', 25, is_pointer=True),
            ],
        )

    @pytest.fixture
    def array_function_info(self):
        """Function info for ``int array_sum(int* array, int size)``."""
        return self._function(
            'array_sum',
            'int',
            [
                self._param('array', 'int*', 15, is_pointer=True, is_array=True),
                self._param('size', 'int', 30),
            ],
        )

    @pytest.fixture
    def freertos_function_info(self):
        """Function info for ``void vTaskFunction(void* pvParameters)``."""
        return self._function(
            'vTaskFunction',
            'void',
            [
                self._param('pvParameters', 'void*', 20, is_pointer=True),
            ],
        )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_prompt_builder_initialization(self, prompt_builder):
        """The builder exposes ``templates`` and ``system_role`` attributes."""
        assert prompt_builder is not None
        assert hasattr(prompt_builder, 'templates')
        assert hasattr(prompt_builder, 'system_role')

    def test_build_basic_prompt(self, prompt_builder, simple_function_info):
        """A basic prompt mentions the function, its type and the goal."""
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness', 'memory_safety'],
        )
        assert prompt is not None
        assert len(prompt) > 0
        assert 'add' in prompt
        assert 'int' in prompt
        assert 'functional_correctness' in prompt.lower()

    def test_build_pointer_prompt(self, prompt_builder, pointer_function_info):
        """A prompt for a pointer-taking function mentions pointers."""
        prompt = prompt_builder.build_prompt(
            function_info=pointer_function_info.to_dict(),
            verification_goals=['memory_safety', 'pointer_validity'],
        )
        assert prompt is not None
        assert 'swap' in prompt
        assert 'pointer' in prompt.lower()
        assert 'memory_safety' in prompt.lower()

    def test_build_array_prompt(self, prompt_builder, array_function_info):
        """A prompt for an array-taking function mentions arrays and bounds."""
        prompt = prompt_builder.build_prompt(
            function_info=array_function_info.to_dict(),
            verification_goals=['array_bounds', 'memory_safety'],
        )
        assert prompt is not None
        assert 'array_sum' in prompt
        assert 'array' in prompt.lower()
        assert 'bounds' in prompt.lower()

    def test_build_freertos_prompt(self, prompt_builder, freertos_function_info):
        """A prompt for a FreeRTOS-style task mentions FreeRTOS and tasks."""
        prompt = prompt_builder.build_prompt(
            function_info=freertos_function_info.to_dict(),
            verification_goals=['task_safety', 'concurrency'],
        )
        assert prompt is not None
        assert 'vTaskFunction' in prompt
        assert 'freertos' in prompt.lower()
        assert 'task' in prompt.lower()

    def test_build_prompt_with_context(self, prompt_builder, simple_function_info):
        """Extra project context is reflected in the prompt text."""
        context = {
            'project_type': 'embedded_system',
            'safety_critical': True,
            'coding_standards': ['MISRA', 'ISO26262'],
        }
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            context=context,
        )
        assert prompt is not None
        assert 'embedded' in prompt.lower()
        assert 'safety' in prompt.lower()

    def test_build_prompt_with_hints(self, prompt_builder, simple_function_info):
        """Every supplied hint appears in the prompt text."""
        hints = [
            "Check for integer overflow",
            "Validate input parameters",
            "Consider edge cases",
        ]
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            hints=hints,
        )
        assert prompt is not None
        assert 'integer overflow' in prompt.lower()
        assert 'input parameters' in prompt.lower()
        assert 'edge cases' in prompt.lower()

    def test_template_selection(self, prompt_builder):
        """``_select_template`` returns a template for known and unknown keys."""
        # Basic function template.
        basic_template = prompt_builder._select_template('basic')
        assert basic_template is not None
        assert 'function' in basic_template.lower()

        # Pointer function template.
        pointer_template = prompt_builder._select_template('pointer')
        assert pointer_template is not None
        assert 'pointer' in pointer_template.lower()

        # Array function template.
        array_template = prompt_builder._select_template('array')
        assert array_template is not None
        assert 'array' in array_template.lower()

        # An unrecognised key must still yield some (default) template.
        default_template = prompt_builder._select_template('unknown')
        assert default_template is not None

    def test_context_building(self, prompt_builder, simple_function_info):
        """``_build_context`` exposes the expected keys."""
        context = prompt_builder._build_context(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness', 'memory_safety'],
            context={'project_type': 'safety_critical'},
            hints=['Check bounds'],
        )
        assert context is not None
        assert 'function_name' in context
        assert 'verification_goals' in context
        assert 'project_type' in context
        assert 'hints' in context

    def test_system_role_inclusion(self, prompt_builder, simple_function_info):
        """The builder's ``system_role`` text is included in the prompt."""
        with patch.object(
            prompt_builder, 'system_role', 'You are a formal verification expert.'
        ):
            prompt = prompt_builder.build_prompt(
                function_info=simple_function_info.to_dict(),
                verification_goals=['functional_correctness'],
            )
        assert 'formal verification expert' in prompt.lower()

    def test_multiple_verification_goals(self, prompt_builder, simple_function_info):
        """Every requested verification goal is mentioned in the prompt."""
        goals = [
            'functional_correctness',
            'memory_safety',
            'integer_overflow',
            'null_pointer',
            'array_bounds',
        ]
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=goals,
        )
        assert prompt is not None
        # NOTE(review): this test expects the space-separated spelling
        # ("functional correctness") while other tests in this class expect
        # the underscore spelling -- confirm PromptBuilder emits both forms.
        lowered = prompt.lower()
        for goal in goals:
            assert goal.replace('_', ' ') in lowered

    def test_complex_function_parameters(self, prompt_builder):
        """Pointer-to-pointer and function-pointer parameters are handled."""
        complex_function = self._function(
            'complex_function',
            'struct Result*',
            [
                self._param(
                    'input_array', 'const int**', 20,
                    is_pointer=True, is_array=True,
                ),
                self._param(
                    'callback', 'int (*)(int, void*)', 40,
                    is_pointer=True,
                ),
            ],
            is_static=True,
        )
        prompt = prompt_builder.build_prompt(
            function_info=complex_function.to_dict(),
            verification_goals=['memory_safety', 'functional_correctness'],
        )
        assert prompt is not None
        assert 'complex_function' in prompt
        assert 'callback' in prompt.lower()

    def test_empty_verification_goals(self, prompt_builder, simple_function_info):
        """An empty goal list still produces a prompt naming the function."""
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=[],
        )
        assert prompt is not None
        assert 'add' in prompt

    def test_special_characters_in_prompt(self, prompt_builder):
        """A ``const char*`` parameter does not break prompt construction."""
        special_function = self._function(
            'special_chars',
            'int',
            [
                self._param('str', 'const char*', 15, is_pointer=True),
            ],
        )
        prompt = prompt_builder.build_prompt(
            function_info=special_function.to_dict(),
            verification_goals=['string_safety', 'buffer_overflow'],
        )
        assert prompt is not None
        assert 'special_chars' in prompt
        assert 'string' in prompt.lower()

    def test_prompt_validation(self, prompt_builder):
        """``_validate_prompt`` accepts real text and rejects blank input."""
        # A normal prompt is valid.
        valid_prompt = "Generate CBMC specification for function test"
        assert prompt_builder._validate_prompt(valid_prompt) is True

        # The empty string is invalid.
        assert prompt_builder._validate_prompt("") is False

        # Whitespace-only input is invalid.  These must be real newline/tab
        # escapes: with doubled backslashes the string would contain literal
        # backslash characters and would not be whitespace-only at all.
        whitespace_prompt = " \n \t "
        assert prompt_builder._validate_prompt(whitespace_prompt) is False

    def test_refinement_prompt_building(self, prompt_builder, simple_function_info):
        """A refinement prompt mentions the reported issues and suggestions."""
        # Two specification clauses separated by a real newline; the clause
        # keywords keep their leading backslash.
        original_spec = "\\requires a >= 0;\n\\ensures return == a + b;"
        issues = ["Missing overflow check", "Insufficient parameter validation"]
        suggestions = ["Add overflow check", "Add parameter validation"]
        prompt = prompt_builder.build_refinement_prompt(
            original_spec, simple_function_info.to_dict(), issues, suggestions
        )
        assert prompt is not None
        assert 'refinement' in prompt.lower()
        assert 'overflow check' in prompt.lower()
        assert 'parameter validation' in prompt.lower()

    def test_prompt_length_optimization(self, prompt_builder, simple_function_info):
        """A very long hint list still yields a usable prompt."""
        long_hints = [f"Check for issue {i}" for i in range(100)]
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            hints=long_hints,
        )
        # The builder may truncate or condense the hints, but the result must
        # still be a non-empty prompt.
        assert prompt is not None
        assert len(prompt) > 0

    def test_error_handling(self, prompt_builder):
        """``None`` input raises; an empty dict is handled gracefully."""
        # ``None`` as function info must raise.
        with pytest.raises(Exception):
            prompt_builder.build_prompt(None, ['functional_correctness'])

        # An empty mapping should be tolerated without raising.
        invalid_function_info = {}
        prompt = prompt_builder.build_prompt(
            invalid_function_info, ['functional_correctness']
        )
        assert prompt is not None

    def test_template_customization(self, prompt_builder):
        """A template registered under a new key is returned by selection."""
        custom_template = "Custom template for {function_name}"
        # patch.dict restores the mapping on exit, so the extra entry cannot
        # leak into other tests if ``templates`` is shared between instances.
        with patch.dict(prompt_builder.templates, {'custom': custom_template}):
            selected = prompt_builder._select_template('custom')
            assert selected == custom_template

    def test_domain_specific_knowledge(self, prompt_builder, freertos_function_info):
        """The FreeRTOS prompt carries FreeRTOS-specific wording."""
        prompt = prompt_builder.build_prompt(
            function_info=freertos_function_info.to_dict(),
            verification_goals=['task_safety', 'concurrency'],
        )
        assert 'freertos' in prompt.lower()
        assert 'task' in prompt.lower()
        assert 'concurrency' in prompt.lower()

    def test_prompt_consistency(self, prompt_builder, simple_function_info):
        """Building the same prompt twice yields equivalent content."""
        prompt1 = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
        )
        prompt2 = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
        )
        # Exact equality is deliberately not required (the output might embed
        # something like a timestamp); only the key content is compared.
        assert 'add' in prompt1
        assert 'add' in prompt2
        assert 'functional_correctness' in prompt1.lower()
        assert 'functional_correctness' in prompt2.lower()

    def test_prompt_for_void_functions(self, prompt_builder):
        """A ``void`` function still produces a complete prompt."""
        void_function = self._function(
            'print_message',
            'void',
            [
                self._param('message', 'const char*', 20, is_pointer=True),
            ],
        )
        prompt = prompt_builder.build_prompt(
            function_info=void_function.to_dict(),
            verification_goals=['memory_safety'],
        )
        assert prompt is not None
        assert 'print_message' in prompt
        assert 'void' in prompt
        assert 'memory_safety' in prompt.lower()
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports failure to the shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))