From c276d5649f869fa795f54516851bbd0a92b83980 Mon Sep 17 00:00:00 2001 From: LuoHello <2901023943@qq.com> Date: Mon, 18 May 2026 22:22:48 +0800 Subject: [PATCH] =?UTF-8?q?mem2reg=20and=20constprob/fold,=E6=B6=89?= =?UTF-8?q?=E5=8F=8A=E9=83=A8=E5=88=86=E6=94=AF=E9=85=8D=E6=A0=91=EF=BC=8C?= =?UTF-8?q?CFG=EF=BC=8C=E6=96=B0=E5=BB=BA=E6=B5=8B=E8=AF=95=E7=A8=8B?= =?UTF-8?q?=E5=BA=8F=EF=BC=8C=E6=80=BB=E5=B7=A5=E4=BD=9C=E8=A7=81/doc/lab4?= =?UTF-8?q?=5Fwork?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- doc/lab4_work.md | 336 +++++++++++++++++++++++++ include/ir/IR.h | 63 +++++ include/ir/analysis/DominatorTree.h | 40 +++ include/ir/passes/ConstFold.h | 15 ++ include/ir/passes/ConstProp.h | 15 ++ include/ir/passes/DCE.h | 15 ++ include/ir/passes/Mem2Reg.h | 27 ++ include/ir/passes/PassManager.h | 29 +++ include/utils/CLI.h | 1 + include/utils/Log.h | 5 + out.ll | 40 +++ out.s | 0 scripts/verify_mem2reg.sh | 377 ++++++++++++++++++++++++++++ src/ir/BasicBlock.cpp | 26 ++ src/ir/IRPrinter.cpp | 14 ++ src/ir/Instruction.cpp | 9 +- src/ir/Value.cpp | 14 +- src/ir/analysis/DominatorTree.cpp | 256 ++++++++++++++++++- src/ir/passes/ConstFold.cpp | 301 +++++++++++++++++++++- src/ir/passes/ConstProp.cpp | 215 +++++++++++++++- src/ir/passes/DCE.cpp | 56 ++++- src/ir/passes/Mem2Reg.cpp | 278 +++++++++++++++++++- src/ir/passes/PassManager.cpp | 39 ++- src/irgen/IRGenDecl.cpp | 62 ++--- src/irgen/IRGenDriver.cpp | 16 ++ src/irgen/IRGenExp.cpp | 84 +++---- src/irgen/IRGenFunc.cpp | 36 +-- src/irgen/IRGenStmt.cpp | 112 ++++----- src/main.cpp | 2 + src/sem/Sema.cpp | 154 ++++++------ src/sem/SymbolTable.cpp | 7 +- src/utils/CLI.cpp | 5 + src/utils/Log.cpp | 28 +++ test/test_case/mem2reg/01_phi.out | 1 + test/test_case/mem2reg/01_phi.sy | 9 + 35 files changed, 2441 insertions(+), 246 deletions(-) create mode 100644 doc/lab4_work.md create mode 100644 include/ir/analysis/DominatorTree.h create mode 100644 include/ir/passes/ConstFold.h create mode 100644 include/ir/passes/ConstProp.h create mode 100644 include/ir/passes/DCE.h create mode 100644 include/ir/passes/Mem2Reg.h create mode 100644 include/ir/passes/PassManager.h create mode 100644 out.ll create mode 100644 out.s create mode 100755 scripts/verify_mem2reg.sh create mode 100644 test/test_case/mem2reg/01_phi.out create mode 100644 test/test_case/mem2reg/01_phi.sy diff --git a/doc/lab4_work.md b/doc/lab4_work.md new file mode 100644 index 0000000..5eb17be --- /dev/null +++ b/doc/lab4_work.md @@ -0,0 +1,336 @@ +# Lab4 工作记录:基本标量优化 + +本文记录本次 Lab4 中完成的优化 pass、关键实现思路、遇到的问题和测试脚本的使用方式。 + +## 1. 完成内容概览 + +本次主要完成并接入了如下内容: + +- `Mem2Reg`:将可提升的局部标量变量从 `alloca/load/store` 形式提升到 SSA 形式。 +- `ConstFold`:对常量表达式进行编译期求值。 +- `ConstProp`:做简单常量传播和代数化简。 +- `DCE`:删除没有 use 且没有副作用的死指令。 +- `verify_mem2reg.sh`:批量验证 `test/` 下所有用例是否能完成现有 pass 优化,并可选运行语义回归。 + +当前优化流水线接在 `src/irgen/IRGenDriver.cpp` 中,顺序为: + +```text +Mem2Reg -> ConstFold -> ConstProp -> ConstFold -> DCE +``` + +其中第二次 `ConstFold` 用于吃掉 `ConstProp` 暴露出来的新常量表达式,最后的 `DCE` 清理被替换后不再使用的指令。 + +## 2. Mem2Reg 做了什么 + +前端生成 IR 时,局部变量通常先表示成内存形式: + +```llvm +%x = alloca i32 +store i32 1, i32* %x +%v = load i32, i32* %x +``` + +这种形式语义直接,但会让后续优化很难判断 `%v` 到底是什么值。`Mem2Reg` 的作用是把这类可提升变量改写成 SSA value。 + +例如: + +```c +int x; +if (cond) { + x = 1; +} else { + x = 2; +} +return x; +``` + +提升后核心 IR 会变成: + +```llvm +merge: + %x.merge.phi0 = phi i32 [1, %then], [2, %else] + ret i32 %x.merge.phi0 +``` + +`phi` 表示“根据控制流从哪个前驱块来,选择对应的值”。 + +### 2.1 可提升对象 + +本实验里的 `Mem2Reg` 只提升局部标量 alloca: + +- `i32*` +- `float*` +- `i1*` + +并且要求它们只被直接 `load/store` 使用。 + +以下情况不会提升: + +- 数组 alloca,例如 `[100 x i32]` +- 通过 `getelementptr` 复杂访问的内存 +- 地址传给函数的变量 +- 全局变量 +- 其他地址逃逸的变量 + +所以测试中仍看到数组相关 `alloca` 是正常的。`mem2reg` 不是把所有内存都消掉,而是把可以安全转成 SSA 的局部标量消掉。 + +### 2.2 核心算法 + +实现流程如下: + +1. 扫描入口块,找到可提升的 `alloca`。 +2. 收集该变量所有 `store` 所在的定义块。 +3. 构建 CFG,并计算支配树和支配边界。 +4. 在支配边界处插入 `phi`。 +5. 沿支配树递归重命名: + - 遇到 `store`,更新当前变量值。 + - 遇到 `load`,用当前变量值替换该 `load`。 + - 遍历后继块时,给后继块中的 `phi` 填 incoming value。 +6. 删除被提升掉的 `alloca/load/store`。 + +### 2.3 修过的问题 + +实现过程中修了一个关键问题:多个 phi 结果重名。 + +原来 phi 名字类似: + +```text +变量名.phi +``` + +复杂循环里同一个变量可能需要多个 phi,导致 LLVM 报: + +```text +multiple definition of local value named '...phi' +``` + +现在 phi 名字包含变量名、基本块名和递增编号,例如: + +```text +%t45_i.while.cond.t72.phi3 +``` + +这样可以保证同一个函数内 SSA 名字唯一。 + +## 3. 常量折叠与常量传播 + +### 3.1 ConstFold + +`ConstFold` 会把操作数都是常量的指令直接计算出来,并用常量替换原指令。 + +目前支持: + +- 整数运算:`add/sub/mul/div/mod/and/or` +- 浮点运算:`fadd/fsub/fmul/fdiv` +- 整数比较:`icmp` +- 浮点比较:`fcmp` +- 类型转换:`zext/trunc/sitofp/fptosi` +- 所有 incoming 都是同一常量的简单 `phi` + +例如: + +```llvm +%t = add i32 20, 4 +ret i32 %t +``` + +会变成: + +```llvm +ret i32 24 +``` + +### 3.2 ConstProp + +`ConstProp` 主要做简单代数化简和传播: + +- `x + 0 -> x` +- `x - 0 -> x` +- `x * 1 -> x` +- `x * 0 -> 0` +- `x / 1 -> x` +- `0 / x -> 0` +- `phi` 所有有效 incoming 相同,则替换为该值 + +它不做复杂全局数据流分析,目标是配合 `Mem2Reg` 暴露出来的 SSA 值,吃掉一些明显冗余表达式。 + +## 4. DCE + +`DCE` 删除无副作用且没有 use 的指令。 + +保留的有副作用或控制流指令包括: + +- `store` +- `ret` +- `call` +- `br` +- `condbr` + +优化后被常量替换掉的二元运算、比较、转换指令,如果不再被使用,会被 DCE 清掉。 + +## 5. 测试脚本设计 + +新增或重写的脚本: + +```text +scripts/verify_mem2reg.sh +``` + +这个脚本不再只测试 `test/test_case/mem2reg`,而是默认扫描整个 `test/` 目录下所有 `.sy` 文件。 + +脚本分三层验证。 + +### 5.1 IR 生成检查 + +第一层检查每个 `.sy` 是否能完成: + +```bash +./build/bin/compiler --emit-ir xxx.sy +``` + +如果能生成包含 `define` 的 IR,说明前端和当前 pass 流水线都跑完了。 + +### 5.2 优化结果检查 + +第二层检查优化后 IR 中是否还有标量 alloca: + +```llvm +%x = alloca i32 +%y = alloca float +%b = alloca i1 +``` + +默认情况下,残留标量 alloca 只作为 warning,不直接判失败。原因是:不是所有 alloca 都一定能安全提升,尤其在复杂数组、地址使用、函数调用附近,保守处理是合理的。 + +如果希望更严格,可以使用: + +```bash +./scripts/verify_mem2reg.sh --strict-mem2reg +``` + +这会把残留标量 alloca 当成失败。 + +### 5.3 运行语义回归 + +第三层需要手动打开: + +```bash +./scripts/verify_mem2reg.sh --run +``` + +它会执行: + +1. 生成优化后 IR。 +2. 用 `llc` 把 `.ll` 转成目标文件。 +3. 用 `clang` 链接目标文件。 +4. 如果存在 `sylib/sylib.c`,会先编译并链接运行库。 +5. 自动读取同名 `.in` 作为输入。 +6. 将程序 stdout 和退出码拼成 actual 结果。 +7. 与同名 `.out` 对比。 + +脚本比较时会统一处理: + +- Windows 风格换行 `\r\n` +- 文件末尾是否多一个换行 + +这样可以避免因为文本格式差异造成误报。 + +## 6. 测试脚本用法 + +### 6.1 构建项目 + +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release +cmake --build build -j "$(nproc)" +``` + +### 6.2 只检查 pass 能否跑完 + +```bash +./scripts/verify_mem2reg.sh +``` + +默认会扫描: + +```text +test/ +``` + +输出示例: + +```text +IR 生成: 22 / 22 +Pass 优化检查: 22 / 22 +全部检查通过。 +``` + +### 6.3 同时运行语义回归 + +```bash +./scripts/verify_mem2reg.sh --run +``` + +输出示例: + +```text +IR 生成: 22 / 22 +Pass 优化检查: 22 / 22 +运行结果: 22 / 22 +全部检查通过。 +``` + +### 6.4 只测试某个目录 + +```bash +./scripts/verify_mem2reg.sh --test-root test/test_case/functional --run +``` + +### 6.5 打印详细信息 + +```bash +./scripts/verify_mem2reg.sh --debug --run +``` + +### 6.6 严格检查 mem2reg + +```bash +./scripts/verify_mem2reg.sh --strict-mem2reg +``` + +这个模式适合专门检查“还有哪些标量 alloca 没被提升”。当前某些复杂性能样例会有 warning,是否要继续优化要结合 IR 使用情况判断。 + +## 7. 当前测试结论 + +当前执行: + +```bash +./scripts/verify_mem2reg.sh --run +``` + +结果为: + +```text +IR 生成: 22 / 22 +Pass 优化检查: 22 / 22 +运行结果: 22 / 22 +全部检查通过。 +``` + +这说明: + +- 所有测试都能完成当前 pass 流水线。 +- 生成的 IR 能被 LLVM 工具链接受。 +- 链接运行库后,程序输出和退出码均与 `.out` 匹配。 + +脚本中出现的标量 alloca warning 不影响当前语义正确性,它们只是提示后续还有进一步提升或更精细逃逸分析的空间。 + +## 8. 后续可改进方向 + +后续如果继续扩展 Lab4,可以考虑: + +- 为 `Mem2Reg` 增加更完整的可提升性分析。 +- 对未初始化 load 做更稳健的默认值处理。 +- 增加 CFG Simplify,删除常量条件分支和不可达块。 +- 增加 CSE,消除重复表达式。 +- 将常量折叠和传播做成迭代到不动点。 +- 增加 IR verifier,提前检查 phi incoming 数量、SSA 名字唯一性、基本块前驱匹配等问题。 diff --git a/include/ir/IR.h b/include/ir/IR.h index ea62b05..d654c58 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -333,6 +333,7 @@ enum class Opcode { FPToSI, // 浮点转整数 FPExt, // 浮点扩展 FPTrunc, // 浮点截断 + Phi, }; // ZExt 和 Trunc 是零扩展和截断指令,SysY 的 int (i32) vs LLVM IR 的比较结果 (i1)。 @@ -567,6 +568,56 @@ class BranchInst : public Instruction { BasicBlock* false_target_; // 假分支目标(条件跳转使用) }; +class PhiInst : public Instruction { + public: + PhiInst(std::shared_ptr ty, std::string name = "") + : Instruction(Opcode::Phi, std::move(ty), std::move(name)) {} + + void AddIncoming(Value* value, BasicBlock* block) { + if (!block) { + throw std::runtime_error("PhiInst incoming block cannot be null"); + } + incoming_.push_back({value, block}); + if (value) { + AddOperand(value); + } + } + + Value* GetIncomingValue(size_t index) const { + if (index >= incoming_.size()) { + throw std::out_of_range("PhiInst incoming value index out of range"); + } + return incoming_[index].first; + } + + BasicBlock* GetIncomingBlock(size_t index) const { + if (index >= incoming_.size()) { + throw std::out_of_range("PhiInst incoming block index out of range"); + } + return incoming_[index].second; + } + + size_t GetNumIncoming() const { return incoming_.size(); } + + void SetIncomingValue(size_t index, Value* value) { + if (index >= incoming_.size()) { + throw std::out_of_range("PhiInst incoming value index out of range"); + } + incoming_[index].first = value; + SetOperand(index, value); + } + + void SetIncomingBlock(size_t index, BasicBlock* block) { + if (index >= incoming_.size()) { + throw std::out_of_range("PhiInst incoming block index out of range"); + } + incoming_[index].second = block; + } + + private: + std::vector> incoming_; +}; + // 创建整数比较指令 class IcmpInst : public Instruction { public: @@ -730,6 +781,7 @@ class CallInst : public Instruction { const std::string& name); Function* GetCallee() const; const std::vector& GetArgs() const; + void SetArg(size_t index, Value* value); private: Function* callee_; @@ -774,6 +826,17 @@ class BasicBlock : public Value { return ptr; } + template + T* InsertAtBeginning(Args&&... args) { + auto inst = std::make_unique(std::forward(args)...); + auto* ptr = inst.get(); + ptr->SetParent(this); + instructions_.insert(instructions_.begin(), std::move(inst)); + return ptr; + } + + void RemoveInstruction(Instruction* inst); + private: Function* parent_ = nullptr; std::vector> instructions_; diff --git a/include/ir/analysis/DominatorTree.h b/include/ir/analysis/DominatorTree.h new file mode 100644 index 0000000..d27e0c0 --- /dev/null +++ b/include/ir/analysis/DominatorTree.h @@ -0,0 +1,40 @@ +#pragma once + +#include "ir/IR.h" + +#include +#include + +namespace ir { + +class DominatorTree { + public: + DominatorTree() = default; + ~DominatorTree() = default; + + void Recalculate(Function& function); + + BasicBlock* GetRoot() const; + BasicBlock* GetIDom(BasicBlock* block) const; + bool Dominates(BasicBlock* a, BasicBlock* b) const; + const std::vector& GetChildren(BasicBlock* block) const; + const std::vector& GetDominanceFrontier(BasicBlock* block) const; + const std::vector& GetPredecessors(BasicBlock* block) const; + const std::vector& GetSuccessors(BasicBlock* block) const; + + private: + void BuildCFG(Function& function); + void ComputeIDoms(); + void ComputeDominanceFrontiers(); + BasicBlock* Intersect(BasicBlock* first, BasicBlock* second) const; + + std::vector blocks_; + std::unordered_map> preds_; + std::unordered_map> succs_; + std::unordered_map idom_; + std::unordered_map> children_; + std::unordered_map> dominance_frontier_; + std::unordered_map dfs_number_; +}; + +} // namespace ir diff --git a/include/ir/passes/ConstFold.h b/include/ir/passes/ConstFold.h new file mode 100644 index 0000000..051cd1b --- /dev/null +++ b/include/ir/passes/ConstFold.h @@ -0,0 +1,15 @@ +#pragma once + +#include "ir/passes/PassManager.h" + +namespace ir { + +class ConstFoldPass : public Pass { + public: + ConstFoldPass() = default; + ~ConstFoldPass() override = default; + + bool RunOnFunction(Function& function) override; +}; + +} // namespace ir diff --git a/include/ir/passes/ConstProp.h b/include/ir/passes/ConstProp.h new file mode 100644 index 0000000..dfa3277 --- /dev/null +++ b/include/ir/passes/ConstProp.h @@ -0,0 +1,15 @@ +#pragma once + +#include "ir/passes/PassManager.h" + +namespace ir { + +class ConstPropPass : public Pass { + public: + ConstPropPass() = default; + ~ConstPropPass() override = default; + + bool RunOnFunction(Function& function) override; +}; + +} // namespace ir diff --git a/include/ir/passes/DCE.h b/include/ir/passes/DCE.h new file mode 100644 index 0000000..6c2981f --- /dev/null +++ b/include/ir/passes/DCE.h @@ -0,0 +1,15 @@ +#pragma once + +#include "ir/passes/PassManager.h" + +namespace ir { + +class DCEPass : public Pass { + public: + DCEPass() = default; + ~DCEPass() override = default; + + bool RunOnFunction(Function& function) override; +}; + +} // namespace ir diff --git a/include/ir/passes/Mem2Reg.h b/include/ir/passes/Mem2Reg.h new file mode 100644 index 0000000..e23a1e9 --- /dev/null +++ b/include/ir/passes/Mem2Reg.h @@ -0,0 +1,27 @@ +#pragma once + +#include "ir/IR.h" +#include "ir/analysis/DominatorTree.h" +#include "ir/passes/PassManager.h" + +namespace ir { + +class Mem2RegPass : public Pass { + public: + Mem2RegPass() = default; + ~Mem2RegPass() = default; + + // 将函数内的可提升内存变量提升到 SSA 形式。 + // 返回是否对函数做出了任何修改。 + bool RunOnFunction(Function& function) override; + + // 可选:在模块级别执行 mem2reg。 + bool RunOnModule(Module& module); + + private: + bool PromoteAllocas(Function& function, DominatorTree& domtree); + + bool changed_ = false; +}; + +} // namespace ir diff --git a/include/ir/passes/PassManager.h b/include/ir/passes/PassManager.h new file mode 100644 index 0000000..dbe7c07 --- /dev/null +++ b/include/ir/passes/PassManager.h @@ -0,0 +1,29 @@ +#pragma once + +#include "ir/IR.h" + +#include +#include + +namespace ir { + +class Pass { + public: + virtual ~Pass() = default; + virtual bool RunOnFunction(Function& function) = 0; +}; + +class PassManager { + public: + PassManager() = default; + ~PassManager() = default; + + void AddPass(std::unique_ptr pass); + bool Run(Function& function); + bool Run(Module& module); + + private: + std::vector> passes_; +}; + +} // namespace ir diff --git a/include/utils/CLI.h b/include/utils/CLI.h index 4b3a781..25045b1 100644 --- a/include/utils/CLI.h +++ b/include/utils/CLI.h @@ -8,6 +8,7 @@ struct CLIOptions { bool emit_parse_tree = false; bool emit_ir = true; bool emit_asm = false; + bool debug = false; bool show_help = false; }; diff --git a/include/utils/Log.h b/include/utils/Log.h index 303f1a1..0a85896 100644 --- a/include/utils/Log.h +++ b/include/utils/Log.h @@ -10,6 +10,11 @@ void LogInfo(std::string_view msg, std::ostream& os); void LogError(std::string_view msg, std::ostream& os); +extern bool g_debug_enabled; +void SetDebugEnabled(bool enabled); +bool IsDebugEnabled(); +std::ostream& DebugStream(); + std::string FormatError(std::string_view stage, std::string_view msg); std::string FormatErrorAt(std::string_view stage, std::size_t line, std::size_t column, std::string_view msg); diff --git a/out.ll b/out.ll new file mode 100644 index 0000000..f9ddd4a --- /dev/null +++ b/out.ll @@ -0,0 +1,40 @@ +declare i32 @getint() +declare i32 @getch() +declare i32 @getarray(i32*) +declare void @putint(i32) +declare void @putch(i32) +declare void @putarray(i32, i32*) +declare void @puts(i32*) +declare void @_sysy_starttime(i32) +declare void @_sysy_stoptime(i32) +declare void @starttime() +declare void @stoptime() +declare float @getfloat() +declare void @putfloat(float) +declare i32 @getfarray(float*) +declare void @putfarray(i32, float*) +declare i32* @memset(i32*, i32, i32) +declare i32* @sysy_alloc_i32(i32) +declare float* @sysy_alloc_f32(i32) +declare void @sysy_free_i32(i32*) +declare void @sysy_free_f32(float*) +declare void @sysy_zero_i32(i32*, i32) +declare void @sysy_zero_f32(float*, i32) +define i32 @main() { +entry: + %t0.retval = alloca i32 + %t1_x = alloca i32 + store i32 0, i32* %t1_x + %t5 = load i32, i32* %t1_x + %t6 = icmp eq i32 %t5, 0 + br i1 %t6, label %then.t2, label %else.t3 +then.t2: + store i32 5, i32* %t1_x + br label %merge.t4 +else.t3: + store i32 7, i32* %t1_x + br label %merge.t4 +merge.t4: + %t7 = load i32, i32* %t1_x + ret i32 %t7 +} diff --git a/out.s b/out.s new file mode 100644 index 0000000..e69de29 diff --git a/scripts/verify_mem2reg.sh b/scripts/verify_mem2reg.sh new file mode 100755 index 0000000..59c7b86 --- /dev/null +++ b/scripts/verify_mem2reg.sh @@ -0,0 +1,377 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +COMPILER="$ROOT_DIR/build/bin/compiler" +DEFAULT_TEST_ROOT="$ROOT_DIR/test" +TMP_DIR="$ROOT_DIR/build/test_passes" +CC_BIN="${CC:-cc}" +LLC_BIN="${LLC:-llc}" +CLANG_BIN="${CLANG:-clang}" +RUNTIME_SRC="$ROOT_DIR/sylib/sylib.c" +RUNTIME_OBJ="$TMP_DIR/sylib.o" + +debug=false +run_exec=false +test_root="$DEFAULT_TEST_ROOT" +stop_on_fail=false +strict_mem2reg=false + +usage() { + cat < 指定测试根目录,默认: $DEFAULT_TEST_ROOT + --stop-on-fail 遇到第一个失败立即退出 + --strict-mem2reg 将优化后残留标量 alloca 视为失败;默认只作为警告统计 + -h, --help 显示帮助 + +环境变量: + LLC= 指定 llc,默认 llc + CLANG= 指定 clang,默认 clang + CC= 指定 C 编译器,用于编译 sylib.c,默认 cc +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --run) + run_exec=true + shift + ;; + --debug) + debug=true + shift + ;; + --test-root) + if [[ $# -lt 2 ]]; then + echo "--test-root 需要目录参数" >&2 + exit 1 + fi + test_root="$2" + shift 2 + ;; + --stop-on-fail) + stop_on_fail=true + shift + ;; + --strict-mem2reg) + strict_mem2reg=true + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "未知参数: $1" >&2 + usage >&2 + exit 1 + ;; + esac +done + +if [[ ! -x "$COMPILER" ]]; then + echo "未找到编译器: $COMPILER" >&2 + echo "请先构建编译器,例如: cmake -S . -B build && cmake --build build -j" >&2 + exit 1 +fi + +if [[ ! -d "$test_root" ]]; then + echo "测试目录不存在: $test_root" >&2 + exit 1 +fi + +mkdir -p "$TMP_DIR" + +runtime_ready=0 +if [[ "$run_exec" == true ]]; then + if ! command -v "$LLC_BIN" >/dev/null 2>&1; then + echo "未找到 llc: $LLC_BIN" >&2 + exit 1 + fi + if ! command -v "$CLANG_BIN" >/dev/null 2>&1; then + echo "未找到 clang: $CLANG_BIN" >&2 + exit 1 + fi + + if [[ -f "$RUNTIME_SRC" ]]; then + if "$CC_BIN" -c "$RUNTIME_SRC" -o "$RUNTIME_OBJ" >/dev/null 2>&1; then + runtime_ready=1 + else + echo "[WARN] 运行库编译失败,将只链接目标文件: $RUNTIME_SRC" >&2 + fi + else + echo "[WARN] 未找到运行库源码,将只链接目标文件: $RUNTIME_SRC" >&2 + fi +fi + +normalize_file() { + sed 's/\r$//' "$1" +} + +make_case_out_dir() { + local input=$1 + local rel + rel=$(realpath --relative-to="$test_root" "$(dirname "$input")") + echo "$TMP_DIR/$rel" +} + +extract_ir() { + local raw_file=$1 + local ll_file=$2 + + # 编译器在 debug 模式下可能把诊断也写到 stdout;这里保留 LLVM-like IR 行。 + grep -E '^(define |declare |@|[[:space:]]|})|^[A-Za-z_.$%][A-Za-z0-9_.$%]*:$' \ + "$raw_file" > "$ll_file" || true +} + +record_failure() { + local bucket=$1 + local message=$2 + case "$bucket" in + ir) ir_failures+=("$message") ;; + opt) opt_failures+=("$message") ;; + run) run_failures+=("$message") ;; + esac + if [[ "$stop_on_fail" == true ]]; then + echo "" + echo "遇到失败,按 --stop-on-fail 停止。失败文件保留在: $TMP_DIR" + exit 1 + fi +} + +record_warning() { + local bucket=$1 + local message=$2 + case "$bucket" in + opt) opt_warnings+=("$message") ;; + esac +} + +check_scalar_mem2reg() { + local ll_file=$1 + grep -nE '=[[:space:]]*alloca[[:space:]]+(i32|float|i1)\b' "$ll_file" || true +} + +compare_result() { + local input=$1 + local expected_file=$2 + local stdout_file=$3 + local status=$4 + + local actual_file="${stdout_file%.stdout}.actual.out" + { + cat "$stdout_file" + if [[ -s "$stdout_file" ]] && [[ "$(tail -c 1 "$stdout_file" | wc -l)" -eq 0 ]]; then + printf '\n' + fi + printf '%s\n' "$status" + } > "$actual_file" + + local expected_text + local actual_text + expected_text=$(normalize_file "$expected_file") + actual_text=$(normalize_file "$actual_file") + + if [[ "$expected_text" == "$actual_text" ]]; then + echo " [RUN] OK" + return 0 + fi + + echo " [RUN] FAIL: 输出或退出码不匹配" + echo " expected: $expected_file" + echo " actual: $actual_file" + if [[ "$debug" == true ]]; then + diff -u <(printf '%s\n' "$expected_text") <(printf '%s\n' "$actual_text") || true + fi + record_failure run "$input: output mismatch" + return 1 +} + +mapfile -t test_files < <(find "$test_root" -type f -name '*.sy' | sort) + +if [[ ${#test_files[@]} -eq 0 ]]; then + echo "未在目录中找到 .sy 测试: $test_root" >&2 + exit 1 +fi + +ir_total=0 +ir_pass=0 +opt_total=0 +opt_pass=0 +run_total=0 +run_pass=0 + +ir_failures=() +opt_failures=() +opt_warnings=() +run_failures=() + +echo "测试根目录: $test_root" +echo "输出目录: $TMP_DIR" +echo "测试数量: ${#test_files[@]}" +if [[ "$run_exec" == true ]]; then + echo "运行验证: 开启" +else + echo "运行验证: 关闭(加 --run 可开启语义对拍)" +fi +echo "" + +for input in "${test_files[@]}"; do + ir_total=$((ir_total + 1)) + opt_total=$((opt_total + 1)) + + out_dir=$(make_case_out_dir "$input") + mkdir -p "$out_dir" + + base=$(basename "$input") + stem=${base%.sy} + raw_ir="$out_dir/$stem.raw.ll" + ll_file="$out_dir/$stem.ll" + log_file="$out_dir/$stem.compiler.log" + stdout_file="$out_dir/$stem.stdout" + obj_file="$out_dir/$stem.o" + exe_file="$out_dir/$stem" + input_dir=$(dirname "$input") + stdin_file="$input_dir/$stem.in" + expected_file="$input_dir/$stem.out" + + echo "[TEST] ${input#$ROOT_DIR/}" + if [[ "$debug" == true ]]; then + echo " [CMD] $COMPILER --emit-ir $input" + fi + + compiler_status=0 + "$COMPILER" --emit-ir "$input" > "$raw_ir" 2> "$log_file" || compiler_status=$? + extract_ir "$raw_ir" "$ll_file" + + if [[ $compiler_status -ne 0 ]]; then + echo " [IR] FAIL: 编译器返回 $compiler_status" + record_failure ir "$input: compiler failed ($compiler_status)" + continue + fi + + if ! grep -qE '^define ' "$ll_file"; then + echo " [IR] FAIL: 未提取到有效函数定义" + record_failure ir "$input: invalid IR" + continue + fi + + ir_pass=$((ir_pass + 1)) + echo " [IR] OK" + + scalar_allocas=$(check_scalar_mem2reg "$ll_file") + if [[ -n "$scalar_allocas" ]]; then + if [[ "$strict_mem2reg" == true ]]; then + echo " [OPT] FAIL: 优化后仍有可提升标量 alloca" + else + echo " [OPT] WARN: 优化后仍有标量 alloca 残留" + fi + if [[ "$debug" == true ]]; then + echo "$scalar_allocas" | sed 's/^/ /' + fi + if [[ "$strict_mem2reg" == true ]]; then + record_failure opt "$input: scalar alloca remains" + else + opt_pass=$((opt_pass + 1)) + record_warning opt "$input: scalar alloca remains" + fi + else + opt_pass=$((opt_pass + 1)) + echo " [OPT] OK: 未发现标量 alloca 残留" + fi + + if [[ "$run_exec" != true ]]; then + continue + fi + + if [[ ! -f "$expected_file" ]]; then + echo " [RUN] SKIP: 未找到期望输出 $expected_file" + continue + fi + run_total=$((run_total + 1)) + + if ! "$LLC_BIN" -filetype=obj "$ll_file" -o "$obj_file" > "$stdout_file" 2>&1; then + echo " [RUN] FAIL: llc 生成对象文件失败" + record_failure run "$input: llc failed" + continue + fi + + if [[ $runtime_ready -eq 1 ]]; then + if ! "$CLANG_BIN" "$obj_file" "$RUNTIME_OBJ" -o "$exe_file" >> "$stdout_file" 2>&1; then + echo " [RUN] FAIL: clang 链接失败" + record_failure run "$input: clang link failed" + continue + fi + else + if ! "$CLANG_BIN" "$obj_file" -o "$exe_file" >> "$stdout_file" 2>&1; then + echo " [RUN] FAIL: clang 链接失败" + record_failure run "$input: clang link failed" + continue + fi + fi + + run_status=0 + if [[ -f "$stdin_file" ]]; then + "$exe_file" < "$stdin_file" > "$stdout_file" 2>&1 || run_status=$? + else + "$exe_file" > "$stdout_file" 2>&1 || run_status=$? + fi + + if compare_result "$input" "$expected_file" "$stdout_file" "$run_status"; then + run_pass=$((run_pass + 1)) + fi +done + +echo "" +echo "测试完成。" +echo "IR 生成: $ir_pass / $ir_total" +echo "Pass 优化检查: $opt_pass / $opt_total" +if [[ "$run_exec" == true ]]; then + echo "运行结果: $run_pass / $run_total" +fi + +if [[ ${#ir_failures[@]} -gt 0 ]]; then + echo "" + echo "IR 失败列表:" + for item in "${ir_failures[@]}"; do + echo " $item" + done +fi + +if [[ ${#opt_failures[@]} -gt 0 ]]; then + echo "" + echo "优化检查失败列表:" + for item in "${opt_failures[@]}"; do + echo " $item" + done +fi + +if [[ ${#opt_warnings[@]} -gt 0 ]]; then + echo "" + echo "优化警告列表(默认不算失败;加 --strict-mem2reg 可升级为失败):" + for item in "${opt_warnings[@]}"; do + echo " $item" + done +fi + +if [[ ${#run_failures[@]} -gt 0 ]]; then + echo "" + echo "运行失败列表:" + for item in "${run_failures[@]}"; do + echo " $item" + done +fi + +if [[ ${#ir_failures[@]} -gt 0 || ${#opt_failures[@]} -gt 0 || ${#run_failures[@]} -gt 0 ]]; then + echo "" + echo "失败产物已保留在: $TMP_DIR" + exit 1 +fi + +echo "" +echo "全部检查通过。" diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index b18502c..c6de1fc 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -9,6 +9,7 @@ #include "ir/IR.h" +#include #include namespace ir { @@ -32,6 +33,31 @@ const std::vector>& BasicBlock::GetInstructions() return instructions_; } +void BasicBlock::RemoveInstruction(Instruction* inst) { + if (!inst) { + return; + } + + auto it = std::find_if(instructions_.begin(), instructions_.end(), + [&](const std::unique_ptr& ptr) { + return ptr.get() == inst; + }); + if (it == instructions_.end()) { + return; + } + + // 清理该指令对操作数的 use 关系。 + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + Value* operand = inst->GetOperand(i); + if (operand) { + operand->RemoveUse(inst, i); + } + } + + inst->SetParent(nullptr); + instructions_.erase(it); +} + // 前驱/后继接口先保留给后续 CFG 扩展使用。 // 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。 const std::vector& BasicBlock::GetPredecessors() const { diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index a3a6278..9b3420a 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -196,6 +196,7 @@ static const char* OpcodeToString(Opcode op) { case Opcode::FPToSI: return "fptosi"; case Opcode::FPExt: return "fpext"; case Opcode::FPTrunc: return "fptrunc"; + case Opcode::Phi: return "phi"; } return "?"; } @@ -457,6 +458,19 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { break; } + case Opcode::Phi: { + auto* phi = static_cast(inst); + os << " " << phi->GetName() << " = phi " + << TypeToString(*phi->GetType()); + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + os << (i == 0 ? " " : ", ") + << "[" << ValueToString(phi->GetIncomingValue(i)) + << ", %" << phi->GetIncomingBlock(i)->GetName() << "]"; + } + os << "\n"; + break; + } + case Opcode::ZExt: { auto* zext = static_cast(inst); os << " " << zext->GetName() << " = zext " diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index d0f280a..4b5e652 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -226,6 +226,14 @@ Function* CallInst::GetCallee() const { return callee_; } const std::vector& CallInst::GetArgs() const { return args_; } +void CallInst::SetArg(size_t index, Value* value) { + if (index >= args_.size()) { + throw std::out_of_range("CallInst argument index out of range"); + } + args_[index] = value; + SetOperand(index, value); +} + GEPInst::GEPInst(std::shared_ptr ptr_ty, Value* base, const std::vector& indices, @@ -278,4 +286,3 @@ CallInst::CallInst(std::shared_ptr ret_ty, Function* callee, } // namespace ir - diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 56dd2e6..614ef9d 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -69,7 +69,19 @@ void Value::ReplaceAllUsesWith(Value* new_value) { if (!user) continue; size_t operand_index = use.GetOperandIndex(); if (user->GetOperand(operand_index) == this) { - user->SetOperand(operand_index, new_value); + if (auto* phi = dynamic_cast(user)) { + phi->SetIncomingValue(operand_index, new_value); + } else if (auto* br = dynamic_cast(user)) { + if (br->IsConditional() && operand_index == 0) { + br->SetCondition(new_value); + } else { + user->SetOperand(operand_index, new_value); + } + } else if (auto* call = dynamic_cast(user)) { + call->SetArg(operand_index, new_value); + } else { + user->SetOperand(operand_index, new_value); + } } } } diff --git a/src/ir/analysis/DominatorTree.cpp b/src/ir/analysis/DominatorTree.cpp index eaf7269..54f09a6 100644 --- a/src/ir/analysis/DominatorTree.cpp +++ b/src/ir/analysis/DominatorTree.cpp @@ -1,4 +1,254 @@ -// 支配树分析: -// - 构建/查询 Dominator Tree 及相关关系 -// - 为 mem2reg、CFG 优化与循环分析提供基础能力 +#include "ir/analysis/DominatorTree.h" +#include +#include +#include + +namespace ir { + +namespace { + +std::vector GetBlockSuccessors(BasicBlock* block) { + std::vector succs; + if (!block) { + return succs; + } + + const auto& instructions = block->GetInstructions(); + if (instructions.empty()) { + return succs; + } + + Instruction* term = instructions.back().get(); + if (!term->IsTerminator()) { + return succs; + } + + if (term->GetOpcode() == Opcode::Br) { + auto* br = static_cast(term); + succs.push_back(br->GetTarget()); + } else if (term->GetOpcode() == Opcode::CondBr) { + auto* br = static_cast(term); + succs.push_back(br->GetTrueTarget()); + succs.push_back(br->GetFalseTarget()); + } + return succs; +} + +} // namespace + +void DominatorTree::Recalculate(Function& function) { + BuildCFG(function); + if (blocks_.empty()) { + return; + } + + idom_.clear(); + ComputeIDoms(); + ComputeDominanceFrontiers(); +} + +BasicBlock* DominatorTree::GetRoot() const { + if (blocks_.empty()) { + return nullptr; + } + return blocks_.front(); +} + +BasicBlock* DominatorTree::GetIDom(BasicBlock* block) const { + auto it = idom_.find(block); + if (it == idom_.end()) { + return nullptr; + } + return it->second; +} + +bool DominatorTree::Dominates(BasicBlock* a, BasicBlock* b) const { + if (!a || !b) { + return false; + } + if (a == b) { + return true; + } + + auto it = idom_.find(b); + while (it != idom_.end() && it->second != b) { + if (it->second == a) { + return true; + } + b = it->second; + it = idom_.find(b); + } + return false; +} + +const std::vector& DominatorTree::GetChildren(BasicBlock* block) const { + auto it = children_.find(block); + if (it == children_.end()) { + static const std::vector empty; + return empty; + } + return it->second; +} + +const std::vector& DominatorTree::GetDominanceFrontier(BasicBlock* block) const { + auto it = dominance_frontier_.find(block); + if (it == dominance_frontier_.end()) { + static const std::vector empty; + return empty; + } + return it->second; +} + +const std::vector& DominatorTree::GetPredecessors(BasicBlock* block) const { + auto it = preds_.find(block); + if (it == preds_.end()) { + static const std::vector empty; + return empty; + } + return it->second; +} + +const std::vector& DominatorTree::GetSuccessors(BasicBlock* block) const { + auto it = succs_.find(block); + if (it == succs_.end()) { + static const std::vector empty; + return empty; + } + return it->second; +} + +void DominatorTree::BuildCFG(Function& function) { + blocks_.clear(); + preds_.clear(); + succs_.clear(); + idom_.clear(); + children_.clear(); + dominance_frontier_.clear(); + dfs_number_.clear(); + + BasicBlock* entry = function.GetEntry(); + if (!entry) { + return; + } + + std::unordered_set visited; + int next_number = 0; + + std::function dfs = [&](BasicBlock* block) { + if (!block || visited.count(block)) { + return; + } + visited.insert(block); + dfs_number_[block] = next_number++; + blocks_.push_back(block); + + auto successors = GetBlockSuccessors(block); + succs_[block] = successors; + for (BasicBlock* succ : successors) { + preds_[succ].push_back(block); + dfs(succ); + } + }; + + dfs(entry); +} + +void DominatorTree::ComputeIDoms() { + if (blocks_.empty()) { + return; + } + + BasicBlock* entry = blocks_.front(); + idom_[entry] = entry; + + bool changed = true; + while (changed) { + changed = false; + for (BasicBlock* block : blocks_) { + if (block == entry) { + continue; + } + + const auto& predecessors = preds_[block]; + BasicBlock* new_idom = nullptr; + for (BasicBlock* pred : predecessors) { + auto pred_it = idom_.find(pred); + if (pred_it == idom_.end()) { + continue; + } + if (!new_idom) { + new_idom = pred; + } else { + new_idom = Intersect(pred, new_idom); + } + } + + if (!new_idom) { + continue; + } + if (idom_.find(block) == idom_.end() || idom_[block] != new_idom) { + idom_[block] = new_idom; + changed = true; + } + } + } + + children_.clear(); + for (const auto& pair : idom_) { + BasicBlock* block = pair.first; + BasicBlock* parent = pair.second; + if (block != parent) { + children_[parent].push_back(block); + } + } +} + +void DominatorTree::ComputeDominanceFrontiers() { + dominance_frontier_.clear(); + + for (BasicBlock* block : blocks_) { + const auto& predecessors = preds_[block]; + if (predecessors.size() < 2) { + continue; + } + + for (BasicBlock* pred : predecessors) { + BasicBlock* runner = pred; + while (runner != idom_[block]) { + auto& frontier = dominance_frontier_[runner]; + if (std::find(frontier.begin(), frontier.end(), block) == frontier.end()) { + frontier.push_back(block); + } + runner = idom_[runner]; + } + } + } +} + +BasicBlock* DominatorTree::Intersect(BasicBlock* first, BasicBlock* second) const { + std::unordered_set first_ancestors; + for (BasicBlock* block = first; block;) { + first_ancestors.insert(block); + auto it = idom_.find(block); + if (it == idom_.end() || it->second == block) { + break; + } + block = it->second; + } + + for (BasicBlock* block = second; block;) { + if (first_ancestors.count(block)) { + return block; + } + auto it = idom_.find(block); + if (it == idom_.end() || it->second == block) { + break; + } + block = it->second; + } + + return GetRoot(); +} + +} // namespace ir diff --git a/src/ir/passes/ConstFold.cpp b/src/ir/passes/ConstFold.cpp index 19f2d43..f4d99d7 100644 --- a/src/ir/passes/ConstFold.cpp +++ b/src/ir/passes/ConstFold.cpp @@ -1,4 +1,299 @@ -// IR 常量折叠: -// - 折叠可判定的常量表达式 -// - 简化常量控制流分支(按实现范围裁剪) +#include "ir/passes/ConstFold.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ir { + +namespace { + +struct ConstKey { + Type::Kind kind; + int int_value; + uint32_t float_bits; + + bool operator==(const ConstKey& other) const { + return kind == other.kind && int_value == other.int_value && + float_bits == other.float_bits; + } +}; + +struct ConstKeyHash { + size_t operator()(const ConstKey& key) const { + size_t h = std::hash{}(static_cast(key.kind)); + h ^= std::hash{}(key.int_value) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(key.float_bits) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; + } +}; + +ConstantInt* GetIntConstant(std::shared_ptr ty, int value) { + static std::unordered_map, ConstKeyHash> cache; + ConstKey key{ty->GetKind(), value, 0}; + auto it = cache.find(key); + if (it != cache.end()) { + return it->second.get(); + } + auto constant = std::make_unique(ty, value); + auto* ptr = constant.get(); + cache.emplace(key, std::move(constant)); + return ptr; +} + +ConstantFloat* GetFloatConstant(float value) { + static std::unordered_map, ConstKeyHash> cache; + uint32_t bits = 0; + std::memcpy(&bits, &value, sizeof(bits)); + ConstKey key{Type::Kind::Float, 0, bits}; + auto it = cache.find(key); + if (it != cache.end()) { + return it->second.get(); + } + auto constant = std::make_unique(Type::GetFloatType(), value); + auto* ptr = constant.get(); + cache.emplace(key, std::move(constant)); + return ptr; +} + +void ReplaceUse(User* user, size_t index, Value* value) { + if (auto* phi = dynamic_cast(user)) { + phi->SetIncomingValue(index, value); + } else if (auto* br = dynamic_cast(user)) { + if (br->IsConditional() && index == 0) { + br->SetCondition(value); + } else { + user->SetOperand(index, value); + } + } else if (auto* call = dynamic_cast(user)) { + call->SetArg(index, value); + } else { + user->SetOperand(index, value); + } +} + +void ReplaceAllUses(Value* old_value, Value* new_value) { + auto uses = old_value->GetUses(); + for (const auto& use : uses) { + User* user = use.GetUser(); + if (!user) { + continue; + } + size_t index = use.GetOperandIndex(); + if (user->GetOperand(index) == old_value) { + ReplaceUse(user, index, new_value); + } + } +} + +bool IsFoldableBinary(Opcode op) { + switch (op) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: + case Opcode::And: + case Opcode::Or: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + return true; + default: + return false; + } +} + +ConstantValue* FoldBinary(BinaryInst* inst) { + auto* lhs_i = dynamic_cast(inst->GetLhs()); + auto* rhs_i = dynamic_cast(inst->GetRhs()); + if (lhs_i && rhs_i) { + int lhs = lhs_i->GetValue(); + int rhs = rhs_i->GetValue(); + int result = 0; + switch (inst->GetOpcode()) { + case Opcode::Add: result = lhs + rhs; break; + case Opcode::Sub: result = lhs - rhs; break; + case Opcode::Mul: result = lhs * rhs; break; + case Opcode::Div: + if (rhs == 0) return nullptr; + result = lhs / rhs; + break; + case Opcode::Mod: + if (rhs == 0) return nullptr; + result = lhs % rhs; + break; + case Opcode::And: result = (lhs != 0 && rhs != 0) ? 1 : 0; break; + case Opcode::Or: result = (lhs != 0 || rhs != 0) ? 1 : 0; break; + default: return nullptr; + } + return GetIntConstant(inst->GetType(), result); + } + + auto* lhs_f = dynamic_cast(inst->GetLhs()); + auto* rhs_f = dynamic_cast(inst->GetRhs()); + if (lhs_f && rhs_f) { + float lhs = lhs_f->GetValue(); + float rhs = rhs_f->GetValue(); + float result = 0.0f; + switch (inst->GetOpcode()) { + case Opcode::FAdd: result = lhs + rhs; break; + case Opcode::FSub: result = lhs - rhs; break; + case Opcode::FMul: result = lhs * rhs; break; + case Opcode::FDiv: + if (rhs == 0.0f) return nullptr; + result = lhs / rhs; + break; + default: return nullptr; + } + return GetFloatConstant(result); + } + + return nullptr; +} + +ConstantValue* FoldICmp(IcmpInst* inst) { + auto* lhs_c = dynamic_cast(inst->GetLhs()); + auto* rhs_c = dynamic_cast(inst->GetRhs()); + if (!lhs_c || !rhs_c) { + return nullptr; + } + + int lhs = lhs_c->GetValue(); + int rhs = rhs_c->GetValue(); + bool result = false; + switch (inst->GetPredicate()) { + case IcmpInst::Predicate::EQ: result = lhs == rhs; break; + case IcmpInst::Predicate::NE: result = lhs != rhs; break; + case IcmpInst::Predicate::LT: result = lhs < rhs; break; + case IcmpInst::Predicate::LE: result = lhs <= rhs; break; + case IcmpInst::Predicate::GT: result = lhs > rhs; break; + case IcmpInst::Predicate::GE: result = lhs >= rhs; break; + } + return GetIntConstant(Type::GetInt1Type(), result ? 1 : 0); +} + +ConstantValue* FoldFCmp(FcmpInst* inst) { + auto* lhs_c = dynamic_cast(inst->GetLhs()); + auto* rhs_c = dynamic_cast(inst->GetRhs()); + if (!lhs_c || !rhs_c) { + return nullptr; + } + + float lhs = lhs_c->GetValue(); + float rhs = rhs_c->GetValue(); + bool ordered = !std::isnan(lhs) && !std::isnan(rhs); + bool result = false; + switch (inst->GetPredicate()) { + case FcmpInst::Predicate::FALSE: result = false; break; + case FcmpInst::Predicate::OEQ: result = ordered && lhs == rhs; break; + case FcmpInst::Predicate::OGT: result = ordered && lhs > rhs; break; + case FcmpInst::Predicate::OGE: result = ordered && lhs >= rhs; break; + case FcmpInst::Predicate::OLT: result = ordered && lhs < rhs; break; + case FcmpInst::Predicate::OLE: result = ordered && lhs <= rhs; break; + case FcmpInst::Predicate::ONE: result = ordered && lhs != rhs; break; + case FcmpInst::Predicate::ORD: result = ordered; break; + case FcmpInst::Predicate::UNO: result = !ordered; break; + case FcmpInst::Predicate::UEQ: result = !ordered || lhs == rhs; break; + case FcmpInst::Predicate::UGT: result = !ordered || lhs > rhs; break; + case FcmpInst::Predicate::UGE: result = !ordered || lhs >= rhs; break; + case FcmpInst::Predicate::ULT: result = !ordered || lhs < rhs; break; + case FcmpInst::Predicate::ULE: result = !ordered || lhs <= rhs; break; + case FcmpInst::Predicate::UNE: result = !ordered || lhs != rhs; break; + case FcmpInst::Predicate::TRUE: result = true; break; + } + return GetIntConstant(Type::GetInt1Type(), result ? 1 : 0); +} + +ConstantValue* FoldCast(Instruction* inst) { + if (auto* zext = dynamic_cast(inst)) { + if (auto* value = dynamic_cast(zext->GetValue())) { + return GetIntConstant(zext->GetType(), value->GetValue() != 0 ? 1 : 0); + } + } else if (auto* trunc = dynamic_cast(inst)) { + if (auto* value = dynamic_cast(trunc->GetValue())) { + int result = trunc->GetType()->IsInt1() ? (value->GetValue() != 0 ? 1 : 0) + : value->GetValue(); + return GetIntConstant(trunc->GetType(), result); + } + } else if (auto* sitofp = dynamic_cast(inst)) { + if (auto* value = dynamic_cast(sitofp->GetValue())) { + return GetFloatConstant(static_cast(value->GetValue())); + } + } else if (auto* fptosi = dynamic_cast(inst)) { + if (auto* value = dynamic_cast(fptosi->GetValue())) { + return GetIntConstant(fptosi->GetType(), static_cast(value->GetValue())); + } + } + return nullptr; +} + +ConstantValue* FoldPhi(PhiInst* phi) { + if (phi->GetNumIncoming() == 0) { + return nullptr; + } + + auto* first = dynamic_cast(phi->GetIncomingValue(0)); + if (!first) { + return nullptr; + } + for (size_t i = 1; i < phi->GetNumIncoming(); ++i) { + if (phi->GetIncomingValue(i) != first) { + return nullptr; + } + } + return first; +} + +ConstantValue* TryFold(Instruction* inst) { + if (auto* binary = dynamic_cast(inst)) { + if (IsFoldableBinary(binary->GetOpcode())) { + return FoldBinary(binary); + } + } + if (auto* icmp = dynamic_cast(inst)) { + return FoldICmp(icmp); + } + if (auto* fcmp = dynamic_cast(inst)) { + return FoldFCmp(fcmp); + } + if (auto* phi = dynamic_cast(inst)) { + return FoldPhi(phi); + } + return FoldCast(inst); +} + +} // namespace + +bool ConstFoldPass::RunOnFunction(Function& function) { + bool changed = false; + + for (const auto& block_ptr : function.GetBlocks()) { + BasicBlock* block = block_ptr.get(); + std::vector to_remove; + for (const auto& inst_ptr : block->GetInstructions()) { + Instruction* inst = inst_ptr.get(); + ConstantValue* folded = TryFold(inst); + if (!folded) { + continue; + } + ReplaceAllUses(inst, folded); + to_remove.push_back(inst); + changed = true; + } + + for (Instruction* inst : to_remove) { + block->RemoveInstruction(inst); + } + } + + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/ConstProp.cpp b/src/ir/passes/ConstProp.cpp index 1768b71..4d54751 100644 --- a/src/ir/passes/ConstProp.cpp +++ b/src/ir/passes/ConstProp.cpp @@ -1,5 +1,212 @@ -// 常量传播(Constant Propagation): -// - 沿 use-def 关系传播已知常量 -// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会 -// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用 +#include "ir/passes/ConstProp.h" +#include +#include +#include +#include + +namespace ir { + +namespace { + +struct IntKey { + Type::Kind kind; + int value; + + bool operator==(const IntKey& other) const { + return kind == other.kind && value == other.value; + } +}; + +struct IntKeyHash { + size_t operator()(const IntKey& key) const { + size_t h = std::hash{}(static_cast(key.kind)); + h ^= std::hash{}(key.value) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; + } +}; + +ConstantInt* GetIntConstant(std::shared_ptr ty, int value) { + static std::unordered_map, IntKeyHash> cache; + IntKey key{ty->GetKind(), value}; + auto it = cache.find(key); + if (it != cache.end()) { + return it->second.get(); + } + auto constant = std::make_unique(ty, value); + auto* ptr = constant.get(); + cache.emplace(key, std::move(constant)); + return ptr; +} + +bool IsZero(Value* value) { + auto* constant = dynamic_cast(value); + return constant && constant->GetValue() == 0; +} + +bool IsOne(Value* value) { + auto* constant = dynamic_cast(value); + return constant && constant->GetValue() == 1; +} + +bool IsFloatZero(Value* value) { + auto* constant = dynamic_cast(value); + return constant && constant->GetValue() == 0.0f; +} + +bool IsFloatOne(Value* value) { + auto* constant = dynamic_cast(value); + return constant && constant->GetValue() == 1.0f; +} + +void ReplaceUse(User* user, size_t index, Value* value) { + if (auto* phi = dynamic_cast(user)) { + phi->SetIncomingValue(index, value); + } else if (auto* br = dynamic_cast(user)) { + if (br->IsConditional() && index == 0) { + br->SetCondition(value); + } else { + user->SetOperand(index, value); + } + } else if (auto* call = dynamic_cast(user)) { + call->SetArg(index, value); + } else { + user->SetOperand(index, value); + } +} + +void ReplaceAllUses(Value* old_value, Value* new_value) { + auto uses = old_value->GetUses(); + for (const auto& use : uses) { + User* user = use.GetUser(); + if (!user) { + continue; + } + size_t index = use.GetOperandIndex(); + if (user->GetOperand(index) == old_value) { + ReplaceUse(user, index, new_value); + } + } +} + +Value* SimplifyBinary(BinaryInst* inst) { + Value* lhs = inst->GetLhs(); + Value* rhs = inst->GetRhs(); + + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Or: + if (IsZero(lhs)) return rhs; + if (IsZero(rhs)) return lhs; + break; + case Opcode::Sub: + if (IsZero(rhs)) return lhs; + break; + case Opcode::Mul: + if (IsZero(lhs) || IsZero(rhs)) return GetIntConstant(inst->GetType(), 0); + if (IsOne(lhs)) return rhs; + if (IsOne(rhs)) return lhs; + break; + case Opcode::Div: + if (IsZero(lhs)) return GetIntConstant(inst->GetType(), 0); + if (IsOne(rhs)) return lhs; + break; + case Opcode::Mod: + if (IsZero(lhs) || IsOne(rhs)) return GetIntConstant(inst->GetType(), 0); + break; + case Opcode::And: + if (IsZero(lhs) || IsZero(rhs)) return GetIntConstant(inst->GetType(), 0); + if (IsOne(lhs)) return rhs; + if (IsOne(rhs)) return lhs; + break; + case Opcode::FAdd: + if (IsFloatZero(lhs)) return rhs; + if (IsFloatZero(rhs)) return lhs; + break; + case Opcode::FSub: + if (IsFloatZero(rhs)) return lhs; + break; + case Opcode::FMul: + if (IsFloatOne(lhs)) return rhs; + if (IsFloatOne(rhs)) return lhs; + break; + case Opcode::FDiv: + if (IsFloatOne(rhs)) return lhs; + break; + default: + break; + } + + return nullptr; +} + +Value* SimplifyPhi(PhiInst* phi) { + if (phi->GetNumIncoming() == 0) { + return nullptr; + } + + Value* same = nullptr; + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + Value* incoming = phi->GetIncomingValue(i); + if (incoming == phi) { + continue; + } + if (!same) { + same = incoming; + continue; + } + if (incoming != same) { + return nullptr; + } + } + return same; +} + +Value* TrySimplify(Instruction* inst) { + if (auto* binary = dynamic_cast(inst)) { + return SimplifyBinary(binary); + } + if (auto* phi = dynamic_cast(inst)) { + return SimplifyPhi(phi); + } + if (auto* zext = dynamic_cast(inst)) { + if (zext->GetValue()->GetType()->GetKind() == zext->GetType()->GetKind()) { + return zext->GetValue(); + } + } + if (auto* trunc = dynamic_cast(inst)) { + if (trunc->GetValue()->GetType()->GetKind() == trunc->GetType()->GetKind()) { + return trunc->GetValue(); + } + } + return nullptr; +} + +} // namespace + +bool ConstPropPass::RunOnFunction(Function& function) { + bool changed = false; + + for (const auto& block_ptr : function.GetBlocks()) { + BasicBlock* block = block_ptr.get(); + std::vector to_remove; + for (const auto& inst_ptr : block->GetInstructions()) { + Instruction* inst = inst_ptr.get(); + Value* replacement = TrySimplify(inst); + if (!replacement || replacement == inst) { + continue; + } + ReplaceAllUses(inst, replacement); + to_remove.push_back(inst); + changed = true; + } + + for (Instruction* inst : to_remove) { + block->RemoveInstruction(inst); + } + } + + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/DCE.cpp b/src/ir/passes/DCE.cpp index 5a0db91..e477b39 100644 --- a/src/ir/passes/DCE.cpp +++ b/src/ir/passes/DCE.cpp @@ -1,4 +1,54 @@ -// 死代码删除(DCE): -// - 删除无用指令与无用基本块 -// - 通常与 CFG 简化配合使用 +#include "ir/passes/DCE.h" +#include + +namespace ir { + +namespace { + +bool HasSideEffectOrControl(Instruction* inst) { + switch (inst->GetOpcode()) { + case Opcode::Store: + case Opcode::Ret: + case Opcode::Call: + case Opcode::Br: + case Opcode::CondBr: + return true; + default: + return false; + } +} + +bool IsRemovable(Instruction* inst) { + return !HasSideEffectOrControl(inst) && inst->GetUses().empty(); +} + +} // namespace + +bool DCEPass::RunOnFunction(Function& function) { + bool changed = false; + bool local_changed = true; + + while (local_changed) { + local_changed = false; + for (const auto& block_ptr : function.GetBlocks()) { + BasicBlock* block = block_ptr.get(); + std::vector to_remove; + for (const auto& inst_ptr : block->GetInstructions()) { + Instruction* inst = inst_ptr.get(); + if (IsRemovable(inst)) { + to_remove.push_back(inst); + } + } + for (Instruction* inst : to_remove) { + block->RemoveInstruction(inst); + local_changed = true; + changed = true; + } + } + } + + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/Mem2Reg.cpp b/src/ir/passes/Mem2Reg.cpp index 0b052ba..3786d86 100644 --- a/src/ir/passes/Mem2Reg.cpp +++ b/src/ir/passes/Mem2Reg.cpp @@ -1,4 +1,276 @@ -// Mem2Reg(SSA 构造): -// - 将局部变量的 alloca/load/store 提升为 SSA 形式 -// - 插入 PHI 并重写使用,依赖支配树等分析 +#include "ir/passes/Mem2Reg.h" +#include "ir/analysis/DominatorTree.h" +#include "utils/Log.h" +#include +#include +#include +#include +#include +#include +#include + +namespace ir { + +namespace { + +bool IsScalarAlloca(AllocaInst* alloca) { + if (!alloca) { + return false; + } + auto ty = alloca->GetType(); + return ty->IsPtrInt32() || ty->IsPtrFloat() || ty->IsPtrInt1(); +} + +std::shared_ptr GetAllocatedElementType(AllocaInst* alloca) { + if (!alloca) { + return nullptr; + } + auto ty = alloca->GetType(); + if (ty->IsPtrInt32()) { + return Type::GetInt32Type(); + } + if (ty->IsPtrFloat()) { + return Type::GetFloatType(); + } + if (ty->IsPtrInt1()) { + return Type::GetInt1Type(); + } + return nullptr; +} + +bool CollectAllocaUsers(AllocaInst* alloca, + std::vector& loads, + std::vector& stores) { + loads.clear(); + stores.clear(); + if (!alloca) { + return false; + } + + for (const auto& use : alloca->GetUses()) { + auto* user = use.GetUser(); + if (!user) { + return false; + } + + if (auto* load = dynamic_cast(user)) { + if (load->GetPtr() != alloca) { + return false; + } + loads.push_back(load); + } else if (auto* store = dynamic_cast(user)) { + if (store->GetPtr() != alloca) { + return false; + } + stores.push_back(store); + } else { + return false; + } + } + return true; +} + +bool RenameBlocks(BasicBlock* block, Value* incoming, AllocaInst* alloca, + const DominatorTree& domtree, + const std::unordered_map& phi_for_block, + std::unordered_map& block_out, + bool apply_changes) { + Value* current = incoming; + auto phi_it = phi_for_block.find(block); + if (phi_it != phi_for_block.end()) { + current = phi_it->second; + } + + std::vector to_remove; + for (const auto& inst_ptr : block->GetInstructions()) { + Instruction* inst = inst_ptr.get(); + if (phi_it != phi_for_block.end() && inst == phi_it->second) { + continue; + } + + if (inst->GetOpcode() == Opcode::Load) { + auto* load = static_cast(inst); + if (load->GetPtr() == alloca) { + if (!current) { + return false; + } + if (apply_changes) { + load->ReplaceAllUsesWith(current); + } + to_remove.push_back(inst); + continue; + } + } + + if (inst->GetOpcode() == Opcode::Store) { + auto* store = static_cast(inst); + if (store->GetPtr() == alloca) { + current = store->GetValue(); + if (apply_changes) { + to_remove.push_back(inst); + } + continue; + } + } + } + + block_out[block] = current; + + for (BasicBlock* succ : domtree.GetSuccessors(block)) { + auto succ_phi = phi_for_block.find(succ); + if (succ_phi != phi_for_block.end()) { + if (!current) { + return false; + } + if (!apply_changes) { + succ_phi->second->AddIncoming(current, block); + } + } + } + + if (apply_changes) { + for (Instruction* inst : to_remove) { + block->RemoveInstruction(inst); + } + } + + for (BasicBlock* child : domtree.GetChildren(block)) { + if (!RenameBlocks(child, current, alloca, domtree, phi_for_block, + block_out, apply_changes)) { + return false; + } + } + + return true; +} + +std::string MakePhiName(AllocaInst* alloca, BasicBlock* block, int id) { + std::string base = alloca && !alloca->GetName().empty() ? alloca->GetName() + : "mem2reg"; + std::string block_name = block && !block->GetName().empty() ? block->GetName() + : "block"; + return base + "." + block_name + ".phi" + std::to_string(id); +} + +} // namespace + +bool Mem2RegPass::RunOnFunction(Function& function) { + changed_ = false; + DebugStream() << "[DEBUG] Mem2RegPass: starting on function " << function.GetName() << std::endl; + DominatorTree domtree; + domtree.Recalculate(function); + DebugStream() << "[DEBUG] Mem2RegPass: dominator tree built for " << function.GetName() << std::endl; + changed_ = PromoteAllocas(function, domtree); + DebugStream() << "[DEBUG] Mem2RegPass: finished on function " << function.GetName() << " changed=" << changed_ << std::endl; + return changed_; +} + +bool Mem2RegPass::RunOnModule(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed = RunOnFunction(*function) || changed; + } + } + return changed; +} + +bool Mem2RegPass::PromoteAllocas(Function& function, DominatorTree& domtree) { + BasicBlock* entry = function.GetEntry(); + if (!entry) { + return false; + } + + std::vector allocas; + for (const auto& inst_ptr : entry->GetInstructions()) { + if (auto* alloca = dynamic_cast(inst_ptr.get())) { + if (IsScalarAlloca(alloca)) { + allocas.push_back(alloca); + } + } + } + + bool changed = false; + int phi_id = 0; + for (AllocaInst* alloca : allocas) { + DebugStream() << "[DEBUG] Mem2RegPass: processing alloca " << alloca->GetName() << std::endl; + + std::vector loads; + std::vector stores; + if (!CollectAllocaUsers(alloca, loads, stores)) { + DebugStream() << "[DEBUG] Mem2RegPass: CollectAllocaUsers failed for " << alloca->GetName() << std::endl; + continue; + } + DebugStream() << "[DEBUG] Mem2RegPass: loads=" << loads.size() << " stores=" << stores.size() << std::endl; + if (stores.empty()) { + continue; + } + + std::unordered_set def_blocks; + for (StoreInst* store : stores) { + if (store->GetParent()) { + def_blocks.insert(store->GetParent()); + } + } + if (def_blocks.empty()) { + continue; + } + + auto element_type = GetAllocatedElementType(alloca); + if (!element_type) { + continue; + } + + std::unordered_map phi_for_block; + std::vector worklist(def_blocks.begin(), def_blocks.end()); + std::unordered_set has_phi; + while (!worklist.empty()) { + BasicBlock* block = worklist.back(); + worklist.pop_back(); + DebugStream() << "[DEBUG] Mem2RegPass: worklist block=" << block->GetName() << std::endl; + for (BasicBlock* frontier : domtree.GetDominanceFrontier(block)) { + if (has_phi.insert(frontier).second) { + PhiInst* phi = frontier->InsertAtBeginning(element_type, + MakePhiName(alloca, frontier, phi_id++)); + DebugStream() << "[DEBUG] Mem2RegPass: inserted phi in " << frontier->GetName() << std::endl; + phi_for_block[frontier] = phi; + if (!def_blocks.count(frontier)) { + worklist.push_back(frontier); + } + } + } + } + + std::unordered_map block_out; + DebugStream() << "[DEBUG] Mem2RegPass: before dry run RenameBlocks for " << alloca->GetName() << std::endl; + if (!RenameBlocks(function.GetEntry(), nullptr, alloca, domtree, + phi_for_block, block_out, false)) { + DebugStream() << "[DEBUG] Mem2RegPass: dry run failed for " << alloca->GetName() << std::endl; + for (auto& pair : phi_for_block) { + pair.first->RemoveInstruction(pair.second); + } + continue; + } + + DebugStream() << "[DEBUG] Mem2RegPass: before apply RenameBlocks for " << alloca->GetName() << std::endl; + if (!RenameBlocks(function.GetEntry(), nullptr, alloca, domtree, + phi_for_block, block_out, true)) { + DebugStream() << "[DEBUG] Mem2RegPass: apply run failed for " << alloca->GetName() << std::endl; + for (auto& pair : phi_for_block) { + pair.first->RemoveInstruction(pair.second); + } + continue; + } + + if (alloca->GetUses().empty()) { + entry->RemoveInstruction(alloca); + } + changed = true; + } + + changed_ = changed; + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/PassManager.cpp b/src/ir/passes/PassManager.cpp index 044328f..3038369 100644 --- a/src/ir/passes/PassManager.cpp +++ b/src/ir/passes/PassManager.cpp @@ -1 +1,38 @@ -// IR Pass 管理骨架。 +#include "ir/passes/PassManager.h" +#include "utils/Log.h" +#include + +namespace ir { + +void PassManager::AddPass(std::unique_ptr pass) { + if (pass) { + passes_.push_back(std::move(pass)); + } +} + +bool PassManager::Run(Function& function) { + bool changed = false; + DebugStream() << "[DEBUG] PassManager: running " << passes_.size() << " pass(es) on function " + << function.GetName() << std::endl; + for (const auto& pass : passes_) { + if (pass) { + DebugStream() << "[DEBUG] PassManager: before pass" << std::endl; + changed = pass->RunOnFunction(function) || changed; + DebugStream() << "[DEBUG] PassManager: after pass" << std::endl; + } + } + DebugStream() << "[DEBUG] PassManager: finished function " << function.GetName() << std::endl; + return changed; +} + +bool PassManager::Run(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed = Run(*function) || changed; + } + } + return changed; +} + +} // namespace ir diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index e2143bc..f9eda4f 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -100,14 +100,14 @@ std::string MakeStaticArrayName(const ir::Function& func, // visitDecl: 处理声明 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { - std::cerr << "[DEBUG] visitDecl: 开始处理声明" << std::endl; + DebugStream() << "[DEBUG] visitDecl: 开始处理声明" << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } // 处理 varDecl if (auto* varDecl = ctx->varDecl()) { - std::cerr << "[DEBUG] visitDecl: 处理变量声明" << std::endl; + DebugStream() << "[DEBUG] visitDecl: 处理变量声明" << std::endl; for (auto* varDef : varDecl->varDef()) { varDef->accept(this); } @@ -115,20 +115,20 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { // 处理 constDecl if (ctx->constDecl()) { - std::cerr << "[DEBUG] visitDecl: 处理常量声明" << std::endl; + DebugStream() << "[DEBUG] visitDecl: 处理常量声明" << std::endl; auto* constDecl = ctx->constDecl(); for (auto* constDef : constDecl->constDef()) { constDef->accept(this); } } - std::cerr << "[DEBUG] visitDecl: 声明处理完成" << std::endl; + DebugStream() << "[DEBUG] visitDecl: 声明处理完成" << std::endl; return {}; } // visitConstDecl: 处理常量声明 std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { - std::cerr << "[DEBUG] visitConstDecl: 开始处理常量声明" << std::endl; + DebugStream() << "[DEBUG] visitConstDecl: 开始处理常量声明" << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法常量声明")); } @@ -139,13 +139,13 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { } } - std::cerr << "[DEBUG] visitConstDecl: 常量声明处理完成" << std::endl; + DebugStream() << "[DEBUG] visitConstDecl: 常量声明处理完成" << std::endl; return {}; } // visitConstDef: 处理常量定义 - 从符号表获取常量值 std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { - std::cerr << "[DEBUG] visitConstDef: 开始处理常量定义" << std::endl; + DebugStream() << "[DEBUG] visitConstDef: 开始处理常量定义" << std::endl; if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法常量定义")); } @@ -158,7 +158,7 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { throw std::runtime_error(FormatError("irgen", "常量符号未找到: " + const_name)); } - std::cerr << "[DEBUG] visitConstDef: 从符号表获取常量 " << const_name + DebugStream() << "[DEBUG] visitConstDef: 从符号表获取常量 " << const_name << ", is_array_const: " << sym->IsArrayConstant() << std::endl; // 根据符号表中的常量值创建 IR 常量 @@ -270,11 +270,11 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { ir::ConstantValue* const_value = nullptr; if (sym->type->IsInt32()) { const_value = builder_.CreateConstInt(sym->GetIntConstant()); - std::cerr << "[DEBUG] visitConstDef: 整型常量 " << const_name + DebugStream() << "[DEBUG] visitConstDef: 整型常量 " << const_name << " = " << sym->GetIntConstant() << std::endl; } else if (sym->type->IsFloat()) { const_value = builder_.CreateConstFloat(sym->GetFloatConstant()); - std::cerr << "[DEBUG] visitConstDef: 浮点常量 " << const_name + DebugStream() << "[DEBUG] visitConstDef: 浮点常量 " << const_name << " = " << sym->GetFloatConstant() << std::endl; } @@ -287,13 +287,13 @@ std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { // visitVarDef: 处理变量定义 - 从符号表获取类型信息 std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { - std::cerr << "[DEBUG] visitVarDef: 开始处理变量定义" << std::endl; + DebugStream() << "[DEBUG] visitVarDef: 开始处理变量定义" << std::endl; if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法变量定义")); } std::string varName = ctx->Ident()->getText(); - std::cerr << "[DEBUG] visitVarDef: 变量名称: " << varName << std::endl; + DebugStream() << "[DEBUG] visitVarDef: 变量名称: " << varName << std::endl; // 防止重复分配 if (storage_map_.find(ctx) != storage_map_.end()) { @@ -306,17 +306,17 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { throw std::runtime_error(FormatError("irgen", "变量符号未找到: " + varName)); } - std::cerr << "[DEBUG] visitVarDef: 变量类型: " + DebugStream() << "[DEBUG] visitVarDef: 变量类型: " << (sym->type->IsInt32() ? "int" : sym->type->IsFloat() ? "float" : sym->type->IsArray() ? "array" : "unknown") << std::endl; // 根据作用域处理 if (func_ == nullptr) { - std::cerr << "[DEBUG] visitVarDef: 处理全局变量" << std::endl; + DebugStream() << "[DEBUG] visitVarDef: 处理全局变量" << std::endl; return HandleGlobalVariable(ctx, varName, sym); } else { - std::cerr << "[DEBUG] visitVarDef: 处理局部变量" << std::endl; + DebugStream() << "[DEBUG] visitVarDef: 处理局部变量" << std::endl; return HandleLocalVariable(ctx, varName, sym); } } @@ -325,7 +325,7 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx, const std::string& varName, const Symbol* sym) { - std::cerr << "[DEBUG] HandleGlobalVariable: 开始处理全局变量 " << varName << std::endl; + DebugStream() << "[DEBUG] HandleGlobalVariable: 开始处理全局变量 " << varName << std::endl; if (!sym) { throw std::runtime_error(FormatError("irgen", "符号表信息缺失: " + varName)); @@ -349,7 +349,7 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx, const auto& dimensions = array_ty->GetDimensions(); size_t total_size = array_ty->GetElementCount(); - std::cerr << "[DEBUG] HandleGlobalVariable: 全局数组 " << varName << " 维度: "; + DebugStream() << "[DEBUG] HandleGlobalVariable: 全局数组 " << varName << " 维度: "; for (int d : dimensions) std::cerr << d << " "; std::cerr << ", 总大小: " << total_size << std::endl; @@ -359,7 +359,7 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx, // 处理初始化值(使用带维度感知的展平) std::vector init_consts; if (auto* initVal = ctx->initVal()) { - std::cerr << "[DEBUG] HandleGlobalVariable: 处理初始化值" << std::endl; + DebugStream() << "[DEBUG] HandleGlobalVariable: 处理初始化值" << std::endl; // 全局变量的初始化必须是常量表达式(语义检查已保证) std::vector flat_vals = FlattenInitVal( initVal, dimensions, is_float); @@ -439,7 +439,7 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx, global_map_[varName] = global_var; } - std::cerr << "[DEBUG] HandleGlobalVariable: 全局变量处理完成" << std::endl; + DebugStream() << "[DEBUG] HandleGlobalVariable: 全局变量处理完成" << std::endl; return {}; } @@ -447,7 +447,7 @@ std::any IRGenImpl::HandleGlobalVariable(SysYParser::VarDefContext* ctx, std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx, const std::string& varName, const Symbol* sym) { - std::cerr << "[DEBUG] HandleLocalVariable: 开始处理局部变量 " << varName << std::endl; + DebugStream() << "[DEBUG] HandleLocalVariable: 开始处理局部变量 " << varName << std::endl; if (!sym) { throw std::runtime_error(FormatError("irgen", "符号表信息缺失: " + varName)); @@ -473,7 +473,7 @@ std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx, const bool use_heap_storage = current_function_is_recursive_ || total_bytes > kLocalArrayHeapThresholdBytes; - std::cerr << "[DEBUG] HandleLocalVariable: 局部数组 " << varName + DebugStream() << "[DEBUG] HandleLocalVariable: 局部数组 " << varName << " 总大小: " << total_size << std::endl; ir::Value* array_slot = nullptr; @@ -520,7 +520,7 @@ std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx, if (is_all_zero_init && !use_heap_storage) { builder_.CreateStore(module_.GetContext().GetAggregateZero(sym->type), array_slot); - std::cerr << "[DEBUG] HandleLocalVariable: aggregate zeroinitializer store for " + DebugStream() << "[DEBUG] HandleLocalVariable: aggregate zeroinitializer store for " << varName << std::endl; return {}; } @@ -617,35 +617,35 @@ std::any IRGenImpl::HandleLocalVariable(SysYParser::VarDefContext* ctx, builder_.CreateStore(init, slot); } - std::cerr << "[DEBUG] HandleLocalVariable: 局部变量处理完成" << std::endl; + DebugStream() << "[DEBUG] HandleLocalVariable: 局部变量处理完成" << std::endl; return {}; } // visitInitVal: 处理初始化值 std::any IRGenImpl::visitInitVal(SysYParser::InitValContext* ctx) { - std::cerr << "[DEBUG] visitInitVal: 开始处理初始化值" << std::endl; + DebugStream() << "[DEBUG] visitInitVal: 开始处理初始化值" << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法初始化值")); } // 如果是单个表达式 if (ctx->exp()) { - std::cerr << "[DEBUG] visitInitVal: 处理表达式初始化" << std::endl; + DebugStream() << "[DEBUG] visitInitVal: 处理表达式初始化" << std::endl; return EvalExpr(*ctx->exp()); } // 如果是聚合初始化(花括号列表) else if (!ctx->initVal().empty()) { - std::cerr << "[DEBUG] visitInitVal: 处理聚合初始化" << std::endl; + DebugStream() << "[DEBUG] visitInitVal: 处理聚合初始化" << std::endl; return ProcessNestedInitVals(ctx); } - std::cerr << "[DEBUG] visitInitVal: 空初始化列表" << std::endl; + DebugStream() << "[DEBUG] visitInitVal: 空初始化列表" << std::endl; return std::vector{}; } // ProcessNestedInitVals: 处理嵌套聚合初始化 std::vector IRGenImpl::ProcessNestedInitVals(SysYParser::InitValContext* ctx) { - std::cerr << "[DEBUG] ProcessNestedInitVals: 开始处理嵌套初始化值" << std::endl; + DebugStream() << "[DEBUG] ProcessNestedInitVals: 开始处理嵌套初始化值" << std::endl; std::vector all_values; for (auto* init_val : ctx->initVal()) { @@ -655,13 +655,13 @@ std::vector IRGenImpl::ProcessNestedInitVals(SysYParser::InitValCont // 尝试获取单个值 ir::Value* value = std::any_cast(result); all_values.push_back(value); - std::cerr << "[DEBUG] ProcessNestedInitVals: 获取到单个值" << std::endl; + DebugStream() << "[DEBUG] ProcessNestedInitVals: 获取到单个值" << std::endl; } catch (const std::bad_any_cast&) { try { // 尝试获取值列表(嵌套情况) std::vector nested_values = std::any_cast>(result); - std::cerr << "[DEBUG] ProcessNestedInitVals: 获取到嵌套值列表, 大小: " + DebugStream() << "[DEBUG] ProcessNestedInitVals: 获取到嵌套值列表, 大小: " << nested_values.size() << std::endl; all_values.insert(all_values.end(), nested_values.begin(), nested_values.end()); @@ -674,7 +674,7 @@ std::vector IRGenImpl::ProcessNestedInitVals(SysYParser::InitValCont } } - std::cerr << "[DEBUG] ProcessNestedInitVals: 共获取 " << all_values.size() + DebugStream() << "[DEBUG] ProcessNestedInitVals: 共获取 " << all_values.size() << " 个初始化值" << std::endl; return all_values; } diff --git a/src/irgen/IRGenDriver.cpp b/src/irgen/IRGenDriver.cpp index ac19a7b..49a8e38 100644 --- a/src/irgen/IRGenDriver.cpp +++ b/src/irgen/IRGenDriver.cpp @@ -4,6 +4,11 @@ #include "SysYParser.h" #include "ir/IR.h" +#include "ir/passes/ConstFold.h" +#include "ir/passes/ConstProp.h" +#include "ir/passes/DCE.h" +#include "ir/passes/Mem2Reg.h" +#include "ir/passes/PassManager.h" #include "utils/Log.h" // 修改 GenerateIR 函数 @@ -12,5 +17,16 @@ std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, auto module = std::make_unique(); IRGenImpl gen(*module, sema_result.context, sema_result.symbol_table); tree.accept(&gen); + + ir::PassManager pass_manager; + pass_manager.AddPass(std::make_unique()); + pass_manager.AddPass(std::make_unique()); + pass_manager.AddPass(std::make_unique()); + pass_manager.AddPass(std::make_unique()); + pass_manager.AddPass(std::make_unique()); + DebugStream() << "[DEBUG] IRGenDriver: before mem2reg" << std::endl; + pass_manager.Run(*module); + DebugStream() << "[DEBUG] IRGenDriver: after scalar opts" << std::endl; + return module; } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 5d8128b..a8346c6 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -23,7 +23,7 @@ // 表达式生成 ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { - std::cerr << "[DEBUG IRGEN] EvalExpr: 开始处理表达式 " << expr.getText() << std::endl; + DebugStream() << "[DEBUG IRGEN] EvalExpr: 开始处理表达式 " << expr.getText() << std::endl; try { auto result_any = expr.accept(this); @@ -34,7 +34,7 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { try { ir::Value* result = std::any_cast(result_any); - std::cerr << "[DEBUG] EvalExpr: success, result = " << (void*)result << std::endl; + DebugStream() << "[DEBUG] EvalExpr: success, result = " << (void*)result << std::endl; return result; } catch (const std::bad_any_cast& e) { std::cerr << "[ERROR] EvalExpr: bad any_cast - " << e.what() << std::endl; @@ -48,24 +48,24 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { } ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { - std::cerr << "[DEBUG IRGEN] EvalCond: 开始处理条件表达式 " << cond.getText() << std::endl; + DebugStream() << "[DEBUG IRGEN] EvalCond: 开始处理条件表达式 " << cond.getText() << std::endl; return std::any_cast(cond.accept(this)); } // 基本表达式:数字、变量、括号表达式 std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitPrimaryExp: 开始处理基本表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitPrimaryExp: 开始处理基本表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少基本表达式")); } - std::cerr << "[DEBUG] visitPrimaryExp" << std::endl; + DebugStream() << "[DEBUG] visitPrimaryExp" << std::endl; // 处理数字字面量 if (ctx->DECIMAL_INT()) { int value = std::stoi(ctx->DECIMAL_INT()->getText()); ir::Value* const_int = builder_.CreateConstInt(value); - std::cerr << "[DEBUG] visitPrimaryExp: constant int " << value + DebugStream() << "[DEBUG] visitPrimaryExp: constant int " << value << " created as " << (void*)const_int << std::endl; return static_cast(const_int); } @@ -81,7 +81,7 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { value = 0.0f; } ir::Value* const_float = builder_.CreateConstFloat(value); - std::cerr << "[DEBUG] visitPrimaryExp: constant hex float " << value + DebugStream() << "[DEBUG] visitPrimaryExp: constant hex float " << value << " created as " << (void*)const_float << std::endl; return static_cast(const_float); } @@ -97,7 +97,7 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { value = 0.0f; } ir::Value* const_float = builder_.CreateConstFloat(value); - std::cerr << "[DEBUG] visitPrimaryExp: constant dec float " << value + DebugStream() << "[DEBUG] visitPrimaryExp: constant dec float " << value << " created as " << (void*)const_float << std::endl; return static_cast(const_float); } @@ -106,7 +106,7 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { std::string hex = ctx->HEX_INT()->getText(); int value = std::stoi(hex, nullptr, 16); ir::Value* const_int = builder_.CreateConstInt(value); - std::cerr << "[DEBUG] visitPrimaryExp: constant hex int " << value + DebugStream() << "[DEBUG] visitPrimaryExp: constant hex int " << value << " created as " << (void*)const_int << std::endl; return static_cast(const_int); } @@ -115,26 +115,26 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { std::string oct = ctx->OCTAL_INT()->getText(); int value = std::stoi(oct, nullptr, 8); ir::Value* const_int = builder_.CreateConstInt(value); - std::cerr << "[DEBUG] visitPrimaryExp: constant octal int " << value + DebugStream() << "[DEBUG] visitPrimaryExp: constant octal int " << value << " created as " << (void*)const_int << std::endl; return static_cast(const_int); } if (ctx->ZERO()) { ir::Value* const_int = builder_.CreateConstInt(0); - std::cerr << "[DEBUG] visitPrimaryExp: constant zero int created" << std::endl; + DebugStream() << "[DEBUG] visitPrimaryExp: constant zero int created" << std::endl; return static_cast(const_int); } // 处理变量 if (ctx->lVal()) { - std::cerr << "[DEBUG] visitPrimaryExp: visiting lVal" << std::endl; + DebugStream() << "[DEBUG] visitPrimaryExp: visiting lVal" << std::endl; return ctx->lVal()->accept(this); } // 处理括号表达式 if (ctx->L_PAREN() && ctx->exp()) { - std::cerr << "[DEBUG] visitPrimaryExp: visiting parenthesized expression" << std::endl; + DebugStream() << "[DEBUG] visitPrimaryExp: visiting parenthesized expression" << std::endl; return EvalExpr(*ctx->exp()); } @@ -144,13 +144,13 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { // 左值(变量)处理 std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitLVal: 开始处理左值 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitLVal: 开始处理左值 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法左值")); } std::string varName = ctx->Ident()->getText(); - std::cerr << "[DEBUG] visitLVal: " << varName << std::endl; + DebugStream() << "[DEBUG] visitLVal: " << varName << std::endl; // 先检查语义分析中常量绑定 const SysYParser::ConstDefContext* const_decl = sema_.ResolveConstUse(ctx); @@ -166,7 +166,7 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { // 如果是常量,直接返回常量值 if (sym && sym->kind == SymbolKind::Constant) { - std::cerr << "[DEBUG] visitLVal: 找到常量 " << varName << std::endl; + DebugStream() << "[DEBUG] visitLVal: 找到常量 " << varName << std::endl; if (sym->IsScalarConstant()) { if (sym->type->IsInt32()) { @@ -394,7 +394,7 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { } std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitAddExp: 开始处理加法表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitAddExp: 开始处理加法表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法加法表达式")); } @@ -418,7 +418,7 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { } ir::Value* right = std::any_cast(right_any); - std::cerr << "[DEBUG] visitAddExp: left=" << (void*)left + DebugStream() << "[DEBUG] visitAddExp: left=" << (void*)left << ", type=" << (left->GetType()->IsFloat() ? "float" : "int") << ", right=" << (void*)right << ", type=" << (right->GetType()->IsFloat() ? "float" : "int") << std::endl; @@ -458,7 +458,7 @@ std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitMulExp: 开始处理乘法表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitMulExp: 开始处理乘法表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); } @@ -482,7 +482,7 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { } ir::Value* right = std::any_cast(right_any); - std::cerr << "[DEBUG] visitMulExp: left=" << (void*)left + DebugStream() << "[DEBUG] visitMulExp: left=" << (void*)left << ", type=" << (left->GetType()->IsFloat() ? "float" : "int") << ", right=" << (void*)right << ", type=" << (right->GetType()->IsFloat() ? "float" : "int") << std::endl; @@ -532,7 +532,7 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { // 逻辑与 std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitLAndExp: 开始处理逻辑与表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitLAndExp: 开始处理逻辑与表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); if (!ctx->lAndExp()) { @@ -562,7 +562,7 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { // 逻辑或 std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitLOrExp: 开始处理逻辑或表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitLOrExp: 开始处理逻辑或表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); if (!ctx->lOrExp()) { @@ -591,32 +591,32 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { } std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitExp: 开始处理表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitExp: 开始处理表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) throw std::runtime_error(FormatError("irgen", "非法表达式")); return ctx->addExp()->accept(this); } std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitCond: 开始处理条件 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitCond: 开始处理条件 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) throw std::runtime_error(FormatError("irgen", "非法条件表达式")); return ctx->lOrExp()->accept(this); } std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitCallExp: 开始处理函数调用 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitCallExp: 开始处理函数调用 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法函数调用")); } std::string funcName = ctx->Ident()->getText(); - std::cout << "[DEBUG IRGEN] visitCallExp: 调用函数 " << funcName << std::endl; + DebugStream() << "[DEBUG IRGEN] visitCallExp: 调用函数 " << funcName << std::endl; // 查找函数对象 ir::Function* callee = module_.FindFunction(funcName); // 如果没找到,可能是运行时函数还没声明,尝试动态声明 if (!callee) { - std::cout << "[DEBUG IRGEN] 函数 " << funcName << " 未找到,尝试动态声明" << std::endl; + DebugStream() << "[DEBUG IRGEN] 函数 " << funcName << " 未找到,尝试动态声明" << std::endl; // 根据函数名动态创建运行时函数声明 callee = CreateRuntimeFunctionDecl(funcName); @@ -631,7 +631,7 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) { auto argList = ctx->funcRParams()->accept(this); try { args = std::any_cast>(argList); - std::cout << "[DEBUG IRGEN] visitCallExp: 收集到 " << args.size() << " 个参数" << std::endl; + DebugStream() << "[DEBUG IRGEN] visitCallExp: 收集到 " << args.size() << " 个参数" << std::endl; } catch (const std::bad_any_cast& e) { std::cerr << "[ERROR] visitCallExp: 函数调用参数类型错误: " << e.what() << std::endl; } @@ -673,13 +673,13 @@ std::any IRGenImpl::visitCallExp(SysYParser::UnaryExpContext* ctx) { return static_cast(builder_.CreateConstInt(0)); } - std::cout << "[DEBUG IRGEN] visitCallExp: 函数调用完成,返回值 " << (void*)callResult << std::endl; + DebugStream() << "[DEBUG IRGEN] visitCallExp: 函数调用完成,返回值 " << (void*)callResult << std::endl; return static_cast(callResult); } // 动态创建运行时函数声明的辅助函数 ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) { - std::cerr << "[DEBUG IRGEN] CreateRuntimeFunctionDecl: 开始创建运行时函数声明 " << funcName << std::endl; + DebugStream() << "[DEBUG IRGEN] CreateRuntimeFunctionDecl: 开始创建运行时函数声明 " << funcName << std::endl; // 根据常见运行时函数名创建对应的函数类型 if (funcName == "getint" || funcName == "getch") { @@ -792,7 +792,7 @@ ir::Function* IRGenImpl::CreateRuntimeFunctionDecl(const std::string& funcName) } std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitUnaryExp: 开始处理一元表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitUnaryExp: 开始处理一元表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法一元表达式")); } @@ -852,7 +852,7 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { // 实现函数调用 std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitFuncRParams: 开始处理函数参数 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitFuncRParams: 开始处理函数参数 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) return std::vector{}; std::vector args; for (auto* exp : ctx->exp()) { @@ -863,7 +863,7 @@ std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { // visitConstExp - 处理常量表达式 std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitConstExp: 开始处理常量表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitConstExp: 开始处理常量表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->addExp()) { throw std::runtime_error(FormatError("irgen", "非法常量表达式")); } @@ -884,7 +884,7 @@ std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) { // visitConstInitVal - 处理常量初始化值 std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitConstInitVal: 开始处理常量初始化值 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitConstInitVal: 开始处理常量初始化值 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法常量初始化值")); } @@ -929,7 +929,7 @@ std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) { } std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitRelExp: 开始处理关系表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitRelExp: 开始处理关系表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法关系表达式")); } @@ -940,7 +940,7 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { auto* lhs = std::any_cast(left_any); auto* rhs = std::any_cast(right_any); - std::cerr << "[DEBUG] visitRelExp: left=" << (void*)lhs + DebugStream() << "[DEBUG] visitRelExp: left=" << (void*)lhs << ", type=" << (lhs->GetType()->IsFloat() ? "float" : "int") << ", right=" << (void*)rhs << ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int") << std::endl; @@ -1004,7 +1004,7 @@ std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { } std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitEqExp: 开始处理相等表达式 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitEqExp: 开始处理相等表达式 " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法相等表达式")); } @@ -1015,7 +1015,7 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { auto* lhs = std::any_cast(left_any); auto* rhs = std::any_cast(right_any); - std::cerr << "[DEBUG] visitEqExp: left=" << (void*)lhs + DebugStream() << "[DEBUG] visitEqExp: left=" << (void*)lhs << ", type=" << (lhs->GetType()->IsFloat() ? "float" : "int") << ", right=" << (void*)rhs << ", type=" << (rhs->GetType()->IsFloat() ? "float" : "int") << std::endl; @@ -1062,8 +1062,8 @@ std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) { - std::cerr << "[DEBUG IRGEN] EvalAssign: 开始处理赋值语句 " << (ctx ? ctx->getText() : "") << std::endl; - std::cout << "[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] EvalAssign: 开始处理赋值语句 " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitCond: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->lVal() || !ctx->exp()) { throw std::runtime_error(FormatError("irgen", "非法赋值语句")); } @@ -1127,8 +1127,8 @@ ir::Value* IRGenImpl::EvalAssign(SysYParser::StmtContext* ctx) { } else { // 普通标量赋值 // 调试输出指针类型 - std::cerr << "[DEBUG] base_ptr type: " << base_ptr->GetType() << std::endl; - std::cerr << "[DEBUG] rhs type: " << rhs->GetType()<< std::endl; + DebugStream() << "[DEBUG] base_ptr type: " << base_ptr->GetType() << std::endl; + DebugStream() << "[DEBUG] rhs type: " << rhs->GetType()<< std::endl; // 如果 base_ptr 不是标量指针类型,可能需要特殊处理 if (!base_ptr->GetType()->IsPtrInt32() && !base_ptr->GetType()->IsPtrFloat()) { diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index d5334e0..5d108f5 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -54,7 +54,7 @@ IRGenImpl::IRGenImpl(ir::Module& module, } void IRGenImpl::AddRuntimeFunctions() { - std::cerr << "[DEBUG IRGEN] 添加运行时库函数声明" << std::endl; + DebugStream() << "[DEBUG IRGEN] 添加运行时库函数声明" << std::endl; // 输入函数(返回 int) module_.CreateFunction("getint", @@ -155,21 +155,21 @@ void IRGenImpl::AddRuntimeFunctions() { ir::Type::GetVoidType(), {ir::Type::GetPtrFloatType(), ir::Type::GetInt32Type()})); - std::cerr << "[DEBUG IRGEN] 运行时库函数声明完成" << std::endl; + DebugStream() << "[DEBUG IRGEN] 运行时库函数声明完成" << std::endl; } // 修正:没有 mainFuncDef,通过函数名找到 main std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitCompUnit" << std::endl; - std::cerr << "[DEBUG] IRGen: 符号表地址 = " << &symbol_table_ << std::endl; - std::cerr << "[DEBUG] IRGen: 开始生成 IR" << std::endl; + DebugStream() << "[DEBUG IRGEN] visitCompUnit" << std::endl; + DebugStream() << "[DEBUG] IRGen: 符号表地址 = " << &symbol_table_ << std::endl; + DebugStream() << "[DEBUG] IRGen: 开始生成 IR" << std::endl; // 尝试查找 main 函数 const Symbol* main_sym = symbol_table_.lookup("main"); if (main_sym) { - std::cerr << "[DEBUG] IRGen: 找到 main 函数符号" << std::endl; + DebugStream() << "[DEBUG] IRGen: 找到 main 函数符号" << std::endl; } else { - std::cerr << "[DEBUG] IRGen: 未找到 main 函数符号" << std::endl; + DebugStream() << "[DEBUG] IRGen: 未找到 main 函数符号" << std::endl; } if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); @@ -193,7 +193,7 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { } std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitFuncDef: " << (ctx && ctx->Ident() ? ctx->Ident()->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } @@ -255,7 +255,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { auto func_type = ir::Type::GetFunctionType(ret_type, param_types); // 调试输出 - std::cerr << "[DEBUG] visitFuncDef: 创建函数 " << funcName + DebugStream() << "[DEBUG] visitFuncDef: 创建函数 " << funcName << ",返回类型: " << (ret_type->IsVoid() ? "void" : ret_type->IsFloat() ? "float" : "int") << ",参数数量: " << param_types.size() << std::endl; @@ -268,7 +268,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { throw std::runtime_error(FormatError("irgen", "创建函数失败: " + funcName)); } - std::cerr << "[DEBUG] visitFuncDef: 函数对象地址: " << (void*)func_ << std::endl; + DebugStream() << "[DEBUG] visitFuncDef: 函数对象地址: " << (void*)func_ << std::endl; // 设置插入点 auto* entry_block = func_->GetEntry(); @@ -328,7 +328,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { throw std::runtime_error(FormatError("irgen", "函数对象无效")); } - std::cerr << "[DEBUG] visitFuncDef: 为函数 " << funcName + DebugStream() << "[DEBUG] visitFuncDef: 为函数 " << funcName << " 添加参数 " << name << ",类型: " << (param_ty->IsInt32() ? "int32" : param_ty->IsFloat() ? "float" : param_ty->IsPtrInt32() ? "ptr_int32" : param_ty->IsPtrFloat() ? "ptr_float" : "other") @@ -371,17 +371,17 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { pointer_param_names_.erase(name); } - std::cerr << "[DEBUG] visitFuncDef: 参数 " << name << " 处理完成" << std::endl; + DebugStream() << "[DEBUG] visitFuncDef: 参数 " << name << " 处理完成" << std::endl; } } // 生成函数体 - std::cerr << "[DEBUG] visitFuncDef: 开始生成函数体" << std::endl; + DebugStream() << "[DEBUG] visitFuncDef: 开始生成函数体" << std::endl; ctx->block()->accept(this); // 如果当前插入块没有终止指令,添加默认返回 if (auto* cur = builder_.GetInsertBlock(); cur && !cur->HasTerminator()) { - std::cerr << "[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回" << std::endl; + DebugStream() << "[DEBUG] visitFuncDef: 函数体没有终止指令,添加默认返回" << std::endl; if (function_cleanup_block_) { if (ret_type->IsFloat()) { builder_.CreateStore(builder_.CreateConstFloat(0.0f), function_return_slot_); @@ -420,7 +420,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { throw; } - std::cerr << "[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成" << std::endl; + DebugStream() << "[DEBUG] visitFuncDef: 函数 " << funcName << " 生成完成" << std::endl; func_ = nullptr; current_function_name_.clear(); current_function_is_recursive_ = false; @@ -467,7 +467,7 @@ ir::AllocaInst* IRGenImpl::CreateEntryAllocaFloat(const std::string& name) { std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitBlock: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句块")); } @@ -482,7 +482,7 @@ std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { } auto* cur = builder_.GetInsertBlock(); - std::cerr << "[DEBUG] current insert block: " + DebugStream() << "[DEBUG] current insert block: " << (cur ? cur->GetName() : "") << std::endl; if (cur && cur->HasTerminator()) { break; @@ -500,7 +500,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( } // 用于遍历块内项,返回是否继续访问后续项(如遇到 return/break/continue 则终止访问) std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitBlockItem: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少块内项")); } diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 6b4661d..725ab76 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -16,7 +16,7 @@ // - 空语句、块语句嵌套分发之外的更多语句形态 std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { - std::cerr << "[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] visitStmt: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } @@ -65,7 +65,7 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { // 修改 HandleReturnStmt 函数 IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) { - std::cerr << "[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] HandleReturnStmt: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); } @@ -132,7 +132,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleReturnStmt(SysYParser::StmtContext* ctx) { // if语句(待实现) IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) { - std::cerr << "[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] HandleIfStmt: " << (ctx ? ctx->getText() : "") << std::endl; auto* cond = ctx->cond(); auto* thenStmt = ctx->stmt(0); @@ -148,10 +148,10 @@ IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) { auto* elseBlock = (ctx->Else() && elseStmt) ? func_->CreateBlock(uniq("else")) : nullptr; auto* mergeBlock = func_->CreateBlock(uniq("merge")); - std::cerr << "[DEBUG IF] thenBlock: " << thenBlock->GetName() << std::endl; - if (elseBlock) std::cerr << "[DEBUG IF] elseBlock: " << elseBlock->GetName() << std::endl; - std::cerr << "[DEBUG IF] mergeBlock: " << mergeBlock->GetName() << std::endl; - std::cerr << "[DEBUG IF] current insert block before cond: " + DebugStream() << "[DEBUG IF] thenBlock: " << thenBlock->GetName() << std::endl; + if (elseBlock) DebugStream() << "[DEBUG IF] elseBlock: " << elseBlock->GetName() << std::endl; + DebugStream() << "[DEBUG IF] mergeBlock: " << mergeBlock->GetName() << std::endl; + DebugStream() << "[DEBUG IF] current insert block before cond: " << builder_.GetInsertBlock()->GetName() << std::endl; // 生成条件 @@ -168,59 +168,59 @@ IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) { // 创建条件跳转 if (elseBlock) { - std::cerr << "[DEBUG IF] Creating condbr: " << condValue->GetName() + DebugStream() << "[DEBUG IF] Creating condbr: " << condValue->GetName() << " -> " << thenBlock->GetName() << ", " << elseBlock->GetName() << std::endl; builder_.CreateCondBr(condValue, thenBlock, elseBlock); } else { - std::cerr << "[DEBUG IF] Creating condbr: " << condValue->GetName() + DebugStream() << "[DEBUG IF] Creating condbr: " << condValue->GetName() << " -> " << thenBlock->GetName() << ", " << mergeBlock->GetName() << std::endl; builder_.CreateCondBr(condValue, thenBlock, mergeBlock); } // 生成 then 分支 - std::cerr << "[DEBUG IF] Generating then branch in block: " << thenBlock->GetName() << std::endl; + DebugStream() << "[DEBUG IF] Generating then branch in block: " << thenBlock->GetName() << std::endl; builder_.SetInsertPoint(thenBlock); auto thenResult = thenStmt->accept(this); bool thenTerminated = (std::any_cast(thenResult) == BlockFlow::Terminated); - std::cerr << "[DEBUG IF] then branch terminated: " << thenTerminated << std::endl; + DebugStream() << "[DEBUG IF] then branch terminated: " << thenTerminated << std::endl; if (!thenTerminated) { - std::cerr << "[DEBUG IF] Adding br to merge block from then" << std::endl; + DebugStream() << "[DEBUG IF] Adding br to merge block from then" << std::endl; builder_.CreateBr(mergeBlock); } - std::cerr << "[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator() << std::endl; + DebugStream() << "[DEBUG IF] then block has terminator: " << thenBlock->HasTerminator() << std::endl; // 生成 else 分支 bool elseTerminated = false; if (elseBlock) { - std::cout << "[DEBUG IF] Generating else branch in block: " << elseBlock->GetName() << std::endl; + DebugStream() << "[DEBUG IF] Generating else branch in block: " << elseBlock->GetName() << std::endl; builder_.SetInsertPoint(elseBlock); auto elseResult = elseStmt->accept(this); elseTerminated = (std::any_cast(elseResult) == BlockFlow::Terminated); - std::cout << "[DEBUG IF] else branch terminated: " << elseTerminated << std::endl; + DebugStream() << "[DEBUG IF] else branch terminated: " << elseTerminated << std::endl; if (!elseTerminated) { - std::cout << "[DEBUG IF] Adding br to merge block from else" << std::endl; + DebugStream() << "[DEBUG IF] Adding br to merge block from else" << std::endl; builder_.CreateBr(mergeBlock); } - std::cout << "[DEBUG IF] else block has terminator: " << elseBlock->HasTerminator() << std::endl; + DebugStream() << "[DEBUG IF] else block has terminator: " << elseBlock->HasTerminator() << std::endl; } // 决定后续插入点 - std::cout << "[DEBUG IF] thenTerminated=" << thenTerminated + DebugStream() << "[DEBUG IF] thenTerminated=" << thenTerminated << ", elseTerminated=" << elseTerminated << std::endl; if (elseBlock) { - std::cout << "[DEBUG IF] Setting insert point to merge block: " + DebugStream() << "[DEBUG IF] Setting insert point to merge block: " << mergeBlock->GetName() << std::endl; builder_.SetInsertPoint(mergeBlock); } else { - std::cout << "[DEBUG IF] No else, setting insert point to merge block: " + DebugStream() << "[DEBUG IF] No else, setting insert point to merge block: " << mergeBlock->GetName() << std::endl; builder_.SetInsertPoint(mergeBlock); } - std::cout << "[DEBUG IF] Final insert block: " + DebugStream() << "[DEBUG IF] Final insert block: " << builder_.GetInsertBlock()->GetName() << std::endl; return BlockFlow::Continue; @@ -228,13 +228,13 @@ IRGenImpl::BlockFlow IRGenImpl::HandleIfStmt(SysYParser::StmtContext* ctx) { // while语句(待实现)IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { - std::cout << "[DEBUG IRGEN] HandleWhileStmt: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] HandleWhileStmt: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->cond() || !ctx->stmt(0)) { throw std::runtime_error(FormatError("irgen", "非法 while 语句")); } - std::cout << "[DEBUG WHILE] Current insert block before while: " + DebugStream() << "[DEBUG WHILE] Current insert block before while: " << builder_.GetInsertBlock()->GetName() << std::endl; auto uniq = [&](const std::string& prefix) { @@ -246,18 +246,18 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { auto* bodyBlock = func_->CreateBlock(uniq("while.body")); auto* exitBlock = func_->CreateBlock(uniq("while.exit")); - std::cout << "[DEBUG WHILE] condBlock: " << condBlock->GetName() << std::endl; - std::cout << "[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName() << std::endl; - std::cout << "[DEBUG WHILE] exitBlock: " << exitBlock->GetName() << std::endl; + DebugStream() << "[DEBUG WHILE] condBlock: " << condBlock->GetName() << std::endl; + DebugStream() << "[DEBUG WHILE] bodyBlock: " << bodyBlock->GetName() << std::endl; + DebugStream() << "[DEBUG WHILE] exitBlock: " << exitBlock->GetName() << std::endl; - std::cout << "[DEBUG WHILE] Adding br to condBlock from current block" << std::endl; + DebugStream() << "[DEBUG WHILE] Adding br to condBlock from current block" << std::endl; builder_.CreateBr(condBlock); loopStack_.push_back({condBlock, bodyBlock, exitBlock}); - std::cout << "[DEBUG WHILE] loopStack size: " << loopStack_.size() << std::endl; + DebugStream() << "[DEBUG WHILE] loopStack size: " << loopStack_.size() << std::endl; // 条件块 - std::cout << "[DEBUG WHILE] Generating condition in block: " << condBlock->GetName() << std::endl; + DebugStream() << "[DEBUG WHILE] Generating condition in block: " << condBlock->GetName() << std::endl; builder_.SetInsertPoint(condBlock); auto* condValue = EvalCond(*ctx->cond()); if (!condValue->GetType()->IsInt1()) { @@ -270,28 +270,28 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { } } builder_.CreateCondBr(condValue, bodyBlock, exitBlock); - std::cout << "[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator() << std::endl; + DebugStream() << "[DEBUG WHILE] condBlock has terminator: " << condBlock->HasTerminator() << std::endl; // 循环体 - std::cout << "[DEBUG WHILE] Generating body in block: " << bodyBlock->GetName() << std::endl; + DebugStream() << "[DEBUG WHILE] Generating body in block: " << bodyBlock->GetName() << std::endl; builder_.SetInsertPoint(bodyBlock); auto bodyResult = ctx->stmt(0)->accept(this); bool bodyTerminated = (std::any_cast(bodyResult) == BlockFlow::Terminated); - std::cout << "[DEBUG WHILE] body terminated: " << bodyTerminated << std::endl; + DebugStream() << "[DEBUG WHILE] body terminated: " << bodyTerminated << std::endl; if (!bodyTerminated) { - std::cout << "[DEBUG WHILE] Adding br to condBlock from body" << std::endl; + DebugStream() << "[DEBUG WHILE] Adding br to condBlock from body" << std::endl; builder_.CreateBr(condBlock); } - std::cout << "[DEBUG WHILE] bodyBlock has terminator: " << bodyBlock->HasTerminator() << std::endl; + DebugStream() << "[DEBUG WHILE] bodyBlock has terminator: " << bodyBlock->HasTerminator() << std::endl; loopStack_.pop_back(); - std::cout << "[DEBUG WHILE] loopStack size after pop: " << loopStack_.size() << std::endl; + DebugStream() << "[DEBUG WHILE] loopStack size after pop: " << loopStack_.size() << std::endl; // 设置插入点为 exitBlock - std::cout << "[DEBUG WHILE] Setting insert point to exitBlock: " << exitBlock->GetName() << std::endl; + DebugStream() << "[DEBUG WHILE] Setting insert point to exitBlock: " << exitBlock->GetName() << std::endl; builder_.SetInsertPoint(exitBlock); - std::cout << "[DEBUG WHILE] exitBlock has terminator before return: " + DebugStream() << "[DEBUG WHILE] exitBlock has terminator before return: " << exitBlock->HasTerminator() << std::endl; return BlockFlow::Continue; @@ -299,15 +299,15 @@ IRGenImpl::BlockFlow IRGenImpl::HandleWhileStmt(SysYParser::StmtContext* ctx) { // break语句(待实现) IRGenImpl::BlockFlow IRGenImpl::HandleBreakStmt(SysYParser::StmtContext* ctx) { - std::cout << "[DEBUG IRGEN] HandleBreakStmt: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] HandleBreakStmt: " << (ctx ? ctx->getText() : "") << std::endl; if (loopStack_.empty()) { throw std::runtime_error(FormatError("irgen", "break 语句不在循环中")); } - std::cout << "[DEBUG BREAK] Current insert block before break: " + DebugStream() << "[DEBUG BREAK] Current insert block before break: " << builder_.GetInsertBlock()->GetName() << std::endl; - std::cout << "[DEBUG BREAK] Breaking to exitBlock: " + DebugStream() << "[DEBUG BREAK] Breaking to exitBlock: " << loopStack_.back().exitBlock->GetName() << std::endl; // 跳转到循环退出块 @@ -318,15 +318,15 @@ IRGenImpl::BlockFlow IRGenImpl::HandleBreakStmt(SysYParser::StmtContext* ctx) { } IRGenImpl::BlockFlow IRGenImpl::HandleContinueStmt(SysYParser::StmtContext* ctx) { - std::cout << "[DEBUG IRGEN] HandleContinueStmt: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] HandleContinueStmt: " << (ctx ? ctx->getText() : "") << std::endl; if (loopStack_.empty()) { throw std::runtime_error(FormatError("irgen", "continue 语句不在循环中")); } - std::cout << "[DEBUG CONTINUE] Current insert block before continue: " + DebugStream() << "[DEBUG CONTINUE] Current insert block before continue: " << builder_.GetInsertBlock()->GetName() << std::endl; - std::cout << "[DEBUG CONTINUE] Continuing to condBlock: " + DebugStream() << "[DEBUG CONTINUE] Continuing to condBlock: " << loopStack_.back().condBlock->GetName() << std::endl; // 跳转到循环条件块 @@ -340,7 +340,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleContinueStmt(SysYParser::StmtContext* ctx) // 赋值语句 // 赋值语句 IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { - std::cout << "[DEBUG IRGEN] HandleAssignStmt: " << (ctx ? ctx->getText() : "") << std::endl; + DebugStream() << "[DEBUG IRGEN] HandleAssignStmt: " << (ctx ? ctx->getText() : "") << std::endl; if (!ctx || !ctx->lVal() || !ctx->exp()) { throw std::runtime_error(FormatError("irgen", "非法赋值语句")); @@ -354,7 +354,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { auto* lval = ctx->lVal(); std::string varName = lval->Ident()->getText(); - std::cerr << "[DEBUG] HandleAssignStmt: assigning to " << varName << std::endl; + DebugStream() << "[DEBUG] HandleAssignStmt: assigning to " << varName << std::endl; // 1. 检查是否为常量(不能给常量赋值) auto* const_decl = sema_.ResolveConstUse(lval); @@ -372,7 +372,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { auto it = storage_map_.find(var_decl); if (it != storage_map_.end()) { base_ptr = it->second; - std::cerr << "[DEBUG] HandleAssignStmt: found in storage_map_ for " << varName + DebugStream() << "[DEBUG] HandleAssignStmt: found in storage_map_ for " << varName << ", ptr = " << (void*)base_ptr << std::endl; } } @@ -382,7 +382,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { auto it2 = param_map_.find(varName); if (it2 != param_map_.end()) { base_ptr = it2->second; - std::cerr << "[DEBUG] HandleAssignStmt: found in param_map_ for " << varName + DebugStream() << "[DEBUG] HandleAssignStmt: found in param_map_ for " << varName << ", ptr = " << (void*)base_ptr << std::endl; } } @@ -392,7 +392,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { auto it3 = global_map_.find(varName); if (it3 != global_map_.end()) { base_ptr = it3->second; - std::cerr << "[DEBUG] HandleAssignStmt: found in global_map_ for " << varName + DebugStream() << "[DEBUG] HandleAssignStmt: found in global_map_ for " << varName << ", ptr = " << (void*)base_ptr << std::endl; } } @@ -402,7 +402,7 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { auto it4 = local_var_map_.find(varName); if (it4 != local_var_map_.end()) { base_ptr = it4->second; - std::cerr << "[DEBUG] HandleAssignStmt: found in local_var_map_ for " << varName + DebugStream() << "[DEBUG] HandleAssignStmt: found in local_var_map_ for " << varName << ", ptr = " << (void*)base_ptr << std::endl; } } @@ -497,21 +497,21 @@ IRGenImpl::BlockFlow IRGenImpl::HandleAssignStmt(SysYParser::StmtContext* ctx) { builder_.CreateStore(rhs, elem_ptr); } else { // 普通标量赋值 - std::cerr << "[DEBUG] HandleAssignStmt: scalar assignment to " << varName + DebugStream() << "[DEBUG] HandleAssignStmt: scalar assignment to " << varName << ", ptr = " << (void*)base_ptr << ", rhs = " << (void*)rhs << std::endl; // 在 HandleAssignStmt 中,存储前添加类型调试 if (base_ptr && base_ptr->GetType()) { - std::cerr << "[DEBUG] Is int32: " << base_ptr->GetType()->IsInt32() << std::endl; - std::cerr << "[DEBUG] Is float: " << base_ptr->GetType()->IsFloat() << std::endl; - std::cerr << "[DEBUG] Is ptr int32: " << base_ptr->GetType()->IsPtrInt32() << std::endl; - std::cerr << "[DEBUG] Is ptr float: " << base_ptr->GetType()->IsPtrFloat() << std::endl; - std::cerr << "[DEBUG] Is array: " << base_ptr->GetType()->IsArray() << std::endl; + DebugStream() << "[DEBUG] Is int32: " << base_ptr->GetType()->IsInt32() << std::endl; + DebugStream() << "[DEBUG] Is float: " << base_ptr->GetType()->IsFloat() << std::endl; + DebugStream() << "[DEBUG] Is ptr int32: " << base_ptr->GetType()->IsPtrInt32() << std::endl; + DebugStream() << "[DEBUG] Is ptr float: " << base_ptr->GetType()->IsPtrFloat() << std::endl; + DebugStream() << "[DEBUG] Is array: " << base_ptr->GetType()->IsArray() << std::endl; } if (rhs && rhs->GetType()) { - std::cerr << "[DEBUG] Value is int32: " << rhs->GetType()->IsInt32() << std::endl; + DebugStream() << "[DEBUG] Value is int32: " << rhs->GetType()->IsInt32() << std::endl; } if (base_ptr->GetType()->IsPtrFloat() && rhs->GetType()->IsInt32()) { rhs = builder_.CreateSIToFP(rhs, ir::Type::GetFloatType(), diff --git a/src/main.cpp b/src/main.cpp index 88ed747..820313b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -35,6 +35,8 @@ int main(int argc, char** argv) { } auto sema = RunSema(*comp_unit); + SetDebugEnabled(opts.debug); + auto module = GenerateIR(*comp_unit, sema); if (opts.emit_ir) { ir::IRPrinter printer; diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 8d51715..424dbbe 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -67,7 +67,7 @@ public: } else { return_type = ir::Type::GetInt32Type(); } - std::cout << "[DEBUG] 进入函数: " << name + DebugStream() << "[DEBUG] 进入函数: " << name << " 返回类型: " << (return_type->IsInt32() ? "int" : return_type->IsFloat() ? "float" : "void") << std::endl; @@ -83,7 +83,7 @@ public: if (ctx->block()) { // 处理函数体 ctx->block()->accept(this); } - std::cout << "[DEBUG] 函数 " << name + DebugStream() << "[DEBUG] 函数 " << name << " has_return: " << current_func_has_return_ << " return_type_is_void: " << return_type->IsVoid() << std::endl; @@ -170,7 +170,7 @@ public: std::vector dims; bool is_array = !ctx->constExp().empty(); // 调试输出 - std::cout << "[DEBUG] CheckVarDef: " << name + DebugStream() << "[DEBUG] CheckVarDef: " << name << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") << " is_array: " << is_array << " dim_count: " << ctx->constExp().size() << std::endl; @@ -185,23 +185,23 @@ public: throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); } dims.push_back(dim); - std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; + DebugStream() << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; } // 创建数组类型 type = ir::Type::GetArrayType(base_type, dims); - std::cout << "[DEBUG] 创建数组类型完成" << std::endl; - std::cout << "[DEBUG] type->IsArray(): " << type->IsArray() << std::endl; - std::cout << "[DEBUG] type->GetKind(): " << (int)type->GetKind() << std::endl; + DebugStream() << "[DEBUG] 创建数组类型完成" << std::endl; + DebugStream() << "[DEBUG] type->IsArray(): " << type->IsArray() << std::endl; + DebugStream() << "[DEBUG] type->GetKind(): " << (int)type->GetKind() << std::endl; // 验证数组类型 if (type->IsArray()) { auto* arr_type = dynamic_cast(type.get()); if (arr_type) { - std::cout << "[DEBUG] ArrayType dimensions: "; + DebugStream() << "[DEBUG] ArrayType dimensions: "; for (int d : arr_type->GetDimensions()) { - std::cout << d << " "; + DebugStream() << d << " "; } - std::cout << std::endl; - std::cout << "[DEBUG] Element type: " + DebugStream() << std::endl; + DebugStream() << "[DEBUG] Element type: " << (arr_type->GetElementType()->IsInt32() ? "int" : arr_type->GetElementType()->IsFloat() ? "float" : "unknown") << std::endl; @@ -230,7 +230,7 @@ public: sym.param_types.clear(); // 确保不混淆 } table_.addSymbol(sym); // 添加到符号表 - std::cout << "[DEBUG] 符号添加完成: " << name + DebugStream() << "[DEBUG] 符号添加完成: " << name << " type_kind: " << (int)sym.type->GetKind() << " is_array: " << sym.type->IsArray() << std::endl; @@ -250,7 +250,7 @@ public: std::shared_ptr type = base_type; std::vector dims; bool is_array = !ctx->constExp().empty(); - std::cout << "[DEBUG] CheckConstDef: " << name + DebugStream() << "[DEBUG] CheckConstDef: " << name << " base_type: " << (base_type->IsInt32() ? "int" : base_type->IsFloat() ? "float" : "unknown") << " is_array: " << is_array << " dim_count: " << ctx->constExp().size() << std::endl; @@ -262,10 +262,10 @@ public: throw std::runtime_error(FormatError("sema", "数组维度必须为正整数")); } dims.push_back(dim); - std::cout << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; + DebugStream() << "[DEBUG] dim[" << dims.size() - 1 << "] = " << dim << std::endl; } type = ir::Type::GetArrayType(base_type, dims); - std::cout << "[DEBUG] 创建数组类型完成,IsArray: " << type->IsArray() << std::endl; + DebugStream() << "[DEBUG] 创建数组类型完成,IsArray: " << type->IsArray() << std::endl; } // ========== 绑定维度表达式 ========== @@ -280,7 +280,7 @@ public: BindConstInitVal(ctx->constInitVal()); init_values = table_.EvaluateConstInitVal(ctx->constInitVal(), dims, base_type); - std::cout << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl; + DebugStream() << "[DEBUG] 初始化值数量: " << init_values.size() << std::endl; } // 计算期望的元素数量 @@ -288,12 +288,12 @@ public: if (is_array) { expected_count = 1; for (int d : dims) expected_count *= d; - std::cout << "[DEBUG] 期望元素数量: " << expected_count << std::endl; + DebugStream() << "[DEBUG] 期望元素数量: " << expected_count << std::endl; } // 如果初始化值不足,补零 if (is_array && init_values.size() < expected_count) { - std::cout << "[DEBUG] 初始化值不足,补零" << std::endl; + DebugStream() << "[DEBUG] 初始化值不足,补零" << std::endl; SymbolTable::ConstValue zero; if (base_type->IsInt32()) { zero.kind = SymbolTable::ConstValue::INT; @@ -314,13 +314,13 @@ public: Symbol sym; sym.name = name; sym.kind = SymbolKind::Constant; - std::cout << "CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind << std::endl; + DebugStream() << "CheckConstDef: before addSymbol, sym.kind = " << (int)sym.kind << std::endl; sym.type = type; sym.scope_level = table_.currentScopeLevel(); sym.is_initialized = true; sym.var_def_ctx = nullptr; sym.const_def_ctx = ctx; - std::cout << "保存常量定义上下文: " << name << ", ctx: " << ctx << std::endl; + DebugStream() << "保存常量定义上下文: " << name << ", ctx: " << ctx << std::endl; // ========== 存储常量值 ========== if (is_array) { @@ -338,7 +338,7 @@ public: sym.array_const_values.push_back(cv); } - std::cout << "[DEBUG] 存储数组常量,共 " << sym.array_const_values.size() + DebugStream() << "[DEBUG] 存储数组常量,共 " << sym.array_const_values.size() << " 个元素" << std::endl; } else if (!init_values.empty()) { @@ -346,11 +346,11 @@ public: if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::INT) { sym.is_int_const = true; sym.const_value.i32 = init_values[0].int_val; - std::cout << "[DEBUG] 存储整型常量: " << init_values[0].int_val << std::endl; + DebugStream() << "[DEBUG] 存储整型常量: " << init_values[0].int_val << std::endl; } else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) { sym.is_int_const = false; sym.const_value.f32 = init_values[0].float_val; - std::cout << "[DEBUG] 存储浮点常量: " << init_values[0].float_val << std::endl; + DebugStream() << "[DEBUG] 存储浮点常量: " << init_values[0].float_val << std::endl; } else if (base_type->IsInt32() && init_values[0].kind == SymbolTable::ConstValue::FLOAT) { // 整型常量用浮点数初始化(需要检查是否为整数) float f = init_values[0].float_val; @@ -361,12 +361,12 @@ public: } sym.is_int_const = true; sym.const_value.i32 = i; - std::cout << "[DEBUG] 浮点转整型常量: " << f << " -> " << i << std::endl; + DebugStream() << "[DEBUG] 浮点转整型常量: " << f << " -> " << i << std::endl; } else if (base_type->IsFloat() && init_values[0].kind == SymbolTable::ConstValue::INT) { // 浮点常量用整型初始化,隐式转换 sym.is_int_const = false; sym.const_value.f32 = static_cast(init_values[0].int_val); - std::cout << "[DEBUG] 整型转浮点常量: " << init_values[0].int_val + DebugStream() << "[DEBUG] 整型转浮点常量: " << init_values[0].int_val << " -> " << static_cast(init_values[0].int_val) << std::endl; } } else { @@ -374,15 +374,15 @@ public: if (!is_array) { throw std::runtime_error(FormatError("sema", "常量必须有初始化值: " + name)); } - std::cout << "[DEBUG] 数组常量无初始化器,将全部补零" << std::endl; + DebugStream() << "[DEBUG] 数组常量无初始化器,将全部补零" << std::endl; } table_.addSymbol(sym); - std::cout << "CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind << std::endl; + DebugStream() << "CheckConstDef: after addSymbol, sym.kind = " << (int)sym.kind << std::endl; auto* stored = table_.lookup(name); - std::cout << "CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx << std::endl; + DebugStream() << "CheckConstDef: after addSymbol, stored const_def_ctx = " << stored->const_def_ctx << std::endl; - std::cout << "[DEBUG] 常量符号添加完成: " << name + DebugStream() << "[DEBUG] 常量符号添加完成: " << name << " is_array_const: " << sym.is_array_const << " element_count: " << sym.array_const_values.size() << std::endl; } @@ -407,20 +407,20 @@ public: std::any visitStmt(SysYParser::StmtContext* ctx) override { if (!ctx) return {}; // 调试输出 - std::cout << "[DEBUG] visitStmt: "; - if (ctx->Return()) std::cout << "Return "; - if (ctx->If()) std::cout << "If "; - if (ctx->While()) std::cout << "While "; - if (ctx->Break()) std::cout << "Break "; - if (ctx->Continue()) std::cout << "Continue "; - if (ctx->lVal() && ctx->Assign()) std::cout << "Assign "; - if (ctx->exp() && ctx->Semi()) std::cout << "ExpStmt "; - if (ctx->block()) std::cout << "Block "; - std::cout << std::endl; + DebugStream() << "[DEBUG] visitStmt: "; + if (ctx->Return()) DebugStream() << "Return "; + if (ctx->If()) DebugStream() << "If "; + if (ctx->While()) DebugStream() << "While "; + if (ctx->Break()) DebugStream() << "Break "; + if (ctx->Continue()) DebugStream() << "Continue "; + if (ctx->lVal() && ctx->Assign()) DebugStream() << "Assign "; + if (ctx->exp() && ctx->Semi()) DebugStream() << "ExpStmt "; + if (ctx->block()) DebugStream() << "Block "; + DebugStream() << std::endl; // 判断语句类型 - 注意:Return() 返回的是 TerminalNode* if (ctx->Return() != nullptr) { // return 语句 - std::cout << "[DEBUG] 检测到 return 语句" << std::endl; + DebugStream() << "[DEBUG] 检测到 return 语句" << std::endl; return visitReturnStmtInternal(ctx); } else if (ctx->lVal() != nullptr && ctx->Assign() != nullptr) { // 赋值语句 @@ -449,14 +449,14 @@ public: // return 语句内部实现 std::any visitReturnStmtInternal(SysYParser::StmtContext* ctx) { - std::cout << "[DEBUG] visitReturnStmtInternal 被调用" << std::endl; + DebugStream() << "[DEBUG] visitReturnStmtInternal 被调用" << std::endl; std::shared_ptr expected = current_func_return_type_; if (!expected) { throw std::runtime_error(FormatError("sema", "return 语句不在函数体内")); } if (ctx->exp() != nullptr) { // 有返回值的 return - std::cout << "[DEBUG] 有返回值的 return" << std::endl; + DebugStream() << "[DEBUG] 有返回值的 return" << std::endl; ExprInfo ret_val = CheckExp(ctx->exp()); if (expected->IsVoid()) { throw std::runtime_error(FormatError("sema", "void 函数不能返回值")); @@ -469,23 +469,23 @@ public: } // 设置 has_return 标志 current_func_has_return_ = true; - std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl; + DebugStream() << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl; } else { // 无返回值的 return - std::cout << "[DEBUG] 无返回值的 return" << std::endl; + DebugStream() << "[DEBUG] 无返回值的 return" << std::endl; if (!expected->IsVoid()) { throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值")); } // 设置 has_return 标志 current_func_has_return_ = true; - std::cout << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl; + DebugStream() << "[DEBUG] 设置 current_func_has_return_ = true" << std::endl; } return {}; } // 左值表达式(变量引用) std::any visitLVal(SysYParser::LValContext* ctx) override { - std::cout << "[DEBUG] visitLVal: " << ctx->getText() << std::endl; + DebugStream() << "[DEBUG] visitLVal: " << ctx->getText() << std::endl; if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("sema", "非法变量引用")); } @@ -496,7 +496,7 @@ public: } // 检查数组访问 bool is_array_access = !ctx->exp().empty(); - std::cout << "[DEBUG] name: " << name + DebugStream() << "[DEBUG] name: " << name << ", is_array_access: " << is_array_access << ", subscript_count: " << ctx->exp().size() << std::endl; ExprInfo result; @@ -504,7 +504,7 @@ public: bool is_array_or_ptr = false; if (sym->type) { is_array_or_ptr = sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat(); - std::cout << "[DEBUG] type_kind: " << (int)sym->type->GetKind() + DebugStream() << "[DEBUG] type_kind: " << (int)sym->type->GetKind() << ", is_array: " << sym->type->IsArray() << ", is_ptr: " << (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) << std::endl; } @@ -517,7 +517,7 @@ public: if (auto* arr_type = dynamic_cast(sym->type.get())) { dim_count = arr_type->GetDimensions().size(); elem_type = arr_type->GetElementType(); - std::cout << "[DEBUG] 数组维度: " << dim_count << std::endl; + DebugStream() << "[DEBUG] 数组维度: " << dim_count << std::endl; } } else if (sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) { dim_count = 1; @@ -526,11 +526,11 @@ public: } else if (sym->type->IsPtrFloat()) { elem_type = ir::Type::GetFloatType(); } - std::cout << "[DEBUG] 指针类型, dim_count: 1" << std::endl; + DebugStream() << "[DEBUG] 指针类型, dim_count: 1" << std::endl; } if (is_array_access) { - std::cout << "[DEBUG] 有下标访问,期望维度: " << dim_count + DebugStream() << "[DEBUG] 有下标访问,期望维度: " << dim_count << ", 实际下标数: " << ctx->exp().size() << std::endl; if (ctx->exp().size() != dim_count) { throw std::runtime_error(FormatError("sema", "数组下标个数不匹配")); @@ -545,9 +545,9 @@ public: result.is_lvalue = true; result.is_const = false; } else { - std::cout << "[DEBUG] 无下标访问" << std::endl; + DebugStream() << "[DEBUG] 无下标访问" << std::endl; if (sym->type->IsArray()) { - std::cout << "[DEBUG] 数组名作为地址,转换为指针" << std::endl; + DebugStream() << "[DEBUG] 数组名作为地址,转换为指针" << std::endl; if (auto* arr_type = dynamic_cast(sym->type.get())) { if (arr_type->GetElementType()->IsInt32()) { result.type = ir::Type::GetPtrInt32Type(); @@ -669,7 +669,7 @@ public: // 主表达式 std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { - std::cout << "[DEBUG] visitPrimaryExp: " << ctx->getText() << std::endl; + DebugStream() << "[DEBUG] visitPrimaryExp: " << ctx->getText() << std::endl; ExprInfo result; if (ctx->lVal()) { // 左值表达式 result = CheckLValue(ctx->lVal()); @@ -701,14 +701,14 @@ public: // 一元表达式 std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { - std::cout << "[DEBUG] visitUnaryExp: " << ctx->getText() << std::endl; + DebugStream() << "[DEBUG] visitUnaryExp: " << ctx->getText() << std::endl; ExprInfo result; if (ctx->primaryExp()) { ctx->primaryExp()->accept(this); auto* info = sema_.GetExprType(ctx->primaryExp()); if (info) result = *info; } else if (ctx->Ident() && ctx->L_PAREN()) { // 函数调用 - std::cout << "[DEBUG] 函数调用: " << ctx->Ident()->getText() << std::endl; + DebugStream() << "[DEBUG] 函数调用: " << ctx->Ident()->getText() << std::endl; result = CheckFuncCall(ctx); } else if (ctx->unaryOp()) { // 一元运算 ctx->unaryExp()->accept(this); @@ -1074,7 +1074,7 @@ public: // 新增:同时返回两者 SemaResult TakeResult() { - std::cerr << "[DEBUG] TakeResult 前: 符号表作用域数量 = " + DebugStream() << "[DEBUG] TakeResult 前: 符号表作用域数量 = " << table_.getScopeCount() << std::endl; // 可选:打印符号表内容 @@ -1084,7 +1084,7 @@ public: result.context = std::move(sema_); result.symbol_table = std::move(table_); - std::cerr << "[DEBUG] TakeResult 后: 符号表作用域数量 = " + DebugStream() << "[DEBUG] TakeResult 后: 符号表作用域数量 = " << result.symbol_table.getScopeCount() << std::endl; return result; } @@ -1106,7 +1106,7 @@ private: if (!ctx || !ctx->addExp()) { throw std::runtime_error(FormatError("sema", "无效表达式")); } - std::cout << "[DEBUG] CheckExp: " << ctx->getText() << std::endl; + DebugStream() << "[DEBUG] CheckExp: " << ctx->getText() << std::endl; ctx->addExp()->accept(this); auto* info = sema_.GetExprType(ctx->addExp()); if (!info) { @@ -1157,18 +1157,18 @@ private: if (!sym) { throw std::runtime_error(FormatError("sema", "未定义的变量: " + name)); } - std::cout << "CheckLValue: found sym->name = " << sym->name + DebugStream() << "CheckLValue: found sym->name = " << sym->name << ", sym->kind = " << (int)sym->kind << std::endl; if (sym->kind == SymbolKind::Variable && sym->var_def_ctx) { sema_.BindVarUse(ctx, sym->var_def_ctx); - std::cout << "绑定变量: " << name << " -> VarDefContext" << std::endl; + DebugStream() << "绑定变量: " << name << " -> VarDefContext" << std::endl; } else if (sym->kind == SymbolKind::Constant && sym->const_def_ctx) { sema_.BindConstUse(ctx, sym->const_def_ctx); - std::cout << "绑定常量: " << name << " -> ConstDefContext" << std::endl; + DebugStream() << "绑定常量: " << name << " -> ConstDefContext" << std::endl; } - std::cout << "CheckLValue 绑定变量: " << name + DebugStream() << "CheckLValue 绑定变量: " << name << ", sym->kind: " << (int)sym->kind << ", sym->var_def_ctx: " << sym->var_def_ctx << ", sym->const_def_ctx: " << sym->const_def_ctx << std::endl; @@ -1203,9 +1203,9 @@ private: } else if (sym->type->IsPtrFloat()) { elem_type = ir::Type::GetFloatType(); } - std::cout << "数组参数维度: " << dim_count << " 维, dims: "; - for (int d : dims) std::cout << d << " "; - std::cout << std::endl; + DebugStream() << "数组参数维度: " << dim_count << " 维, dims: "; + for (int d : dims) DebugStream() << d << " "; + DebugStream() << std::endl; } else if (sym->type && (sym->type->IsPtrInt32() || sym->type->IsPtrFloat())) { // 普通指针,只能有一个下标 dim_count = 1; @@ -1218,7 +1218,7 @@ private: size_t subscript_count = ctx->exp().size(); - std::cout << "dim_count: " << dim_count << ", subscript_count: " << subscript_count << std::endl; + DebugStream() << "dim_count: " << dim_count << ", subscript_count: " << subscript_count << std::endl; if (dim_count > 0 || sym->is_array_param || sym->type->IsArray() || sym->type->IsPtrInt32() || sym->type->IsPtrFloat()) { @@ -1239,11 +1239,11 @@ private: if (subscript_count == dim_count) { // 完全索引,返回元素类型 - std::cout << "完全索引,返回元素类型" << std::endl; + DebugStream() << "完全索引,返回元素类型" << std::endl; return {elem_type, true, false}; } else { // 部分索引,返回子数组的指针类型 - std::cout << "部分索引,返回指针类型" << std::endl; + DebugStream() << "部分索引,返回指针类型" << std::endl; // 计算剩余维度的指针类型 if (elem_type->IsInt32()) { return {ir::Type::GetPtrInt32Type(), false, false}; @@ -1257,7 +1257,7 @@ private: // 没有下标访问 if (sym->type && sym->type->IsArray()) { // 数组名作为地址 - std::cout << "数组名作为地址" << std::endl; + DebugStream() << "数组名作为地址" << std::endl; if (auto* arr_type = dynamic_cast(sym->type.get())) { if (arr_type->GetElementType()->IsInt32()) { return {ir::Type::GetPtrInt32Type(), false, true}; @@ -1268,7 +1268,7 @@ private: return {ir::Type::GetPtrInt32Type(), false, true}; } else if (sym->is_array_param) { // 数组参数名作为地址 - std::cout << "数组参数名作为地址" << std::endl; + DebugStream() << "数组参数名作为地址" << std::endl; if (sym->type->IsPtrInt32()) { return {ir::Type::GetPtrInt32Type(), false, true}; } else { @@ -1292,14 +1292,14 @@ private: throw std::runtime_error(FormatError("sema", "非法函数调用")); } std::string func_name = ctx->Ident()->getText(); - std::cout << "[DEBUG] CheckFuncCall: " << func_name << std::endl; + DebugStream() << "[DEBUG] CheckFuncCall: " << func_name << std::endl; auto* func_sym = table_.lookup(func_name); if (!func_sym || func_sym->kind != SymbolKind::Function) { throw std::runtime_error(FormatError("sema", "未定义的函数: " + func_name)); } std::vector args; if (ctx->funcRParams()) { - std::cout << "[DEBUG] 处理函数调用参数:" << std::endl; + DebugStream() << "[DEBUG] 处理函数调用参数:" << std::endl; for (auto* exp : ctx->funcRParams()->exp()) { if (exp) { args.push_back(CheckExp(exp)); @@ -1310,7 +1310,7 @@ private: throw std::runtime_error(FormatError("sema", "参数个数不匹配")); } for (size_t i = 0; i < std::min(args.size(), func_sym->param_types.size()); ++i) { - std::cout << "[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind() + DebugStream() << "[DEBUG] 检查参数 " << i << ": 实参类型 " << (int)args[i].type->GetKind() << " 形参类型 " << (int)func_sym->param_types[i]->GetKind() << std::endl; if (!IsTypeCompatible(args[i].type, func_sym->param_types[i])) { throw std::runtime_error(FormatError("sema", "参数类型不匹配")); @@ -1511,10 +1511,10 @@ private: sym.array_dims = dims; table_.addSymbol(sym); - std::cout << "[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind() + DebugStream() << "[DEBUG] 添加参数: " << name << " type_kind: " << (int)param_type->GetKind() << " is_array: " << is_array << " dims: "; - for (int d : dims) std::cout << d << " "; - std::cout << std::endl; + for (int d : dims) DebugStream() << d << " "; + DebugStream() << std::endl; } } diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index 4253825..82e948d 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -1,4 +1,5 @@ #include "sem/SymbolTable.h" +#include "utils/Log.h" #include // 用于访问父节点 #include #include @@ -10,7 +11,7 @@ #ifdef DEBUG_SYMBOL_TABLE #include -#define DEBUG_MSG(msg) std::cerr << "[SymbolTable Debug] " << msg << std::endl +#define DEBUG_MSG(msg) DebugStream() << "[SymbolTable Debug] " << msg << std::endl #else #define DEBUG_MSG(msg) #endif @@ -48,7 +49,7 @@ bool SymbolTable::addSymbol(const Symbol& sym) { // 立即验证存储的符号 const auto& stored = current_scope[sym.name]; - std::cout << "SymbolTable::addSymbol: stored " << sym.name + DebugStream() << "SymbolTable::addSymbol: stored " << sym.name << " with kind=" << (int)stored.kind << ", const_def_ctx=" << stored.const_def_ctx << std::endl; @@ -69,7 +70,7 @@ const Symbol* SymbolTable::lookup(const std::string& name) const { const auto& scope = scopes_[*it]; auto found = scope.find(name); if (found != scope.end()) { - std::cout << "SymbolTable::lookup: found " << name + DebugStream() << "SymbolTable::lookup: found " << name << " in active scope index " << *it << ", kind=" << (int)found->second.kind << ", const_def_ctx=" << found->second.const_def_ctx diff --git a/src/utils/CLI.cpp b/src/utils/CLI.cpp index 21b6d20..61212ba 100644 --- a/src/utils/CLI.cpp +++ b/src/utils/CLI.cpp @@ -58,6 +58,11 @@ CLIOptions ParseCLI(int argc, char** argv) { continue; } + if (std::strcmp(arg, "--debug") == 0) { + opt.debug = true; + continue; + } + if (arg[0] == '-') { throw std::runtime_error( FormatError("cli", std::string("未知参数: ") + arg + diff --git a/src/utils/Log.cpp b/src/utils/Log.cpp index e540ba8..5f9191c 100644 --- a/src/utils/Log.cpp +++ b/src/utils/Log.cpp @@ -2,17 +2,44 @@ #include "utils/Log.h" +#include #include +#include #include +bool g_debug_enabled = false; + namespace { +class NullBuffer : public std::streambuf { + protected: + int overflow(int c) override { return c; } +}; + +std::ostream& NullStream() { + static NullBuffer null_buffer; + static std::ostream null_stream(&null_buffer); + return null_stream; +} + bool IsCLIError(const std::string_view msg) { return HasErrorPrefix(msg, "cli"); } } // namespace +void SetDebugEnabled(bool enabled) { + g_debug_enabled = enabled; +} + +bool IsDebugEnabled() { + return g_debug_enabled; +} + +std::ostream& DebugStream() { + return g_debug_enabled ? std::cerr : NullStream(); +} + void LogInfo(const std::string_view msg, std::ostream& os) { os << "[info] " << msg << "\n"; } @@ -57,6 +84,7 @@ void PrintHelp(std::ostream& os) { << " --emit-parse-tree 仅在显式模式下启用语法树输出\n" << " --emit-ir 仅在显式模式下启用 IR 输出\n" << " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n" + << " --debug 启用调试日志输出\n" << "\n" << "说明:\n" << " - 默认输出 IR\n" diff --git a/test/test_case/mem2reg/01_phi.out b/test/test_case/mem2reg/01_phi.out new file mode 100644 index 0000000..7ed6ff8 --- /dev/null +++ b/test/test_case/mem2reg/01_phi.out @@ -0,0 +1 @@ +5 diff --git a/test/test_case/mem2reg/01_phi.sy b/test/test_case/mem2reg/01_phi.sy new file mode 100644 index 0000000..aa9db2b --- /dev/null +++ b/test/test_case/mem2reg/01_phi.sy @@ -0,0 +1,9 @@ +int main() { + int x = 0; + if (x == 0) { + x = 5; + } else { + x = 7; + } + return x; +}