|
|
"""
|
|
|
AST信息提取器模块
|
|
|
|
|
|
本模块实现了从解析后的C/C++ AST中提取函数签名、全局变量、调用关系、
|
|
|
数据结构等元数据的功能。支持CBMC特定的分析,如指针使用模式、内存访问分析等。
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import time
|
|
|
from typing import Dict, List, Optional, Set, Union, Any, Callable
|
|
|
from pathlib import Path
|
|
|
from collections import defaultdict
|
|
|
|
|
|
try:
|
|
|
from pycparser import c_ast
|
|
|
except ImportError:
|
|
|
raise ImportError("pycparser is required. Install with: pip install pycparser==2.21")
|
|
|
|
|
|
from .metadata import (
|
|
|
CodeMetadata, FunctionInfo, VariableInfo, CallRelation, DataStructure,
|
|
|
ParameterInfo, StructField, MemoryAccessPattern, SourceLocation,
|
|
|
VariableScope, MemoryAccessType, VerificationHint
|
|
|
)
|
|
|
from .c_parser import CParser
|
|
|
from ..utils.logger import get_logger
|
|
|
|
|
|
|
|
|
class ASTVisitor(c_ast.NodeVisitor):
    """Base AST visitor shared by all extractors in this module.

    Holds a reference to the owning ``ASTExtractor`` and its logger, and
    remembers the source file of the most recently visited node that carried
    coordinates.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        """Bind the visitor to ``extractor`` and reset the traversal state."""
        self.extractor = extractor
        self.logger = extractor.logger
        self.current_function = None
        self.current_scope = []
        self.source_file = None

    def visit(self, node: c_ast.Node) -> Any:
        """Visit ``node``, recording the file of its coordinate when present.

        Returns whatever the dispatched ``visit_*`` method returns.
        """
        coord = getattr(node, 'coord', None)
        if coord:
            self.source_file = coord.file
        return super().visit(node)

    def _stringify_operand(self, operand: Optional[c_ast.Node]) -> Optional[str]:
        """Stringify ``operand``, parenthesising it when it is itself compound.

        Without the parentheses ``(a + b) * c`` would be rendered as
        ``a + b * c``, which is a different expression.
        """
        text = self._stringify_expression(operand)
        if text is not None and isinstance(operand, (c_ast.BinaryOp, c_ast.TernaryOp)):
            return f"({text})"
        return text

    def _stringify_expression(self, expr_node: Optional[c_ast.Node]) -> Optional[str]:
        """Safely convert a simple expression node to C-like source text.

        Supports constants, identifiers, unary, binary and ternary
        expressions.

        Args:
            expr_node: The expression node, or ``None``.

        Returns:
            The rendered expression, or ``None`` when the node is ``None``,
            is of an unsupported type, or contains an unsupported operand.
        """
        if expr_node is None:
            return None

        try:
            if isinstance(expr_node, c_ast.Constant):
                return str(expr_node.value)

            if isinstance(expr_node, c_ast.ID):
                return expr_node.name

            if isinstance(expr_node, c_ast.UnaryOp):
                operand = self._stringify_operand(expr_node.expr)
                if operand is None:
                    # Do not embed the text "None" in the rendered expression.
                    return None
                op = expr_node.op
                # pycparser encodes postfix increment/decrement as 'p++'/'p--'.
                if op in ('p++', 'p--'):
                    return f"{operand}{op[1:]}"
                if op == 'sizeof':
                    return f"sizeof({operand})"
                return f"{op}{operand}"

            if isinstance(expr_node, c_ast.BinaryOp):
                left = self._stringify_operand(expr_node.left)
                right = self._stringify_operand(expr_node.right)
                return f"{left} {expr_node.op} {right}" if left and right else None

            if isinstance(expr_node, c_ast.TernaryOp):
                cond = self._stringify_operand(expr_node.cond)
                true_val = self._stringify_operand(expr_node.iftrue)
                false_val = self._stringify_operand(expr_node.iffalse)
                if cond and true_val and false_val:
                    return f"{cond} ? {true_val} : {false_val}"
                return None

            # Any other expression kind is reported as None rather than
            # stringified with a guess.
            self.logger.debug(f"Unsupported expression type: {type(expr_node).__name__}")
            return None

        except Exception as e:
            self.logger.debug(f"Error stringifying expression: {e}")
            return None
