You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
cbmc/codedetect/src/parser/ast_extractor.py

865 lines
33 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
AST信息提取器模块
本模块实现了从解析后的C/C++ AST中提取函数签名、全局变量、调用关系、
数据结构等元数据的功能。支持CBMC特定的分析如指针使用模式、内存访问分析等。
"""
import os
import time
from typing import Dict, List, Optional, Set, Union, Any, Callable
from pathlib import Path
from collections import Counter, defaultdict

try:
    from pycparser import c_ast
except ImportError:
    raise ImportError("pycparser is required. Install with: pip install pycparser==2.21")

from .metadata import (
    CodeMetadata, FunctionInfo, VariableInfo, CallRelation, DataStructure,
    ParameterInfo, StructField, MemoryAccessPattern, SourceLocation,
    VariableScope, MemoryAccessType, VerificationHint
)
from .c_parser import CParser
from ..utils.logger import get_logger
class ASTVisitor(c_ast.NodeVisitor):
    """Base visitor shared by all extractors.

    Holds a back-reference to the owning ``ASTExtractor`` (and its logger),
    per-traversal state used by subclasses (``current_function``,
    ``current_scope``) and the file named by the most recently visited node
    that carried a coordinate. It also provides a best-effort renderer that
    turns simple expression nodes back into C-like text.
    """

    # pycparser spells postfix increment/decrement as 'p++' / 'p--'.
    _POSTFIX_OPS = {'p++': '++', 'p--': '--'}
    # Operators rendered with a parenthesised operand: sizeof(x).
    _PAREN_OPS = ('sizeof', '_Alignof')
    # Operand kinds wrapped in parentheses inside a binary/ternary expression.
    # pycparser discards source parentheses, so without this `(a + b) * c`
    # would be rendered as `a + b * c`, which means something else.
    _BINARY_WRAP = (c_ast.BinaryOp, c_ast.TernaryOp)
    # Operand kinds wrapped under a unary operator, cast, subscript, member
    # access or call, e.g. `-(a + b)` or `(*p)++`.
    _UNARY_WRAP = (c_ast.BinaryOp, c_ast.TernaryOp, c_ast.UnaryOp, c_ast.Cast)

    def __init__(self, extractor: 'ASTExtractor'):
        self.extractor = extractor
        self.logger = extractor.logger
        self.current_function = None
        self.current_scope = []
        self.source_file = None

    def visit(self, node: c_ast.Node) -> Any:
        """Visit *node*, remembering the file named by its coordinate (if any)."""
        if hasattr(node, 'coord') and node.coord:
            self.source_file = node.coord.file
        return super().visit(node)

    def _stringify_expression(self, expr_node: Optional[c_ast.Node]) -> Optional[str]:
        """Safely render an expression node as C-like source text.

        Supported: constants, identifiers, unary (including postfix and
        ``sizeof``), binary and ternary operators, casts, array subscripts,
        member access and function calls.

        Args:
            expr_node: The expression node, or None.

        Returns:
            The rendered text, or None when the node or any of its
            sub-expressions is of an unsupported kind, or rendering fails.
        """
        if expr_node is None:
            return None
        try:
            if isinstance(expr_node, c_ast.Constant):
                return str(expr_node.value)
            if isinstance(expr_node, c_ast.ID):
                return expr_node.name
            if isinstance(expr_node, c_ast.UnaryOp):
                return self._stringify_unary(expr_node)
            if isinstance(expr_node, c_ast.BinaryOp):
                left = self._stringify_operand(expr_node.left, self._BINARY_WRAP)
                right = self._stringify_operand(expr_node.right, self._BINARY_WRAP)
                if left is None or right is None:
                    return None
                return f"{left} {expr_node.op} {right}"
            if isinstance(expr_node, c_ast.TernaryOp):
                parts = [
                    self._stringify_operand(part, self._BINARY_WRAP)
                    for part in (expr_node.cond, expr_node.iftrue, expr_node.iffalse)
                ]
                if None in parts:
                    return None
                return "{} ? {} : {}".format(*parts)
            if isinstance(expr_node, c_ast.Cast):
                type_text = self._stringify_typename(expr_node.to_type)
                operand = self._stringify_operand(expr_node.expr, self._UNARY_WRAP)
                if type_text is None or operand is None:
                    return None
                return f"({type_text}){operand}"
            if isinstance(expr_node, c_ast.ArrayRef):
                base = self._stringify_operand(expr_node.name, self._UNARY_WRAP)
                index = self._stringify_expression(expr_node.subscript)
                if base is None or index is None:
                    return None
                return f"{base}[{index}]"
            if isinstance(expr_node, c_ast.StructRef):
                base = self._stringify_operand(expr_node.name, self._UNARY_WRAP)
                field = self._stringify_expression(expr_node.field)
                if base is None or field is None:
                    return None
                return f"{base}{expr_node.type}{field}"
            if isinstance(expr_node, c_ast.FuncCall):
                callee = self._stringify_operand(expr_node.name, self._UNARY_WRAP)
                args = expr_node.args.exprs if expr_node.args is not None else []
                rendered = [self._stringify_expression(arg) for arg in args]
                if callee is None or None in rendered:
                    return None
                return f"{callee}({', '.join(rendered)})"
            # Other expression kinds yield None rather than a misleading string.
            self.logger.debug(f"Unsupported expression type: {type(expr_node).__name__}")
            return None
        except Exception as e:
            self.logger.debug(f"Error stringifying expression: {e}")
            return None

    def _stringify_unary(self, node: c_ast.UnaryOp) -> Optional[str]:
        """Render a unary operation; None if its operand cannot be rendered."""
        op = node.op
        if op in self._PAREN_OPS:
            # sizeof may take a type (Typename) or an expression.
            if isinstance(node.expr, c_ast.Typename):
                inner = self._stringify_typename(node.expr)
            else:
                inner = self._stringify_expression(node.expr)
            return None if inner is None else f"{op}({inner})"
        operand = self._stringify_operand(node.expr, self._UNARY_WRAP)
        if operand is None:
            return None
        if op in self._POSTFIX_OPS:
            return f"{operand}{self._POSTFIX_OPS[op]}"
        return f"{op}{operand}"

    def _stringify_operand(self, node: Optional[c_ast.Node], wrap_types: tuple) -> Optional[str]:
        """Render *node*, parenthesising it when it is one of *wrap_types*."""
        text = self._stringify_expression(node)
        if text is not None and isinstance(node, wrap_types):
            return f"({text})"
        return text

    def _stringify_typename(self, node: Optional[c_ast.Node]) -> Optional[str]:
        """Render a simple type name such as ``unsigned int*`` or ``struct s``.

        Returns None for types this helper does not model (function types,
        arrays, anonymous aggregates, ...).
        """
        type_node = node.type if isinstance(node, c_ast.Typename) else node
        stars = ''
        while isinstance(type_node, c_ast.PtrDecl):
            stars += '*'
            type_node = type_node.type
        if not isinstance(type_node, c_ast.TypeDecl):
            return None
        base = type_node.type
        if isinstance(base, c_ast.IdentifierType):
            text = ' '.join(base.names)
        elif isinstance(base, (c_ast.Struct, c_ast.Union, c_ast.Enum)) and base.name:
            text = f"{type(base).__name__.lower()} {base.name}"
        else:
            return None
        return text + stars
class FunctionExtractor(ASTVisitor):
    """Collect function definitions and prototypes as ``FunctionInfo`` records.

    Each record is stored in ``self.functions`` (keyed by name) and forwarded
    to ``extractor.metadata.add_function``. A definition always replaces an
    earlier prototype of the same name; a prototype never replaces an
    already-recorded definition.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        super().__init__(extractor)
        self.functions = {}

    def visit_FuncDef(self, node: c_ast.FuncDef) -> None:
        """Record a function definition, then visit its body."""
        function_info = self._extract_function_info(node)
        if function_info:
            self._register(function_info)
        self.current_function = function_info
        # Only the body is visited: visiting node.decl would hand the
        # definition's own Decl to visit_Decl and re-register the function
        # as a declaration-only prototype.
        if node.body is not None:
            self.visit(node.body)
        self.current_function = None

    def visit_Decl(self, node: c_ast.Decl) -> None:
        """Record a function prototype; recurse into any other declaration."""
        if isinstance(node.type, c_ast.FuncDecl):
            existing = self.functions.get(node.name)
            if existing is None or existing.is_declaration_only:
                function_info = self._extract_function_decl_info(node)
                if function_info:
                    self._register(function_info)
            # A prototype's parameter list declares no further functions.
            return
        self.generic_visit(node)

    def _register(self, function_info: FunctionInfo) -> None:
        """Store *function_info* locally and in the shared metadata."""
        self.functions[function_info.name] = function_info
        self.extractor.metadata.add_function(function_info)

    def _extract_function_info(self, node: c_ast.FuncDef) -> Optional[FunctionInfo]:
        """Build a FunctionInfo from a function definition.

        Returns:
            The record, or None (after logging an error) if extraction fails.
        """
        try:
            decl = node.decl
            storage = decl.storage or []
            funcspec = getattr(decl, 'funcspec', None) or []
            body_coord = node.body.coord if node.body is not None else None
            return FunctionInfo(
                name=decl.name,
                return_type=self._extract_type(decl.type.type),
                parameters=self._collect_parameters(decl.type),
                is_static='static' in storage,
                is_inline='inline' in funcspec,
                is_declaration_only=False,
                location=self._extract_location(decl.coord),
                body_location=self._extract_location(body_coord)
            )
        except Exception as e:
            self.logger.error(f"Error extracting function info: {e}")
            return None

    def _extract_function_decl_info(self, node: c_ast.Decl) -> Optional[FunctionInfo]:
        """Build a declaration-only FunctionInfo from a function prototype.

        Returns:
            The record, or None (after logging an error) if extraction fails.
        """
        try:
            storage = node.storage or []
            funcspec = getattr(node, 'funcspec', None) or []
            return FunctionInfo(
                name=node.name,
                return_type=self._extract_type(node.type.type),
                parameters=self._collect_parameters(node.type),
                is_static='static' in storage,
                is_inline='inline' in funcspec,
                is_declaration_only=True,
                location=self._extract_location(node.coord)
            )
        except Exception as e:
            self.logger.error(f"Error extracting function declaration info: {e}")
            return None

    def _collect_parameters(self, func_decl: c_ast.Node) -> List[ParameterInfo]:
        """Return ParameterInfo for every real parameter of *func_decl*.

        A lone ``(void)`` and the variadic ``...`` marker declare no
        parameter and are skipped.
        """
        args = getattr(func_decl, 'args', None)
        if not args or not args.params:
            return []
        params = args.params
        if len(params) == 1 and self._is_void_param(params[0]):
            return []
        parameters = []
        for index, param in enumerate(params):
            if isinstance(param, c_ast.EllipsisParam):
                continue
            param_info = self._extract_parameter_info(param, index)
            if param_info:
                parameters.append(param_info)
        return parameters

    @staticmethod
    def _is_void_param(param: c_ast.Node) -> bool:
        """True for the unnamed plain ``void`` entry of ``f(void)``."""
        type_node = getattr(param, 'type', None)
        return (getattr(param, 'name', None) is None
                and isinstance(type_node, c_ast.TypeDecl)
                and isinstance(type_node.type, c_ast.IdentifierType)
                and type_node.type.names == ['void'])

    def _extract_parameter_info(self, node: c_ast.Decl,
                                index: Optional[int] = None) -> Optional[ParameterInfo]:
        """Build ParameterInfo for one parameter.

        Args:
            node: The parameter's ``Decl`` (or ``Typename`` when unnamed).
            index: Position in the parameter list; used to give unnamed
                parameters a placeholder name that is stable across runs.

        Returns:
            The record, or None (after logging an error) if extraction fails.
        """
        try:
            is_array = isinstance(node.type, c_ast.ArrayDecl)
            array_size = (self._stringify_expression(getattr(node.type, 'dim', None))
                          if is_array else None)
            param_name = getattr(node, 'name', None)
            if param_name is None:
                # Position-based placeholders are reproducible; object ids are not.
                suffix = index if index is not None else id(node)
                param_name = f"<unnamed_param_{suffix}>"
            return ParameterInfo(
                name=param_name,
                type_name=self._extract_type(node.type),
                is_pointer=self._is_pointer_type(node.type),
                is_array=is_array,
                array_size=array_size,
                location=self._extract_location(node.coord)
            )
        except Exception as e:
            self.logger.error(f"Error extracting parameter info: {e}")
            return None

    def _extract_type(self, node: c_ast.Node) -> str:
        """Render a type node as text, e.g. ``unsigned int*`` or ``struct s[]``.

        Unmodelled type nodes fall back to their pycparser class name.
        """
        if isinstance(node, c_ast.IdentifierType):
            return ' '.join(node.names)
        if isinstance(node, c_ast.PtrDecl):
            return self._extract_type(node.type) + '*'
        if isinstance(node, c_ast.ArrayDecl):
            return self._extract_type(node.type) + '[]'
        if isinstance(node, c_ast.TypeDecl):
            return self._extract_type(node.type)
        if isinstance(node, (c_ast.Struct, c_ast.Union, c_ast.Enum)):
            keyword = type(node).__name__.lower()
            return f"{keyword} {node.name}" if node.name else keyword
        return type(node).__name__

    def _is_pointer_type(self, node: c_ast.Node) -> bool:
        """True when the outermost type constructor is a pointer."""
        return isinstance(node, c_ast.PtrDecl)

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a pycparser coordinate to a SourceLocation (None if absent)."""
        if not coord:
            return None
        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
class VariableExtractor(ASTVisitor):
    """Collect variable declarations: globals, locals, parameters and statics.

    Records are keyed by name in ``self.variables`` (a later declaration with
    the same name replaces an earlier one) and forwarded to
    ``extractor.metadata.add_variable``. Function prototypes, prototype
    parameter names, struct/union members and tag-only declarations such as
    ``struct s { ... };`` are not variables and are skipped.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        super().__init__(extractor)
        self.variables = {}
        self.global_scope = True
        self.current_params: set[str] = set()

    def visit_FuncDef(self, node: c_ast.FuncDef) -> None:
        """Record a function's parameters and locals, then restore global scope."""
        args = getattr(node.decl.type, 'args', None)
        params = list(args.params) if args and args.params else []
        # K&R-style definitions declare parameter types in param_decls.
        params += list(node.param_decls or [])

        self.global_scope = False
        self.current_scope.append(node.decl.name)
        self.current_function = node.decl.name
        self.current_params = {
            p.name for p in params
            if isinstance(p, c_ast.Decl) and getattr(p, 'name', None)
        }
        try:
            # visit_Decl does not descend into FuncDecl, so parameters are
            # visited explicitly here.
            for param in params:
                if isinstance(param, c_ast.Decl):
                    self.visit(param)
            if node.body is not None:
                self.visit(node.body)
        finally:
            self.current_function = None
            self.current_scope.pop()
            self.current_params = set()
            self.global_scope = True

    def visit_Decl(self, node: c_ast.Decl) -> None:
        """Record a named variable declaration and visit its sub-nodes."""
        if isinstance(node.type, c_ast.FuncDecl):
            # A prototype: neither it nor its parameter names are variables.
            return
        if node.name is not None:
            variable_info = self._extract_variable_info(node)
            if variable_info:
                self.variables[variable_info.name] = variable_info
                self.extractor.metadata.add_variable(variable_info)
        self.generic_visit(node)

    def visit_Struct(self, node: c_ast.Struct) -> None:
        """Do not descend: struct members are fields, not variables."""
        return None

    def visit_Union(self, node: c_ast.Union) -> None:
        """Do not descend: union members are fields, not variables."""
        return None

    def _extract_variable_info(self, node: c_ast.Decl) -> Optional[VariableInfo]:
        """Build a VariableInfo for a declaration in the current scope.

        Returns:
            The record, or None (after logging an error) if extraction fails.
        """
        try:
            is_array = isinstance(node.type, c_ast.ArrayDecl)
            array_size = (self._stringify_expression(getattr(node.type, 'dim', None))
                          if is_array else None)
            # pycparser stores declaration qualifiers as plain strings.
            quals = getattr(node, 'quals', None) or []
            is_static = 'static' in (getattr(node, 'storage', None) or [])
            initial_value = None
            if getattr(node, 'init', None):
                initial_value = self._extract_initializer(node.init)

            scope = VariableScope.GLOBAL if self.global_scope else VariableScope.LOCAL
            if node.name in self.current_params:
                scope = VariableScope.PARAMETER
            elif is_static:
                scope = VariableScope.STATIC

            return VariableInfo(
                name=node.name,
                type_name=self._extract_type(node.type),
                scope=scope,
                is_pointer=self._is_pointer_type(node.type),
                is_array=is_array,
                array_size=array_size,
                is_const='const' in quals,
                is_volatile='volatile' in quals,
                is_static=is_static,
                initial_value=initial_value,
                location=self._extract_location(node.coord)
            )
        except Exception as e:
            self.logger.error(f"Error extracting variable info: {e}")
            return None

    def _extract_type(self, node: c_ast.Node) -> str:
        """Render a type node as text, e.g. ``unsigned int*`` or ``struct s[]``.

        Unmodelled type nodes fall back to their pycparser class name.
        """
        if isinstance(node, c_ast.IdentifierType):
            return ' '.join(node.names)
        if isinstance(node, c_ast.PtrDecl):
            return self._extract_type(node.type) + '*'
        if isinstance(node, c_ast.ArrayDecl):
            return self._extract_type(node.type) + '[]'
        if isinstance(node, c_ast.TypeDecl):
            return self._extract_type(node.type)
        if isinstance(node, (c_ast.Struct, c_ast.Union, c_ast.Enum)):
            keyword = type(node).__name__.lower()
            return f"{keyword} {node.name}" if node.name else keyword
        return type(node).__name__

    def _is_pointer_type(self, node: c_ast.Node) -> bool:
        """True when the outermost type constructor is a pointer."""
        return isinstance(node, c_ast.PtrDecl)

    def _extract_initializer(self, node: c_ast.Node) -> Optional[str]:
        """Render an initializer as C-like text.

        Initializer lists are abbreviated to ``{...}``. Other expressions are
        rendered by ``_stringify_expression``, giving None for unsupported forms
        instead of a pycparser object dump.
        """
        try:
            if isinstance(node, c_ast.Constant):
                return node.value
            if isinstance(node, c_ast.InitList):
                return "{...}"
            return self._stringify_expression(node)
        except Exception:
            return None

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a pycparser coordinate to a SourceLocation (None if absent)."""
        if not coord:
            return None
        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
class CallGraphExtractor(ASTVisitor):
    """Record caller -> callee relations for calls inside function bodies.

    One ``CallRelation`` is emitted per call site. A callee that cannot be
    resolved to a plain function name is reported as ``'<indirect>'`` or
    ``'<indirect>:<label>'`` with ``call_type='indirect'``.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        super().__init__(extractor)
        self.call_relations = []
        # caller name -> callee names, one entry per call site.
        self.function_calls = defaultdict(list)

    def visit_FuncDef(self, node: c_ast.FuncDef) -> None:
        """Make the defined function the caller for calls found in it."""
        self.current_function = node.decl.name
        self.generic_visit(node)
        self.current_function = None

    def visit_FuncCall(self, node: c_ast.FuncCall) -> None:
        """Record a call made from the current function, then visit its arguments."""
        if self.current_function:
            callee_name = self._extract_callee_name(node.name)
            if callee_name:
                location = self._extract_location(node.coord)
                call_type = 'indirect' if callee_name.startswith('<indirect>') else 'direct'
                call_relation = CallRelation(
                    caller=self.current_function,
                    callee=callee_name,
                    call_type=call_type,
                    call_sites=[location] if location else []
                )
                self.call_relations.append(call_relation)
                self.function_calls[self.current_function].append(callee_name)
                self.extractor.metadata.add_call_relation(call_relation)
        self.generic_visit(node)

    def _extract_callee_name(self, node: c_ast.Node) -> Optional[str]:
        """Resolve the called expression of a FuncCall to a callee label.

        Returns:
            The function name for direct calls, an ``<indirect>...`` label
            otherwise, or None when a member call has no usable field name.
        """
        if isinstance(node, c_ast.ID):
            return node.name
        if isinstance(node, c_ast.UnaryOp) and node.op == '&':
            # (&func)(...) calls func directly.
            return self._extract_callee_name(node.expr)
        if isinstance(node, c_ast.ArrayRef):
            # table[i](...): an element of a function-pointer array is called.
            return self._as_indirect(self._extract_callee_name(node.name))
        if isinstance(node, c_ast.StructRef):
            # Function pointer stored in a struct/union member.
            field_name = self._extract_callee_name(node.field)
            if not field_name:
                return None
            # The base may be any expression (a->b, (*p), f()), so render it
            # safely instead of assuming it is an ID.
            base = self._stringify_expression(node.name) or '<expr>'
            self.logger.debug(f"StructRef function call: {base}{node.type}{field_name}")
            return f"<indirect>:{field_name}"
        # Calls through dereferenced pointers, casts, call results, ...
        self.logger.debug(f"Unresolvable call type: {type(node).__name__}")
        return "<indirect>"

    @staticmethod
    def _as_indirect(label: Optional[str]) -> str:
        """Prefix *label* with ``<indirect>`` unless it is already indirect."""
        if not label:
            return "<indirect>"
        if label.startswith('<indirect>'):
            return label
        return f"<indirect>:{label}"

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a pycparser coordinate to a SourceLocation (None if absent)."""
        if not coord:
            return None
        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
class DataStructureExtractor(ASTVisitor):
    """Collect struct/union definitions and opaque (body-less) tag references.

    Records are keyed by name in ``self.data_structures`` and forwarded to
    ``extractor.metadata.add_data_structure``.
    """

    def __init__(self, extractor: 'ASTExtractor'):
        super().__init__(extractor)
        self.data_structures = {}
        # Sequence number for anonymous aggregates; unlike id(), the
        # resulting names are identical from run to run.
        self._anonymous_count = 0

    def visit_Struct(self, node: c_ast.Struct) -> None:
        """Record a struct definition or opaque reference, then recurse."""
        self._record_aggregate(node, 'struct')
        self.generic_visit(node)

    def visit_Union(self, node: c_ast.Union) -> None:
        """Record a union definition or opaque reference, then recurse."""
        self._record_aggregate(node, 'union')
        self.generic_visit(node)

    def _record_aggregate(self, node: c_ast.Node, kind: str) -> None:
        """Register *node* as a full definition or as an opaque placeholder.

        A tag without a body (``decls is None``) is either a forward
        declaration or a mere use such as ``struct foo *p``. It is recorded as
        opaque only while no structure of that name is known, so a later use
        never overwrites a full definition. A definition always replaces an
        earlier opaque placeholder.
        """
        if node.decls is not None:
            structure_info = self._extract_structure_info(node)
            if structure_info:
                self._store(structure_info)
        elif node.name and node.name not in self.data_structures:
            structure_info = DataStructure(
                name=node.name,
                kind=kind,
                fields=[],
                is_opaque=True,
                is_anonymous=False,
                location=self._extract_location(node.coord)
            )
            self._store(structure_info)
            self.logger.debug(f"Added opaque {kind}: {node.name}")

    def _store(self, structure_info: DataStructure) -> None:
        """Store *structure_info* locally and in the shared metadata."""
        self.data_structures[structure_info.name] = structure_info
        self.extractor.metadata.add_data_structure(structure_info)

    def _extract_structure_info(self, node) -> Optional[DataStructure]:
        """Build a DataStructure (with fields) from a struct/union definition.

        Returns:
            The record, or None (after logging an error) if extraction fails.
        """
        try:
            kind = 'struct' if isinstance(node, c_ast.Struct) else 'union'
            is_anonymous = node.name is None
            if node.name:
                name = node.name
            else:
                self._anonymous_count += 1
                name = f"anonymous_{self._anonymous_count}"
            fields = []
            for decl in node.decls or []:
                field_info = self._extract_field_info(decl)
                if field_info:
                    fields.append(field_info)
            return DataStructure(
                name=name,
                kind=kind,
                fields=fields,
                is_anonymous=is_anonymous,
                location=self._extract_location(node.coord)
            )
        except Exception as e:
            self.logger.error(f"Error extracting structure info: {e}")
            return None

    def _extract_field_info(self, node: c_ast.Decl) -> Optional[StructField]:
        """Build a StructField for one member declaration.

        Returns:
            The record, or None (after logging an error) if extraction fails.
        """
        try:
            is_array = isinstance(node.type, c_ast.ArrayDecl)
            array_size = (self._stringify_expression(getattr(node.type, 'dim', None))
                          if is_array else None)
            # Bit-field width. Literals that cannot be parsed (or non-constant
            # widths) leave bit_field as None instead of dropping the field.
            bit_field = None
            bitsize = getattr(node, 'bitsize', None)
            if isinstance(bitsize, c_ast.Constant):
                bit_field = self._parse_int_literal(bitsize.value)
            return StructField(
                name=node.name,
                type_name=self._extract_type(node.type),
                is_pointer=self._is_pointer_type(node.type),
                is_array=is_array,
                array_size=array_size,
                bit_field=bit_field,
                location=self._extract_location(node.coord)
            )
        except Exception as e:
            self.logger.error(f"Error extracting field info: {e}")
            return None

    @staticmethod
    def _parse_int_literal(text: str) -> Optional[int]:
        """Parse a C integer literal (decimal, octal, hex, binary; u/l suffixes).

        Returns:
            The integer value, or None if *text* is not an integer literal.
        """
        digits = text.rstrip('uUlL')
        try:
            if digits[:2] in ('0x', '0X'):
                return int(digits[2:], 16)
            if digits[:2] in ('0b', '0B'):
                return int(digits[2:], 2)
            if len(digits) > 1 and digits.startswith('0'):
                return int(digits[1:], 8)
            return int(digits, 10)
        except ValueError:
            return None

    def _extract_type(self, node: c_ast.Node) -> str:
        """Render a type node as text, e.g. ``unsigned int*`` or ``struct s[]``.

        Unmodelled type nodes fall back to their pycparser class name.
        """
        if isinstance(node, c_ast.IdentifierType):
            return ' '.join(node.names)
        if isinstance(node, c_ast.PtrDecl):
            return self._extract_type(node.type) + '*'
        if isinstance(node, c_ast.ArrayDecl):
            return self._extract_type(node.type) + '[]'
        if isinstance(node, c_ast.TypeDecl):
            return self._extract_type(node.type)
        if isinstance(node, (c_ast.Struct, c_ast.Union, c_ast.Enum)):
            keyword = type(node).__name__.lower()
            return f"{keyword} {node.name}" if node.name else keyword
        return type(node).__name__

    def _is_pointer_type(self, node: c_ast.Node) -> bool:
        """True when the outermost type constructor is a pointer."""
        return isinstance(node, c_ast.PtrDecl)

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a pycparser coordinate to a SourceLocation (None if absent)."""
        if not coord:
            return None
        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
