""" AST信息提取器模块 本模块实现了从解析后的C/C++ AST中提取函数签名、全局变量、调用关系、 数据结构等元数据的功能。支持CBMC特定的分析,如指针使用模式、内存访问分析等。 """ import os import time from typing import Dict, List, Optional, Set, Union, Any, Callable from pathlib import Path from collections import defaultdict try: from pycparser import c_ast except ImportError: raise ImportError("pycparser is required. Install with: pip install pycparser==2.21") from .metadata import ( CodeMetadata, FunctionInfo, VariableInfo, CallRelation, DataStructure, ParameterInfo, StructField, MemoryAccessPattern, SourceLocation, VariableScope, MemoryAccessType, VerificationHint ) from .c_parser import CParser from ..utils.logger import get_logger class ASTVisitor(c_ast.NodeVisitor): """基础AST访问器""" def __init__(self, extractor: 'ASTExtractor'): self.extractor = extractor self.logger = extractor.logger self.current_function = None self.current_scope = [] self.source_file = None def visit(self, node: c_ast.Node) -> Any: """访问节点""" if hasattr(node, 'coord') and node.coord: self.source_file = node.coord.file return super().visit(node) def _stringify_expression(self, expr_node: Optional[c_ast.Node]) -> Optional[str]: """安全地将表达式节点转换为字符串""" if expr_node is None: return None try: if isinstance(expr_node, c_ast.Constant): return str(expr_node.value) elif isinstance(expr_node, c_ast.ID): return expr_node.name elif isinstance(expr_node, c_ast.UnaryOp): return f"{expr_node.op}{self._stringify_expression(expr_node.expr)}" elif isinstance(expr_node, c_ast.BinaryOp): left = self._stringify_expression(expr_node.left) right = self._stringify_expression(expr_node.right) return f"{left} {expr_node.op} {right}" if left and right else None elif isinstance(expr_node, c_ast.TernaryOp): cond = self._stringify_expression(expr_node.cond) true_val = self._stringify_expression(expr_node.iftrue) false_val = self._stringify_expression(expr_node.iffalse) if cond and true_val and false_val: return f"{cond} ? {true_val} : {false_val}" return None else: # 对于其他类型的表达式,返回None而不是尝试字符串化 self.logger.debug(f"Unsupported expression type: {type(expr_node).__name__}") return None except Exception as e: self.logger.debug(f"Error stringifying expression: {e}") return None class FunctionExtractor(ASTVisitor): """函数提取器""" def __init__(self, extractor: 'ASTExtractor'): super().__init__(extractor) self.functions = {} def visit_FuncDef(self, node: c_ast.FuncDef) -> None: """访问函数定义""" function_info = self._extract_function_info(node) if function_info: self.functions[function_info.name] = function_info self.extractor.metadata.add_function(function_info) # 继续访问函数体 self.current_function = function_info self.generic_visit(node) self.current_function = None def visit_Decl(self, node: c_ast.Decl) -> None: """访问声明(包括函数声明)""" if isinstance(node.type, c_ast.FuncDecl) and not hasattr(node, 'init'): # 这是函数声明,不是定义 function_info = self._extract_function_decl_info(node) if function_info: self.functions[function_info.name] = function_info self.extractor.metadata.add_function(function_info) self.generic_visit(node) def _extract_function_info(self, node: c_ast.FuncDef) -> Optional[FunctionInfo]: """从函数定义提取信息""" try: decl = node.decl return_type = self._extract_type(decl.type.type) location = self._extract_location(decl.coord) # 提取参数 parameters = [] if hasattr(decl.type, 'args') and decl.type.args: for param in decl.type.args.params: param_info = self._extract_parameter_info(param) if param_info: parameters.append(param_info) # 检查函数属性 is_static = 'static' in (decl.storage or []) is_inline = 'inline' in (getattr(decl, 'funcspec', []) or []) # 提取函数体位置 body_location = self._extract_location(node.body.coord) if node.body.coord else None return FunctionInfo( name=decl.name, return_type=return_type, parameters=parameters, is_static=is_static, is_inline=is_inline, is_declaration_only=False, location=location, body_location=body_location ) except Exception as e: self.logger.error(f"Error extracting function info: {e}") return None def _extract_function_decl_info(self, node: c_ast.Decl) -> Optional[FunctionInfo]: """从函数声明提取信息""" try: return_type = self._extract_type(node.type.type) location = self._extract_location(node.coord) # 提取参数 parameters = [] if hasattr(node.type, 'args') and node.type.args: for param in node.type.args.params: param_info = self._extract_parameter_info(param) if param_info: parameters.append(param_info) return FunctionInfo( name=node.name, return_type=return_type, parameters=parameters, is_declaration_only=True, location=location ) except Exception as e: self.logger.error(f"Error extracting function declaration info: {e}") return None def _extract_parameter_info(self, node: c_ast.Decl) -> Optional[ParameterInfo]: """提取参数信息""" try: type_name = self._extract_type(node.type) is_pointer = self._is_pointer_type(node.type) is_array = isinstance(node.type, c_ast.ArrayDecl) array_size = self._stringify_expression(node.type.dim) if is_array and hasattr(node.type, 'dim') else None location = self._extract_location(node.coord) # 处理未命名参数 param_name = getattr(node, 'name', None) if param_name is None: # 生成稳定的占位符名称 param_name = f"" return ParameterInfo( name=param_name, type_name=type_name, is_pointer=is_pointer, is_array=is_array, array_size=array_size, location=location ) except Exception as e: self.logger.error(f"Error extracting parameter info: {e}") return None def _extract_type(self, node: c_ast.Node) -> str: """提取类型名称""" if isinstance(node, c_ast.IdentifierType): return ' '.join(node.names) elif isinstance(node, c_ast.PtrDecl): return self._extract_type(node.type) + '*' elif isinstance(node, c_ast.ArrayDecl): return self._extract_type(node.type) + '[]' elif isinstance(node, c_ast.TypeDecl): return self._extract_type(node.type) else: return str(type(node).__name__) def _is_pointer_type(self, node: c_ast.Node) -> bool: """检查是否为指针类型""" return isinstance(node, c_ast.PtrDecl) def _extract_location(self, coord) -> Optional[SourceLocation]: """提取源代码位置""" if not coord: return None return SourceLocation( file=coord.file or self.source_file or "", line=coord.line, column=coord.column ) class VariableExtractor(ASTVisitor): """变量提取器""" def __init__(self, extractor: 'ASTExtractor'): super().__init__(extractor) self.variables = {} self.global_scope = True self.current_params: set[str] = set() def visit_FuncDef(self, node: c_ast.FuncDef) -> None: """访问函数定义""" self.global_scope = False self.current_scope.append(node.decl.name) self.current_function = node.decl.name # 提取参数名称集合 self.current_params = set() if getattr(node.decl.type, 'args', None) and node.decl.type.args: for p in node.decl.type.args.params: if isinstance(p, c_ast.Decl) and getattr(p, 'name', None): self.current_params.add(p.name) self.generic_visit(node) self.current_function = None self.current_scope.pop() self.current_params = set() self.global_scope = True def visit_Decl(self, node: c_ast.Decl) -> None: """访问变量声明""" if not isinstance(node.type, c_ast.FuncDecl): variable_info = self._extract_variable_info(node) if variable_info: self.variables[variable_info.name] = variable_info self.extractor.metadata.add_variable(variable_info) self.generic_visit(node) def _extract_variable_info(self, node: c_ast.Decl) -> Optional[VariableInfo]: """提取变量信息""" try: type_name = self._extract_type(node.type) is_pointer = self._is_pointer_type(node.type) is_array = isinstance(node.type, c_ast.ArrayDecl) array_size = self._stringify_expression(node.type.dim) if is_array and hasattr(node.type, 'dim') else None is_const = any(n.name == 'const' for n in getattr(node, 'quals', [])) is_volatile = any(n.name == 'volatile' for n in getattr(node, 'quals', [])) is_static = 'static' in (getattr(node, 'storage', []) or []) initial_value = None if hasattr(node, 'init') and node.init: initial_value = self._extract_initializer(node.init) location = self._extract_location(node.coord) # 确定作用域 scope = VariableScope.GLOBAL if self.global_scope else VariableScope.LOCAL if node.name in self.current_params: scope = VariableScope.PARAMETER elif is_static: scope = VariableScope.STATIC return VariableInfo( name=node.name, type_name=type_name, scope=scope, is_pointer=is_pointer, is_array=is_array, array_size=array_size, is_const=is_const, is_volatile=is_volatile, is_static=is_static, initial_value=initial_value, location=location ) except Exception as e: self.logger.error(f"Error extracting variable info: {e}") return None def _extract_type(self, node: c_ast.Node) -> str: """提取类型名称""" if isinstance(node, c_ast.IdentifierType): return ' '.join(node.names) elif isinstance(node, c_ast.PtrDecl): return self._extract_type(node.type) + '*' elif isinstance(node, c_ast.ArrayDecl): return self._extract_type(node.type) + '[]' elif isinstance(node, c_ast.TypeDecl): return self._extract_type(node.type) else: return str(type(node).__name__) def _is_pointer_type(self, node: c_ast.Node) -> bool: """检查是否为指针类型""" return isinstance(node, c_ast.PtrDecl) def _extract_initializer(self, node: c_ast.Node) -> Optional[str]: """提取初始值""" try: if isinstance(node, c_ast.Constant): return node.value elif isinstance(node, c_ast.InitList): return "{...}" # 简化的初始化列表表示 elif isinstance(node, c_ast.UnaryOp) and node.op == '&': return f"&{self._extract_initializer(node.expr)}" else: return str(node) except Exception: return None def _extract_location(self, coord) -> Optional[SourceLocation]: """提取源代码位置""" if not coord: return None return SourceLocation( file=coord.file or self.source_file or "", line=coord.line, column=coord.column ) class CallGraphExtractor(ASTVisitor): """调用图提取器""" def __init__(self, extractor: 'ASTExtractor'): super().__init__(extractor) self.call_relations = [] self.function_calls = defaultdict(list) def visit_FuncDef(self, node: c_ast.FuncDef) -> None: """访问函数定义""" self.current_function = node.decl.name self.generic_visit(node) self.current_function = None def visit_FuncCall(self, node: c_ast.FuncCall) -> None: """访问函数调用""" if self.current_function: callee_name = self._extract_callee_name(node.name) if callee_name: location = self._extract_location(node.coord) # 确定调用类型 call_type = 'indirect' if callee_name.startswith('') else 'direct' call_relation = CallRelation( caller=self.current_function, callee=callee_name, call_type=call_type, call_sites=[location] if location else [] ) self.call_relations.append(call_relation) self.extractor.metadata.add_call_relation(call_relation) self.generic_visit(node) def _extract_callee_name(self, node: c_ast.Node) -> Optional[str]: """提取被调用函数名称""" if isinstance(node, c_ast.ID): return node.name elif isinstance(node, c_ast.UnaryOp) and node.op == '&': return self._extract_callee_name(node.expr) elif isinstance(node, c_ast.ArrayRef): # 函数指针数组调用 return self._extract_callee_name(node.name) elif isinstance(node, c_ast.StructRef): # 结构体/联合体中的函数指针调用 field_name = self._extract_callee_name(node.field) if field_name: # 标记为间接调用 self.logger.debug(f"StructRef function call: {node.name.name}.{field_name}") return f":{field_name}" return None elif isinstance(node, c_ast.PtrDecl): # 函数指针调用 return "" else: # 无法解析的调用类型 self.logger.debug(f"Unresolvable call type: {type(node).__name__}") return "" def _extract_location(self, coord) -> Optional[SourceLocation]: """提取源代码位置""" if not coord: return None return SourceLocation( file=coord.file or self.source_file or "", line=coord.line, column=coord.column ) class DataStructureExtractor(ASTVisitor): """数据结构提取器""" def __init__(self, extractor: 'ASTExtractor'): super().__init__(extractor) self.data_structures = {} def visit_Struct(self, node: c_ast.Struct) -> None: """访问结构体定义""" if node.decls: structure_info = self._extract_structure_info(node) if structure_info: self.data_structures[structure_info.name] = structure_info self.extractor.metadata.add_data_structure(structure_info) elif node.name: # 处理不透明结构体(前向声明) structure_info = DataStructure( name=node.name, kind='struct', fields=[], is_opaque=True, is_anonymous=False, location=self._extract_location(node.coord) ) self.data_structures[structure_info.name] = structure_info self.extractor.metadata.add_data_structure(structure_info) self.logger.debug(f"Added opaque struct: {node.name}") self.generic_visit(node) def visit_Union(self, node: c_ast.Union) -> None: """访问联合体定义""" if node.decls: structure_info = self._extract_structure_info(node) if structure_info: self.data_structures[structure_info.name] = structure_info self.extractor.metadata.add_data_structure(structure_info) elif node.name: # 处理不透明联合体(前向声明) structure_info = DataStructure( name=node.name, kind='union', fields=[], is_opaque=True, is_anonymous=False, location=self._extract_location(node.coord) ) self.data_structures[structure_info.name] = structure_info self.extractor.metadata.add_data_structure(structure_info) self.logger.debug(f"Added opaque union: {node.name}") self.generic_visit(node) def _extract_structure_info(self, node) -> Optional[DataStructure]: """提取数据结构信息""" try: kind = 'struct' if isinstance(node, c_ast.Struct) else 'union' name = node.name or f"anonymous_{id(node)}" is_anonymous = node.name is None location = self._extract_location(node.coord) # 提取字段 fields = [] if node.decls: for decl in node.decls: field_info = self._extract_field_info(decl) if field_info: fields.append(field_info) return DataStructure( name=name, kind=kind, fields=fields, is_anonymous=is_anonymous, location=location ) except Exception as e: self.logger.error(f"Error extracting structure info: {e}") return None def _extract_field_info(self, node: c_ast.Decl) -> Optional[StructField]: """提取字段信息""" try: type_name = self._extract_type(node.type) is_pointer = self._is_pointer_type(node.type) is_array = isinstance(node.type, c_ast.ArrayDecl) array_size = self._stringify_expression(node.type.dim) if is_array and hasattr(node.type, 'dim') else None location = self._extract_location(node.coord) # 检查位域 bit_field = None if hasattr(node, 'bitsize') and node.bitsize: if isinstance(node.bitsize, c_ast.Constant): bit_field = int(node.bitsize.value) return StructField( name=node.name, type_name=type_name, is_pointer=is_pointer, is_array=is_array, array_size=array_size, bit_field=bit_field, location=location ) except Exception as e: self.logger.error(f"Error extracting field info: {e}") return None def _extract_type(self, node: c_ast.Node) -> str: """提取类型名称""" if isinstance(node, c_ast.IdentifierType): return ' '.join(node.names) elif isinstance(node, c_ast.PtrDecl): return self._extract_type(node.type) + '*' elif isinstance(node, c_ast.ArrayDecl): return self._extract_type(node.type) + '[]' elif isinstance(node, c_ast.TypeDecl): return self._extract_type(node.type) else: return str(type(node).__name__) def _is_pointer_type(self, node: c_ast.Node) -> bool: """检查是否为指针类型""" return isinstance(node, c_ast.PtrDecl) def _extract_location(self, coord) -> Optional[SourceLocation]: """提取源代码位置""" if not coord: return None return SourceLocation( file=coord.file or self.source_file or "", line=coord.line, column=coord.column ) class PointerAnalyzer(ASTVisitor): """指针分析器""" def __init__(self, extractor: 'ASTExtractor'): super().__init__(extractor) self.memory_patterns = [] def visit_UnaryOp(self, node: c_ast.UnaryOp) -> None: """访问一元操作符(指针解引用、取地址等)""" if node.op in ['*', '&']: self._analyze_pointer_operation(node) self.generic_visit(node) def visit_ArrayRef(self, node: c_ast.ArrayRef) -> None: """访问数组引用(指针算术)""" self._analyze_array_access(node) self.generic_visit(node) def visit_Assignment(self, node: c_ast.Assignment) -> None: """访问赋值操作(指针赋值)""" if isinstance(node.lvalue, c_ast.UnaryOp) and node.lvalue.op == '*': # 指针解引用赋值 self._analyze_pointer_dereference(node.lvalue, MemoryAccessType.WRITE) self.generic_visit(node) def _analyze_pointer_operation(self, node: c_ast.UnaryOp) -> None: """分析指针操作""" try: if node.op == '*': # 指针解引用 variable_name = self._extract_variable_name(node.expr) if variable_name: location = self._extract_location(node.coord) pattern = MemoryAccessPattern( variable_name=variable_name, access_type=MemoryAccessType.READ, access_locations=[location] if location else [], is_dereferenced=True ) self.memory_patterns.append(pattern) self.extractor.metadata.add_memory_pattern(pattern) elif node.op == '&': # 取地址操作 variable_name = self._extract_variable_name(node.expr) if variable_name: location = self._extract_location(node.coord) pattern = MemoryAccessPattern( variable_name=variable_name, access_type=MemoryAccessType.READ, access_locations=[location] if location else [] ) self.memory_patterns.append(pattern) self.extractor.metadata.add_memory_pattern(pattern) except Exception as e: self.logger.error(f"Error analyzing pointer operation: {e}") def _analyze_array_access(self, node: c_ast.ArrayRef) -> None: """分析数组访问""" try: array_name = self._extract_variable_name(node.name) if array_name: location = self._extract_location(node.coord) index_expr = self._extract_expression(node.subscript) pattern = MemoryAccessPattern( variable_name=array_name, access_type=MemoryAccessType.READ, access_locations=[location] if location else [], is_indexed=True, indices=[index_expr] if index_expr else [] ) self.memory_patterns.append(pattern) self.extractor.metadata.add_memory_pattern(pattern) except Exception as e: self.logger.error(f"Error analyzing array access: {e}") def _analyze_pointer_dereference(self, node: c_ast.UnaryOp, access_type: MemoryAccessType) -> None: """分析指针解引用""" try: variable_name = self._extract_variable_name(node.expr) if variable_name: location = self._extract_location(node.coord) pattern = MemoryAccessPattern( variable_name=variable_name, access_type=access_type, access_locations=[location] if location else [], is_dereferenced=True ) self.memory_patterns.append(pattern) self.extractor.metadata.add_memory_pattern(pattern) except Exception as e: self.logger.error(f"Error analyzing pointer dereference: {e}") def _extract_variable_name(self, node: c_ast.Node) -> Optional[str]: """提取变量名称""" if isinstance(node, c_ast.ID): return node.name elif isinstance(node, c_ast.ArrayRef): return self._extract_variable_name(node.name) elif isinstance(node, c_ast.UnaryOp) and node.op == '*': return self._extract_variable_name(node.expr) else: return None def _extract_expression(self, node: c_ast.Node) -> Optional[str]: """提取表达式字符串""" try: if isinstance(node, c_ast.Constant): return node.value elif isinstance(node, c_ast.ID): return node.name elif isinstance(node, c_ast.BinaryOp): return f"{self._extract_expression(node.left)} {node.op} {self._extract_expression(node.right)}" else: return str(node) except Exception: return None def _extract_location(self, coord) -> Optional[SourceLocation]: """提取源代码位置""" if not coord: return None return SourceLocation( file=coord.file or self.source_file or "", line=coord.line, column=coord.column ) class ASTExtractor: """AST信息提取器主类""" def __init__(self, config: Optional[Dict[str, Any]] = None): self.config = config or {} self.logger = get_logger('ast_extractor') self.metadata = None self.extractors = [] def extract_metadata(self, ast: c_ast.FileAST, file_path: Union[str, Path]) -> CodeMetadata: """从AST提取元数据""" start_time = time.time() file_path = Path(file_path) self.logger.info(f"Extracting metadata from: {file_path}") # 初始化元数据 stat = file_path.stat() if file_path.exists() else None self.metadata = CodeMetadata( file_path=str(file_path), file_size=stat.st_size if stat else 0, modification_time=str(stat.st_mtime) if stat else "" ) # 创建并运行提取器 extractors = [ FunctionExtractor(self), VariableExtractor(self), CallGraphExtractor(self), DataStructureExtractor(self), PointerAnalyzer(self) ] for extractor in extractors: try: extractor.visit(ast) self.logger.debug(f"Completed extraction with {extractor.__class__.__name__}") except Exception as e: self.logger.error(f"Error in {extractor.__class__.__name__}: {e}") # 后处理和验证 self._post_process_metadata() duration = time.time() - start_time self.logger.info(f"Metadata extraction completed in {duration:.2f}s") self.logger.info(f"Extracted {len(self.metadata.functions)} functions, " f"{len(self.metadata.variables)} variables, " f"{len(self.metadata.call_relations)} call relations") return self.metadata def _post_process_metadata(self) -> None: """后处理元数据""" try: # 识别递归函数 self._identify_recursive_functions() # 计算函数复杂度 self._calculate_function_complexity() # 添加CBMC特定的验证提示 self._add_verification_hints() # 验证元数据完整性 self._validate_metadata() except Exception as e: self.logger.error(f"Error in post-processing: {e}") def _identify_recursive_functions(self) -> None: """识别递归函数""" for relation in self.metadata.call_relations: if relation.caller == relation.callee: relation.is_recursive = True if relation.caller in self.metadata.functions: self.metadata.functions[relation.caller].verification_hints.add( VerificationHint.REENTRANT ) def _calculate_function_complexity(self) -> None: """计算函数复杂度指标""" for function_name, function_info in self.metadata.functions.items(): # 简单的复杂度计算:基于参数数量和调用关系 complexity_metrics = { 'parameter_count': len(function_info.parameters), 'call_count': len([ rel for rel in self.metadata.call_relations if rel.caller == function_name ]), 'is_called': len([ rel for rel in self.metadata.call_relations if rel.callee == function_name ]) > 0 } function_info.complexity_metrics.update(complexity_metrics) def _add_verification_hints(self) -> None: """添加CBMC特定的验证提示""" for function_name, function_info in self.metadata.functions.items(): # 检查指针参数 has_pointer_params = any(param.is_pointer for param in function_info.parameters) if has_pointer_params: function_info.verification_hints.add(VerificationHint.NO_NULL_DEREFERENCE) function_info.verification_hints.add(VerificationHint.BOUNDS_CHECK) # 检查内存分配模式 if 'alloc' in function_name.lower() or 'malloc' in function_name.lower(): function_info.verification_hints.add(VerificationHint.NO_LEAK) def _validate_metadata(self) -> None: """验证元数据完整性""" # 检查函数定义的完整性 for function_name, function_info in self.metadata.functions.items(): if not function_info.is_declaration_only and not function_info.body_location: self.logger.warning(f"Function {function_name} has no body location") # 检查调用关系的完整性 for relation in self.metadata.call_relations: if relation.caller not in self.metadata.functions: self.logger.warning(f"Unknown caller in call relation: {relation.caller}") if relation.callee not in self.metadata.functions: self.logger.warning(f"Unknown callee in call relation: {relation.callee}") def get_extraction_stats(self) -> Dict[str, Any]: """获取提取统计信息""" if not self.metadata: return {} return { 'functions': len(self.metadata.functions), 'variables': len(self.metadata.variables), 'call_relations': len(self.metadata.call_relations), 'data_structures': len(self.metadata.data_structures), 'memory_patterns': len(self.metadata.memory_patterns), 'includes': len(self.metadata.includes), 'macros': len(self.metadata.macros) } def filter_functions(self, function_names: Set[str]) -> CodeMetadata: """过滤指定函数的元数据""" if not self.metadata: return CodeMetadata(file_path="", file_size=0, modification_time="") filtered_metadata = CodeMetadata( file_path=self.metadata.file_path, file_size=self.metadata.file_size, modification_time=self.metadata.modification_time ) # 添加指定的函数 for func_name in function_names: if func_name in self.metadata.functions: filtered_metadata.add_function(self.metadata.functions[func_name]) # 添加相关的变量和调用关系 related_functions = set(function_names) for relation in self.metadata.call_relations: if relation.caller in function_names or relation.callee in function_names: filtered_metadata.add_call_relation(relation) related_functions.add(relation.caller) related_functions.add(relation.callee) # 添加相关的变量 for var_name, var_info in self.metadata.variables.items(): if var_info.scope == VariableScope.GLOBAL: filtered_metadata.add_variable(var_info) return filtered_metadata