|
|
|
|
|
|
|
|
|
class FunctionExtractor(ASTVisitor):
    """Collects function definitions and prototypes into the metadata.

    Results are stored both in ``self.functions`` (keyed by function name)
    and in ``self.extractor.metadata`` via ``add_function``.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        """Create the extractor with an empty ``functions`` mapping."""
        super().__init__(extractor)
        self.functions = {}

    def _register_function(self, function_info: FunctionInfo) -> None:
        """Record ``function_info`` locally and in the shared metadata."""
        self.functions[function_info.name] = function_info
        self.extractor.metadata.add_function(function_info)

    def visit_FuncDef(self, node: c_ast.FuncDef) -> None:
        """Record a function definition, then traverse its body."""
        function_info = self._extract_function_info(node)
        if function_info:
            self._register_function(function_info)

        self.current_function = function_info
        # Traverse everything except the definition's own declarator:
        # visiting it would reach visit_Decl and re-register the function
        # as a declaration-only entry.
        for _, child in node.children():
            if child is not node.decl:
                self.visit(child)
        self.current_function = None

    def visit_Decl(self, node: c_ast.Decl) -> None:
        """Record a function prototype (a ``Decl`` whose type is ``FuncDecl``).

        A prototype never replaces an already recorded definition of the
        same name. Other declarations are traversed generically.
        """
        if isinstance(node.type, c_ast.FuncDecl):
            # NOTE: c_ast.Decl always has an `init` attribute, so the test
            # must be on the declared type, not on hasattr(node, 'init').
            known = self.functions.get(node.name)
            if known is None or known.is_declaration_only:
                function_info = self._extract_function_decl_info(node)
                if function_info:
                    self._register_function(function_info)
            # Do not descend into the prototype: a function-typed parameter
            # would otherwise be registered as a function itself.
            return

        self.generic_visit(node)

    @staticmethod
    def _is_void_param(param: c_ast.Node) -> bool:
        """Return True when ``param`` is the lone ``void`` of ``f(void)``."""
        decl_type = getattr(param, 'type', None)
        inner = getattr(decl_type, 'type', None)
        return (
            getattr(param, 'name', None) is None
            and isinstance(decl_type, c_ast.TypeDecl)
            and isinstance(inner, c_ast.IdentifierType)
            and inner.names == ['void']
        )

    def _extract_parameters(self, func_decl: c_ast.Node) -> List[ParameterInfo]:
        """Extract the parameter list of a ``FuncDecl`` node.

        ``f(void)`` yields an empty list and a variadic ``...`` marker is
        skipped. Unnamed parameters get a position-based placeholder name.
        """
        param_list = getattr(func_decl, 'args', None)
        if param_list is None:
            return []

        params = param_list.params or []
        if len(params) == 1 and self._is_void_param(params[0]):
            return []

        parameters = []
        for index, param in enumerate(params):
            if isinstance(param, c_ast.EllipsisParam):
                # '...' carries no type information to record.
                self.logger.debug("Skipping variadic '...' parameter")
                continue
            param_info = self._extract_parameter_info(param, index)
            if param_info:
                parameters.append(param_info)
        return parameters

    def _extract_function_info(self, node: c_ast.FuncDef) -> Optional[FunctionInfo]:
        """Build a ``FunctionInfo`` from a function definition.

        Returns ``None`` (after logging the error) if extraction fails.
        """
        try:
            decl = node.decl
            return_type = self._extract_type(decl.type.type)
            location = self._extract_location(decl.coord)
            parameters = self._extract_parameters(decl.type)

            # Function attributes.
            is_static = 'static' in (decl.storage or [])
            is_inline = 'inline' in (getattr(decl, 'funcspec', []) or [])

            # Location of the function body, when the parser recorded one.
            body_location = self._extract_location(node.body.coord) if node.body.coord else None

            return FunctionInfo(
                name=decl.name,
                return_type=return_type,
                parameters=parameters,
                is_static=is_static,
                is_inline=is_inline,
                is_declaration_only=False,
                location=location,
                body_location=body_location
            )

        except Exception as e:
            self.logger.error(f"Error extracting function info: {e}")
            return None

    def _extract_function_decl_info(self, node: c_ast.Decl) -> Optional[FunctionInfo]:
        """Build a declaration-only ``FunctionInfo`` from a prototype.

        Returns ``None`` (after logging the error) if extraction fails.
        """
        try:
            return_type = self._extract_type(node.type.type)
            location = self._extract_location(node.coord)
            parameters = self._extract_parameters(node.type)

            return FunctionInfo(
                name=node.name,
                return_type=return_type,
                parameters=parameters,
                is_declaration_only=True,
                location=location
            )

        except Exception as e:
            self.logger.error(f"Error extracting function declaration info: {e}")
            return None

    def _extract_parameter_info(self, node: c_ast.Decl,
                                index: Optional[int] = None) -> Optional[ParameterInfo]:
        """Build a ``ParameterInfo`` from one parameter node.

        Args:
            node: The parameter node (named ``Decl`` or unnamed ``Typename``).
            index: Zero-based position of the parameter. When given, an
                unnamed parameter is called ``<unnamed_param_{index}>``,
                which is reproducible across runs; without it the
                placeholder falls back to the object id.

        Returns:
            The parameter description, or ``None`` (after logging the
            error) if extraction fails.
        """
        try:
            type_name = self._extract_type(node.type)
            is_pointer = self._is_pointer_type(node.type)
            is_array = isinstance(node.type, c_ast.ArrayDecl)
            array_size = self._stringify_expression(node.type.dim) if is_array and hasattr(node.type, 'dim') else None
            location = self._extract_location(node.coord)

            # Unnamed parameters get a placeholder name.
            param_name = getattr(node, 'name', None)
            if param_name is None:
                suffix = index if index is not None else id(node)
                param_name = f"<unnamed_param_{suffix}>"

            return ParameterInfo(
                name=param_name,
                type_name=type_name,
                is_pointer=is_pointer,
                is_array=is_array,
                array_size=array_size,
                location=location
            )

        except Exception as e:
            self.logger.error(f"Error extracting parameter info: {e}")
            return None

    def _extract_type(self, node: c_ast.Node) -> str:
        """Render a type node as text.

        Pointers append ``*`` and arrays append ``[]``; node kinds not
        handled explicitly are rendered as their AST class name.
        """
        if isinstance(node, c_ast.IdentifierType):
            return ' '.join(node.names)
        if isinstance(node, c_ast.PtrDecl):
            return self._extract_type(node.type) + '*'
        if isinstance(node, c_ast.ArrayDecl):
            return self._extract_type(node.type) + '[]'
        if isinstance(node, c_ast.TypeDecl):
            return self._extract_type(node.type)
        return str(type(node).__name__)

    def _is_pointer_type(self, node: c_ast.Node) -> bool:
        """Return True when the outermost declarator is a pointer."""
        return isinstance(node, c_ast.PtrDecl)

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a parser coordinate into a ``SourceLocation``.

        Returns ``None`` for a missing coordinate. The file falls back to
        the last visited source file, then to an empty string.
        """
        if not coord:
            return None

        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