class PointerAnalyzer(ASTVisitor):
    """Record memory-access patterns: dereferences, address-of and indexing.

    Each pattern is stored in ``self.memory_patterns`` and forwarded to
    ``extractor.metadata.add_memory_pattern``.
    """

    # Unary operators that keep the same underlying pointer variable for
    # naming purposes: `*p`, `p++`, `++p`, ...
    _PASS_THROUGH_OPS = ('*', 'p++', 'p--', '++', '--')

    def __init__(self, extractor: 'ASTExtractor'):
        super().__init__(extractor)
        self.memory_patterns = []

    def visit_UnaryOp(self, node: c_ast.UnaryOp) -> None:
        """Record `*expr` (read) and `&expr` operations, then recurse."""
        if node.op in ('*', '&'):
            self._analyze_pointer_operation(node)
        self.generic_visit(node)

    def visit_ArrayRef(self, node: c_ast.ArrayRef) -> None:
        """Record an indexed read, then recurse."""
        self._analyze_array_access(node)
        self.generic_visit(node)

    def visit_Assignment(self, node: c_ast.Assignment) -> None:
        """Record writes through `*p = ...` and `a[i] = ...`.

        The dereference/subscript node on the left-hand side is not re-visited
        as a whole, so a plain store is not also counted as a read. A compound
        assignment (`+=`, ...) reads the target too and records both.
        """
        lvalue = node.lvalue
        is_compound = node.op != '='
        if isinstance(lvalue, c_ast.UnaryOp) and lvalue.op == '*':
            self._analyze_pointer_dereference(lvalue, MemoryAccessType.WRITE)
            if is_compound:
                self._analyze_pointer_dereference(lvalue, MemoryAccessType.READ)
            self.visit(lvalue.expr)
        elif isinstance(lvalue, c_ast.ArrayRef):
            self._analyze_array_access(lvalue, MemoryAccessType.WRITE)
            if is_compound:
                self._analyze_array_access(lvalue, MemoryAccessType.READ)
            self.visit(lvalue.name)
            self.visit(lvalue.subscript)
        else:
            self.visit(lvalue)
        self.visit(node.rvalue)

    def _record_pattern(self, variable_name: str, access_type: MemoryAccessType,
                        coord, **extra) -> None:
        """Create a MemoryAccessPattern at *coord* and store it."""
        location = self._extract_location(coord)
        pattern = MemoryAccessPattern(
            variable_name=variable_name,
            access_type=access_type,
            access_locations=[location] if location else [],
            **extra
        )
        self.memory_patterns.append(pattern)
        self.extractor.metadata.add_memory_pattern(pattern)

    def _analyze_pointer_operation(self, node: c_ast.UnaryOp) -> None:
        """Record a dereference (`*p`) or address-of (`&x`) as a READ pattern."""
        try:
            variable_name = self._extract_variable_name(node.expr)
            if not variable_name:
                return
            if node.op == '*':
                self._record_pattern(variable_name, MemoryAccessType.READ,
                                     node.coord, is_dereferenced=True)
            elif node.op == '&':
                self._record_pattern(variable_name, MemoryAccessType.READ, node.coord)
        except Exception as e:
            self.logger.error(f"Error analyzing pointer operation: {e}")

    def _analyze_array_access(self, node: c_ast.ArrayRef,
                              access_type: MemoryAccessType = MemoryAccessType.READ) -> None:
        """Record an indexed access of *access_type* (READ by default)."""
        try:
            array_name = self._extract_variable_name(node.name)
            if array_name:
                index_expr = self._extract_expression(node.subscript)
                self._record_pattern(array_name, access_type, node.coord,
                                     is_indexed=True,
                                     indices=[index_expr] if index_expr else [])
        except Exception as e:
            self.logger.error(f"Error analyzing array access: {e}")

    def _analyze_pointer_dereference(self, node: c_ast.UnaryOp,
                                     access_type: MemoryAccessType) -> None:
        """Record a dereference of *access_type* for the pointer in *node*."""
        try:
            variable_name = self._extract_variable_name(node.expr)
            if variable_name:
                self._record_pattern(variable_name, access_type,
                                     node.coord, is_dereferenced=True)
        except Exception as e:
            self.logger.error(f"Error analyzing pointer dereference: {e}")

    def _extract_variable_name(self, node: c_ast.Node) -> Optional[str]:
        """Return the base variable an access goes through, or None.

        Looks through subscripts, dereferences, increments/decrements and casts,
        e.g. `a[i][j]` -> `a`, `*p++` -> `p`, `(char *)buf` -> `buf`.
        """
        if isinstance(node, c_ast.ID):
            return node.name
        if isinstance(node, c_ast.ArrayRef):
            return self._extract_variable_name(node.name)
        if isinstance(node, c_ast.UnaryOp) and node.op in self._PASS_THROUGH_OPS:
            return self._extract_variable_name(node.expr)
        if isinstance(node, c_ast.Cast):
            return self._extract_variable_name(node.expr)
        return None

    def _extract_expression(self, node: c_ast.Node) -> Optional[str]:
        """Render an index expression as C-like text (None if unsupported)."""
        try:
            return self._stringify_expression(node)
        except Exception:
            return None

    def _extract_location(self, coord) -> Optional[SourceLocation]:
        """Convert a pycparser coordinate to a SourceLocation (None if absent)."""
        if not coord:
            return None
        return SourceLocation(
            file=coord.file or self.source_file or "",
            line=coord.line,
            column=coord.column
        )
