diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..53eb24d 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -29,13 +29,22 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; + std::any visitExp(SysYParser::ExpContext* ctx) override; + std::any visitCond(SysYParser::CondContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitLValue(SysYParser::LValueContext* ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; private: enum class BlockFlow { @@ -43,8 +52,16 @@ class IRGenImpl final : public SysYBaseVisitor { Terminated, }; + struct LoopTargets { + ir::BasicBlock* continue_target; + ir::BasicBlock* break_target; + }; + BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); ir::Value* EvalExpr(SysYParser::ExpContext& expr); + ir::Value* EvalCond(SysYParser::CondContext& cond); + ir::Value* ToBoolValue(ir::Value* v); + std::string NextBlockName(); ir::Module& module_; const SemanticContext& sema_; @@ -52,6 +69,8 @@ class IRGenImpl final : public SysYBaseVisitor { ir::IRBuilder builder_; // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 std::unordered_map storage_map_; + std::unordered_map named_storage_; + std::vector loop_stack_; }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 9ac057b..b684a41 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -1,30 +1,213 @@ -// 基于语法树的语义检查与名称绑定。 -#pragma once +#ifndef SEMANTIC_ANALYSIS_H +#define SEMANTIC_ANALYSIS_H +#include "SymbolTable.h" +#include "SysYBaseVisitor.h" +#include "SysYParser.h" +#include +#include +#include #include +#include +#include -#include "SysYParser.h" +// 错误信息结构体 +struct ErrorMsg { + std::string msg; + int line; + int column; + ErrorMsg(std::string m, int l, int c) : msg(std::move(m)), line(l), column(c) {} +}; + +// 前向声明 +namespace antlr4 { + class ParserRuleContext; + namespace tree { + class ParseTree; + } +} + +// 语义/IR生成上下文核心类 +class IRGenContext { +public: + // 错误管理 + void RecordError(const ErrorMsg& err) { errors_.push_back(err); } + const std::vector& GetErrors() const { return errors_; } + bool HasError() const { return !errors_.empty(); } + void ClearErrors() { errors_.clear(); } + + // 类型绑定/查询 - 使用 void* 以兼容测试代码 + void SetType(void* ctx, SymbolType type) { + node_type_map_[ctx] = type; + } + + SymbolType GetType(void* ctx) const { + auto it = node_type_map_.find(ctx); + return it == node_type_map_.end() ? SymbolType::TYPE_UNKNOWN : it->second; + } + + // 常量值绑定/查询 - 使用 void* 以兼容测试代码 + void SetConstVal(void* ctx, const std::any& val) { + const_val_map_[ctx] = val; + } + + std::any GetConstVal(void* ctx) const { + auto it = const_val_map_.find(ctx); + return it == const_val_map_.end() ? std::any() : it->second; + } + + // 循环状态管理 + void EnterLoop() { sym_table_.EnterLoop(); } + void ExitLoop() { sym_table_.ExitLoop(); } + bool InLoop() const { return sym_table_.InLoop(); } + + // 类型判断工具函数 + bool IsIntType(const std::any& val) const { + return val.type() == typeid(long) || val.type() == typeid(int); + } + + bool IsFloatType(const std::any& val) const { + return val.type() == typeid(double) || val.type() == typeid(float); + } + + // 当前函数返回类型 + SymbolType GetCurrentFuncReturnType() const { + return current_func_ret_type_; + } + + void SetCurrentFuncReturnType(SymbolType type) { + current_func_ret_type_ = type; + } + + // 符号表访问 + SymbolTable& GetSymbolTable() { return sym_table_; } + const SymbolTable& GetSymbolTable() const { return sym_table_; } + + // 作用域管理 + void EnterScope() { sym_table_.EnterScope(); } + void LeaveScope() { sym_table_.LeaveScope(); } + size_t GetScopeDepth() const { return sym_table_.GetScopeDepth(); } +private: + SymbolTable sym_table_; + std::unordered_map node_type_map_; + std::unordered_map const_val_map_; + std::vector errors_; + SymbolType current_func_ret_type_ = SymbolType::TYPE_UNKNOWN; +}; + +// 与现有 IRGen/主流程保持兼容的语义上下文占位。 class SemanticContext { public: - void BindVarUse(SysYParser::VarContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; - } + void BindVarUse(const SysYParser::LValueContext* use, + SysYParser::VarDefContext* decl) { + var_uses_[use] = decl; + } - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::VarContext* use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? nullptr : it->second; - } + SysYParser::VarDefContext* ResolveVarUse( + const SysYParser::LValueContext* use) const { + auto it = var_uses_.find(use); + return it == var_uses_.end() ? nullptr : it->second; + } private: - std::unordered_map - var_uses_; + std::unordered_map + var_uses_; }; -// 目前仅检查: -// - 变量先声明后使用 -// - 局部变量不允许重复定义 +// 错误信息格式化工具函数 +inline std::string FormatErrMsg(const std::string& msg, int line, int col) { + std::ostringstream oss; + oss << "[行:" << line << ",列:" << col << "] " << msg; + return oss.str(); +} + +// 语义分析访问器 - 继承自生成的基类 +class SemaVisitor : public SysYBaseVisitor { +public: + explicit SemaVisitor(IRGenContext& ctx) : ir_ctx_(ctx) {} + + // 必须实现的 ANTLR4 接口 + std::any visit(antlr4::tree::ParseTree* tree) override { + if (tree) { + return tree->accept(this); + } + return std::any(); + } + + std::any visitTerminal(antlr4::tree::TerminalNode* node) override { + return std::any(); + } + + std::any visitErrorNode(antlr4::tree::ErrorNode* node) override { + if (node) { + int line = node->getSymbol()->getLine(); + int col = node->getSymbol()->getCharPositionInLine() + 1; + ir_ctx_.RecordError(ErrorMsg("语法错误节点", line, col)); + } + return std::any(); + } + + // 核心访问方法 + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; + std::any visitDecl(SysYParser::DeclContext* ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + std::any visitBtype(SysYParser::BtypeContext* ctx) override; + std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; + std::any visitConstInitValue(SysYParser::ConstInitValueContext* ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; + std::any visitVarDef(SysYParser::VarDefContext* ctx) override; + std::any visitInitValue(SysYParser::InitValueContext* ctx) override; + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; + std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override; + std::any visitFuncFParams(SysYParser::FuncFParamsContext* ctx) override; + std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override; + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; + std::any visitStmt(SysYParser::StmtContext* ctx) override; + std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; + std::any visitExp(SysYParser::ExpContext* ctx) override; + std::any visitCond(SysYParser::CondContext* ctx) override; + std::any visitLValue(SysYParser::LValueContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitUnaryOp(SysYParser::UnaryOpContext* ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; + std::any visitConstExp(SysYParser::ConstExpContext* ctx) override; + + // 通用子节点访问 + std::any visitChildren(antlr4::tree::ParseTree* node) override { + std::any result; + if (node) { + for (auto* child : node->children) { + if (child) { + result = child->accept(this); + } + } + } + return result; + } + + // 获取上下文引用 + IRGenContext& GetContext() { return ir_ctx_; } + const IRGenContext& GetContext() const { return ir_ctx_; } + +private: + IRGenContext& ir_ctx_; +}; + +// 语义分析入口函数 +void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx); + +// 兼容旧流程入口。 SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); + +#endif // SEMANTIC_ANALYSIS_H \ No newline at end of file diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index c9396dd..7aaf966 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -1,17 +1,201 @@ -// 极简符号表:记录局部变量定义点。 -#pragma once +#ifndef SYMBOL_TABLE_H +#define SYMBOL_TABLE_H +#include #include +#include #include +#include +#include -#include "SysYParser.h" +// 核心类型枚举 +enum class SymbolType { + TYPE_UNKNOWN, // 未知类型 + TYPE_INT, // 整型 + TYPE_FLOAT, // 浮点型 + TYPE_VOID, // 空类型 + TYPE_ARRAY, // 数组类型 + TYPE_FUNCTION // 函数类型 +}; + +// 获取类型名称字符串 +inline const char* SymbolTypeToString(SymbolType type) { + switch (type) { + case SymbolType::TYPE_INT: return "int"; + case SymbolType::TYPE_FLOAT: return "float"; + case SymbolType::TYPE_VOID: return "void"; + case SymbolType::TYPE_ARRAY: return "array"; + case SymbolType::TYPE_FUNCTION: return "function"; + default: return "unknown"; + } +} + +// 变量信息结构体 +struct VarInfo { + SymbolType type = SymbolType::TYPE_UNKNOWN; + bool is_const = false; + std::any const_val; + std::vector array_dims; // 数组维度,空表示非数组 + void* decl_ctx = nullptr; // 关联的语法节点 + + // 检查是否为数组类型 + bool IsArray() const { return !array_dims.empty(); } + + // 获取数组元素总数 + int GetArrayElementCount() const { + int count = 1; + for (int dim : array_dims) { + count *= dim; + } + return count; + } +}; + +// 函数信息结构体 +struct FuncInfo { + SymbolType ret_type = SymbolType::TYPE_UNKNOWN; + std::string name; + std::vector param_types; // 参数类型列表 + void* decl_ctx = nullptr; // 关联的语法节点 + + // 检查参数匹配 + bool CheckParams(const std::vector& actual_params) const { + if (actual_params.size() != param_types.size()) { + return false; + } + + for (size_t i = 0; i < param_types.size(); ++i) { + if (param_types[i] != actual_params[i] && + param_types[i] != SymbolType::TYPE_UNKNOWN && + actual_params[i] != SymbolType::TYPE_UNKNOWN) { + return false; + } + } + return true; + } +}; + +// 作用域条目结构体 +struct ScopeEntry { + // 变量符号表:符号名 -> (符号信息, 声明节点) + std::unordered_map> var_symbols; + + // 函数符号表:符号名 -> (函数信息, 声明节点) + std::unordered_map> func_symbols; + + // 清空作用域 + void Clear() { + var_symbols.clear(); + func_symbols.clear(); + } +}; +// 符号表核心类 class SymbolTable { - public: - void Add(const std::string& name, SysYParser::VarDefContext* decl); - bool Contains(const std::string& name) const; - SysYParser::VarDefContext* Lookup(const std::string& name) const; +public: + // ========== 作用域管理 ========== + + // 进入新作用域 + void EnterScope(); + + // 离开当前作用域 + void LeaveScope(); + + // 获取当前作用域深度 + size_t GetScopeDepth() const { return scopes_.size(); } + + // 检查作用域栈是否为空 + bool IsEmpty() const { return scopes_.empty(); } + + // ========== 变量符号管理 ========== + + // 检查当前作用域是否包含指定变量 + bool CurrentScopeHasVar(const std::string& name) const; + + // 绑定变量到当前作用域 + void BindVar(const std::string& name, const VarInfo& info, void* decl_ctx); + + // 查找变量(从当前作用域向上遍历) + bool LookupVar(const std::string& name, VarInfo& out_info, void*& out_decl_ctx) const; + + // 快速查找变量(不获取详细信息) + bool HasVar(const std::string& name) const { + VarInfo info; + void* ctx; + return LookupVar(name, info, ctx); + } + + // ========== 函数符号管理 ========== + + // 检查当前作用域是否包含指定函数 + bool CurrentScopeHasFunc(const std::string& name) const; + + // 绑定函数到当前作用域 + void BindFunc(const std::string& name, const FuncInfo& info, void* decl_ctx); + + // 查找函数(从当前作用域向上遍历) + bool LookupFunc(const std::string& name, FuncInfo& out_info, void*& out_decl_ctx) const; + + // 快速查找函数(不获取详细信息) + bool HasFunc(const std::string& name) const { + FuncInfo info; + void* ctx; + return LookupFunc(name, info, ctx); + } + + // ========== 循环状态管理 ========== + + // 进入循环 + void EnterLoop(); + + // 离开循环 + void ExitLoop(); + + // 检查是否在循环内 + bool InLoop() const; + + // 获取循环嵌套深度 + int GetLoopDepth() const { return loop_depth_; } + + // ========== 辅助功能 ========== + + // 清空所有作用域和状态 + void Clear(); + + // 获取当前作用域中所有变量名 + std::vector GetCurrentScopeVarNames() const; + + // 获取当前作用域中所有函数名 + std::vector GetCurrentScopeFuncNames() const; + + // 调试:打印符号表内容 + void Dump() const; - private: - std::unordered_map table_; +private: + // 作用域栈 + std::stack scopes_; + + // 循环嵌套深度 + int loop_depth_ = 0; }; + +// 类型兼容性检查函数 +inline bool IsTypeCompatible(SymbolType expected, SymbolType actual) { + if (expected == SymbolType::TYPE_UNKNOWN || actual == SymbolType::TYPE_UNKNOWN) { + return true; // 未知类型视为兼容 + } + + // 基本类型兼容规则 + if (expected == actual) { + return true; + } + + // int 可以隐式转换为 float + if (expected == SymbolType::TYPE_FLOAT && actual == SymbolType::TYPE_INT) { + return true; + } + + return false; +} + +#endif // SYMBOL_TABLE_H \ No newline at end of file diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..75cfdf0 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -60,17 +60,29 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { // - const、数组、全局变量等不同声明形态; // - 更丰富的类型系统。 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少变量声明")); + } + if (!ctx->varDecl()) { + // 当前先忽略 constDecl 与其它声明形态。 + return {}; + } + return ctx->varDecl()->accept(this); +} + +std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } if (!ctx->btype() || !ctx->btype()->INT()) { throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); + for (auto* var_def : ctx->varDef()) { + if (!var_def) { + throw std::runtime_error(FormatError("irgen", "非法变量声明")); + } + var_def->accept(this); } - var_def->accept(this); return {}; } @@ -83,15 +95,16 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量定义")); } - if (!ctx->lValue()) { + if (!ctx->ID()) { throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); } - GetLValueName(*ctx->lValue()); + const std::string name = ctx->ID()->getText(); if (storage_map_.find(ctx) != storage_map_.end()) { throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); } auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); storage_map_[ctx] = slot; + named_storage_[name] = slot; ir::Value* init = nullptr; if (auto* init_value = ctx->initValue()) { diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..7e22485 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -24,21 +24,62 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { return std::any_cast(expr.accept(this)); } +ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { + return std::any_cast(cond.accept(this)); +} -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); +ir::Value* IRGenImpl::ToBoolValue(ir::Value* v) { + if (!v) { + throw std::runtime_error(FormatError("irgen", "条件值为空")); } - return EvalExpr(*ctx->exp()); + auto* zero = builder_.CreateConstInt(0); + return builder_.CreateCmp(ir::CmpOp::Ne, v, zero, module_.GetContext().NextTemp()); } +std::string IRGenImpl::NextBlockName() { + std::string temp = module_.GetContext().NextTemp(); + if (!temp.empty() && temp.front() == '%') { + return "bb" + temp.substr(1); + } + return "bb" + temp; +} -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { +std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("irgen", "非法表达式")); + } + return ctx->addExp()->accept(this); +} + +std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { + if (!ctx || !ctx->lOrExp()) { + throw std::runtime_error(FormatError("irgen", "非法条件表达式")); + } + return ctx->lOrExp()->accept(this); +} + +std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法基本表达式")); + } + if (ctx->exp()) { + return EvalExpr(*ctx->exp()); + } + if (ctx->number()) { + return ctx->number()->accept(this); + } + if (ctx->lValue()) { + return ctx->lValue()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "不支持的基本表达式")); +} + +std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { + if (!ctx || !ctx->ILITERAL()) { throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); } return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); + builder_.CreateConstInt(std::stoi(ctx->getText()))); } // 变量使用的处理流程: @@ -47,34 +88,192 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { // 3. 最后生成 load,把内存中的值读出来。 // // 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { +std::any IRGenImpl::visitLValue(SysYParser::LValueContext* ctx) { + if (!ctx || !ctx->ID()) { throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); } - auto* decl = sema_.ResolveVarUse(ctx->var()); - if (!decl) { - throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); - } - auto it = storage_map_.find(decl); - if (it == storage_map_.end()) { + const std::string name = ctx->ID()->getText(); + auto it = named_storage_.find(name); + if (it == named_storage_.end()) { throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + FormatError("irgen", "变量声明缺少存储槽位: " + name)); } return static_cast( builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); } +std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法一元表达式")); + } + if (ctx->primaryExp()) { + return ctx->primaryExp()->accept(this); + } + if (ctx->unaryOp() && ctx->unaryExp()) { + ir::Value* v = std::any_cast(ctx->unaryExp()->accept(this)); + if (ctx->unaryOp()->SUB()) { + auto* zero = builder_.CreateConstInt(0); + return static_cast(builder_.CreateSub( + zero, v, module_.GetContext().NextTemp())); + } + if (ctx->unaryOp()->ADD()) { + return v; + } + throw std::runtime_error(FormatError("irgen", "当前不支持逻辑非运算")); + } + throw std::runtime_error(FormatError("irgen", "当前不支持函数调用表达式")); +} + +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + } + if (ctx->mulExp()) { + if (!ctx->unaryExp()) { + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + } + ir::Value* lhs = std::any_cast(ctx->mulExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->unaryExp()->accept(this)); + if (ctx->MUL()) { + return static_cast( + builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->DIV()) { + return static_cast( + builder_.CreateDiv(lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->MOD()) { + return static_cast( + builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + } + if (ctx->unaryExp()) { + return ctx->unaryExp()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); +} -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { +std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法加法表达式")); } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); - return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + if (ctx->addExp()) { + if (!ctx->mulExp()) { + throw std::runtime_error(FormatError("irgen", "非法加法表达式")); + } + ir::Value* lhs = std::any_cast(ctx->addExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->mulExp()->accept(this)); + if (ctx->ADD()) { + return static_cast( + builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->SUB()) { + return static_cast( + builder_.CreateSub(lhs, rhs, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法加法表达式")); + } + if (ctx->mulExp()) { + return ctx->mulExp()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "非法加法表达式")); +} + +std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + } + if (ctx->relExp()) { + if (!ctx->addExp()) { + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + } + ir::Value* lhs = std::any_cast(ctx->relExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->addExp()->accept(this)); + if (ctx->LT()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Lt, lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->LE()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Le, lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->GT()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Gt, lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->GE()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Ge, lhs, rhs, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + } + if (ctx->addExp()) { + return ctx->addExp()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); +} + +std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + } + if (ctx->eqExp()) { + if (!ctx->relExp()) { + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + } + ir::Value* lhs = std::any_cast(ctx->eqExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->relExp()->accept(this)); + if (ctx->EQ()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Eq, lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->NE()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Ne, lhs, rhs, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + } + if (ctx->relExp()) { + return ctx->relExp()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); +} + +std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); + } + if (ctx->lAndExp()) { + if (!ctx->eqExp()) { + throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); + } + auto* lhs = ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); + auto* rhs = ToBoolValue(std::any_cast(ctx->eqExp()->accept(this))); + return static_cast( + builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->eqExp()) { + return ToBoolValue(std::any_cast(ctx->eqExp()->accept(this))); + } + throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); +} + +std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); + } + if (ctx->lOrExp()) { + if (!ctx->lAndExp()) { + throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); + } + auto* lhs = ToBoolValue(std::any_cast(ctx->lOrExp()->accept(this))); + auto* rhs = ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); + auto* sum = builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp()); + return static_cast(ToBoolValue(sum)); + } + if (ctx->lAndExp()) { + return ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); + } + throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..737563d 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -38,11 +38,14 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - auto* func = ctx->funcDef(); - if (!func) { + if (ctx->funcDef().empty()) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } - func->accept(this); + for (auto* func : ctx->funcDef()) { + if (func) { + func->accept(this); + } + } return {}; } @@ -79,6 +82,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); builder_.SetInsertPoint(func_->GetEntry()); storage_map_.clear(); + named_storage_.clear(); ctx->blockStmt()->accept(this); // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..61ad87e 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -19,9 +19,101 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } + if (ctx->lValue() && ctx->ASSIGN() && ctx->exp()) { + if (!ctx->lValue()->ID()) { + throw std::runtime_error(FormatError("irgen", "赋值语句左值非法")); + } + const std::string name = ctx->lValue()->ID()->getText(); + auto slot_it = named_storage_.find(name); + if (slot_it == named_storage_.end()) { + throw std::runtime_error(FormatError("irgen", "赋值目标未声明: " + name)); + } + ir::Value* rhs = EvalExpr(*ctx->exp()); + builder_.CreateStore(rhs, slot_it->second); + return BlockFlow::Continue; + } + if (ctx->blockStmt()) { + ctx->blockStmt()->accept(this); + return builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator() + ? BlockFlow::Terminated + : BlockFlow::Continue; + } + if (ctx->IF()) { + if (!ctx->cond() || ctx->stmt().empty()) { + throw std::runtime_error(FormatError("irgen", "if 语句不完整")); + } + auto* then_bb = func_->CreateBlock(NextBlockName()); + auto* merge_bb = func_->CreateBlock(NextBlockName()); + auto* else_bb = ctx->ELSE() ? func_->CreateBlock(NextBlockName()) : merge_bb; + + ir::Value* cond = ToBoolValue(EvalCond(*ctx->cond())); + builder_.CreateCondBr(cond, then_bb, else_bb); + + builder_.SetInsertPoint(then_bb); + auto then_flow = std::any_cast(ctx->stmt(0)->accept(this)); + if (then_flow != BlockFlow::Terminated) { + builder_.CreateBr(merge_bb); + } + + if (ctx->ELSE()) { + builder_.SetInsertPoint(else_bb); + auto else_flow = std::any_cast(ctx->stmt(1)->accept(this)); + if (else_flow != BlockFlow::Terminated) { + builder_.CreateBr(merge_bb); + } + } + + builder_.SetInsertPoint(merge_bb); + return BlockFlow::Continue; + } + if (ctx->WHILE()) { + if (!ctx->cond() || ctx->stmt().empty()) { + throw std::runtime_error(FormatError("irgen", "while 语句不完整")); + } + auto* cond_bb = func_->CreateBlock(NextBlockName()); + auto* body_bb = func_->CreateBlock(NextBlockName()); + auto* exit_bb = func_->CreateBlock(NextBlockName()); + + builder_.CreateBr(cond_bb); + builder_.SetInsertPoint(cond_bb); + ir::Value* cond = ToBoolValue(EvalCond(*ctx->cond())); + builder_.CreateCondBr(cond, body_bb, exit_bb); + + loop_stack_.push_back({cond_bb, exit_bb}); + builder_.SetInsertPoint(body_bb); + auto body_flow = std::any_cast(ctx->stmt(0)->accept(this)); + if (body_flow != BlockFlow::Terminated) { + builder_.CreateBr(cond_bb); + } + loop_stack_.pop_back(); + + builder_.SetInsertPoint(exit_bb); + return BlockFlow::Continue; + } + if (ctx->BREAK()) { + if (loop_stack_.empty()) { + throw std::runtime_error(FormatError("irgen", "break 不在循环中")); + } + builder_.CreateBr(loop_stack_.back().break_target); + return BlockFlow::Terminated; + } + if (ctx->CONTINUE()) { + if (loop_stack_.empty()) { + throw std::runtime_error(FormatError("irgen", "continue 不在循环中")); + } + builder_.CreateBr(loop_stack_.back().continue_target); + return BlockFlow::Terminated; + } if (ctx->returnStmt()) { return ctx->returnStmt()->accept(this); } + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + return BlockFlow::Continue; + } + if (ctx->SEMICOLON()) { + return BlockFlow::Continue; + } throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); } diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 745374c..ed13673 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -1,200 +1,224 @@ #include "sem/Sema.h" -#include #include -#include - -#include "SysYBaseVisitor.h" -#include "sem/SymbolTable.h" -#include "utils/Log.h" namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("sema", "非法左值")); +SymbolType ParseType(const std::string& text) { + if (text == "int") { + return SymbolType::TYPE_INT; } - return lvalue.ID()->getText(); -} - -class SemaVisitor final : public SysYBaseVisitor { - public: - std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少编译单元")); - } - auto* func = ctx->funcDef(); - if (!func || !func->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!func->ID() || func->ID()->getText() != "main") { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - func->accept(this); - if (!seen_return_) { - throw std::runtime_error( - FormatError("sema", "main 函数必须包含 return 语句")); - } - return {}; + if (text == "float") { + return SymbolType::TYPE_FLOAT; + } + if (text == "void") { + return SymbolType::TYPE_VOID; } + return SymbolType::TYPE_UNKNOWN; +} - std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { - if (!ctx || !ctx->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); - } - const auto& items = ctx->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); - } - ctx->blockStmt()->accept(this); - return {}; +} // namespace + +std::any SemaVisitor::visitCompUnit(SysYParser::CompUnitContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitDecl(SysYParser::DeclContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitBtype(SysYParser::BtypeContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitConstDef(SysYParser::ConstDefContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitConstInitValue(SysYParser::ConstInitValueContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitVarDecl(SysYParser::VarDeclContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitVarDef(SysYParser::VarDefContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitInitValue(SysYParser::InitValueContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitFuncDef(SysYParser::FuncDefContext* ctx) { + SymbolType ret_type = SymbolType::TYPE_UNKNOWN; + if (ctx && ctx->funcType()) { + ret_type = ParseType(ctx->funcType()->getText()); } - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少语句块")); - } - const auto& items = ctx->blockItem(); - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; - } - if (seen_return_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - current_item_index_ = i; - total_items_ = items.size(); - item->accept(this); - } + ir_ctx_.SetCurrentFuncReturnType(ret_type); + return visitChildren(ctx); +} + +std::any SemaVisitor::visitFuncType(SysYParser::FuncTypeContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitFuncFParams(SysYParser::FuncFParamsContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitFuncFParam(SysYParser::FuncFParamContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { + ir_ctx_.EnterScope(); + std::any result = visitChildren(ctx); + ir_ctx_.LeaveScope(); + return result; +} + +std::any SemaVisitor::visitBlockItem(SysYParser::BlockItemContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitStmt(SysYParser::StmtContext* ctx) { + if (!ctx) { return {}; } - std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return {}; - } - if (ctx->stmt()) { - ctx->stmt()->accept(this); - return {}; - } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + if (ctx->WHILE()) { + ir_ctx_.EnterLoop(); + std::any result = visitChildren(ctx); + ir_ctx_.ExitLoop(); + return result; } - std::any visitDecl(SysYParser::DeclContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); - } - auto* var_def = ctx->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - const std::string name = GetLValueName(*var_def->lValue()); - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); - } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); - } - init->exp()->accept(this); - } - table_.Add(name, var_def); - return {}; + if (ctx->BREAK() && !ir_ctx_.InLoop()) { + ir_ctx_.RecordError( + ErrorMsg("break 只能出现在循环语句中", ctx->getStart()->getLine(), + ctx->getStart()->getCharPositionInLine() + 1)); } - std::any visitStmt(SysYParser::StmtContext* ctx) override { - if (!ctx || !ctx->returnStmt()) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - ctx->returnStmt()->accept(this); - return {}; + if (ctx->CONTINUE() && !ir_ctx_.InLoop()) { + ir_ctx_.RecordError( + ErrorMsg("continue 只能出现在循环语句中", ctx->getStart()->getLine(), + ctx->getStart()->getCharPositionInLine() + 1)); } - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); - } - ctx->exp()->accept(this); - seen_return_ = true; - if (current_item_index_ + 1 != total_items_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - return {}; - } + return visitChildren(ctx); +} - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "非法括号表达式")); - } - ctx->exp()->accept(this); +std::any SemaVisitor::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { + if (!ctx) { return {}; } - std::any visitVarExp(SysYParser::VarExpContext* ctx) override { - if (!ctx || !ctx->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); - } - ctx->var()->accept(this); - return {}; + if (ctx->exp() && ir_ctx_.GetCurrentFuncReturnType() == SymbolType::TYPE_VOID) { + ir_ctx_.RecordError( + ErrorMsg("void 函数不应返回表达式", ctx->getStart()->getLine(), + ctx->getStart()->getCharPositionInLine() + 1)); } - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); - } - return {}; + if (!ctx->exp() && + ir_ctx_.GetCurrentFuncReturnType() != SymbolType::TYPE_VOID && + ir_ctx_.GetCurrentFuncReturnType() != SymbolType::TYPE_UNKNOWN) { + ir_ctx_.RecordError( + ErrorMsg("非 void 函数 return 必须带表达式", ctx->getStart()->getLine(), + ctx->getStart()->getCharPositionInLine() + 1)); } - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); - } - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); + return visitChildren(ctx); +} + +std::any SemaVisitor::visitExp(SysYParser::ExpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitCond(SysYParser::CondContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitLValue(SysYParser::LValueContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitNumber(SysYParser::NumberContext* ctx) { + if (!ctx) { return {}; } - std::any visitVar(SysYParser::VarContext* ctx) override { - if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("sema", "非法变量引用")); - } - const std::string name = ctx->ID()->getText(); - auto* decl = table_.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); - } - sema_.BindVarUse(ctx, decl); - return {}; + if (ctx->ILITERAL()) { + ir_ctx_.SetType(ctx, SymbolType::TYPE_INT); + ir_ctx_.SetConstVal(ctx, std::any(0L)); + } else if (ctx->FLITERAL()) { + ir_ctx_.SetType(ctx, SymbolType::TYPE_FLOAT); + ir_ctx_.SetConstVal(ctx, std::any(0.0)); } - SemanticContext TakeSemanticContext() { return std::move(sema_); } + return {}; +} - private: - SymbolTable table_; - SemanticContext sema_; - bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; -}; +std::any SemaVisitor::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + return visitChildren(ctx); +} -} // namespace +std::any SemaVisitor::visitUnaryOp(SysYParser::UnaryOpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitMulExp(SysYParser::MulExpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitAddExp(SysYParser::AddExpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitRelExp(SysYParser::RelExpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitEqExp(SysYParser::EqExpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitLAndExp(SysYParser::LAndExpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitLOrExp(SysYParser::LOrExpContext* ctx) { + return visitChildren(ctx); +} + +std::any SemaVisitor::visitConstExp(SysYParser::ConstExpContext* ctx) { + return visitChildren(ctx); +} + +void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx) { + if (!ctx) { + throw std::invalid_argument("CompUnitContext is null"); + } + SemaVisitor visitor(ir_ctx); + visitor.visit(ctx); +} SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { - SemaVisitor visitor; - comp_unit.accept(&visitor); - return visitor.TakeSemanticContext(); + IRGenContext ctx; + RunSemanticAnalysis(&comp_unit, ctx); + return SemanticContext(); } diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index ffeea89..b896ce4 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -1,17 +1,164 @@ -// 维护局部变量声明的注册与查找。 +#include "../../include/sem/SymbolTable.h" +#include +#include +#include -#include "sem/SymbolTable.h" +// 进入新作用域 +void SymbolTable::EnterScope() { + scopes_.push(ScopeEntry()); +} + +// 离开当前作用域 +void SymbolTable::LeaveScope() { + if (scopes_.empty()) { + throw std::runtime_error("SymbolTable Error: 作用域栈为空,无法退出"); + } + scopes_.pop(); +} + +// 绑定变量到当前作用域 +void SymbolTable::BindVar(const std::string& name, const VarInfo& info, void* decl_ctx) { + if (CurrentScopeHasVar(name)) { + throw std::runtime_error("变量'" + name + "'在当前作用域重复定义"); + } + scopes_.top().var_symbols[name] = {info, decl_ctx}; +} + +// 绑定函数到当前作用域 +void SymbolTable::BindFunc(const std::string& name, const FuncInfo& info, void* decl_ctx) { + if (CurrentScopeHasFunc(name)) { + throw std::runtime_error("函数'" + name + "'在当前作用域重复定义"); + } + scopes_.top().func_symbols[name] = {info, decl_ctx}; +} + +// 查找变量(从当前作用域向上遍历) +bool SymbolTable::LookupVar(const std::string& name, VarInfo& out_info, void*& out_decl_ctx) const { + if (scopes_.empty()) { + return false; + } + auto temp_stack = scopes_; + while (!temp_stack.empty()) { + auto& scope = temp_stack.top(); + auto it = scope.var_symbols.find(name); + if (it != scope.var_symbols.end()) { + out_info = it->second.first; + out_decl_ctx = it->second.second; + return true; + } + temp_stack.pop(); + } + return false; +} + +// 查找函数(从当前作用域向上遍历,通常函数在全局作用域) +bool SymbolTable::LookupFunc(const std::string& name, FuncInfo& out_info, void*& out_decl_ctx) const { + if (scopes_.empty()) { + return false; + } + auto temp_stack = scopes_; + while (!temp_stack.empty()) { + auto& scope = temp_stack.top(); + auto it = scope.func_symbols.find(name); + if (it != scope.func_symbols.end()) { + out_info = it->second.first; + out_decl_ctx = it->second.second; + return true; + } + temp_stack.pop(); + } + return false; +} -void SymbolTable::Add(const std::string& name, - SysYParser::VarDefContext* decl) { - table_[name] = decl; +// 检查当前作用域是否包含指定变量 +bool SymbolTable::CurrentScopeHasVar(const std::string& name) const { + if (scopes_.empty()) { + return false; + } + return scopes_.top().var_symbols.count(name) > 0; } -bool SymbolTable::Contains(const std::string& name) const { - return table_.find(name) != table_.end(); +// 检查当前作用域是否包含指定函数 +bool SymbolTable::CurrentScopeHasFunc(const std::string& name) const { + if (scopes_.empty()) { + return false; + } + return scopes_.top().func_symbols.count(name) > 0; } -SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - auto it = table_.find(name); - return it == table_.end() ? nullptr : it->second; +// 进入循环 +void SymbolTable::EnterLoop() { + loop_depth_++; } + +// 离开循环 +void SymbolTable::ExitLoop() { + if (loop_depth_ > 0) loop_depth_--; +} + +// 检查是否在循环内 +bool SymbolTable::InLoop() const { + return loop_depth_ > 0; +} + +// 清空所有作用域和状态 +void SymbolTable::Clear() { + while (!scopes_.empty()) { + scopes_.pop(); + } + loop_depth_ = 0; +} + +// 获取当前作用域中所有变量名 +std::vector SymbolTable::GetCurrentScopeVarNames() const { + std::vector names; + if (!scopes_.empty()) { + for (const auto& pair : scopes_.top().var_symbols) { + names.push_back(pair.first); + } + } + return names; +} + +// 获取当前作用域中所有函数名 +std::vector SymbolTable::GetCurrentScopeFuncNames() const { + std::vector names; + if (!scopes_.empty()) { + for (const auto& pair : scopes_.top().func_symbols) { + names.push_back(pair.first); + } + } + return names; +} + +// 调试:打印符号表内容 +void SymbolTable::Dump() const { + std::cout << "符号表内容 (作用域深度: " << scopes_.size() << "):\n"; + int scope_idx = 0; + auto temp_stack = scopes_; + + while (!temp_stack.empty()) { + std::cout << "\n作用域 " << scope_idx++ << ":\n"; + auto& scope = temp_stack.top(); + + std::cout << " 变量:\n"; + for (const auto& var_pair : scope.var_symbols) { + const VarInfo& info = var_pair.second.first; + std::cout << " " << var_pair.first << ": " + << SymbolTypeToString(info.type) + << (info.is_const ? " (const)" : "") + << (info.IsArray() ? " [数组]" : "") + << "\n"; + } + + std::cout << " 函数:\n"; + for (const auto& func_pair : scope.func_symbols) { + const FuncInfo& info = func_pair.second.first; + std::cout << " " << func_pair.first << ": " + << SymbolTypeToString(info.ret_type) << " (" + << info.param_types.size() << " 个参数)\n"; + } + + temp_stack.pop(); + } +} \ No newline at end of file