|
|
|
|
|
|
|
|
|
class VariableExtractor(ASTVisitor):
    """Collects variable declarations into the metadata.

    Results are stored both in ``self.variables`` (keyed by variable name)
    and in ``self.extractor.metadata`` via ``add_variable``.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        """Create the extractor, starting at global scope."""
        super().__init__(extractor)
        self.variables = {}
        self.global_scope = True
        self.current_params: Set[str] = set()

    def visit_FuncDef(self, node: c_ast.FuncDef) -> None:
        """Traverse a function definition with function-local scope state."""
        self.global_scope = False
        self.current_scope.append(node.decl.name)
        self.current_function = node.decl.name

        # Names of the function's parameters, used to classify declarations.
        param_list = getattr(node.decl.type, 'args', None)
        self.current_params = {
            p.name
            for p in (param_list.params if param_list else [])
            if isinstance(p, c_ast.Decl) and getattr(p, 'name', None)
        }

        self.generic_visit(node)

        self.current_function = None
        self.current_scope.pop()
        self.current_params = set()
        self.global_scope = True

    def visit_Decl(self, node: c_ast.Decl) -> None:
        """Record a variable declaration, then traverse its children.

        Function prototypes are ignored, as are declarations without a
        name (e.g. a bare ``struct S { ... };``), which declare no variable.
        """
        if not isinstance(node.type, c_ast.FuncDecl) and node.name is not None:
            variable_info = self._extract_variable_info(node)
            if variable_info:
                self.variables[variable_info.name] = variable_info
                self.extractor.metadata.add_variable(variable_info)

        self.generic_visit(node)

    def _extract_variable_info(self, node: c_ast.Decl) -> Optional[VariableInfo]:
        """Build a ``VariableInfo`` from a declaration node.

        Returns ``None`` (after logging the error) if extraction fails.
        """
        try:
            type_name = self._extract_type(node.type)
            is_pointer = self._is_pointer_type(node.type)
            is_array = isinstance(node.type, c_ast.ArrayDecl)
            array_size = self._stringify_expression(node.type.dim) if is_array and hasattr(node.type, 'dim') else None

            # pycparser stores qualifiers and storage classes as lists of
            # plain strings, so membership tests are the right check.
            qualifiers = getattr(node, 'quals', None) or []
            is_const = 'const' in qualifiers
            is_volatile = 'volatile' in qualifiers
            is_static = 'static' in (getattr(node, 'storage', []) or [])

            initial_value = None
            if hasattr(node, 'init') and node.init:
                initial_value = self._extract_initializer(node.init)

            location = self._extract_location(node.coord)

            # Scope: parameter takes precedence over static, which takes
            # precedence over plain global/local.
            if node.name in self.current_params:
                scope = VariableScope.PARAMETER
            elif is_static:
                scope = VariableScope.STATIC
            elif self.global_scope:
                scope = VariableScope.GLOBAL
            else:
                scope = VariableScope.LOCAL

            return VariableInfo(
                name=node.name,
                type_name=type_name,
                scope=scope,
                is_pointer=is_pointer,
                is_array=is_array,
                array_size=array_size,
                is_const=is_const,
                is_volatile=is_volatile,
                is_static=is_static,
                initial_value=initial_value,
                location=location
            )

        except Exception as e:
            self.logger.error(f"Error extracting variable info: {e}")
            return None

    def _extract_type(self, node: c_ast.Node) -> str:
        """Render a type node as text.

        Pointers append ``*`` and arrays append ``[]``; node kinds not
        handled explicitly are rendered as their AST class name.
        """
        if isinstance(node, c_ast.IdentifierType):
            return ' '.join(node.names)
        if isinstance(node, c_ast.PtrDecl):
            return self._extract_type(node.type) + '*'
        if isinstance(node, c_ast.ArrayDecl):
            return self._extract_type(node.type) + '[]'
        if isinstance(node, c_ast.TypeDecl):
            return self._extract_type(node.type)
        return str(type(node).__name__)

    def _is_pointer_type(self, node: c_ast.Node) -> bool:
        """Return True when the outermost declarator is a pointer."""
        return isinstance(node, c_ast.PtrDecl)

    def _extract_initializer(self, node: c_ast.Node) -> Optional[str]:
        """Render an initializer expression as text.

        Constants yield their literal text, initializer lists are
        abbreviated to ``{...}``, and address-of expressions are rendered
        as ``&<operand>``. Other expressions are rendered with
        ``_stringify_expression`` when it supports them, and with
        ``str(node)`` otherwise. Returns ``None`` if rendering raises.
        """
        try:
            if isinstance(node, c_ast.Constant):
                return node.value
            if isinstance(node, c_ast.InitList):
                return "{...}"  # abbreviated initializer list
            if isinstance(node, c_ast.UnaryOp) and node.op == '&':
                return f"&{self._extract_initializer(node.expr)}"
            # Prefer readable source text (e.g. an identifier's name) over
            # the AST node's repr.
            rendered = self._stringify_expression(node)
            return rendered if rendered is not None else str(node)
        except Exception:
            return None

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a parser coordinate into a ``SourceLocation``.

        Returns ``None`` for a missing coordinate. The file falls back to
        the last visited source file, then to an empty string.
        """
        if not coord:
            return None

        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
