You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

296 lines
9.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "sem/SymbolTable.h"
#include <antlr4-runtime.h> // 用于访问父节点
// ---------- Constructor ----------
// Creates the table with one initial scope (the global scope) and then fills
// that scope with the built-in library functions.
SymbolTable::SymbolTable() {
scopes_.emplace_back(); // create the global scope first so the registration below has a scope to insert into
registerBuiltinFunctions(); // register the built-in library functions
}
// ---------- Scope management ----------
// Opens a new, empty scope that becomes the innermost one.
void SymbolTable::enterScope() {
scopes_.emplace_back();
}
// Closes the innermost scope.
// The last remaining scope (the global one) is never removed, so calling this
// while only one scope is open does nothing.
void SymbolTable::exitScope() {
    const bool only_global_left = scopes_.size() <= 1;
    if (only_global_left) {
        return;
    }
    scopes_.pop_back();
}
// ---------- Adding and looking up symbols ----------
// Inserts a copy of `sym` into the innermost scope, keyed by `sym.name`.
// Returns true when the symbol was inserted, or false (leaving the scope
// unchanged) when that scope already holds an entry with the same name.
bool SymbolTable::addSymbol(const Symbol& sym) {
    auto& innermost = scopes_.back();
    // emplace() never overwrites an existing key; `.second` tells us whether
    // the insertion actually took place.
    return innermost.emplace(sym.name, sym).second;
}
// Resolves `name` by searching from the innermost scope outward.
// Returns a pointer to the first matching entry stored in the table, or
// nullptr when no open scope defines the name.
Symbol* SymbolTable::lookup(const std::string& name) {
    auto scope_it = scopes_.rbegin();
    while (scope_it != scopes_.rend()) {
        const auto entry = scope_it->find(name);
        if (entry != scope_it->end()) {
            return &entry->second;
        }
        ++scope_it;
    }
    return nullptr;
}
// Resolves `name` in the innermost scope only; enclosing scopes are ignored.
// Returns a pointer to the stored entry, or nullptr if the innermost scope
// has no such name.
Symbol* SymbolTable::lookupCurrent(const std::string& name) {
    auto& innermost = scopes_.back();
    const auto entry = innermost.find(name);
    return entry == innermost.end() ? nullptr : &entry->second;
}
// ---------- 兼容原接口 ----------
void SymbolTable::Add(const std::string& name, SysYParser::VarDefContext* decl) {
Symbol sym;
sym.name = name;
sym.kind = SymbolKind::Variable;
sym.type = getTypeFromVarDef(decl);
sym.var_def_ctx = decl;
sym.scope_level = currentScopeLevel();
addSymbol(sym);
}
bool SymbolTable::Contains(const std::string& name) const {
// const 方法不能修改 scopes_我们模拟查找
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
if (it->find(name) != it->end()) {
return true;
}
}
return false;
}
// Finds the innermost symbol called `name` and returns its variable
// definition context.
// Returns nullptr when the name is not defined in any open scope, or when the
// innermost symbol with that name is not a variable (e.g. a function); in the
// latter case outer scopes are not consulted.
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
    for (auto scope_it = scopes_.rbegin(); scope_it != scopes_.rend(); ++scope_it) {
        const auto entry = scope_it->find(name);
        if (entry == scope_it->end()) {
            continue;  // not declared at this level; try the enclosing scope
        }
        // The innermost declaration decides the result.
        const Symbol& symbol = entry->second;
        return symbol.kind == SymbolKind::Variable ? symbol.var_def_ctx : nullptr;
    }
    return nullptr;
}
// ---------- Helper: find the VarDeclContext enclosing a VarDefContext ----------
// Follows the parent links starting at `varDef->parent` and returns the first
// ancestor that is a VarDeclContext, or nullptr if the chain ends without one.
// `varDef` is dereferenced unconditionally and must not be null.
static SysYParser::VarDeclContext* getOuterVarDecl(SysYParser::VarDefContext* varDef) {
    for (auto node = varDef->parent; node; node = node->parent) {
        auto enclosing = dynamic_cast<SysYParser::VarDeclContext*>(node);
        if (enclosing) {
            return enclosing;
        }
    }
    return nullptr;
}
// ---------- Helper: find the ConstDeclContext (constant declaration) enclosing a VarDefContext ----------
// Follows the parent links starting at `varDef->parent` and returns the first
// ancestor that is a ConstDeclContext, or nullptr if the chain ends without one.
// `varDef` is dereferenced unconditionally and must not be null.
static SysYParser::ConstDeclContext* getOuterConstDecl(SysYParser::VarDefContext* varDef) {
    for (auto node = varDef->parent; node; node = node->parent) {
        auto enclosing = dynamic_cast<SysYParser::ConstDeclContext*>(node);
        if (enclosing) {
            return enclosing;
        }
    }
    return nullptr;
}
// Constant-expression evaluation (placeholder; real constant folding still has to be implemented).
// At present `ctx` is not inspected and the result is always 0.
static int evaluateConstExp(SysYParser::ConstExpContext* ctx) {
// TODO: implement constant folding; returns 0 for now
return 0;
}
// Builds the IR type for the variable defined by `ctx`.
// The element type is taken from the bType of the enclosing VarDecl or, when
// there is no enclosing VarDecl, of the enclosing ConstDecl; it defaults to
// int when neither yields int or float.
// If `ctx` carries constExp dimension expressions, each one is evaluated with
// evaluateConstExp() and the result is an array type over the element type;
// otherwise the element type itself is returned.
// `ctx` is dereferenced unconditionally and must not be null.
std::shared_ptr<ir::Type> SymbolTable::getTypeFromVarDef(SysYParser::VarDefContext* ctx) {
    // 1. Element type (int / float).
    std::shared_ptr<ir::Type> elem_type;
    if (auto var_decl = getOuterVarDecl(ctx)) {
        auto spec = var_decl->bType();
        if (spec->Int()) {
            elem_type = ir::Type::GetInt32Type();
        } else if (spec->Float()) {
            elem_type = ir::Type::GetFloatType();
        }
    } else if (auto const_decl = getOuterConstDecl(ctx)) {
        auto spec = const_decl->bType();
        if (spec->Int()) {
            elem_type = ir::Type::GetInt32Type();
        } else if (spec->Float()) {
            elem_type = ir::Type::GetFloatType();
        }
    }
    if (!elem_type) {
        elem_type = ir::Type::GetInt32Type();  // default to int
    }

    // 2. Array dimensions, one per constExp attached to the definition.
    std::vector<int> extents;
    for (auto dim_expr : ctx->constExp()) {
        extents.push_back(evaluateConstExp(dim_expr));
    }

    if (extents.empty()) {
        return elem_type;  // scalar
    }
    return ir::Type::GetArrayType(elem_type, extents);
}
// Builds the IR function type for the function defined by `ctx`.
// The return type follows funcType (void / int / float, falling back to int).
// Each formal parameter contributes int or float according to its bType
// (falling back to int); a parameter written with at least one '[' is
// turned into the corresponding pointer type instead.
// `ctx` is dereferenced unconditionally and must not be null.
std::shared_ptr<ir::Type> SymbolTable::getTypeFromFuncDef(SysYParser::FuncDefContext* ctx) {
    // 1. Return type.
    auto ret_spec = ctx->funcType();
    std::shared_ptr<ir::Type> result_type;
    if (ret_spec->Void()) {
        result_type = ir::Type::GetVoidType();
    } else if (!ret_spec->Int() && ret_spec->Float()) {
        result_type = ir::Type::GetFloatType();
    } else {
        result_type = ir::Type::GetInt32Type();  // int, or fallback
    }

    // 2. Parameter types.
    std::vector<std::shared_ptr<ir::Type>> formal_types;
    if (auto formals = ctx->funcFParams()) {
        for (auto formal : formals->funcFParam()) {
            auto spec = formal->bType();
            std::shared_ptr<ir::Type> formal_type;
            if (!spec->Int() && spec->Float()) {
                formal_type = ir::Type::GetFloatType();
            } else {
                formal_type = ir::Type::GetInt32Type();  // int, or fallback
            }

            // Array parameters ("[ ]" or "[ exp ]") decay to a pointer.
            const bool is_array = !formal->L_BRACK().empty();
            if (is_array && formal_type->IsInt32()) {
                formal_type = ir::Type::GetPtrInt32Type();
            } else if (is_array && formal_type->IsFloat()) {
                formal_type = ir::Type::GetPtrFloatType();
            }
            formal_types.push_back(formal_type);
        }
    }
    return ir::Type::GetFunctionType(result_type, formal_types);
}
// ----- 注册内置库函数-----
void SymbolTable::registerBuiltinFunctions() {
// 确保当前处于全局作用域scopes_ 只有一层)
// 1. getint: int getint()
Symbol getint;
getint.name = "getint";
getint.kind = SymbolKind::Function;
getint.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {}); // 无参数
getint.param_types = {};
getint.scope_level = 0;
getint.is_builtin = true;
addSymbol(getint);
// 2. getfloat: float getfloat()
Symbol getfloat;
getfloat.name = "getfloat";
getfloat.kind = SymbolKind::Function;
getfloat.type = ir::Type::GetFunctionType(ir::Type::GetFloatType(), {});
getfloat.param_types = {};
getfloat.scope_level = 0;
getfloat.is_builtin = true;
addSymbol(getfloat);
// 3. getch: int getch()
Symbol getch;
getch.name = "getch";
getch.kind = SymbolKind::Function;
getch.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), {});
getch.param_types = {};
getch.scope_level = 0;
getch.is_builtin = true;
addSymbol(getch);
// 4. putint: void putint(int)
std::vector<std::shared_ptr<ir::Type>> putint_params = { ir::Type::GetInt32Type() };
Symbol putint;
putint.name = "putint";
putint.kind = SymbolKind::Function;
putint.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putint_params);
putint.param_types = putint_params;
putint.scope_level = 0;
putint.is_builtin = true;
addSymbol(putint);
// 5. putfloat: void putfloat(float)
std::vector<std::shared_ptr<ir::Type>> putfloat_params = { ir::Type::GetFloatType() };
Symbol putfloat;
putfloat.name = "putfloat";
putfloat.kind = SymbolKind::Function;
putfloat.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putfloat_params);
putfloat.param_types = putfloat_params;
putfloat.scope_level = 0;
putfloat.is_builtin = true;
addSymbol(putfloat);
// 6. putch: void putch(int)
std::vector<std::shared_ptr<ir::Type>> putch_params = { ir::Type::GetInt32Type() };
Symbol putch;
putch.name = "putch";
putch.kind = SymbolKind::Function;
putch.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putch_params);
putch.param_types = putch_params;
putch.scope_level = 0;
putch.is_builtin = true;
addSymbol(putch);
// 7. getarray: int getarray(int a[])
// 参数类型: int a[] 退化为 int* 即 PtrInt32
std::vector<std::shared_ptr<ir::Type>> getarray_params = { ir::Type::GetPtrInt32Type() };
Symbol getarray;
getarray.name = "getarray";
getarray.kind = SymbolKind::Function;
getarray.type = ir::Type::GetFunctionType(ir::Type::GetInt32Type(), getarray_params);
getarray.param_types = getarray_params;
getarray.scope_level = 0;
getarray.is_builtin = true;
addSymbol(getarray);
// 8. putarray: void putarray(int n, int a[])
// 参数: int n, int a[] -> 实际类型: int, int*
std::vector<std::shared_ptr<ir::Type>> putarray_params = { ir::Type::GetInt32Type(), ir::Type::GetPtrInt32Type() };
Symbol putarray;
putarray.name = "putarray";
putarray.kind = SymbolKind::Function;
putarray.type = ir::Type::GetFunctionType(ir::Type::GetVoidType(), putarray_params);
putarray.param_types = putarray_params;
putarray.scope_level = 0;
putarray.is_builtin = true;
addSymbol(putarray);
// 9. putf: void putf(char fmt[], ...) —— 可选,但为了完整性
// 参数: char fmt[] 退化为 char*,但 SysY 中没有 char 类型,可能使用 int 数组或特殊处理,此处略过
}