"""Unit tests for PromptBuilder.

This module exercises the PromptBuilder class: template selection, context
building, domain-specific guidance and prompt formatting.
"""
import pytest
from unittest.mock import Mock, patch
from typing import Dict, Any, List

from src.spec.prompt_builder import PromptBuilder
from src.parser.metadata import FunctionInfo, VariableInfo, SourceLocation


def _param(name: str, type_: str, column: int, *,
           is_pointer: bool = False, is_array: bool = False) -> VariableInfo:
    """Build a parameter ``VariableInfo`` located on line 1 at *column*."""
    return VariableInfo(
        name=name,
        type=type_,
        location=SourceLocation(line=1, column=column),
        is_pointer=is_pointer,
        is_array=is_array,
        array_size=None,
    )


def _function(name: str, return_type: str, parameters: List[VariableInfo], *,
              is_static: bool = False) -> FunctionInfo:
    """Build a non-inline ``FunctionInfo`` at line 1, column 1 with no locals or calls."""
    return FunctionInfo(
        name=name,
        return_type=return_type,
        parameters=parameters,
        variables=[],
        function_calls=[],
        location=SourceLocation(line=1, column=1),
        is_static=is_static,
        is_inline=False,
    )


class TestPromptBuilder:
    """Unit tests for PromptBuilder."""

    @pytest.fixture
    def prompt_builder(self):
        """Create a fresh PromptBuilder instance."""
        return PromptBuilder()

    @pytest.fixture
    def simple_function_info(self):
        """Plain function: ``int add(int a, int b)``."""
        return _function('add', 'int', [
            _param('a', 'int', 10),
            _param('b', 'int', 20),
        ])

    @pytest.fixture
    def pointer_function_info(self):
        """Pointer function: ``void swap(int* a, int* b)``."""
        return _function('swap', 'void', [
            _param('a', 'int*', 15, is_pointer=True),
            _param('b', 'int*', 25, is_pointer=True),
        ])

    @pytest.fixture
    def array_function_info(self):
        """Array function: ``int array_sum(int* array, int size)``."""
        return _function('array_sum', 'int', [
            _param('array', 'int*', 15, is_pointer=True, is_array=True),
            _param('size', 'int', 30),
        ])

    @pytest.fixture
    def freertos_function_info(self):
        """FreeRTOS task function: ``void vTaskFunction(void* pvParameters)``."""
        return _function('vTaskFunction', 'void', [
            _param('pvParameters', 'void*', 20, is_pointer=True),
        ])

    def test_prompt_builder_initialization(self, prompt_builder):
        """PromptBuilder exposes ``templates`` and ``system_role`` after construction."""
        assert prompt_builder is not None
        assert hasattr(prompt_builder, 'templates')
        assert hasattr(prompt_builder, 'system_role')

    def test_build_basic_prompt(self, prompt_builder, simple_function_info):
        """A basic prompt mentions the function name, types and goals."""
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness', 'memory_safety']
        )

        assert prompt is not None
        assert len(prompt) > 0
        assert 'add' in prompt
        assert 'int' in prompt
        assert 'functional_correctness' in prompt.lower()

    def test_build_pointer_prompt(self, prompt_builder, pointer_function_info):
        """A pointer-function prompt mentions pointers and memory safety."""
        prompt = prompt_builder.build_prompt(
            function_info=pointer_function_info.to_dict(),
            verification_goals=['memory_safety', 'pointer_validity']
        )

        assert prompt is not None
        assert 'swap' in prompt
        assert 'pointer' in prompt.lower()
        assert 'memory_safety' in prompt.lower()

    def test_build_array_prompt(self, prompt_builder, array_function_info):
        """An array-function prompt mentions arrays and bounds."""
        prompt = prompt_builder.build_prompt(
            function_info=array_function_info.to_dict(),
            verification_goals=['array_bounds', 'memory_safety']
        )

        assert prompt is not None
        assert 'array_sum' in prompt
        assert 'array' in prompt.lower()
        assert 'bounds' in prompt.lower()

    def test_build_freertos_prompt(self, prompt_builder, freertos_function_info):
        """A FreeRTOS task-function prompt mentions FreeRTOS and tasks."""
        prompt = prompt_builder.build_prompt(
            function_info=freertos_function_info.to_dict(),
            verification_goals=['task_safety', 'concurrency']
        )

        assert prompt is not None
        assert 'vTaskFunction' in prompt
        assert 'freertos' in prompt.lower()
        assert 'task' in prompt.lower()

    def test_build_prompt_with_context(self, prompt_builder, simple_function_info):
        """Extra project context is reflected in the prompt."""
        context = {
            'project_type': 'embedded_system',
            'safety_critical': True,
            'coding_standards': ['MISRA', 'ISO26262']
        }

        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            context=context
        )

        assert prompt is not None
        assert 'embedded' in prompt.lower()
        assert 'safety' in prompt.lower()

    def test_build_prompt_with_hints(self, prompt_builder, simple_function_info):
        """User-supplied hints are included in the prompt."""
        hints = [
            "Check for integer overflow",
            "Validate input parameters",
            "Consider edge cases"
        ]

        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            hints=hints
        )

        assert prompt is not None
        assert 'integer overflow' in prompt.lower()
        assert 'input parameters' in prompt.lower()
        assert 'edge cases' in prompt.lower()

    def test_template_selection(self, prompt_builder):
        """``_select_template`` returns a matching template, or a default for unknown keys."""
        # Basic function template
        basic_template = prompt_builder._select_template('basic')
        assert basic_template is not None
        assert 'function' in basic_template.lower()

        # Pointer function template
        pointer_template = prompt_builder._select_template('pointer')
        assert pointer_template is not None
        assert 'pointer' in pointer_template.lower()

        # Array function template
        array_template = prompt_builder._select_template('array')
        assert array_template is not None
        assert 'array' in array_template.lower()

        # Unknown key falls back to a default template
        default_template = prompt_builder._select_template('unknown')
        assert default_template is not None

    def test_context_building(self, prompt_builder, simple_function_info):
        """``_build_context`` collects the function, goals, project context and hints."""
        context = prompt_builder._build_context(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness', 'memory_safety'],
            context={'project_type': 'safety_critical'},
            hints=['Check bounds']
        )

        assert context is not None
        assert 'function_name' in context
        assert 'verification_goals' in context
        assert 'project_type' in context
        assert 'hints' in context

    def test_system_role_inclusion(self, prompt_builder, simple_function_info):
        """The configured system role text appears in the built prompt."""
        with patch.object(prompt_builder, 'system_role', 'You are a formal verification expert.'):
            prompt = prompt_builder.build_prompt(
                function_info=simple_function_info.to_dict(),
                verification_goals=['functional_correctness']
            )

            assert 'formal verification expert' in prompt.lower()

    def test_multiple_verification_goals(self, prompt_builder, simple_function_info):
        """Every requested goal appears in the prompt in its space-separated form."""
        goals = [
            'functional_correctness',
            'memory_safety',
            'integer_overflow',
            'null_pointer',
            'array_bounds'
        ]

        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=goals
        )

        assert prompt is not None
        for goal in goals:
            assert goal.replace('_', ' ') in prompt.lower()

    def test_complex_function_parameters(self, prompt_builder):
        """Double pointers and function-pointer parameters are handled."""
        complex_function = _function(
            'complex_function',
            'struct Result*',
            [
                _param('input_array', 'const int**', 20, is_pointer=True, is_array=True),
                _param('callback', 'int (*)(int, void*)', 40, is_pointer=True),
            ],
            is_static=True,
        )

        prompt = prompt_builder.build_prompt(
            function_info=complex_function.to_dict(),
            verification_goals=['memory_safety', 'functional_correctness']
        )

        assert prompt is not None
        assert 'complex_function' in prompt
        assert 'callback' in prompt.lower()

    def test_empty_verification_goals(self, prompt_builder, simple_function_info):
        """An empty goal list still yields a prompt naming the function."""
        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=[]
        )

        assert prompt is not None
        assert 'add' in prompt

    def test_special_characters_in_prompt(self, prompt_builder):
        """String (``const char*``) parameters produce a string-related prompt."""
        special_function = _function('special_chars', 'int', [
            _param('str', 'const char*', 15, is_pointer=True),
        ])

        prompt = prompt_builder.build_prompt(
            function_info=special_function.to_dict(),
            verification_goals=['string_safety', 'buffer_overflow']
        )

        assert prompt is not None
        assert 'special_chars' in prompt
        assert 'string' in prompt.lower()

    def test_prompt_validation(self, prompt_builder):
        """``_validate_prompt`` accepts real text and rejects empty/blank prompts."""
        # Valid prompt
        valid_prompt = "Generate CBMC specification for function test"
        is_valid = prompt_builder._validate_prompt(valid_prompt)
        assert is_valid is True

        # Invalid prompt: empty
        invalid_prompt = ""
        is_valid = prompt_builder._validate_prompt(invalid_prompt)
        assert is_valid is False

        # Invalid prompt: whitespace only (real newline/tab characters, not
        # escaped backslash sequences, so the string is genuinely blank)
        invalid_prompt = "   \n  \t  "
        is_valid = prompt_builder._validate_prompt(invalid_prompt)
        assert is_valid is False

    def test_refinement_prompt_building(self, prompt_builder, simple_function_info):
        """A refinement prompt mentions refinement, the issues and the suggestions."""
        # Two ACSL clauses on separate lines; the backslashes belong to the
        # ACSL keywords (\requires, \ensures).
        original_spec = "\\requires a >= 0;\n\\ensures return == a + b;"
        issues = ["Missing overflow check", "Insufficient parameter validation"]
        suggestions = ["Add overflow check", "Add parameter validation"]

        prompt = prompt_builder.build_refinement_prompt(
            original_spec,
            simple_function_info.to_dict(),
            issues,
            suggestions
        )

        assert prompt is not None
        assert 'refinement' in prompt.lower()
        assert 'overflow check' in prompt.lower()
        assert 'parameter validation' in prompt.lower()

    def test_prompt_length_optimization(self, prompt_builder, simple_function_info):
        """A very long hint list still produces a prompt."""
        long_hints = [f"Check for issue {i}" for i in range(100)]

        prompt = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness'],
            hints=long_hints
        )

        assert prompt is not None
        # The prompt should remain valid, though it may be truncated or condensed.

    def test_error_handling(self, prompt_builder):
        """``None`` function info raises; an empty dict is handled gracefully."""
        # None function info: the exact exception type is not pinned down here.
        with pytest.raises(Exception):
            prompt_builder.build_prompt(None, ['functional_correctness'])

        # Empty function info should be handled without raising.
        invalid_function_info = {}
        prompt = prompt_builder.build_prompt(invalid_function_info, ['functional_correctness'])
        assert prompt is not None

    def test_template_customization(self, prompt_builder):
        """A custom template registered in ``templates`` is returned verbatim."""
        custom_template = "Custom template for {function_name}"
        # patch.dict restores ``templates`` on exit, so the custom entry cannot
        # leak into other tests if the dict happens to be shared between instances.
        with patch.dict(prompt_builder.templates, {'custom': custom_template}):
            template = prompt_builder._select_template('custom')

        assert template == custom_template

    def test_domain_specific_knowledge(self, prompt_builder, freertos_function_info):
        """FreeRTOS functions get FreeRTOS/task/concurrency guidance."""
        prompt = prompt_builder.build_prompt(
            function_info=freertos_function_info.to_dict(),
            verification_goals=['task_safety', 'concurrency']
        )

        assert 'freertos' in prompt.lower()
        assert 'task' in prompt.lower()
        assert 'concurrency' in prompt.lower()

    def test_prompt_consistency(self, prompt_builder, simple_function_info):
        """Building the same prompt twice yields the same essential content."""
        prompt1 = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness']
        )

        prompt2 = prompt_builder.build_prompt(
            function_info=simple_function_info.to_dict(),
            verification_goals=['functional_correctness']
        )

        # Prompts may differ in incidental details (e.g. timestamps), so only
        # the essential content is compared.
        assert 'add' in prompt1
        assert 'add' in prompt2
        assert 'functional_correctness' in prompt1.lower()
        assert 'functional_correctness' in prompt2.lower()

    def test_prompt_for_void_functions(self, prompt_builder):
        """A void function's prompt mentions its name, ``void`` and the goal."""
        void_function = _function('print_message', 'void', [
            _param('message', 'const char*', 20, is_pointer=True),
        ])

        prompt = prompt_builder.build_prompt(
            function_info=void_function.to_dict(),
            verification_goals=['memory_safety']
        )

        assert prompt is not None
        assert 'print_message' in prompt
        assert 'void' in prompt
        assert 'memory_safety' in prompt.lower()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])