|
|
|
|
|
|
|
|
|
class CallGraphExtractor(ASTVisitor):
    """Collects caller/callee relations for calls made inside functions.

    Each call is stored in ``self.call_relations``, indexed per caller in
    ``self.function_calls``, and added to ``self.extractor.metadata`` via
    ``add_call_relation``.
    """

    INDIRECT_MARKER = '<indirect>'

    def __init__(self, extractor: 'ASTExtractor'):
        """Create the extractor with empty call collections."""
        super().__init__(extractor)
        self.call_relations = []
        self.function_calls = defaultdict(list)

    def visit_FuncDef(self, node: c_ast.FuncDef) -> None:
        """Traverse a function definition, tracking it as the current caller."""
        self.current_function = node.decl.name
        self.generic_visit(node)
        self.current_function = None

    def visit_FuncCall(self, node: c_ast.FuncCall) -> None:
        """Record a call made inside the current function.

        Calls outside any function body are ignored. Nested calls (in the
        callee expression or arguments) are still traversed.
        """
        if self.current_function:
            callee_name = self._extract_callee_name(node.name)
            if callee_name:
                location = self._extract_location(node.coord)

                # Names carrying the indirect marker denote calls through
                # pointers.
                call_type = 'indirect' if callee_name.startswith('<indirect>') else 'direct'

                call_relation = CallRelation(
                    caller=self.current_function,
                    callee=callee_name,
                    call_type=call_type,
                    call_sites=[location] if location else []
                )

                self.call_relations.append(call_relation)
                self.function_calls[self.current_function].append(callee_name)
                self.extractor.metadata.add_call_relation(call_relation)

        self.generic_visit(node)

    def _extract_callee_name(self, node: c_ast.Node) -> Optional[str]:
        """Derive a callee name from the callee expression of a call.

        Returns:
            The identifier for a plain call; ``<indirect>:<name>`` for a
            call through a struct/union member or an array element;
            ``<indirect>`` when the target cannot be named; or ``None``
            when a member name cannot be determined.
        """
        if isinstance(node, c_ast.ID):
            return node.name

        if isinstance(node, c_ast.UnaryOp) and node.op == '&':
            return self._extract_callee_name(node.expr)

        if isinstance(node, c_ast.ArrayRef):
            # Call through an element of a function-pointer array: this is
            # always an indirect call, so mark it as such.
            base = self._extract_callee_name(node.name)
            if not base or base.startswith(self.INDIRECT_MARKER):
                return base
            return f"<indirect>:{base}"

        if isinstance(node, c_ast.StructRef):
            # Call through a function pointer stored in a struct/union.
            field_name = self._extract_callee_name(node.field)
            if field_name:
                # The owner may be any expression (e.g. `(*p).f`), not just
                # an identifier, so it must not be assumed to have a
                # string `.name`.
                owner = getattr(node.name, 'name', None)
                if not isinstance(owner, str):
                    owner = type(node.name).__name__
                self.logger.debug(f"StructRef function call: {owner}.{field_name}")
                return f"<indirect>:{field_name}"
            return None

        if isinstance(node, c_ast.PtrDecl):
            # Function pointer call.
            return "<indirect>"

        # Any other callee expression cannot be resolved to a name.
        self.logger.debug(f"Unresolvable call type: {type(node).__name__}")
        return "<indirect>"

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a parser coordinate into a ``SourceLocation``.

        Returns ``None`` for a missing coordinate. The file falls back to
        the last visited source file, then to an empty string.
        """
        if not coord:
            return None

        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
|
|
|
|
|
|
|
|
|
class DataStructureExtractor(ASTVisitor):
    """Collects struct and union definitions into the metadata.

    Results are stored both in ``self.data_structures`` (keyed by name) and
    in ``self.extractor.metadata`` via ``add_data_structure``.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        """Create the extractor with an empty ``data_structures`` mapping."""
        super().__init__(extractor)
        self.data_structures = {}

    def visit_Struct(self, node: c_ast.Struct) -> None:
        """Record a struct node, then traverse its members."""
        self._register_aggregate(node, 'struct')
        self.generic_visit(node)

    def visit_Union(self, node: c_ast.Union) -> None:
        """Record a union node, then traverse its members."""
        self._register_aggregate(node, 'union')
        self.generic_visit(node)

    def _store(self, structure_info: DataStructure) -> None:
        """Record ``structure_info`` locally and in the shared metadata."""
        self.data_structures[structure_info.name] = structure_info
        self.extractor.metadata.add_data_structure(structure_info)

    def _register_aggregate(self, node, kind: str) -> None:
        """Register a struct/union definition or an opaque reference.

        A node with a member list (even an empty one) is a definition. A
        named node without one is a forward declaration or a plain use of
        the type (``struct S x;``); it is recorded as opaque only if the
        name is not known yet, so it never replaces a full definition.
        """
        if node.decls is not None:
            structure_info = self._extract_structure_info(node)
            if structure_info:
                self._store(structure_info)
        elif node.name and node.name not in self.data_structures:
            structure_info = DataStructure(
                name=node.name,
                kind=kind,
                fields=[],
                is_opaque=True,
                is_anonymous=False,
                location=self._extract_location(node.coord)
            )
            self._store(structure_info)
            self.logger.debug(f"Added opaque {kind}: {node.name}")

    def _extract_structure_info(self, node) -> Optional[DataStructure]:
        """Build a ``DataStructure`` from a struct or union definition.

        Anonymous aggregates are named ``anonymous_<id>``. Returns ``None``
        (after logging the error) if extraction fails.
        """
        try:
            kind = 'struct' if isinstance(node, c_ast.Struct) else 'union'
            name = node.name or f"anonymous_{id(node)}"
            is_anonymous = node.name is None
            location = self._extract_location(node.coord)

            # Members whose extraction failed are left out.
            extracted = (self._extract_field_info(decl) for decl in (node.decls or []))
            fields = [field for field in extracted if field]

            return DataStructure(
                name=name,
                kind=kind,
                fields=fields,
                is_anonymous=is_anonymous,
                location=location
            )

        except Exception as e:
            self.logger.error(f"Error extracting structure info: {e}")
            return None

    @staticmethod
    def _parse_int_literal(text: str) -> Optional[int]:
        """Parse a C integer literal (decimal, hex, octal, with suffixes).

        Returns ``None`` when ``text`` is not an integer literal.
        """
        digits = text.rstrip('uUlL')
        try:
            # Base 0 accepts decimal and 0x-prefixed hex.
            return int(digits, 0)
        except ValueError:
            pass
        try:
            # C-style octal such as '017', which base 0 rejects.
            return int(digits, 8)
        except ValueError:
            return None

    def _extract_field_info(self, node: c_ast.Decl) -> Optional[StructField]:
        """Build a ``StructField`` from one member declaration.

        The bit-field width is recorded only when it is a parsable integer
        constant; otherwise ``bit_field`` is ``None`` and the field is still
        returned. Returns ``None`` (after logging the error) if extraction
        fails.
        """
        try:
            type_name = self._extract_type(node.type)
            is_pointer = self._is_pointer_type(node.type)
            is_array = isinstance(node.type, c_ast.ArrayDecl)
            array_size = self._stringify_expression(node.type.dim) if is_array and hasattr(node.type, 'dim') else None
            location = self._extract_location(node.coord)

            # Bit-field width, if any.
            bit_field = None
            bitsize = getattr(node, 'bitsize', None)
            if bitsize and isinstance(bitsize, c_ast.Constant):
                bit_field = self._parse_int_literal(str(bitsize.value))

            return StructField(
                name=node.name,
                type_name=type_name,
                is_pointer=is_pointer,
                is_array=is_array,
                array_size=array_size,
                bit_field=bit_field,
                location=location
            )

        except Exception as e:
            self.logger.error(f"Error extracting field info: {e}")
            return None

    def _extract_type(self, node: c_ast.Node) -> str:
        """Render a type node as text.

        Pointers append ``*`` and arrays append ``[]``; node kinds not
        handled explicitly are rendered as their AST class name.
        """
        if isinstance(node, c_ast.IdentifierType):
            return ' '.join(node.names)
        if isinstance(node, c_ast.PtrDecl):
            return self._extract_type(node.type) + '*'
        if isinstance(node, c_ast.ArrayDecl):
            return self._extract_type(node.type) + '[]'
        if isinstance(node, c_ast.TypeDecl):
            return self._extract_type(node.type)
        return str(type(node).__name__)

    def _is_pointer_type(self, node: c_ast.Node) -> bool:
        """Return True when the outermost declarator is a pointer."""
        return isinstance(node, c_ast.PtrDecl)

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a parser coordinate into a ``SourceLocation``.

        Returns ``None`` for a missing coordinate. The file falls back to
        the last visited source file, then to an empty string.
        """
        if not coord:
            return None

        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
|
|
|
|
|
|
|
|
|
class PointerAnalyzer(ASTVisitor):
    """Records memory access patterns for pointer and array expressions.

    Each pattern is stored in ``self.memory_patterns`` and added to
    ``self.extractor.metadata`` via ``add_memory_pattern``.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        """Create the analyzer with an empty ``memory_patterns`` list."""
        super().__init__(extractor)
        self.memory_patterns = []

    def _record(self, pattern: MemoryAccessPattern) -> None:
        """Store ``pattern`` locally and in the shared metadata."""
        self.memory_patterns.append(pattern)
        self.extractor.metadata.add_memory_pattern(pattern)

    def visit_UnaryOp(self, node: c_ast.UnaryOp) -> None:
        """Record dereference (``*``) and address-of (``&``) operations."""
        if node.op in ['*', '&']:
            self._analyze_pointer_operation(node)

        self.generic_visit(node)

    def visit_ArrayRef(self, node: c_ast.ArrayRef) -> None:
        """Record an array element access as a read."""
        self._analyze_array_access(node)
        self.generic_visit(node)

    def visit_Assignment(self, node: c_ast.Assignment) -> None:
        """Record writes through ``*p`` or ``a[i]`` assignment targets.

        For a plain ``=`` the target is recorded as a WRITE only; its
        sub-expressions and the right-hand side are still traversed. For
        compound assignments (``+=`` etc.) the target is also read, so the
        generic traversal additionally records the READ.
        """
        target = node.lvalue
        plain_store = node.op == '='

        if isinstance(target, c_ast.UnaryOp) and target.op == '*':
            self._analyze_pointer_dereference(target, MemoryAccessType.WRITE)
            if plain_store:
                # Skip the target node itself so the store is not also
                # recorded as a READ by visit_UnaryOp.
                self.visit(target.expr)
                self.visit(node.rvalue)
                return
        elif isinstance(target, c_ast.ArrayRef):
            self._analyze_array_access(target, MemoryAccessType.WRITE)
            if plain_store:
                # Same reasoning as above, for visit_ArrayRef.
                self.visit(target.name)
                self.visit(target.subscript)
                self.visit(node.rvalue)
                return

        self.generic_visit(node)

    def _analyze_pointer_operation(self, node: c_ast.UnaryOp) -> None:
        """Record a READ pattern for a ``*expr`` or ``&expr`` operation.

        Dereferences are flagged with ``is_dereferenced=True``. Nothing is
        recorded when no variable name can be derived from the operand.
        """
        try:
            if node.op not in ('*', '&'):
                return

            variable_name = self._extract_variable_name(node.expr)
            if not variable_name:
                return

            location = self._extract_location(node.coord)
            # Only dereferences set the flag; address-of keeps the default.
            extra = {'is_dereferenced': True} if node.op == '*' else {}
            self._record(MemoryAccessPattern(
                variable_name=variable_name,
                access_type=MemoryAccessType.READ,
                access_locations=[location] if location else [],
                **extra
            ))

        except Exception as e:
            self.logger.error(f"Error analyzing pointer operation: {e}")

    def _analyze_array_access(self, node: c_ast.ArrayRef,
                              access_type: Optional[MemoryAccessType] = None) -> None:
        """Record an indexed access pattern for ``name[subscript]``.

        Args:
            node: The array reference.
            access_type: Kind of access; defaults to ``MemoryAccessType.READ``.
        """
        try:
            if access_type is None:
                access_type = MemoryAccessType.READ

            array_name = self._extract_variable_name(node.name)
            if not array_name:
                return

            location = self._extract_location(node.coord)
            index_expr = self._extract_expression(node.subscript)

            self._record(MemoryAccessPattern(
                variable_name=array_name,
                access_type=access_type,
                access_locations=[location] if location else [],
                is_indexed=True,
                indices=[index_expr] if index_expr else []
            ))

        except Exception as e:
            self.logger.error(f"Error analyzing array access: {e}")

    def _analyze_pointer_dereference(self, node: c_ast.UnaryOp, access_type: MemoryAccessType) -> None:
        """Record a dereference pattern of the given ``access_type``."""
        try:
            variable_name = self._extract_variable_name(node.expr)
            if not variable_name:
                return

            location = self._extract_location(node.coord)
            self._record(MemoryAccessPattern(
                variable_name=variable_name,
                access_type=access_type,
                access_locations=[location] if location else [],
                is_dereferenced=True
            ))

        except Exception as e:
            self.logger.error(f"Error analyzing pointer dereference: {e}")

    def _extract_variable_name(self, node: c_ast.Node) -> Optional[str]:
        """Return the base identifier of an access expression.

        Looks through array references and dereferences; returns ``None``
        for any other expression kind.
        """
        if isinstance(node, c_ast.ID):
            return node.name
        if isinstance(node, c_ast.ArrayRef):
            return self._extract_variable_name(node.name)
        if isinstance(node, c_ast.UnaryOp) and node.op == '*':
            return self._extract_variable_name(node.expr)
        return None

    def _extract_expression(self, node: c_ast.Node) -> Optional[str]:
        """Render an index expression as text.

        Constants, identifiers and binary operations are rendered as source
        text; anything else falls back to ``str(node)``. Returns ``None``
        if rendering raises.
        """
        try:
            if isinstance(node, c_ast.Constant):
                return node.value
            if isinstance(node, c_ast.ID):
                return node.name
            if isinstance(node, c_ast.BinaryOp):
                left = self._extract_expression(node.left)
                right = self._extract_expression(node.right)
                return f"{left} {node.op} {right}"
            return str(node)
        except Exception:
            return None

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a parser coordinate into a ``SourceLocation``.

        Returns ``None`` for a missing coordinate. The file falls back to
        the last visited source file, then to an empty string.
        """
        if not coord:
            return None

        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
|
|
|
|
|
|
|
|
|
class ASTExtractor:
    """Main entry point for extracting metadata from a parsed C AST.

    Runs every visitor-based extractor over the AST, then post-processes
    the collected ``CodeMetadata`` (recursion flags, complexity metrics,
    verification hints, consistency warnings).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Create an extractor.

        Args:
            config: Optional configuration mapping; stored as ``self.config``.
        """
        self.config = config or {}
        self.logger = get_logger('ast_extractor')
        self.metadata = None
        self.extractors = []

    def extract_metadata(self, ast: c_ast.FileAST, file_path: Union[str, Path]) -> CodeMetadata:
        """Extract metadata from ``ast``.

        Args:
            ast: The parsed translation unit.
            file_path: Path of the source file. Its size and modification
                time are recorded when the file can be stat'ed.

        Returns:
            The populated ``CodeMetadata`` (also kept as ``self.metadata``).
            A failure in one extractor is logged and does not stop the
            remaining extractors.
        """
        start_time = time.perf_counter()
        file_path = Path(file_path)

        self.logger.info(f"Extracting metadata from: {file_path}")

        # Initialise the metadata container; a missing or unreadable file
        # simply yields empty size/mtime values.
        try:
            stat = file_path.stat()
        except OSError:
            stat = None

        self.metadata = CodeMetadata(
            file_path=str(file_path),
            file_size=stat.st_size if stat else 0,
            modification_time=str(stat.st_mtime) if stat else ""
        )

        # Create and run the extractors; keep them reachable for callers
        # that want the per-extractor results.
        self.extractors = [
            FunctionExtractor(self),
            VariableExtractor(self),
            CallGraphExtractor(self),
            DataStructureExtractor(self),
            PointerAnalyzer(self)
        ]

        for extractor in self.extractors:
            try:
                extractor.visit(ast)
                self.logger.debug(f"Completed extraction with {extractor.__class__.__name__}")
            except Exception as e:
                self.logger.error(f"Error in {extractor.__class__.__name__}: {e}")

        # Post-processing and validation.
        self._post_process_metadata()

        duration = time.perf_counter() - start_time
        self.logger.info(f"Metadata extraction completed in {duration:.2f}s")
        self.logger.info(f"Extracted {len(self.metadata.functions)} functions, "
                         f"{len(self.metadata.variables)} variables, "
                         f"{len(self.metadata.call_relations)} call relations")

        return self.metadata

    def _post_process_metadata(self) -> None:
        """Run all post-processing steps; any error is logged, not raised."""
        try:
            # Flag recursive functions.
            self._identify_recursive_functions()

            # Compute per-function complexity metrics.
            self._calculate_function_complexity()

            # Attach CBMC-specific verification hints.
            self._add_verification_hints()

            # Log consistency warnings.
            self._validate_metadata()

        except Exception as e:
            self.logger.error(f"Error in post-processing: {e}")

    def _identify_recursive_functions(self) -> None:
        """Flag directly recursive calls (caller equals callee).

        The relation is marked ``is_recursive`` and, when the caller is a
        known function, it receives the ``REENTRANT`` verification hint.
        """
        for relation in self.metadata.call_relations:
            if relation.caller != relation.callee:
                continue
            relation.is_recursive = True
            if relation.caller in self.metadata.functions:
                self.metadata.functions[relation.caller].verification_hints.add(
                    VerificationHint.REENTRANT
                )

    def _calculate_function_complexity(self) -> None:
        """Update each function's ``complexity_metrics``.

        Metrics: ``parameter_count``, ``call_count`` (relations in which
        the function is the caller) and ``is_called`` (whether any relation
        names it as callee).
        """
        # Tally the relations once instead of rescanning them per function.
        outgoing = defaultdict(int)
        called = set()
        for rel in self.metadata.call_relations:
            outgoing[rel.caller] += 1
            called.add(rel.callee)

        for function_name, function_info in self.metadata.functions.items():
            function_info.complexity_metrics.update({
                'parameter_count': len(function_info.parameters),
                'call_count': outgoing.get(function_name, 0),
                'is_called': function_name in called
            })

    def _add_verification_hints(self) -> None:
        """Attach CBMC-specific verification hints to functions.

        Functions with pointer parameters get ``NO_NULL_DEREFERENCE`` and
        ``BOUNDS_CHECK``; functions whose name contains ``alloc`` get
        ``NO_LEAK``.
        """
        for function_name, function_info in self.metadata.functions.items():
            # Pointer parameters.
            if any(param.is_pointer for param in function_info.parameters):
                function_info.verification_hints.add(VerificationHint.NO_NULL_DEREFERENCE)
                function_info.verification_hints.add(VerificationHint.BOUNDS_CHECK)

            # Allocation-like names ('malloc' already contains 'alloc').
            if 'alloc' in function_name.lower():
                function_info.verification_hints.add(VerificationHint.NO_LEAK)

    def _validate_metadata(self) -> None:
        """Log warnings about inconsistencies in the collected metadata."""
        functions = self.metadata.functions

        # Definitions should have a body location.
        for function_name, function_info in functions.items():
            if not function_info.is_declaration_only and not function_info.body_location:
                self.logger.warning(f"Function {function_name} has no body location")

        # Call relations should reference known functions.
        for relation in self.metadata.call_relations:
            if relation.caller not in functions:
                self.logger.warning(f"Unknown caller in call relation: {relation.caller}")
            # Indirect-call placeholders are not function names and can
            # never be resolved, so they are not worth a warning.
            if relation.callee.startswith('<indirect>'):
                continue
            if relation.callee not in functions:
                self.logger.warning(f"Unknown callee in call relation: {relation.callee}")

    def get_extraction_stats(self) -> Dict[str, Any]:
        """Return item counts of the current metadata.

        Returns an empty dict when no metadata has been extracted yet.
        """
        if not self.metadata:
            return {}

        return {
            'functions': len(self.metadata.functions),
            'variables': len(self.metadata.variables),
            'call_relations': len(self.metadata.call_relations),
            'data_structures': len(self.metadata.data_structures),
            'memory_patterns': len(self.metadata.memory_patterns),
            'includes': len(self.metadata.includes),
            'macros': len(self.metadata.macros)
        }

    def filter_functions(self, function_names: Set[str]) -> CodeMetadata:
        """Return a metadata subset restricted to ``function_names``.

        The result contains the named functions that exist, every call
        relation whose caller or callee is one of them, and all variables
        with global scope. Returns an empty ``CodeMetadata`` when nothing
        has been extracted yet.
        """
        if not self.metadata:
            return CodeMetadata(file_path="", file_size=0, modification_time="")

        filtered_metadata = CodeMetadata(
            file_path=self.metadata.file_path,
            file_size=self.metadata.file_size,
            modification_time=self.metadata.modification_time
        )

        # Requested functions, in the order they were extracted so the
        # result does not depend on set iteration order.
        for func_name, func_info in self.metadata.functions.items():
            if func_name in function_names:
                filtered_metadata.add_function(func_info)

        # Call relations touching any requested function.
        for relation in self.metadata.call_relations:
            if relation.caller in function_names or relation.callee in function_names:
                filtered_metadata.add_call_relation(relation)

        # Global variables are always kept.
        for var_info in self.metadata.variables.values():
            if var_info.scope == VariableScope.GLOBAL:
                filtered_metadata.add_variable(var_info)

        return filtered_metadata