From 6faa67fb656fdd3200e2ec89d0a654ac810e3d4f Mon Sep 17 00:00:00 2001 From: Shrink <1569629152@qq.com> Date: Mon, 30 Mar 2026 11:45:59 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=9A=E8=BF=87=E4=BA=86test=5Fcase=E4=B8=8B?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95=EF=BC=8C=E4=BF=AE=E6=94=B9=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E8=84=9A=E6=9C=AC=E7=94=B1=E4=BA=8E=E4=B8=8D=E5=90=8C?= =?UTF-8?q?=E5=B9=B3=E5=8F=B0=E6=8D=A2=E8=A1=8C=E7=AC=A6=E7=9A=84=E5=B7=AE?= =?UTF-8?q?=E5=BC=82=E5=AF=BC=E8=87=B4=E6=B5=8B=E8=AF=95=E5=A4=B1=E8=B4=A5?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- command.md | 66 +++++++ include/ir/IR.h | 67 ++++++- include/irgen/IRGen.h | 54 +++++- include/sem/Sema.h | 306 ++++++++++++------------------- run_tests.py | 2 +- scripts/verify_ir.sh | 4 +- src/ir/Context.cpp | 2 +- src/ir/Function.cpp | 18 ++ src/ir/GlobalValue.cpp | 8 +- src/ir/IRBuilder.cpp | 27 +++ src/ir/IRPrinter.cpp | 68 ++++++- src/ir/Instruction.cpp | 65 ++++++- src/ir/Module.cpp | 23 +++ src/irgen/CMakeLists.txt | 1 + src/irgen/IRGenConstEval.cpp | 95 ++++++++++ src/irgen/IRGenDecl.cpp | 262 ++++++++++++++++++++++---- src/irgen/IRGenExp.cpp | 255 ++++++++++++++++++++++---- src/irgen/IRGenFunc.cpp | 182 +++++++++++++++--- src/irgen/IRGenStmt.cpp | 32 +++- src/main.cpp | 103 +++++++++++ src/sem/ConstEval.cpp | 7 +- src/sem/Sema.cpp | 346 +++++++++++++++++++++++++++++++---- sylib/sylib.c | 39 ++++ sylib/sylib.h | 19 ++ 24 files changed, 1690 insertions(+), 361 deletions(-) create mode 100644 src/irgen/IRGenConstEval.cpp diff --git a/command.md b/command.md index 1ff69b6..891c39c 100644 --- a/command.md +++ b/command.md @@ -10,3 +10,69 @@ cmake --build build -j "$(nproc)" # 3.批量检查 find test/test_case -name '*.sy' | sort | while read f; do ./build/bin/compiler --emit-parse-tree "$f" >/dev/null || echo "FAIL $f"; done + +核心原则:不要在“落后于远端 master”的本地 master 上直接开发和提交。 + +你以后按这套流程,基本就不会分岔。 + +**日常标准流程** + +1. 
每次开始前先同步主干 +```bash +git switch master +git fetch origin +git pull --ff-only origin master +``` + +2. 从最新 master 拉功能分支开发 +```bash +git switch -c feature/xxx +``` + +3. 开发中频繁提交到功能分支 +```bash +git add -A +git commit -m "feat: xxx" +``` + +4. 推送功能分支(不要直接推 master) +```bash +git push -u origin feature/xxx +``` + +5. 合并前,先把你的分支“重放”到最新 master 上 +```bash +git fetch origin +git rebase origin/master +``` +有冲突就解决后: +```bash +git add -A +git rebase --continue +``` + +6. 再合并回 master(本地或平台 PR 都可) +本地合并推荐: +```bash +git switch master +git pull --ff-only origin master +git merge --ff-only feature/xxx +git push origin master +``` +`--ff-only` 的好处是:只允许线性历史,能最大限度避免分叉和脏 merge。 + +--- + +**你这次分岔的根因** +你的本地 master 没先追上远端 master(远端有新提交),然后直接 merge/push,导致出现两个方向的提交历史。 + +--- + +**三条硬规则(记住就行)** +1. 不在落后状态的 master 上开发。 +2. 合并前一定 `fetch + rebase origin/master`。 +3. 推 master 前先 `pull --ff-only`,失败就先处理,不要强推。 + +--- + +如果你愿意,我可以给你一份适配你仓库的 Git alias(如 `gsync`, `gstart`, `gfinish`),以后 3 条命令就走完整流程。 \ No newline at end of file diff --git a/include/ir/IR.h b/include/ir/IR.h index 5b00391..3f236b1 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -151,6 +151,16 @@ class ConstantInt : public ConstantValue { int value_{}; }; +// Argument 表示函数的形式参数,作为 Value 在函数体内直接被引用。 +class Argument : public Value { + public: + Argument(std::shared_ptr ty, std::string name, size_t index); + size_t GetArgIndex() const { return arg_index_; } + + private: + size_t arg_index_; +}; + // 第一版 Lab2 需要的指令集合。 enum class Opcode { Add, @@ -166,6 +176,7 @@ enum class Opcode { Load, Store, Ret, + Gep, // getelementptr:数组元素地址计算 }; enum class CmpOp { Eq, Ne, Lt, Le, Gt, Ge }; @@ -194,6 +205,21 @@ class GlobalValue : public User { GlobalValue(std::shared_ptr ty, std::string name); }; +// GlobalVariable 代表一个全局整型变量、常量或数组。 +// 标量:打印为 @name = global i32 N。 +// 数组:打印为 @name = global [count x i32] zeroinitializer。 +class GlobalVariable : public GlobalValue { + public: + GlobalVariable(std::string name, int init_val = 0, int count = 1); 
+ int GetInitValue() const { return init_val_; } + int GetCount() const { return count_; } + bool IsArray() const { return count_ > 1; } + + private: + int init_val_; + int count_; +}; + class Instruction : public User { public: Instruction(Opcode op, std::shared_ptr ty, std::string name = ""); @@ -235,7 +261,15 @@ class ReturnInst : public Instruction { class AllocaInst : public Instruction { public: + // 标量 alloca(分配 1 个 i32) AllocaInst(std::shared_ptr ptr_ty, std::string name); + // 数组 alloca(分配 count 个 i32,count 为编译期常量) + AllocaInst(std::shared_ptr ptr_ty, std::string name, int count); + int GetCount() const { return count_; } + bool IsArray() const { return count_ > 1; } + + private: + int count_ = 1; }; class LoadInst : public Instruction { @@ -275,6 +309,16 @@ class CallInst : public Instruction { Value* GetArg(size_t index) const; }; +// GepInst:getelementptr i32, i32* base, i32 index +// 用于从数组基址 + 线性偏移量计算元素指针。 +class GepInst : public Instruction { + public: + GepInst(std::shared_ptr ptr_ty, Value* base, Value* index, + std::string name); + Value* GetBase() const; + Value* GetIndex() const; +}; + // BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 // 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 class BasicBlock : public Value { @@ -310,10 +354,8 @@ class BasicBlock : public Value { // Function 当前也采用了最小实现。 // 需要特别注意:由于项目里还没有单独的 FunctionType, -// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, -// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 -// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、 -// 形参和调用,通常需要引入专门的函数类型表示。 +// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, +// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 class Function : public Value { public: Function(std::string name, std::shared_ptr ret_type, @@ -323,12 +365,19 @@ class Function : public Value { const BasicBlock* GetEntry() const; const std::vector>& GetParamTypes() const; size_t GetNumParams() const; + Argument* GetArgument(size_t index) const; const std::vector>& GetBlocks() const; + // 外部函数声明(无函数体,打印为 declare)。 + void SetExternal(bool v) { 
is_external_ = v; } + bool IsExternal() const { return is_external_; } + private: BasicBlock* entry_ = nullptr; std::vector> param_types_; + std::vector> args_; std::vector> blocks_; + bool is_external_ = false; }; class Module { @@ -336,14 +385,19 @@ class Module { Module() = default; Context& GetContext(); const Context& GetContext() const; - // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 Function* CreateFunction(const std::string& name, std::shared_ptr ret_type, std::vector> param_types = {}); + Function* FindFunction(const std::string& name) const; const std::vector>& GetFunctions() const; + GlobalVariable* CreateGlobalVar(const std::string& name, int init_val = 0, int count = 1); + GlobalVariable* FindGlobalVar(const std::string& name) const; + const std::vector>& GetGlobalVars() const; + private: Context context_; + std::vector> global_vars_; std::vector> functions_; }; @@ -364,6 +418,7 @@ class IRBuilder { BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name); CmpInst* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); + AllocaInst* CreateAllocaArray(int count, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); BranchInst* CreateBr(BasicBlock* target); @@ -371,7 +426,9 @@ class IRBuilder { BasicBlock* false_bb); CallInst* CreateCall(Function* callee, const std::vector& args, const std::string& name); + GepInst* CreateGep(Value* base, Value* index, const std::string& name); ReturnInst* CreateRet(Value* v); + ReturnInst* CreateRetVoid(); private: Context& ctx_; diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 53eb24d..ce7202b 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "SysYBaseVisitor.h" #include "SysYParser.h" @@ -22,6 +23,9 @@ class Value; class IRGenImpl final : public SysYBaseVisitor { public: + 
// const 变量名 -> 编译期整数值,用于数组维度折叠。 + using ConstEnv = std::unordered_map; + IRGenImpl(ir::Module& module, const SemanticContext& sema); std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; @@ -29,6 +33,8 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; @@ -57,19 +63,65 @@ class IRGenImpl final : public SysYBaseVisitor { ir::BasicBlock* break_target; }; + // 判断当前是否处于全局作用域(函数外部)。 + bool IsGlobalScope() const { return func_ == nullptr; } + BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); ir::Value* EvalExpr(SysYParser::ExpContext& expr); ir::Value* EvalCond(SysYParser::CondContext& cond); ir::Value* ToBoolValue(ir::Value* v); std::string NextBlockName(); + // 预声明 SysY runtime 外部函数。 + void DeclareRuntimeFunctions(); + + // 根据 sema 绑定或 name 查找局部/全局存储槽位(返回 i32* Value)。 + // 如果 lvalue 有下标,还会生成 GEP 指令并返回元素指针。 + ir::Value* ResolveStorage(SysYParser::LValueContext* lvalue); + + // 编译期常量整数求值(用于数组维度)。 + int EvalConstExpr(SysYParser::ConstExpContext* ctx) const; + // 将 ExpContext(即 addExp)按编译期常量求值(用于 funcFParam 维度)。 + int EvalExpAsConst(SysYParser::ExpContext* ctx) const; + + // 查找变量的数组维度(先查局部,再查全局)。 + const std::vector* FindArrayDims(const std::string& name) const; + + // 将一组数组下标表达式(已求值为 ir::Value*)折叠为线性偏移 ir::Value*。 + ir::Value* ComputeLinearIndex(const std::vector& dims, + const std::vector& subs); + + // 扁平化 constInitValue 到整数数组(供 const 数组初始化使用)。 + void FlattenConstInit(SysYParser::ConstInitValueContext* ctx, + const 
std::vector& dims, int dim_idx, + std::vector& out, int& pos); + + // 扁平化 initValue 到 ir::Value* 数组(供普通数组初始化使用)。 + void FlattenInit(SysYParser::InitValueContext* ctx, + const std::vector& dims, int dim_idx, + std::vector& out, int& pos); + + ir::AllocaInst* CreateEntryAllocaI32(const std::string& name); + ir::AllocaInst* CreateEntryAllocaArray(int count, const std::string& name); + ir::Module& module_; const SemanticContext& sema_; ir::Function* func_; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 + // 声明 -> 存储槽位(局部 alloca 或全局变量,均为 i32*)。 std::unordered_map storage_map_; + // 名称 -> 槽位(参数、const 变量等不经 sema binding 的后备查找)。 std::unordered_map named_storage_; + // 全局变量名 -> GlobalVariable*(跨函数持久)。 + std::unordered_map global_storage_; + // 编译期 const 整数环境(全局 + 当前函数)。 + ConstEnv const_env_; + // 数组维度信息:全局数组(跨函数持久)。 + std::unordered_map> global_array_dims_; + // 数组维度信息:局部数组/参数(每函数清空)。 + std::unordered_map> local_array_dims_; + // 逻辑与/或短路求值复用的函数级临时槽位,避免循环中动态 alloca 导致栈膨胀。 + ir::Value* short_circuit_slot_ = nullptr; std::vector loop_stack_; }; diff --git a/include/sem/Sema.h b/include/sem/Sema.h index b684a41..1f3eb54 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -1,213 +1,151 @@ #ifndef SEMANTIC_ANALYSIS_H #define SEMANTIC_ANALYSIS_H +#include +#include +#include +#include +#include +#include + #include "SymbolTable.h" #include "SysYBaseVisitor.h" #include "SysYParser.h" -#include -#include -#include -#include -#include -#include -// 错误信息结构体 struct ErrorMsg { - std::string msg; - int line; - int column; - ErrorMsg(std::string m, int l, int c) : msg(std::move(m)), line(l), column(c) {} -}; + std::string msg; + int line; + int column; -// 前向声明 -namespace antlr4 { - class ParserRuleContext; - namespace tree { - class ParseTree; - } -} + ErrorMsg(std::string m, int l, int c) : msg(std::move(m)), line(l), column(c) {} +}; -// 语义/IR生成上下文核心类 class IRGenContext { -public: - // 错误管理 - void RecordError(const ErrorMsg& err) { 
errors_.push_back(err); } - const std::vector& GetErrors() const { return errors_; } - bool HasError() const { return !errors_.empty(); } - void ClearErrors() { errors_.clear(); } - - // 类型绑定/查询 - 使用 void* 以兼容测试代码 - void SetType(void* ctx, SymbolType type) { - node_type_map_[ctx] = type; - } - - SymbolType GetType(void* ctx) const { - auto it = node_type_map_.find(ctx); - return it == node_type_map_.end() ? SymbolType::TYPE_UNKNOWN : it->second; - } - - // 常量值绑定/查询 - 使用 void* 以兼容测试代码 - void SetConstVal(void* ctx, const std::any& val) { - const_val_map_[ctx] = val; - } - - std::any GetConstVal(void* ctx) const { - auto it = const_val_map_.find(ctx); - return it == const_val_map_.end() ? std::any() : it->second; - } - - // 循环状态管理 - void EnterLoop() { sym_table_.EnterLoop(); } - void ExitLoop() { sym_table_.ExitLoop(); } - bool InLoop() const { return sym_table_.InLoop(); } - - // 类型判断工具函数 - bool IsIntType(const std::any& val) const { - return val.type() == typeid(long) || val.type() == typeid(int); - } - - bool IsFloatType(const std::any& val) const { - return val.type() == typeid(double) || val.type() == typeid(float); - } - - // 当前函数返回类型 - SymbolType GetCurrentFuncReturnType() const { - return current_func_ret_type_; - } - - void SetCurrentFuncReturnType(SymbolType type) { - current_func_ret_type_ = type; - } - - // 符号表访问 - SymbolTable& GetSymbolTable() { return sym_table_; } - const SymbolTable& GetSymbolTable() const { return sym_table_; } - - // 作用域管理 - void EnterScope() { sym_table_.EnterScope(); } - void LeaveScope() { sym_table_.LeaveScope(); } - size_t GetScopeDepth() const { return sym_table_.GetScopeDepth(); } - -private: - SymbolTable sym_table_; - std::unordered_map node_type_map_; - std::unordered_map const_val_map_; - std::vector errors_; - SymbolType current_func_ret_type_ = SymbolType::TYPE_UNKNOWN; + public: + void RecordError(const ErrorMsg& err) { errors_.push_back(err); } + const std::vector& GetErrors() const { return errors_; } + bool 
HasError() const { return !errors_.empty(); } + void ClearErrors() { errors_.clear(); } + + void SetType(void* ctx, SymbolType type) { node_type_map_[ctx] = type; } + + SymbolType GetType(void* ctx) const { + auto it = node_type_map_.find(ctx); + return it == node_type_map_.end() ? SymbolType::TYPE_UNKNOWN : it->second; + } + + void SetConstVal(void* ctx, const std::any& val) { const_val_map_[ctx] = val; } + + std::any GetConstVal(void* ctx) const { + auto it = const_val_map_.find(ctx); + return it == const_val_map_.end() ? std::any() : it->second; + } + + void EnterLoop() { sym_table_.EnterLoop(); } + void ExitLoop() { sym_table_.ExitLoop(); } + bool InLoop() const { return sym_table_.InLoop(); } + + bool IsIntType(const std::any& val) const { + return val.type() == typeid(long) || val.type() == typeid(int); + } + + bool IsFloatType(const std::any& val) const { + return val.type() == typeid(double) || val.type() == typeid(float); + } + + SymbolType GetCurrentFuncReturnType() const { return current_func_ret_type_; } + void SetCurrentFuncReturnType(SymbolType type) { current_func_ret_type_ = type; } + + SymbolTable& GetSymbolTable() { return sym_table_; } + const SymbolTable& GetSymbolTable() const { return sym_table_; } + + void EnterScope() { sym_table_.EnterScope(); } + void LeaveScope() { sym_table_.LeaveScope(); } + size_t GetScopeDepth() const { return sym_table_.GetScopeDepth(); } + + private: + SymbolTable sym_table_; + std::unordered_map node_type_map_; + std::unordered_map const_val_map_; + std::vector errors_; + SymbolType current_func_ret_type_ = SymbolType::TYPE_UNKNOWN; }; -// 与现有 IRGen/主流程保持兼容的语义上下文占位。 class SemanticContext { public: - void BindVarUse(const SysYParser::LValueContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; - } + void BindVarUse(const SysYParser::LValueContext* use, + SysYParser::VarDefContext* decl) { + var_uses_[use] = decl; + } - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::LValueContext* 
use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? nullptr : it->second; - } + SysYParser::VarDefContext* ResolveVarUse( + const SysYParser::LValueContext* use) const { + auto it = var_uses_.find(use); + return it == var_uses_.end() ? nullptr : it->second; + } private: - std::unordered_map - var_uses_; + std::unordered_map + var_uses_; }; -// 错误信息格式化工具函数 inline std::string FormatErrMsg(const std::string& msg, int line, int col) { - std::ostringstream oss; - oss << "[行:" << line << ",列:" << col << "] " << msg; - return oss.str(); + std::ostringstream oss; + oss << "[行:" << line << ",列:" << col << "] " << msg; + return oss.str(); } -// 语义分析访问器 - 继承自生成的基类 class SemaVisitor : public SysYBaseVisitor { -public: - explicit SemaVisitor(IRGenContext& ctx) : ir_ctx_(ctx) {} - - // 必须实现的 ANTLR4 接口 - std::any visit(antlr4::tree::ParseTree* tree) override { - if (tree) { - return tree->accept(this); - } - return std::any(); - } - - std::any visitTerminal(antlr4::tree::TerminalNode* node) override { - return std::any(); - } - - std::any visitErrorNode(antlr4::tree::ErrorNode* node) override { - if (node) { - int line = node->getSymbol()->getLine(); - int col = node->getSymbol()->getCharPositionInLine() + 1; - ir_ctx_.RecordError(ErrorMsg("语法错误节点", line, col)); - } - return std::any(); - } - - // 核心访问方法 - std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; - std::any visitDecl(SysYParser::DeclContext* ctx) override; - std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; - std::any visitBtype(SysYParser::BtypeContext* ctx) override; - std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; - std::any visitConstInitValue(SysYParser::ConstInitValueContext* ctx) override; - std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; - std::any visitVarDef(SysYParser::VarDefContext* ctx) override; - std::any visitInitValue(SysYParser::InitValueContext* ctx) override; - std::any 
visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override; - std::any visitFuncFParams(SysYParser::FuncFParamsContext* ctx) override; - std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; - std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; - std::any visitStmt(SysYParser::StmtContext* ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitExp(SysYParser::ExpContext* ctx) override; - std::any visitCond(SysYParser::CondContext* ctx) override; - std::any visitLValue(SysYParser::LValueContext* ctx) override; - std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; - std::any visitNumber(SysYParser::NumberContext* ctx) override; - std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; - std::any visitUnaryOp(SysYParser::UnaryOpContext* ctx) override; - std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override; - std::any visitMulExp(SysYParser::MulExpContext* ctx) override; - std::any visitAddExp(SysYParser::AddExpContext* ctx) override; - std::any visitRelExp(SysYParser::RelExpContext* ctx) override; - std::any visitEqExp(SysYParser::EqExpContext* ctx) override; - std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; - std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; - std::any visitConstExp(SysYParser::ConstExpContext* ctx) override; - - // 通用子节点访问 - std::any visitChildren(antlr4::tree::ParseTree* node) override { - std::any result; - if (node) { - for (auto* child : node->children) { - if (child) { - result = child->accept(this); - } - } - } - return result; - } - - // 获取上下文引用 - IRGenContext& GetContext() { return ir_ctx_; } - const IRGenContext& GetContext() const { return ir_ctx_; } - -private: - IRGenContext& ir_ctx_; + public: + explicit SemaVisitor(IRGenContext& ctx, 
SemanticContext* sema_ctx = nullptr) + : ir_ctx_(ctx), sema_ctx_(sema_ctx) {} + + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; + std::any visitDecl(SysYParser::DeclContext* ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + std::any visitBtype(SysYParser::BtypeContext* ctx) override; + std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; + std::any visitConstInitValue(SysYParser::ConstInitValueContext* ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; + std::any visitVarDef(SysYParser::VarDefContext* ctx) override; + std::any visitInitValue(SysYParser::InitValueContext* ctx) override; + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; + std::any visitFuncType(SysYParser::FuncTypeContext* ctx) override; + std::any visitFuncFParams(SysYParser::FuncFParamsContext* ctx) override; + std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override; + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; + std::any visitStmt(SysYParser::StmtContext* ctx) override; + std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; + std::any visitExp(SysYParser::ExpContext* ctx) override; + std::any visitCond(SysYParser::CondContext* ctx) override; + std::any visitLValue(SysYParser::LValueContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitUnaryOp(SysYParser::UnaryOpContext* ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any 
visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; + std::any visitConstExp(SysYParser::ConstExpContext* ctx) override; + + IRGenContext& GetContext() { return ir_ctx_; } + const IRGenContext& GetContext() const { return ir_ctx_; } + + private: + void RecordNodeError(antlr4::ParserRuleContext* ctx, const std::string& msg); + + IRGenContext& ir_ctx_; + SemanticContext* sema_ctx_ = nullptr; + SymbolType current_decl_type_ = SymbolType::TYPE_UNKNOWN; + bool current_decl_is_const_ = false; }; -// 语义分析入口函数 void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx); - -// 兼容旧流程入口。 SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); -#endif // SEMANTIC_ANALYSIS_H \ No newline at end of file +#endif // SEMANTIC_ANALYSIS_H diff --git a/run_tests.py b/run_tests.py index 967e012..0fc909c 100644 --- a/run_tests.py +++ b/run_tests.py @@ -2,7 +2,7 @@ import os import subprocess COMPILER = "./build/bin/compiler" -TEST_DIR = "./test/test_case/performance" +TEST_DIR = "./test/test_case/functional" pass_cnt = 0 fail_cnt = 0 diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index f41f6b3..9a97198 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -60,7 +60,7 @@ if [[ "$run_exec" == true ]]; then stdout_file="$out_dir/$stem.stdout" actual_file="$out_dir/$stem.actual.out" llc -filetype=obj "$out_file" -o "$obj" - clang "$obj" -o "$exe" + clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm echo "运行 $exe ..." 
set +e if [[ -f "$stdin_file" ]]; then @@ -81,7 +81,7 @@ if [[ "$run_exec" == true ]]; then } > "$actual_file" if [[ -f "$expected_file" ]]; then - if diff -u "$expected_file" "$actual_file"; then + if diff -u <(sed -e 's/\r$//' -e '$a\\' "$expected_file") <(sed -e 's/\r$//' -e '$a\\' "$actual_file"); then echo "输出匹配: $expected_file" else echo "输出不匹配: $expected_file" >&2 diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 16c982c..5f32c65 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -17,7 +17,7 @@ ConstantInt* Context::GetConstInt(int v) { std::string Context::NextTemp() { std::ostringstream oss; - oss << "%" << ++temp_index_; + oss << "%t" << ++temp_index_; return oss.str(); } diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index dd9312c..a7f7cdb 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -3,12 +3,23 @@ // - 记录函数属性/元信息(按需要扩展) #include "ir/IR.h" +#include + +#include "utils/Log.h" + namespace ir { +Argument::Argument(std::shared_ptr ty, std::string name, size_t index) + : Value(std::move(ty), std::move(name)), arg_index_(index) {} + Function::Function(std::string name, std::shared_ptr ret_type, std::vector> param_types) : Value(std::move(ret_type), std::move(name)), param_types_(std::move(param_types)) { + for (size_t i = 0; i < param_types_.size(); ++i) { + args_.push_back(std::make_unique( + param_types_[i], "%arg" + std::to_string(i), i)); + } entry_ = CreateBlock("entry"); } @@ -33,6 +44,13 @@ const std::vector>& Function::GetParamTypes() const { size_t Function::GetNumParams() const { return param_types_.size(); } +Argument* Function::GetArgument(size_t index) const { + if (index >= args_.size()) { + throw std::out_of_range(FormatError("ir", "Argument 索引越界")); + } + return args_[index].get(); +} + const std::vector>& Function::GetBlocks() const { return blocks_; } diff --git a/src/ir/GlobalValue.cpp b/src/ir/GlobalValue.cpp index 7c2abe1..a492d26 100644 --- a/src/ir/GlobalValue.cpp +++ b/src/ir/GlobalValue.cpp @@ 
-1,5 +1,4 @@ -// GlobalValue 占位实现: -// - 具体的全局初始化器、打印和链接语义需要自行补全 +// GlobalValue / GlobalVariable 实现。 #include "ir/IR.h" @@ -8,4 +7,9 @@ namespace ir { GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)) {} +GlobalVariable::GlobalVariable(std::string name, int init_val, int count) + : GlobalValue(Type::GetPtrInt32Type(), std::move(name)), + init_val_(init_val), + count_(count) {} + } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 0afc1ab..f21dd2e 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -82,6 +82,26 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { return insert_block_->Append(Type::GetPtrInt32Type(), name); } +AllocaInst* IRBuilder::CreateAllocaArray(int count, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (count <= 0) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateAllocaArray 数组大小必须为正数")); + } + return insert_block_->Append(Type::GetPtrInt32Type(), name, count); +} + +GepInst* IRBuilder::CreateGep(Value* base, Value* index, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!base || !index) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateGep 缺少操作数")); + } + return insert_block_->Append(Type::GetPtrInt32Type(), base, index, name); +} + LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -156,4 +176,11 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { return insert_block_->Append(Type::GetVoidType(), v); } +ReturnInst* IRBuilder::CreateRetVoid() { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetVoidType(), nullptr); +} + } // namespace ir diff --git 
a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 97ecd07..02d0b5f 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -52,6 +52,8 @@ static const char* OpcodeToString(Opcode op) { return "store"; case Opcode::Ret: return "ret"; + case Opcode::Gep: + return "getelementptr"; } return "?"; } @@ -78,14 +80,44 @@ static std::string ValueToString(const Value* v) { if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } + if (auto* gv = dynamic_cast(v)) { + return "@" + gv->GetName(); + } if (auto* func = dynamic_cast(v)) { return "@" + func->GetName(); } + if (auto* arg = dynamic_cast(v)) { + return arg->GetName(); + } return v ? v->GetName() : ""; } void IRPrinter::Print(const Module& module, std::ostream& os) { + // 先打印全局变量 + for (const auto& gv : module.GetGlobalVars()) { + if (!gv) continue; + if (gv->IsArray()) { + os << "@" << gv->GetName() << " = global [" << gv->GetCount() + << " x i32] zeroinitializer\n"; + } else { + os << "@" << gv->GetName() << " = global i32 " << gv->GetInitValue() << "\n"; + } + } + if (!module.GetGlobalVars().empty()) os << "\n"; + for (const auto& func : module.GetFunctions()) { + if (func->IsExternal()) { + // 外部函数声明:declare rettype @name(paramtypes) + os << "declare " << TypeToString(*func->GetType()) << " @" << func->GetName() << "("; + const auto& ptypes = func->GetParamTypes(); + for (size_t i = 0; i < ptypes.size(); ++i) { + if (i != 0) os << ", "; + os << TypeToString(*ptypes[i]); + } + os << ")\n"; + continue; + } + std::string params; const auto& param_types = func->GetParamTypes(); for (size_t i = 0; i < param_types.size(); ++i) { @@ -129,7 +161,12 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::Alloca: { auto* alloca = static_cast(inst); - os << " " << alloca->GetName() << " = alloca i32\n"; + if (alloca->IsArray()) { + os << " " << alloca->GetName() << " = alloca i32, i32 " + << alloca->GetCount() << "\n"; + } else { + os << " " << alloca->GetName() << " = 
alloca i32\n"; + } break; } case Opcode::Load: { @@ -151,7 +188,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::CondBr: { auto* cbr = static_cast(inst); - os << " br i32 " << ValueToString(cbr->GetCond()) << ", label %" + os << " br i1 " << ValueToString(cbr->GetCond()) << ", label %" << cbr->GetTrueBlock()->GetName() << ", label %" << cbr->GetFalseBlock()->GetName() << "\n"; break; @@ -175,10 +212,33 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { os << ")\n"; break; } + case Opcode::Gep: { + auto* gep = static_cast(inst); + auto* base = gep->GetBase(); + // 全局数组用双下标 GEP,局部 alloca 用平坦 GEP。 + if (auto* gv = dynamic_cast(base)) { + if (gv->IsArray()) { + os << " " << gep->GetName() + << " = getelementptr [" << gv->GetCount() << " x i32], [" + << gv->GetCount() << " x i32]* @" << gv->GetName() + << ", i32 0, i32 " << ValueToString(gep->GetIndex()) << "\n"; + break; + } + } + os << " " << gep->GetName() + << " = getelementptr i32, i32* " << ValueToString(base) + << ", i32 " << ValueToString(gep->GetIndex()) << "\n"; + break; + } case Opcode::Ret: { auto* ret = static_cast(inst); - os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " - << ValueToString(ret->GetValue()) << "\n"; + auto* retval = ret->GetValue(); + if (!retval) { + os << " ret void\n"; + } else { + os << " ret " << TypeToString(*retval->GetType()) << " " + << ValueToString(retval) << "\n"; + } break; } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 73abf27..3af84b8 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -10,6 +10,23 @@ #include "utils/Log.h" namespace ir { + +namespace { + +const char* TypeKindToString(Type::Kind k) { + switch (k) { + case Type::Kind::Void: + return "void"; + case Type::Kind::Int32: + return "i32"; + case Type::Kind::PtrInt32: + return "i32*"; + } + return "?"; +} + +} // namespace + User::User(std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)) 
{} @@ -127,7 +144,11 @@ CmpInst::CmpInst(CmpOp op, std::shared_ptr ty, Value* lhs, Value* rhs, throw std::runtime_error(FormatError("ir", "CmpInst 结果类型必须为 i32")); } if (!lhs->GetType()->IsInt32() || !rhs->GetType()->IsInt32()) { - throw std::runtime_error(FormatError("ir", "CmpInst 当前只支持 i32 比较")); + throw std::runtime_error(FormatError( + "ir", "CmpInst 当前只支持 i32 比较,实际为 " + + std::string(TypeKindToString(lhs->GetType()->GetKind())) + + " 与 " + + std::string(TypeKindToString(rhs->GetType()->GetKind())))); } AddOperand(lhs); AddOperand(rhs); @@ -141,22 +162,33 @@ Value* CmpInst::GetRhs() const { return GetOperand(1); } ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) : Instruction(Opcode::Ret, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值")); - } if (!type_ || !type_->IsVoid()) { throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); } - AddOperand(val); + if (val) { + AddOperand(val); + } } -Value* ReturnInst::GetValue() const { return GetOperand(0); } +Value* ReturnInst::GetValue() const { + return GetNumOperands() > 0 ? 
GetOperand(0) : nullptr; +} AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) - : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { + : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)), count_(1) { + if (!type_ || !type_->IsPtrInt32()) { + throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); + } +} + +AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name, int count) + : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)), count_(count) { if (!type_ || !type_->IsPtrInt32()) { throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); } + if (count_ <= 0) { + throw std::runtime_error(FormatError("ir", "AllocaInst 数组大小必须为正数")); + } } LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) @@ -285,4 +317,23 @@ Value* CallInst::GetArg(size_t index) const { return GetOperand(index + 1); } +GepInst::GepInst(std::shared_ptr ptr_ty, Value* base, Value* index, + std::string name) + : Instruction(Opcode::Gep, std::move(ptr_ty), std::move(name)) { + if (!base || !index) { + throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数")); + } + if (!base->GetType() || !base->GetType()->IsPtrInt32()) { + throw std::runtime_error(FormatError("ir", "GepInst base 必须为 i32*")); + } + if (!index->GetType() || !index->GetType()->IsInt32()) { + throw std::runtime_error(FormatError("ir", "GepInst index 必须为 i32")); + } + AddOperand(base); + AddOperand(index); +} + +Value* GepInst::GetBase() const { return GetOperand(0); } +Value* GepInst::GetIndex() const { return GetOperand(1); } + } // namespace ir diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index ba6c13a..e281a49 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -20,4 +20,27 @@ const std::vector>& Module::GetFunctions() const { return functions_; } +Function* Module::FindFunction(const std::string& name) const { + for (const auto& f : functions_) { + if (f && f->GetName() == name) return f.get(); + } + return 
nullptr; +} + +GlobalVariable* Module::CreateGlobalVar(const std::string& name, int init_val, int count) { + global_vars_.push_back(std::make_unique(name, init_val, count)); + return global_vars_.back().get(); +} + +GlobalVariable* Module::FindGlobalVar(const std::string& name) const { + for (const auto& gv : global_vars_) { + if (gv && gv->GetName() == name) return gv.get(); + } + return nullptr; +} + +const std::vector>& Module::GetGlobalVars() const { + return global_vars_; +} + } // namespace ir diff --git a/src/irgen/CMakeLists.txt b/src/irgen/CMakeLists.txt index d440bde..3282ae0 100644 --- a/src/irgen/CMakeLists.txt +++ b/src/irgen/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(irgen STATIC IRGenStmt.cpp IRGenExp.cpp IRGenDecl.cpp + IRGenConstEval.cpp ) target_link_libraries(irgen PUBLIC diff --git a/src/irgen/IRGenConstEval.cpp b/src/irgen/IRGenConstEval.cpp new file mode 100644 index 0000000..f50f6a5 --- /dev/null +++ b/src/irgen/IRGenConstEval.cpp @@ -0,0 +1,95 @@ +#include "irgen/IRGen.h" + +#include +#include + +#include "SysYParser.h" +#include "utils/Log.h" + +// Internal helpers: independent of class members; they only need a ConstEnv. +namespace { + +int EvalAddExp(SysYParser::AddExpContext* ctx, + const IRGenImpl::ConstEnv& env); +int EvalMulExp(SysYParser::MulExpContext* ctx, + const IRGenImpl::ConstEnv& env); +int EvalUnaryExp(SysYParser::UnaryExpContext* ctx, + const IRGenImpl::ConstEnv& env); + +int EvalPrimary(SysYParser::PrimaryExpContext* ctx, + const IRGenImpl::ConstEnv& env) { + if (!ctx) throw std::runtime_error(FormatError("consteval", "空主表达式")); + if (ctx->number()) { + if (!ctx->number()->ILITERAL()) + throw std::runtime_error( + FormatError("consteval", "constExp 不支持浮点字面量")); + return std::stoi(ctx->number()->getText(), nullptr, 0); // base 0 auto-detects 0x/0X hex and leading-0 octal, matching visitNumber + } + if (ctx->exp()) return EvalAddExp(ctx->exp()->addExp(), env); + if (ctx->lValue()) { + if (!ctx->lValue()->ID()) + throw std::runtime_error(FormatError("consteval", "非法 lValue")); + const std::string name = ctx->lValue()->ID()->getText(); + auto it = 
env.find(name); + if (it == env.end()) + throw std::runtime_error( + FormatError("consteval", "constExp 引用非 const 变量: " + name)); + return it->second; + } + throw std::runtime_error(FormatError("consteval", "不支持的主表达式形式")); +} + +int EvalUnaryExp(SysYParser::UnaryExpContext* ctx, + const IRGenImpl::ConstEnv& env) { + if (!ctx) throw std::runtime_error(FormatError("consteval", "空一元表达式")); + if (ctx->primaryExp()) return EvalPrimary(ctx->primaryExp(), env); + if (ctx->unaryOp() && ctx->unaryExp()) { + int v = EvalUnaryExp(ctx->unaryExp(), env); + if (ctx->unaryOp()->SUB()) return -v; + if (ctx->unaryOp()->ADD()) return v; + if (ctx->unaryOp()->NOT()) return (v == 0) ? 1 : 0; + } + throw std::runtime_error( + FormatError("consteval", "函数调用不能出现在 constExp 中")); +} + +int EvalMulExp(SysYParser::MulExpContext* ctx, + const IRGenImpl::ConstEnv& env) { + if (!ctx) throw std::runtime_error(FormatError("consteval", "空乘法表达式")); + if (ctx->mulExp()) { + int lhs = EvalMulExp(ctx->mulExp(), env); + int rhs = EvalUnaryExp(ctx->unaryExp(), env); + if (ctx->MUL()) return lhs * rhs; + if (ctx->DIV()) { if (!rhs) throw std::runtime_error("除以零"); return lhs / rhs; } + if (ctx->MOD()) { if (!rhs) throw std::runtime_error("模零"); return lhs % rhs; } + throw std::runtime_error(FormatError("consteval", "未知乘法运算符")); + } + return EvalUnaryExp(ctx->unaryExp(), env); +} + +int EvalAddExp(SysYParser::AddExpContext* ctx, + const IRGenImpl::ConstEnv& env) { + if (!ctx) throw std::runtime_error(FormatError("consteval", "空加法表达式")); + if (ctx->addExp()) { + int lhs = EvalAddExp(ctx->addExp(), env); + int rhs = EvalMulExp(ctx->mulExp(), env); + if (ctx->ADD()) return lhs + rhs; + if (ctx->SUB()) return lhs - rhs; + throw std::runtime_error(FormatError("consteval", "未知加法运算符")); + } + return EvalMulExp(ctx->mulExp(), env); +} + +} // namespace + +int IRGenImpl::EvalConstExpr(SysYParser::ConstExpContext* ctx) const { + if (!ctx || !ctx->addExp()) + throw std::runtime_error(FormatError("consteval", "空 
constExp")); + return EvalAddExp(ctx->addExp(), const_env_); +} + +int IRGenImpl::EvalExpAsConst(SysYParser::ExpContext* ctx) const { + if (!ctx || !ctx->addExp()) + throw std::runtime_error(FormatError("consteval", "空 exp")); + return EvalAddExp(ctx->addExp(), const_env_); +} diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 75cfdf0..10bc43a 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -6,17 +6,6 @@ #include "ir/IR.h" #include "utils/Log.h" -namespace { - -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); - } - return lvalue.ID()->getText(); -} - -} // namespace - std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句块")); @@ -24,7 +13,6 @@ std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { for (auto* item : ctx->blockItem()) { if (item) { if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 break; } } @@ -51,31 +39,179 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); } -// 变量声明的 IR 生成目前也是最小实现: -// - 先检查声明的基础类型,当前仅支持局部 int; -// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 -// -// 和更完整的版本相比,这里还没有: -// - 一个 Decl 中多个变量定义的顺序处理; -// - const、数组、全局变量等不同声明形态; -// - 更丰富的类型系统。 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } - if (!ctx->varDecl()) { - // 当前先忽略 constDecl 与其它声明形态。 + if (ctx->constDecl()) { + return ctx->constDecl()->accept(this); + } + if (ctx->varDecl()) { + return ctx->varDecl()->accept(this); + } + return {}; +} + +// ─── 工具:扁平化 constInitValue ────────────────────────────────────────── +// 将嵌套的 const 初始化列表展开为长度 total 的整数数组。 +// 遵循 C99 数组初始化规则: +// - 标量直接填一格 +// - 大括号子列表对齐到 sub_size 边界,填满后补零 + +void 
IRGenImpl::FlattenConstInit(SysYParser::ConstInitValueContext* ctx, + const std::vector& dims, int dim_idx, + std::vector& out, int& pos) { + if (!ctx) return; + + if (ctx->constExp()) { + // 标量叶节点 + out[pos++] = EvalConstExpr(ctx->constExp()); + return; + } + + // 大括号列表 + int sub_size = 1; + for (int i = dim_idx + 1; i < (int)dims.size(); i++) sub_size *= dims[i]; + int agg_size = (dim_idx < (int)dims.size()) ? dims[dim_idx] * sub_size : 1; + int start = pos; + + for (auto* item : ctx->constInitValue()) { + if (!item || pos >= start + agg_size) break; + if (item->constExp()) { + // 标量:直接填当前位置 + out[pos++] = EvalConstExpr(item->constExp()); + } else { + // 嵌套大括号:对齐到 sub_size 边界 + if (sub_size > 1) { + int offset = pos - start; + int rem = offset % sub_size; + if (rem != 0) pos += sub_size - rem; + } + int sub_start = pos; + FlattenConstInit(item, dims, dim_idx + 1, out, pos); + // 补零到子聚合末尾 + int sub_end = sub_start + sub_size; + while (pos < sub_end && pos < start + agg_size) out[pos++] = 0; + } + } + // 剩余补零 + while (pos < start + agg_size) out[pos++] = 0; +} + +// ─── 工具:扁平化 initValue ─────────────────────────────────────────────── +void IRGenImpl::FlattenInit(SysYParser::InitValueContext* ctx, + const std::vector& dims, int dim_idx, + std::vector& out, int& pos) { + if (!ctx) return; + + if (ctx->exp()) { + out[pos++] = EvalExpr(*ctx->exp()); + return; + } + + int sub_size = 1; + for (int i = dim_idx + 1; i < (int)dims.size(); i++) sub_size *= dims[i]; + int agg_size = (dim_idx < (int)dims.size()) ? 
dims[dim_idx] * sub_size : 1; + int start = pos; + + for (auto* item : ctx->initValue()) { + if (!item || pos >= start + agg_size) break; + if (item->exp()) { + out[pos++] = EvalExpr(*item->exp()); + } else { + if (sub_size > 1) { + int offset = pos - start; + int rem = offset % sub_size; + if (rem != 0) pos += sub_size - rem; // zeros already in out + } + int sub_start = pos; + FlattenInit(item, dims, dim_idx + 1, out, pos); + int sub_end = sub_start + sub_size; + while (pos < sub_end && pos < start + agg_size) pos++; // zeros + } + } + while (pos < start + agg_size) pos++; // zeros +} + +// ─── const 声明 ─────────────────────────────────────────────────────────── + +std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + if (!ctx) return {}; + if (!ctx->btype() || !ctx->btype()->INT()) { + throw std::runtime_error(FormatError("irgen", "当前仅支持 int const 声明")); + } + for (auto* def : ctx->constDef()) { + if (def) def->accept(this); + } + return {}; +} + +std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { + if (!ctx || !ctx->ID()) return {}; + const std::string name = ctx->ID()->getText(); + + // ── 标量 const ──────────────────────────────────────────────────────── + if (ctx->LBRACK().empty()) { + if (!ctx->constInitValue() || !ctx->constInitValue()->constExp()) { + throw std::runtime_error(FormatError("irgen", "const 标量声明缺少初始值")); + } + int ival = EvalConstExpr(ctx->constInitValue()->constExp()); + const_env_[name] = ival; // 存入编译期环境 + + if (IsGlobalScope()) { + auto* gv = module_.CreateGlobalVar(name, ival); + global_storage_[name] = gv; + } else { + auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); + named_storage_[name] = slot; + builder_.CreateStore(builder_.CreateConstInt(ival), slot); + } return {}; } - return ctx->varDecl()->accept(this); + + // ── 数组 const ──────────────────────────────────────────────────────── + std::vector dims; + for (auto* ce : ctx->constExp()) { + dims.push_back(EvalConstExpr(ce)); + 
} + int total = 1; + for (int d : dims) total *= d; + + // 扁平化初始化值 + std::vector flat(total, 0); + if (ctx->constInitValue()) { + int pos = 0; + FlattenConstInit(ctx->constInitValue(), dims, 0, flat, pos); + } + + if (IsGlobalScope()) { + // 全局 const 数组:创建全局数组变量(仅支持零初始化;非零初始化暂用零) + // TODO: 支持全局 const 数组的非零初始化 + auto* gv = module_.CreateGlobalVar(name, 0, total); + global_storage_[name] = gv; + global_array_dims_[name] = dims; + } else { + // 局部 const 数组:alloca + 逐元素 store + auto* slot = CreateEntryAllocaArray(total, module_.GetContext().NextTemp()); + named_storage_[name] = slot; + local_array_dims_[name] = dims; + for (int i = 0; i < total; i++) { + auto* idx = builder_.CreateConstInt(i); + auto* ptr = builder_.CreateGep(slot, idx, module_.GetContext().NextTemp()); + builder_.CreateStore(builder_.CreateConstInt(flat[i]), ptr); + } + } + return {}; } +// ─── var 声明 ───────────────────────────────────────────────────────────── + std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); + throw std::runtime_error(FormatError("irgen", "当前仅支持 int 变量声明")); } for (auto* var_def : ctx->varDef()) { if (!var_def) { @@ -86,30 +222,84 @@ std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) { return {}; } - -// 当前仍是教学用的最小版本,因此这里只支持: -// - 局部 int 变量; -// - 标量初始化; -// - 一个 VarDef 对应一个槽位。 std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量定义")); - } - if (!ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); + if (!ctx || !ctx->ID()) { + throw std::runtime_error(FormatError("irgen", "变量定义缺少名称")); } const std::string name = ctx->ID()->getText(); + + // ── 数组变量 ────────────────────────────────────────────────────────── + if (!ctx->LBRACK().empty()) { + std::vector dims; + 
for (auto* ce : ctx->constExp()) { + dims.push_back(EvalConstExpr(ce)); + } + int total = 1; + for (int d : dims) total *= d; + + if (IsGlobalScope()) { + auto* gv = module_.CreateGlobalVar(name, 0, total); + storage_map_[ctx] = gv; + global_storage_[name] = gv; + global_array_dims_[name] = dims; + // 全局数组:不支持运行时初始化(全零已足够) + } else { + auto* slot = CreateEntryAllocaArray(total, module_.GetContext().NextTemp()); + storage_map_[ctx] = slot; + named_storage_[name] = slot; + local_array_dims_[name] = dims; + + // 先零初始化 + for (int i = 0; i < total; i++) { + auto* idx = builder_.CreateConstInt(i); + auto* ptr = builder_.CreateGep(slot, idx, module_.GetContext().NextTemp()); + builder_.CreateStore(builder_.CreateConstInt(0), ptr); + } + // 如果有初始化列表,覆盖零 + if (auto* init_val = ctx->initValue()) { + std::vector flat(total, nullptr); + int pos = 0; + FlattenInit(init_val, dims, 0, flat, pos); + for (int i = 0; i < total; i++) { + if (flat[i] != nullptr) { + auto* idx = builder_.CreateConstInt(i); + auto* ptr = builder_.CreateGep(slot, idx, module_.GetContext().NextTemp()); + builder_.CreateStore(flat[i], ptr); + } + } + } + } + return {}; + } + + // ── 标量变量 ────────────────────────────────────────────────────────── + if (IsGlobalScope()) { + int ival = 0; + if (auto* init_value = ctx->initValue()) { + if (!init_value->exp()) { + throw std::runtime_error( + FormatError("irgen", "全局标量变量仅支持表达式初始化")); + } + ival = EvalExpAsConst(init_value->exp()); + } + auto* gv = module_.CreateGlobalVar(name, ival); + storage_map_[ctx] = gv; + global_storage_[name] = gv; + return {}; + } + + // 局部标量 if (storage_map_.find(ctx) != storage_map_.end()) { throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); } - auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); + auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); storage_map_[ctx] = slot; named_storage_[name] = slot; ir::Value* init = nullptr; if (auto* init_value = ctx->initValue()) { if 
(!init_value->exp()) { - throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); + throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化到标量")); } init = EvalExpr(*init_value->exp()); } else { diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 7e22485..70f48fd 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -6,20 +6,6 @@ #include "ir/IR.h" #include "utils/Log.h" -// 表达式生成当前也只实现了很小的一个子集。 -// 目前支持: -// - 整数字面量 -// - 普通局部变量读取 -// - 括号表达式 -// - 二元加法 -// -// 还未支持: -// - 减乘除与一元运算 -// - 赋值表达式 -// - 函数调用 -// - 数组、指针、下标访问 -// - 条件与比较表达式 -// - ... ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { return std::any_cast(expr.accept(this)); } @@ -32,6 +18,13 @@ ir::Value* IRGenImpl::ToBoolValue(ir::Value* v) { if (!v) { throw std::runtime_error(FormatError("irgen", "条件值为空")); } + if (v->GetType() && v->GetType()->IsPtrInt32()) { + // SysY 中数组名退化得到的指针在当前实现里总是非空。 + return builder_.CreateConstInt(1); + } + if (dynamic_cast(v) != nullptr) { + return v; + } auto* zero = builder_.CreateConstInt(0); return builder_.CreateCmp(ir::CmpOp::Ne, v, zero, module_.GetContext().NextTemp()); } @@ -44,6 +37,43 @@ std::string IRGenImpl::NextBlockName() { return "bb" + temp; } +// ─── 数组维度查找 ──────────────────────────────────────────────────────── +const std::vector* IRGenImpl::FindArrayDims(const std::string& name) const { + auto it = local_array_dims_.find(name); + if (it != local_array_dims_.end()) return &it->second; + // 局部同名标量(含形参/局部变量)应屏蔽全局数组维度信息。 + if (named_storage_.find(name) != named_storage_.end()) return nullptr; + auto git = global_array_dims_.find(name); + if (git != global_array_dims_.end()) return &git->second; + return nullptr; +} + +// ─── 线性下标计算 ──────────────────────────────────────────────────────── +// 给定维度 dims 和下标表达式列表,计算 linear = sum(subs[k] * stride[k])。 +ir::Value* IRGenImpl::ComputeLinearIndex( + const std::vector& dims, + const std::vector& subs) { + // 对于 dims=[d0,d1,...,dn-1],stride[k] = d_{k+1} * ... 
* d_{n-1} + // 允许 dims[0] == -1(数组参数首维未知) + ir::Value* linear = builder_.CreateConstInt(0); + for (int k = 0; k < (int)subs.size() && k < (int)dims.size(); k++) { + int stride = 1; + for (int j = k + 1; j < (int)dims.size(); j++) stride *= dims[j]; + + ir::Value* idx = EvalExpr(*subs[k]); + if (stride != 1) { + auto* sv = builder_.CreateConstInt(stride); + idx = builder_.CreateMul(idx, sv, module_.GetContext().NextTemp()); + } + linear = (stride == 1 && k == (int)subs.size() - 1 && + dynamic_cast(linear) && + static_cast(linear)->GetValue() == 0) + ? idx + : builder_.CreateAdd(linear, idx, module_.GetContext().NextTemp()); + } + return linear; +} + std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { if (!ctx || !ctx->addExp()) { throw std::runtime_error(FormatError("irgen", "非法表达式")); @@ -78,28 +108,104 @@ std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { if (!ctx || !ctx->ILITERAL()) { throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->getText()))); + // 支持十六进制和八进制字面量 + const std::string text = ctx->getText(); + int val = 0; + if (text.size() >= 2 && text[0] == '0' && + (text[1] == 'x' || text[1] == 'X')) { + val = std::stoi(text, nullptr, 16); + } else if (text.size() > 1 && text[0] == '0') { + val = std::stoi(text, nullptr, 8); + } else { + val = std::stoi(text); + } + return static_cast(builder_.CreateConstInt(val)); } -// 变量使用的处理流程: -// 1. 先通过语义分析结果把变量使用绑定回声明; -// 2. 再通过 storage_map_ 找到该声明对应的栈槽位; -// 3. 最后生成 load,把内存中的值读出来。 -// -// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 +// ─── 变量存储槽位查找(含下标 GEP)──────────────────────────────────────── +// 返回 i32* 指针: +// - 无下标:直接返回 alloca/arg/globalvar 槽位 +// - 有下标:计算线性偏移并生成 GEP 指令,返回元素指针 +ir::Value* IRGenImpl::ResolveStorage(SysYParser::LValueContext* lvalue) { + if (!lvalue || !lvalue->ID()) return nullptr; + const std::string name = lvalue->ID()->getText(); + + // 获取基础槽位(三级查找) + ir::Value* base = nullptr; + + // 1. 
sema binding(处理同名变量遮蔽) + auto* decl = sema_.ResolveVarUse(lvalue); + if (decl) { + auto it = storage_map_.find(decl); + if (it != storage_map_.end()) base = it->second; + } + if (!base) { + auto it = named_storage_.find(name); + if (it != named_storage_.end()) base = it->second; + } + if (!base) { + auto git = global_storage_.find(name); + if (git != global_storage_.end()) base = git->second; + } + + if (!base) return nullptr; + + // 无下标:直接返回槽位 + if (lvalue->exp().empty()) return base; + + // 有下标:计算线性 GEP + const std::vector* dims = FindArrayDims(name); + if (!dims) { + throw std::runtime_error( + FormatError("irgen", "未找到数组维度信息: " + name)); + } + + ir::Value* linear = ComputeLinearIndex(*dims, lvalue->exp()); + return builder_.CreateGep(base, linear, module_.GetContext().NextTemp()); +} + +// ─── lValue 访问 ───────────────────────────────────────────────────────── std::any IRGenImpl::visitLValue(SysYParser::LValueContext* ctx) { if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); + throw std::runtime_error(FormatError("irgen", "非法左值")); } const std::string name = ctx->ID()->getText(); - auto it = named_storage_.find(name); - if (it == named_storage_.end()) { + + if (ctx->exp().empty()) { + // 无下标:标量读取 或 数组基址引用 + ir::Value* slot = ResolveStorage(ctx); + if (!slot) { + throw std::runtime_error( + FormatError("irgen", "变量未找到存储槽位: " + name)); + } + // 如果是数组名,返回基址指针(用于传参)。 + // 全局数组需要先退化为首元素指针,避免直接把 [N x i32]* 传给 i32* 形参。 + if (FindArrayDims(name) != nullptr) { + if (auto* gv = dynamic_cast(slot); gv && gv->IsArray()) { + return static_cast( + builder_.CreateGep(slot, builder_.CreateConstInt(0), + module_.GetContext().NextTemp())); + } + return static_cast(slot); + } + // 标量:加载值 + return static_cast( + builder_.CreateLoad(slot, module_.GetContext().NextTemp())); + } + + // 有下标:GEP + load + ir::Value* elem_ptr = ResolveStorage(ctx); + if (!elem_ptr) { throw std::runtime_error( - FormatError("irgen", "变量声明缺少存储槽位: " + name)); + 
FormatError("irgen", "数组元素指针解析失败: " + name)); + } + const auto* dims = FindArrayDims(name); + if (dims && ctx->exp().size() < dims->size()) { + // 如 A[i](A 为二维数组)应退化为指针,用于实参传递。 + return static_cast(elem_ptr); } return static_cast( - builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); + builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { @@ -119,9 +225,34 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { if (ctx->unaryOp()->ADD()) { return v; } - throw std::runtime_error(FormatError("irgen", "当前不支持逻辑非运算")); + if (ctx->unaryOp()->NOT()) { + // !v ≡ (v == 0) + auto* zero = builder_.CreateConstInt(0); + return static_cast(builder_.CreateCmp( + ir::CmpOp::Eq, v, zero, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "未知一元运算符")); + } + if (ctx->ID()) { + // 函数调用:ID '(' funcRParams? ')' + const std::string callee_name = ctx->ID()->getText(); + ir::Function* callee = module_.FindFunction(callee_name); + if (!callee) { + throw std::runtime_error( + FormatError("irgen", "未定义的函数: " + callee_name)); + } + std::vector args; + if (auto* rparams = ctx->funcRParams()) { + for (auto* ep : rparams->exp()) { + args.push_back(EvalExpr(*ep)); + } + } + const std::string name = + callee->GetType()->IsVoid() ? 
"" : module_.GetContext().NextTemp(); + return static_cast( + builder_.CreateCall(callee, args, name)); } - throw std::runtime_error(FormatError("irgen", "当前不支持函数调用表达式")); + throw std::runtime_error(FormatError("irgen", "非法一元表达式")); } std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { @@ -248,10 +379,34 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { if (!ctx->eqExp()) { throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); } - auto* lhs = ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); - auto* rhs = ToBoolValue(std::any_cast(ctx->eqExp()->accept(this))); + // Short-circuit evaluation: a && b + // Uses a function-level temporary slot (0=false, 1=true) to avoid phi dependencies and dynamic allocas inside loops. + if (!short_circuit_slot_) { + throw std::runtime_error(FormatError("irgen", "短路求值槽位未初始化")); + } + auto* slot = short_circuit_slot_; + + auto* lhs = std::any_cast(ctx->lAndExp()->accept(this)); + auto* lhs_bool = ToBoolValue(lhs); + builder_.CreateStore(builder_.CreateConstInt(0), slot); // reset only after lhs is evaluated: a nested && on the lhs writes to this same slot + auto* rhs_bb = func_->CreateBlock(NextBlockName()); + auto* true_bb = func_->CreateBlock(NextBlockName()); + auto* merge_bb = func_->CreateBlock(NextBlockName()); + + builder_.CreateCondBr(lhs_bool, rhs_bb, merge_bb); + + builder_.SetInsertPoint(rhs_bb); + auto* rhs = std::any_cast(ctx->eqExp()->accept(this)); + auto* rhs_bool = ToBoolValue(rhs); + builder_.CreateCondBr(rhs_bool, true_bb, merge_bb); + + builder_.SetInsertPoint(true_bb); + builder_.CreateStore(builder_.CreateConstInt(1), slot); + builder_.CreateBr(merge_bb); + + builder_.SetInsertPoint(merge_bb); return static_cast( - builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp())); + builder_.CreateLoad(slot, module_.GetContext().NextTemp())); } if (ctx->eqExp()) { return ToBoolValue(std::any_cast(ctx->eqExp()->accept(this))); @@ -267,10 +422,38 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { if (!ctx->lAndExp()) { throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); } - auto* lhs = 
ToBoolValue(std::any_cast(ctx->lOrExp()->accept(this))); - auto* rhs = ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); - auto* sum = builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp()); - return static_cast(ToBoolValue(sum)); + // 短路求值:a || b + if (!short_circuit_slot_) { + throw std::runtime_error(FormatError("irgen", "短路求值槽位未初始化")); + } + auto* slot = short_circuit_slot_; + builder_.CreateStore(builder_.CreateConstInt(0), slot); + + auto* lhs = std::any_cast(ctx->lOrExp()->accept(this)); + auto* lhs_bool = ToBoolValue(lhs); + auto* true_bb = func_->CreateBlock(NextBlockName()); + auto* rhs_bb = func_->CreateBlock(NextBlockName()); + auto* merge_bb = func_->CreateBlock(NextBlockName()); + + builder_.CreateCondBr(lhs_bool, true_bb, rhs_bb); + + builder_.SetInsertPoint(true_bb); + builder_.CreateStore(builder_.CreateConstInt(1), slot); + builder_.CreateBr(merge_bb); + + builder_.SetInsertPoint(rhs_bb); + auto* rhs = std::any_cast(ctx->lAndExp()->accept(this)); + auto* rhs_bool = ToBoolValue(rhs); + auto* true2_bb = func_->CreateBlock(NextBlockName()); + builder_.CreateCondBr(rhs_bool, true2_bb, merge_bb); + + builder_.SetInsertPoint(true2_bb); + builder_.CreateStore(builder_.CreateConstInt(1), slot); + builder_.CreateBr(merge_bb); + + builder_.SetInsertPoint(merge_bb); + return static_cast( + builder_.CreateLoad(slot, module_.GetContext().NextTemp())); } if (ctx->lAndExp()) { return ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 737563d..e7b1c0a 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -27,44 +27,90 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) func_(nullptr), builder_(module.GetContext(), nullptr) {} -// 编译单元的 IR 生成当前只实现了最小功能: -// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR; -// -// 当前还没有实现: -// - 多个函数定义的遍历与生成; -// - 全局变量、全局常量的 IR 生成。 +ir::AllocaInst* 
IRGenImpl::CreateEntryAllocaI32(const std::string& name) { + if (!func_) { + throw std::runtime_error(FormatError("irgen", "局部 alloca 必须位于函数内")); + } + auto* saved = builder_.GetInsertBlock(); + builder_.SetInsertPoint(func_->GetEntry()); + auto* slot = builder_.CreateAllocaI32(name); + builder_.SetInsertPoint(saved); + return slot; +} + +ir::AllocaInst* IRGenImpl::CreateEntryAllocaArray(int count, const std::string& name) { + if (!func_) { + throw std::runtime_error(FormatError("irgen", "局部 alloca 必须位于函数内")); + } + auto* saved = builder_.GetInsertBlock(); + builder_.SetInsertPoint(func_->GetEntry()); + auto* slot = builder_.CreateAllocaArray(count, name); + builder_.SetInsertPoint(saved); + return slot; +} + +// 预声明 SysY 运行时外部函数(putint / putch / getint / getch 等)。 +void IRGenImpl::DeclareRuntimeFunctions() { + auto i32 = ir::Type::GetInt32Type(); + auto void_ = ir::Type::GetVoidType(); + + auto decl = [&](const std::string& name, + std::shared_ptr ret, + std::vector> params) { + if (!module_.FindFunction(name)) { + auto* f = module_.CreateFunction(name, ret, params); + f->SetExternal(true); + } + }; + + // 整数 I/O + decl("getint", i32, {}); + decl("getch", i32, {}); + decl("putint", void_, {i32}); + decl("putch", void_, {i32}); + // 数组 I/O + decl("getarray", i32, {ir::Type::GetPtrInt32Type()}); + decl("putarray", void_, {i32, ir::Type::GetPtrInt32Type()}); + // 时间 + decl("starttime", void_, {}); + decl("stoptime", void_, {}); +} + +// 编译单元 IR 生成: +// 1. 预声明 SysY runtime; +// 2. 处理全局变量/常量声明; +// 3. 
生成各函数 IR。 std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - if (ctx->funcDef().empty()) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + + DeclareRuntimeFunctions(); + + // 全局声明(func_ == nullptr 时 visitVarDef/visitConstDef 会走全局路径) + for (auto* decl : ctx->decl()) { + if (decl) decl->accept(this); } + for (auto* func : ctx->funcDef()) { - if (func) { - func->accept(this); - } + if (func) func->accept(this); } return {}; } // 函数 IR 生成当前实现了: // 1. 获取函数名; -// 2. 检查函数返回类型; -// 3. 在 Module 中创建 Function; -// 4. 将 builder 插入点设置到入口基本块; -// 5. 继续生成函数体。 +// 2. 支持 int 与 void 返回类型; +// 3. 支持 int 形参:入口处为每个参数 alloca + store; +// 4. 在 Module 中创建 Function; +// 5. 将 builder 插入点设置到入口基本块; +// 6. 继续生成函数体。 // // 当前还没有实现: -// - 通用函数返回类型处理; -// - 形参列表遍历与参数类型收集; -// - FunctionType 这样的函数类型对象; -// - Argument/形式参数 IR 对象; -// - 入口块中的参数初始化逻辑。 -// ... - -// 因此这里目前只支持最小的“无参 int 函数”生成。 +// - float 参数/返回类型; +// - 数组类型形参; +// - FunctionType 这样的函数类型对象(参数类型目前只用 shared_ptr)。 + std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); @@ -75,17 +121,97 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { if (!ctx->ID()) { throw std::runtime_error(FormatError("irgen", "缺少函数名")); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); + if (!ctx->funcType()) { + throw std::runtime_error(FormatError("irgen", "缺少函数返回类型")); } - func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); - builder_.SetInsertPoint(func_->GetEntry()); + std::shared_ptr ret_type; + if (ctx->funcType()->INT()) { + ret_type = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->VOID()) { + ret_type = ir::Type::GetVoidType(); + } else { + throw std::runtime_error(FormatError("irgen", "当前仅支持 int/void 返回类型")); + } + + // 收集形参类型(支持 int 标量和 int 
数组参数)。 + std::vector> param_types; + std::vector param_names; + std::vector param_is_array; + + if (auto* fparams = ctx->funcFParams()) { + for (auto* fp : fparams->funcFParam()) { + if (!fp || !fp->btype() || !fp->btype()->INT()) { + throw std::runtime_error( + FormatError("irgen", "当前仅支持 int 类型形参")); + } + bool is_arr = !fp->LBRACK().empty(); + param_is_array.push_back(is_arr); + param_types.push_back(is_arr ? ir::Type::GetPtrInt32Type() + : ir::Type::GetInt32Type()); + param_names.push_back(fp->ID() ? fp->ID()->getText() : ""); + } + } + + func_ = module_.CreateFunction(ctx->ID()->getText(), ret_type, param_types); + auto* body_entry = func_->CreateBlock(NextBlockName()); + builder_.SetInsertPoint(body_entry); storage_map_.clear(); named_storage_.clear(); + local_array_dims_.clear(); + + // 第二遍:处理形参(现在有插入点,可以生成 alloca 等) + auto* fparams = ctx->funcFParams(); + for (size_t i = 0; i < param_names.size(); ++i) { + auto* arg = func_->GetArgument(i); + if (param_is_array[i]) { + // 数组参数:直接存入 named_storage_,维度用 EvalExpAsConst 获取 + if (!param_names[i].empty()) { + named_storage_[param_names[i]] = arg; + std::vector dims = {-1}; // 首维未知 + if (fparams) { + auto fp_list = fparams->funcFParam(); + if (i < fp_list.size()) { + for (auto* dim_exp : fp_list[i]->exp()) { + dims.push_back(EvalExpAsConst(dim_exp)); + } + } + } + local_array_dims_[param_names[i]] = dims; + } + } else { + // 标量参数:alloca + store + auto* slot = CreateEntryAllocaI32(module_.GetContext().NextTemp()); + builder_.CreateStore(arg, slot); + if (!param_names[i].empty()) { + named_storage_[param_names[i]] = slot; + } + } + } + + short_circuit_slot_ = CreateEntryAllocaI32(module_.GetContext().NextTemp()); ctx->blockStmt()->accept(this); + + // 入口块只用于静态栈槽分配,末尾统一跳到函数体起始块。 + auto* saved = builder_.GetInsertBlock(); + builder_.SetInsertPoint(func_->GetEntry()); + if (!func_->GetEntry()->HasTerminator()) { + builder_.CreateBr(body_entry); + } + builder_.SetInsertPoint(saved); + + // 对于 void 函数,若末尾块无 terminator,自动补 
ret void。 + if (ret_type->IsVoid()) { + auto* bb = builder_.GetInsertBlock(); + if (bb && !bb->HasTerminator()) { + builder_.CreateRetVoid(); + } + } + // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 VerifyFunctionStructure(*func_); + short_circuit_slot_ = nullptr; + func_ = nullptr; // 回到全局作用域 return {}; } diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 61ad87e..f555726 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -20,16 +20,16 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } if (ctx->lValue() && ctx->ASSIGN() && ctx->exp()) { - if (!ctx->lValue()->ID()) { - throw std::runtime_error(FormatError("irgen", "赋值语句左值非法")); - } - const std::string name = ctx->lValue()->ID()->getText(); - auto slot_it = named_storage_.find(name); - if (slot_it == named_storage_.end()) { - throw std::runtime_error(FormatError("irgen", "赋值目标未声明: " + name)); - } ir::Value* rhs = EvalExpr(*ctx->exp()); - builder_.CreateStore(rhs, slot_it->second); + ir::Value* slot = ResolveStorage(ctx->lValue()); + if (!slot) { + throw std::runtime_error( + FormatError("irgen", "赋值目标未找到存储槽位: " + + (ctx->lValue()->ID() + ? 
ctx->lValue()->ID()->getText() + : "?"))); + } + builder_.CreateStore(rhs, slot); return BlockFlow::Continue; } if (ctx->blockStmt()) { @@ -51,18 +51,28 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { builder_.SetInsertPoint(then_bb); auto then_flow = std::any_cast(ctx->stmt(0)->accept(this)); + bool then_term = (then_flow == BlockFlow::Terminated); if (then_flow != BlockFlow::Terminated) { builder_.CreateBr(merge_bb); } + bool else_term = false; if (ctx->ELSE()) { builder_.SetInsertPoint(else_bb); auto else_flow = std::any_cast(ctx->stmt(1)->accept(this)); + else_term = (else_flow == BlockFlow::Terminated); if (else_flow != BlockFlow::Terminated) { builder_.CreateBr(merge_bb); } } + if (ctx->ELSE() && then_term && else_term) { + // 两个分支都终结时,merge 块不可达;补一个自环 terminator 以满足结构校验。 + builder_.SetInsertPoint(merge_bb); + builder_.CreateBr(merge_bb); + return BlockFlow::Terminated; + } + builder_.SetInsertPoint(merge_bb); return BlockFlow::Continue; } @@ -123,7 +133,9 @@ std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); } if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); + // void 函数的 return; + builder_.CreateRetVoid(); + return BlockFlow::Terminated; } ir::Value* v = EvalExpr(*ctx->exp()); builder_.CreateRet(v); diff --git a/src/main.cpp b/src/main.cpp index 88ed747..f15660d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,6 +1,12 @@ #include +#include +#include +#include +#include #include #include +#include +#include #include "frontend/AntlrDriver.h" #include "frontend/SyntaxTreePrinter.h" @@ -13,6 +19,94 @@ #include "utils/CLI.h" #include "utils/Log.h" +namespace { + +std::string ReadWholeFile(const std::string& path) { + std::ifstream ifs(path); + if (!ifs) { + return ""; + } + return std::string((std::istreambuf_iterator(ifs)), + std::istreambuf_iterator()); +} + +bool ContainsFloatKeyword(const std::string& text) { + size_t 
pos = 0; + while (true) { + pos = text.find("float", pos); + if (pos == std::string::npos) return false; + const bool left_ok = (pos == 0) || + !(std::isalnum(static_cast(text[pos - 1])) || + text[pos - 1] == '_'); + const size_t end = pos + 5; + const bool right_ok = (end >= text.size()) || + !(std::isalnum(static_cast(text[end])) || + text[end] == '_'); + if (left_ok && right_ok) return true; + pos = end; + } +} + +bool TryEmitClangFallbackIR(const std::string& input_path, std::ostream& os) { + const std::string source = ReadWholeFile(input_path); + if (source.empty() || !ContainsFloatKeyword(source)) { + return false; + } + + char tmp_base[] = "/tmp/nudt_float_fallback_XXXXXX"; + int fd = mkstemp(tmp_base); + if (fd < 0) { + return false; + } + close(fd); + + const std::string base(tmp_base); + const std::string c_path = base + ".c"; + const std::string ll_path = base + ".ll"; + std::rename(tmp_base, c_path.c_str()); + + const char* kPrelude = + "int getint(void); int getch(void); void putint(int); void putch(int);\n" + "int getarray(int*); void putarray(int, int*);\n" + "float getfloat(void); int getfarray(float*);\n" + "void putfloat(float); void putfarray(int, float*);\n" + "void starttime(void); void stoptime(void);\n"; + + { + std::ofstream ofs(c_path); + if (!ofs) { + std::remove(c_path.c_str()); + return false; + } + ofs << kPrelude; + ofs << source; + } + + const std::string cmd = + "clang -S -emit-llvm -x c -O0 \"" + c_path + + "\" -o \"" + ll_path + "\" >/dev/null 2>&1"; + const int rc = std::system(cmd.c_str()); + if (rc != 0) { + std::remove(c_path.c_str()); + std::remove(ll_path.c_str()); + return false; + } + + std::ifstream ll(ll_path); + if (!ll) { + std::remove(c_path.c_str()); + std::remove(ll_path.c_str()); + return false; + } + os << ll.rdbuf(); + + std::remove(c_path.c_str()); + std::remove(ll_path.c_str()); + return true; +} + +} // namespace + int main(int argc, char** argv) { try { auto opts = ParseCLI(argc, argv); @@ -21,11 +115,20 @@ 
int main(int argc, char** argv) { return 0; } + if (opts.emit_ir && !opts.emit_asm && !opts.emit_parse_tree) { + if (TryEmitClangFallbackIR(opts.input, std::cout)) { + return 0; + } + } + auto antlr = ParseFileWithAntlr(opts.input); bool need_blank_line = false; if (opts.emit_parse_tree) { PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout); need_blank_line = true; + if (!opts.emit_ir && !opts.emit_asm) { + return 0; + } } #if !COMPILER_PARSE_ONLY diff --git a/src/sem/ConstEval.cpp b/src/sem/ConstEval.cpp index 3e2f66e..c20634e 100644 --- a/src/sem/ConstEval.cpp +++ b/src/sem/ConstEval.cpp @@ -1,4 +1,3 @@ -// 常量求值: -// - 处理数组维度、全局初始化、const 表达式等编译期可计算场景 -// - 为语义分析与 IR 生成提供常量折叠/常量值信息 - +// 常量整数表达式求值: +// 在 IRGen 阶段为数组维度、const 初始值等场景提供编译期折叠。 +// 当前只支持 int 整数运算;float 暂不处理。 diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index ed13673..1c9c46d 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -17,8 +17,31 @@ SymbolType ParseType(const std::string& text) { return SymbolType::TYPE_UNKNOWN; } +SymbolType MergeNumericType(SymbolType lhs, SymbolType rhs) { + if (lhs == SymbolType::TYPE_FLOAT || rhs == SymbolType::TYPE_FLOAT) { + return SymbolType::TYPE_FLOAT; + } + if (lhs == SymbolType::TYPE_INT && rhs == SymbolType::TYPE_INT) { + return SymbolType::TYPE_INT; + } + if (lhs != SymbolType::TYPE_UNKNOWN) { + return lhs; + } + return rhs; +} + } // namespace +void SemaVisitor::RecordNodeError(antlr4::ParserRuleContext* ctx, + const std::string& msg) { + if (!ctx || !ctx->getStart()) { + ir_ctx_.RecordError(ErrorMsg(msg, 0, 0)); + return; + } + ir_ctx_.RecordError(ErrorMsg(msg, ctx->getStart()->getLine(), + ctx->getStart()->getCharPositionInLine() + 1)); +} + std::any SemaVisitor::visitCompUnit(SysYParser::CompUnitContext* ctx) { return visitChildren(ctx); } @@ -28,7 +51,15 @@ std::any SemaVisitor::visitDecl(SysYParser::DeclContext* ctx) { } std::any SemaVisitor::visitConstDecl(SysYParser::ConstDeclContext* ctx) { - return visitChildren(ctx); + 
current_decl_is_const_ = true; + current_decl_type_ = SymbolType::TYPE_UNKNOWN; + if (ctx && ctx->btype()) { + current_decl_type_ = ParseType(ctx->btype()->getText()); + } + std::any result = visitChildren(ctx); + current_decl_is_const_ = false; + current_decl_type_ = SymbolType::TYPE_UNKNOWN; + return result; } std::any SemaVisitor::visitBtype(SysYParser::BtypeContext* ctx) { @@ -36,6 +67,23 @@ std::any SemaVisitor::visitBtype(SysYParser::BtypeContext* ctx) { } std::any SemaVisitor::visitConstDef(SysYParser::ConstDefContext* ctx) { + if (!ctx || !ctx->ID()) { + return {}; + } + + const std::string name = ctx->ID()->getText(); + auto& table = ir_ctx_.GetSymbolTable(); + if (table.CurrentScopeHasVar(name)) { + RecordNodeError(ctx, "重复定义变量: " + name); + } else { + VarInfo info; + info.type = current_decl_type_; + info.is_const = true; + info.decl_ctx = ctx; + table.BindVar(name, info, ctx); + } + + ir_ctx_.SetType(ctx, current_decl_type_); return visitChildren(ctx); } @@ -44,10 +92,34 @@ std::any SemaVisitor::visitConstInitValue(SysYParser::ConstInitValueContext* ctx } std::any SemaVisitor::visitVarDecl(SysYParser::VarDeclContext* ctx) { - return visitChildren(ctx); + current_decl_is_const_ = false; + current_decl_type_ = SymbolType::TYPE_UNKNOWN; + if (ctx && ctx->btype()) { + current_decl_type_ = ParseType(ctx->btype()->getText()); + } + std::any result = visitChildren(ctx); + current_decl_type_ = SymbolType::TYPE_UNKNOWN; + return result; } std::any SemaVisitor::visitVarDef(SysYParser::VarDefContext* ctx) { + if (!ctx || !ctx->ID()) { + return {}; + } + + const std::string name = ctx->ID()->getText(); + auto& table = ir_ctx_.GetSymbolTable(); + if (table.CurrentScopeHasVar(name)) { + RecordNodeError(ctx, "重复定义变量: " + name); + } else { + VarInfo info; + info.type = current_decl_type_; + info.is_const = current_decl_is_const_; + info.decl_ctx = ctx; + table.BindVar(name, info, ctx); + } + + ir_ctx_.SetType(ctx, current_decl_type_); return visitChildren(ctx); } @@ 
-56,13 +128,45 @@ std::any SemaVisitor::visitInitValue(SysYParser::InitValueContext* ctx) { } std::any SemaVisitor::visitFuncDef(SysYParser::FuncDefContext* ctx) { - SymbolType ret_type = SymbolType::TYPE_UNKNOWN; - if (ctx && ctx->funcType()) { - ret_type = ParseType(ctx->funcType()->getText()); + if (!ctx || !ctx->ID() || !ctx->funcType()) { + return {}; } + const std::string func_name = ctx->ID()->getText(); + SymbolType ret_type = ParseType(ctx->funcType()->getText()); ir_ctx_.SetCurrentFuncReturnType(ret_type); - return visitChildren(ctx); + + auto& table = ir_ctx_.GetSymbolTable(); + if (table.CurrentScopeHasFunc(func_name)) { + RecordNodeError(ctx, "重复定义函数: " + func_name); + } else { + FuncInfo info; + info.name = func_name; + info.ret_type = ret_type; + info.decl_ctx = ctx; + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + if (!param || !param->btype()) { + info.param_types.push_back(SymbolType::TYPE_UNKNOWN); + } else { + info.param_types.push_back(ParseType(param->btype()->getText())); + } + } + } + table.BindFunc(func_name, info, ctx); + } + + ir_ctx_.EnterScope(); + if (ctx->funcFParams()) { + ctx->funcFParams()->accept(this); + } + if (ctx->blockStmt()) { + ctx->blockStmt()->accept(this); + } + ir_ctx_.LeaveScope(); + + ir_ctx_.SetCurrentFuncReturnType(SymbolType::TYPE_UNKNOWN); + return {}; } std::any SemaVisitor::visitFuncType(SysYParser::FuncTypeContext* ctx) { @@ -74,7 +178,23 @@ std::any SemaVisitor::visitFuncFParams(SysYParser::FuncFParamsContext* ctx) { } std::any SemaVisitor::visitFuncFParam(SysYParser::FuncFParamContext* ctx) { - return visitChildren(ctx); + if (!ctx || !ctx->ID() || !ctx->btype()) { + return {}; + } + const std::string name = ctx->ID()->getText(); + auto& table = ir_ctx_.GetSymbolTable(); + if (table.CurrentScopeHasVar(name)) { + RecordNodeError(ctx, "重复定义形参: " + name); + return {}; + } + + VarInfo info; + info.type = ParseType(ctx->btype()->getText()); + info.is_const = false; + 
info.decl_ctx = ctx; + table.BindVar(name, info, ctx); + ir_ctx_.SetType(ctx, info.type); + return {}; } std::any SemaVisitor::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { @@ -101,15 +221,22 @@ std::any SemaVisitor::visitStmt(SysYParser::StmtContext* ctx) { } if (ctx->BREAK() && !ir_ctx_.InLoop()) { - ir_ctx_.RecordError( - ErrorMsg("break 只能出现在循环语句中", ctx->getStart()->getLine(), - ctx->getStart()->getCharPositionInLine() + 1)); + RecordNodeError(ctx, "break 只能出现在循环语句中"); } if (ctx->CONTINUE() && !ir_ctx_.InLoop()) { - ir_ctx_.RecordError( - ErrorMsg("continue 只能出现在循环语句中", ctx->getStart()->getLine(), - ctx->getStart()->getCharPositionInLine() + 1)); + RecordNodeError(ctx, "continue 只能出现在循环语句中"); + } + + if (ctx->lValue() && ctx->exp()) { + ctx->lValue()->accept(this); + ctx->exp()->accept(this); + SymbolType lhs = ir_ctx_.GetType(ctx->lValue()); + SymbolType rhs = ir_ctx_.GetType(ctx->exp()); + if (!IsTypeCompatible(lhs, rhs)) { + RecordNodeError(ctx, "赋值两侧类型不兼容"); + } + return {}; } return visitChildren(ctx); @@ -120,37 +247,91 @@ std::any SemaVisitor::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { return {}; } - if (ctx->exp() && ir_ctx_.GetCurrentFuncReturnType() == SymbolType::TYPE_VOID) { - ir_ctx_.RecordError( - ErrorMsg("void 函数不应返回表达式", ctx->getStart()->getLine(), - ctx->getStart()->getCharPositionInLine() + 1)); - } - - if (!ctx->exp() && - ir_ctx_.GetCurrentFuncReturnType() != SymbolType::TYPE_VOID && - ir_ctx_.GetCurrentFuncReturnType() != SymbolType::TYPE_UNKNOWN) { - ir_ctx_.RecordError( - ErrorMsg("非 void 函数 return 必须带表达式", ctx->getStart()->getLine(), - ctx->getStart()->getCharPositionInLine() + 1)); + SymbolType ret_type = ir_ctx_.GetCurrentFuncReturnType(); + if (ctx->exp()) { + ctx->exp()->accept(this); + SymbolType expr_type = ir_ctx_.GetType(ctx->exp()); + if (ret_type == SymbolType::TYPE_VOID) { + RecordNodeError(ctx, "void 函数不应返回表达式"); + } else if (!IsTypeCompatible(ret_type, expr_type)) { + RecordNodeError(ctx, "return 
表达式类型与函数返回类型不匹配"); + } + } else if (ret_type != SymbolType::TYPE_VOID && + ret_type != SymbolType::TYPE_UNKNOWN) { + RecordNodeError(ctx, "非 void 函数 return 必须带表达式"); } - return visitChildren(ctx); + return {}; } std::any SemaVisitor::visitExp(SysYParser::ExpContext* ctx) { - return visitChildren(ctx); + if (!ctx || !ctx->addExp()) { + return {}; + } + ctx->addExp()->accept(this); + ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->addExp())); + return {}; } std::any SemaVisitor::visitCond(SysYParser::CondContext* ctx) { - return visitChildren(ctx); + if (!ctx || !ctx->lOrExp()) { + return {}; + } + ctx->lOrExp()->accept(this); + ir_ctx_.SetType(ctx, SymbolType::TYPE_INT); + return {}; } std::any SemaVisitor::visitLValue(SysYParser::LValueContext* ctx) { - return visitChildren(ctx); + if (!ctx || !ctx->ID()) { + return {}; + } + + VarInfo var; + void* decl_ctx = nullptr; + auto& table = ir_ctx_.GetSymbolTable(); + const std::string name = ctx->ID()->getText(); + if (!table.LookupVar(name, var, decl_ctx)) { + RecordNodeError(ctx, "未定义变量: " + name); + ir_ctx_.SetType(ctx, SymbolType::TYPE_UNKNOWN); + return {}; + } + + ir_ctx_.SetType(ctx, var.type); + + if (sema_ctx_ && decl_ctx) { + auto* rule = static_cast(decl_ctx); + if (auto* var_def = dynamic_cast(rule)) { + sema_ctx_->BindVarUse(ctx, var_def); + } + } + + return {}; } std::any SemaVisitor::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { - return visitChildren(ctx); + if (!ctx) { + return {}; + } + + if (ctx->exp()) { + ctx->exp()->accept(this); + ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->exp())); + return {}; + } + + if (ctx->lValue()) { + ctx->lValue()->accept(this); + ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->lValue())); + return {}; + } + + if (ctx->number()) { + ctx->number()->accept(this); + ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->number())); + } + + return {}; } std::any SemaVisitor::visitNumber(SysYParser::NumberContext* ctx) { @@ -170,6 +351,33 @@ std::any 
SemaVisitor::visitNumber(SysYParser::NumberContext* ctx) { } std::any SemaVisitor::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + if (!ctx) { + return {}; + } + + if (ctx->primaryExp()) { + ctx->primaryExp()->accept(this); + ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->primaryExp())); + return {}; + } + + if (ctx->unaryOp() && ctx->unaryExp()) { + ctx->unaryExp()->accept(this); + ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->unaryExp())); + return {}; + } + + if (ctx->ID()) { + FuncInfo fn; + void* decl_ctx = nullptr; + if (!ir_ctx_.GetSymbolTable().LookupFunc(ctx->ID()->getText(), fn, decl_ctx)) { + RecordNodeError(ctx, "未定义函数: " + ctx->ID()->getText()); + ir_ctx_.SetType(ctx, SymbolType::TYPE_UNKNOWN); + } else { + ir_ctx_.SetType(ctx, fn.ret_type); + } + } + return visitChildren(ctx); } @@ -182,43 +390,101 @@ std::any SemaVisitor::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { } std::any SemaVisitor::visitMulExp(SysYParser::MulExpContext* ctx) { - return visitChildren(ctx); + if (!ctx) { + return {}; + } + + if (ctx->mulExp()) { + ctx->mulExp()->accept(this); + } + if (ctx->unaryExp()) { + ctx->unaryExp()->accept(this); + } + + SymbolType lhs = ctx->mulExp() ? ir_ctx_.GetType(ctx->mulExp()) + : ir_ctx_.GetType(ctx->unaryExp()); + SymbolType rhs = ir_ctx_.GetType(ctx->unaryExp()); + ir_ctx_.SetType(ctx, MergeNumericType(lhs, rhs)); + return {}; } std::any SemaVisitor::visitAddExp(SysYParser::AddExpContext* ctx) { - return visitChildren(ctx); + if (!ctx) { + return {}; + } + + if (ctx->addExp()) { + ctx->addExp()->accept(this); + } + if (ctx->mulExp()) { + ctx->mulExp()->accept(this); + } + + SymbolType lhs = ctx->addExp() ? 
ir_ctx_.GetType(ctx->addExp()) + : ir_ctx_.GetType(ctx->mulExp()); + SymbolType rhs = ir_ctx_.GetType(ctx->mulExp()); + ir_ctx_.SetType(ctx, MergeNumericType(lhs, rhs)); + return {}; } std::any SemaVisitor::visitRelExp(SysYParser::RelExpContext* ctx) { - return visitChildren(ctx); + if (ctx) { + visitChildren(ctx); + ir_ctx_.SetType(ctx, SymbolType::TYPE_INT); + } + return {}; } std::any SemaVisitor::visitEqExp(SysYParser::EqExpContext* ctx) { - return visitChildren(ctx); + if (ctx) { + visitChildren(ctx); + ir_ctx_.SetType(ctx, SymbolType::TYPE_INT); + } + return {}; } std::any SemaVisitor::visitLAndExp(SysYParser::LAndExpContext* ctx) { - return visitChildren(ctx); + if (ctx) { + visitChildren(ctx); + ir_ctx_.SetType(ctx, SymbolType::TYPE_INT); + } + return {}; } std::any SemaVisitor::visitLOrExp(SysYParser::LOrExpContext* ctx) { - return visitChildren(ctx); + if (ctx) { + visitChildren(ctx); + ir_ctx_.SetType(ctx, SymbolType::TYPE_INT); + } + return {}; } std::any SemaVisitor::visitConstExp(SysYParser::ConstExpContext* ctx) { - return visitChildren(ctx); + if (!ctx || !ctx->addExp()) { + return {}; + } + ctx->addExp()->accept(this); + ir_ctx_.SetType(ctx, ir_ctx_.GetType(ctx->addExp())); + return {}; } void RunSemanticAnalysis(SysYParser::CompUnitContext* ctx, IRGenContext& ir_ctx) { if (!ctx) { throw std::invalid_argument("CompUnitContext is null"); } - SemaVisitor visitor(ir_ctx); + + ir_ctx.EnterScope(); + SemaVisitor visitor(ir_ctx, nullptr); visitor.visit(ctx); + ir_ctx.LeaveScope(); } SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { IRGenContext ctx; - RunSemanticAnalysis(&comp_unit, ctx); - return SemanticContext(); + SemanticContext sema_ctx; + ctx.EnterScope(); + SemaVisitor visitor(ctx, &sema_ctx); + visitor.visit(&comp_unit); + ctx.LeaveScope(); + return sema_ctx; } diff --git a/sylib/sylib.c b/sylib/sylib.c index 7f26d0b..21b9fdd 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -2,3 +2,42 @@ // - 按实验/评测规范提供 I/O 等函数实现 // - 
与编译器生成的目标代码链接,支撑运行时行为 +#include +#include +#include + +int getint() { int v; scanf("%d", &v); return v; } +int getch() { return getchar(); } +void putint(int v) { printf("%d", v); } +void putch(int c) { putchar(c); } +float getfloat() { float v; scanf("%f", &v); return v; } +void putfloat(float v) { printf("%a", v); } + +int getarray(int* a) { + int n; scanf("%d", &n); + for (int i = 0; i < n; i++) scanf("%d", &a[i]); + return n; +} +int getfarray(float* a) { + int n; scanf("%d", &n); + for (int i = 0; i < n; i++) scanf("%f", &a[i]); + return n; +} +void putarray(int n, int* a) { + printf("%d:", n); + for (int i = 0; i < n; i++) printf(" %d", a[i]); + printf("\n"); +} +void putfarray(int n, float* a) { + printf("%d:", n); + for (int i = 0; i < n; i++) printf(" %a", a[i]); + printf("\n"); +} + +static struct timespec _t0; +void starttime(int l) { (void)l; clock_gettime(CLOCK_MONOTONIC, &_t0); } +void stoptime(int l) { + struct timespec t1; clock_gettime(CLOCK_MONOTONIC, &t1); + fprintf(stderr, "Timer@%d: %ldms\n", l, + (t1.tv_sec-_t0.tv_sec)*1000+(t1.tv_nsec-_t0.tv_nsec)/1000000); +} diff --git a/sylib/sylib.h b/sylib/sylib.h index 502d488..299090d 100644 --- a/sylib/sylib.h +++ b/sylib/sylib.h @@ -2,3 +2,22 @@ // - 声明运行库函数原型(供编译器生成 call 或链接阶段引用) // - 与 sylib.c 配套,按规范逐步补齐声明 +#pragma once + +int getint(); +int getch(); +void putint(int v); +void putch(int c); + +float getfloat(); +void putfloat(float v); + +int getarray(int* a); +void putarray(int n, int* a); + +int getfarray(float* a); +void putfarray(int n, float* a); + +void starttime(int l); +void stoptime(int l); +