class ASTExtractor:
    """Run every extractor over a parsed FileAST and post-process the result.

    Attributes:
        config: Configuration mapping (stored; not consulted by the code here).
        logger: Logger obtained via ``get_logger('ast_extractor')``.
        metadata: CodeMetadata of the most recent extraction, or None.
        extractors: Visitors used by the most recent extraction.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}
        self.logger = get_logger('ast_extractor')
        self.metadata = None
        self.extractors = []

    def extract_metadata(self, ast: c_ast.FileAST, file_path: Union[str, Path]) -> CodeMetadata:
        """Extract functions, variables, calls, structures and memory patterns.

        Each extractor runs independently; an exception in one is logged and
        the others still run.

        Args:
            ast: The parsed translation unit.
            file_path: Source path. Its size and mtime are recorded when it can
                be stat'ed; otherwise 0 and "" are used.

        Returns:
            The populated CodeMetadata (also kept in ``self.metadata``).
        """
        start_time = time.perf_counter()
        file_path = Path(file_path)
        self.logger.info(f"Extracting metadata from: {file_path}")

        try:
            stat = file_path.stat()
        except OSError:
            stat = None
        self.metadata = CodeMetadata(
            file_path=str(file_path),
            file_size=stat.st_size if stat else 0,
            modification_time=str(stat.st_mtime) if stat else ""
        )

        self.extractors = [
            FunctionExtractor(self),
            VariableExtractor(self),
            CallGraphExtractor(self),
            DataStructureExtractor(self),
            PointerAnalyzer(self)
        ]
        for extractor in self.extractors:
            try:
                extractor.visit(ast)
                self.logger.debug(f"Completed extraction with {extractor.__class__.__name__}")
            except Exception as e:
                self.logger.error(f"Error in {extractor.__class__.__name__}: {e}")

        self._post_process_metadata()

        duration = time.perf_counter() - start_time
        self.logger.info(f"Metadata extraction completed in {duration:.2f}s")
        self.logger.info(f"Extracted {len(self.metadata.functions)} functions, "
                         f"{len(self.metadata.variables)} variables, "
                         f"{len(self.metadata.call_relations)} call relations")
        return self.metadata

    def _post_process_metadata(self) -> None:
        """Run the post-processing steps, each isolated from the others' failures."""
        steps = (
            self._identify_recursive_functions,
            self._calculate_function_complexity,
            self._add_verification_hints,
            self._validate_metadata,
        )
        for step in steps:
            try:
                step()
            except Exception as e:
                self.logger.error(f"Error in post-processing: {e}")

    def _identify_recursive_functions(self) -> None:
        """Flag direct and mutual recursion.

        A call relation is recursive when its caller and callee lie in the same
        strongly connected component of the call graph. This covers self-calls
        and cycles such as a -> b -> a. Every function taking part in such a
        cycle receives the REENTRANT verification hint.
        """
        relations = self.metadata.call_relations
        graph: Dict[str, Set[str]] = defaultdict(set)
        for relation in relations:
            graph[relation.caller].add(relation.callee)
        component_of = self._strongly_connected_components(graph)

        recursive_functions: Set[str] = set()
        for relation in relations:
            if component_of[relation.caller] == component_of[relation.callee]:
                relation.is_recursive = True
                recursive_functions.add(relation.caller)
                recursive_functions.add(relation.callee)

        for function_name in recursive_functions:
            if function_name in self.metadata.functions:
                self.metadata.functions[function_name].verification_hints.add(
                    VerificationHint.REENTRANT
                )

    @staticmethod
    def _strongly_connected_components(graph: Dict[str, Set[str]]) -> Dict[str, int]:
        """Map every node of *graph* (callers and callees) to an SCC id.

        Uses iterative Tarjan's algorithm, so deep call chains cannot exceed
        Python's recursion limit. Nodes are visited in sorted order to keep
        the result deterministic.
        """
        nodes: Set[str] = set(graph)
        for targets in graph.values():
            nodes |= targets

        index_of: Dict[str, int] = {}
        lowlink: Dict[str, int] = {}
        on_stack: Set[str] = set()
        stack: List[str] = []
        component_of: Dict[str, int] = {}
        counter = 0

        for root in sorted(nodes):
            if root in index_of:
                continue
            index_of[root] = lowlink[root] = counter
            counter += 1
            stack.append(root)
            on_stack.add(root)
            work = [(root, iter(sorted(graph.get(root, ()))))]
            while work:
                node, children = work[-1]
                descended = False
                for child in children:
                    if child not in index_of:
                        index_of[child] = lowlink[child] = counter
                        counter += 1
                        stack.append(child)
                        on_stack.add(child)
                        work.append((child, iter(sorted(graph.get(child, ())))))
                        descended = True
                        break
                    if child in on_stack:
                        lowlink[node] = min(lowlink[node], index_of[child])
                if descended:
                    continue
                work.pop()
                if lowlink[node] == index_of[node]:
                    # node is the root of an SCC: pop all of its members.
                    while True:
                        member = stack.pop()
                        on_stack.discard(member)
                        component_of[member] = index_of[node]
                        if member == node:
                            break
                if work:
                    parent = work[-1][0]
                    lowlink[parent] = min(lowlink[parent], lowlink[node])
        return component_of

    def _calculate_function_complexity(self) -> None:
        """Store simple per-function metrics in ``complexity_metrics``.

        ``parameter_count`` is the number of parameters. ``call_count`` is the
        number of call relations the function is the caller of. ``is_called``
        says whether any relation names it as callee.
        """
        relations = self.metadata.call_relations
        # Counted once up front instead of rescanning all relations per function.
        outgoing = Counter(rel.caller for rel in relations)
        called = {rel.callee for rel in relations}
        for function_name, function_info in self.metadata.functions.items():
            function_info.complexity_metrics.update({
                'parameter_count': len(function_info.parameters),
                'call_count': outgoing[function_name],
                'is_called': function_name in called,
            })

    def _add_verification_hints(self) -> None:
        """Attach CBMC-oriented verification hints to each function."""
        for function_name, function_info in self.metadata.functions.items():
            # Array parameters are adjusted to pointers in C, so they need the
            # same null/bounds checks as pointer parameters.
            if any(param.is_pointer or param.is_array for param in function_info.parameters):
                function_info.verification_hints.add(VerificationHint.NO_NULL_DEREFERENCE)
                function_info.verification_hints.add(VerificationHint.BOUNDS_CHECK)
            # Heuristic on the name; 'alloc' also matches 'malloc', 'calloc', ...
            if 'alloc' in function_name.lower():
                function_info.verification_hints.add(VerificationHint.NO_LEAK)

    def _validate_metadata(self) -> None:
        """Log warnings for incomplete function and call-relation records."""
        for function_name, function_info in self.metadata.functions.items():
            if not function_info.is_declaration_only and not function_info.body_location:
                self.logger.warning(f"Function {function_name} has no body location")

        for relation in self.metadata.call_relations:
            if relation.caller not in self.metadata.functions:
                self.logger.warning(f"Unknown caller in call relation: {relation.caller}")
            # Indirect targets are unknown by construction; don't warn about them.
            if str(relation.callee).startswith('<indirect>'):
                continue
            if relation.callee not in self.metadata.functions:
                self.logger.warning(f"Unknown callee in call relation: {relation.callee}")

    def get_extraction_stats(self) -> Dict[str, Any]:
        """Return counts per metadata category ({} before any extraction)."""
        if not self.metadata:
            return {}
        return {
            'functions': len(self.metadata.functions),
            'variables': len(self.metadata.variables),
            'call_relations': len(self.metadata.call_relations),
            'data_structures': len(self.metadata.data_structures),
            'memory_patterns': len(self.metadata.memory_patterns),
            'includes': len(self.metadata.includes),
            'macros': len(self.metadata.macros)
        }

    def filter_functions(self, function_names: Set[str]) -> CodeMetadata:
        """Return a CodeMetadata restricted to *function_names*.

        The result contains the named functions that are known, every call
        relation in which one of them is caller or callee, and all global
        variables.

        Args:
            function_names: Names to keep (any iterable of names is accepted).

        Returns:
            The filtered metadata. An empty CodeMetadata is returned before any
            extraction has run.
        """
        if not self.metadata:
            return CodeMetadata(file_path="", file_size=0, modification_time="")

        wanted = set(function_names)
        filtered_metadata = CodeMetadata(
            file_path=self.metadata.file_path,
            file_size=self.metadata.file_size,
            modification_time=self.metadata.modification_time
        )

        for func_name in function_names:
            if func_name in self.metadata.functions:
                filtered_metadata.add_function(self.metadata.functions[func_name])

        for relation in self.metadata.call_relations:
            if relation.caller in wanted or relation.callee in wanted:
                filtered_metadata.add_call_relation(relation)

        for var_info in self.metadata.variables.values():
            if var_info.scope == VariableScope.GLOBAL:
                filtered_metadata.add_variable(var_info)

        return filtered_metadata