diff --git a/.gitignore b/.gitignore index 1ee33a1..51f5f27 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,4 @@ Thumbs.db # Project outputs # ========================= test/test_result/ +sema_check \ No newline at end of file diff --git a/README.md b/README.md index c24a2fa..d1a537a 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,10 @@ 如果希望进一步参考编译相关项目和往届优秀实现,可以查看编译比赛官网的技术支持栏目:。其中的“备赛推荐”整理了一些编译相关项目,也能看到往届优秀作品的开源实现,这些内容都很值得参考。 +此外,仓库中还提供了一份当前实现状态与测试入口的总览文档,便于组内同步进度: + +- `doc/实验进度与测试方法.md` + ## 3. 头歌平台协作流程 头歌平台的代码托管方式与 GitHub/Gitee 类似。如果你希望基于当前仓库快速开始协作,可以参考下面这套流程。 diff --git a/doc/lab2剩余任务分工.md b/doc/lab2剩余任务分工.md new file mode 100644 index 0000000..8969726 --- /dev/null +++ b/doc/lab2剩余任务分工.md @@ -0,0 +1,102 @@ +### 人员 1:基础表达式与赋值支持(lc,已完成) + +- 任务 1.1:支持更多二元运算符(Sub, Mul, Div, Mod) +- 任务 1.2:支持一元运算符(正负号) +- 任务 1.3:支持赋值表达式 +- 任务 1.4:支持逗号分隔的多个变量声明 + +### 人员 2:控制流支持(lyy,已完成) + +- 任务 2.1:支持 if-else 条件语句 +- 任务 2.2:支持 while 循环语句 +- 任务 2.3:支持 break/continue 语句 +- 任务 2.4:支持比较和逻辑表达式 + +### 人员 3:函数与全局变量支持 + +- 任务 3.1:支持全局变量声明与初始化 +- 任务 3.2:支持函数参数处理 +- 任务 3.3:支持函数调用生成 +- 任务 3.4:支持 const 常量声明 + +## 人员 1 完成情况详细说明(更新于 2026-03-30) + +### ✅ 已完成任务 + +人员 1 已完整实现 Lab2 IR 生成的基础功能模块,包括: + +1. **二元运算符**(任务 1.1) + - 实现 `Sub`, `Mul`, `Div`, `Mod` 四种运算符 + - 修改文件:`include/ir/IR.h`, `src/ir/IRBuilder.cpp`, `src/ir/IRPrinter.cpp`, `src/irgen/IRGenExp.cpp` +2. **一元运算符**(任务 1.2) + - 实现正负号运算符(`+`, `-`) + - 新增 `UnaryInst` 类支持一元指令 + - 负号生成 `sub 0, x` 指令(LLVM IR 标准形式) +3. **赋值表达式**(任务 1.3) + - 实现变量赋值语句的 IR 生成 + - 修改文件:`src/irgen/IRGenStmt.cpp` +4. 
**多变量声明**(任务 1.4) + - 支持逗号分隔的变量声明(如 `int a, b, c;`) + - 支持带初始化的多变量声明(如 `int a = 1, b = 2;`) + +### 🧪 测试验证 + +- **Lab1 语法分析**:✅ 通过(10/11 functional 测试,1 个数组测试超出范围) +- **Lab2 语义分析**:✅ 通过(6 正例 + 4 反例) +- **IR 生成测试**:✅ 通过(7/7 自定义测试用例) + - 测试脚本:`./scripts/test_lab2_ir1.sh` + - 测试用例目录:`test/test_case/irgen_lab1_4/` + +### 📝 代码质量 + +- 所有修改已通过编译测试 +- 未影响原有 Lab1 和 Lab2 Sema 功能 +- 代码风格与项目保持一致 +- 关键函数添加了注释说明 + +### 🔄 协作接口 + +人员 1 的实现为后续任务提供了以下接口: + +- **表达式生成**:`visitAddExp`, `visitMulExp`, `visitUnaryExp` +- **语句生成**:`visitStmt`(支持赋值和 return) +- **变量管理**:`storage_map_` 维护变量名到栈槽位的映射 +- **IR 构建**:`IRBuilder::CreateBinary`, `IRBuilder::CreateNeg`, `IRBuilder::CreateStore` + +后续人员可以在此基础上扩展更复杂的功能(控制流、函数调用等)。 + +## 人员 2 完成情况详细说明(更新于 2026-03-31) + +### ✅ 已完成任务 + +人员 2 已完整实现 Lab2 IR 生成中涉及的控制流支持,包括: + +1. **IR 结构与底层辅助拓展** + - 补充 `Int1` 基础类型以及 `Value::IsInt1()`。 + - 新增 `CmpInst`, `ZextInst`, `BranchInst` 以及 `CondBranchInst` 以支持关系计算和跳转逻辑。 + - 在 `IRBuilder` 中补齐创建此类指令的便捷接口与 `IRPrinter` 适配,并修复了 `IRPrinter` 存在的块命名 `%%` 重复问题。 + - 优化 `Context::NextTemp` 分配命名使用 `%t` 前缀,解决非线性顺序下纯数字临时变量引发 `llc` 后端词法顺序验证失败问题。 +2. **比较和逻辑表达式**(任务 2.4) + - 新增实现 `visitRelExp`、`visitEqExp`。 + - 实现条件二元表达式全链路短路求值 (`visitLAndExp`、`visitLOrExp`)。短路时通过控制流跳转+利用局部栈变量分配并多次赋值记录实现栈传递,规避了 `phi` 的麻烦。 + - 利用 `visitCondUnaryExp` 增加逻辑非 `!` 判定。 +3. **控制流框架支持**(任务 2.1 - 2.3) + - 在 `visitStmt` 中完美实现了 `if-else` 条件语句(自动插入无条件跳合块)、`while` 循环语句。 + - 在 `IRGen` 实例中通过 `current_loop_cond_bb_` 等维护循环栈,实现了 `break` 与 `continue`。 + - 修复了此前框架在 `IRGenDecl.cpp` 的 `visitBlock` 中缺少终结向上传递导致的 `break` 生成不匹配死块 BUG 及重复 `Branch` 问题。 +4. **关键前序 Bug 修复** + - 发现了在原框架里 `src/sem/Sema.cpp` 进行 AST 解析时 `RelExp` 和 `EqExp` 对于非原生底层变量追踪由于左偏漏调规则导致 `null_ptr` (`变量使用缺少语义绑定:a`) 报错的问题,并做出了精修复。 + +### 🧪 测试验证 + +- **Lab2 语义分析**:修复后所有已有的语义正例验证正常。 +- **IR 生成与后端执行**:✅ 自建嵌套含复合逻辑循环脚本测试通过。 +- **验证命令**(运行含 break 和 while 的范例文件): + ```bash + cd build && make -j$(nproc) && cd .. 
&& ./scripts/verify_ir.sh test/test_case/functional/29_break.sy --run + ``` + + **完整测试脚本** + ```bash + for f in test/test_case/functional/*.sy; do echo "Testing $f..."; ./scripts/verify_ir.sh "$f" --run > /dev/null || echo "FAILED $f"; done + ``` \ No newline at end of file diff --git a/doc/实验进度与测试方法.md b/doc/实验进度与测试方法.md new file mode 100644 index 0000000..3c982dd --- /dev/null +++ b/doc/实验进度与测试方法.md @@ -0,0 +1,436 @@ +# 实验进度与测试方法 + +## 1. 当前实验进度 + +本文档用于记录当前仓库在各个 Lab 上的实现状态,以及对应的测试与验证方式。 +需要注意:本仓库当前仍处于“课程示例框架 + 逐步补全”的阶段,并不是一个已经完整实现全部 SysY 语义的编译器。 + +### 1.1 Lab1 当前进度 + +Lab1 对应前端语法分析与语法树构建。 + +当前状态: + +- 已提供 `SysY.g4`、ANTLR 驱动与语法树打印能力。 +- 已支持通过 `--emit-parse-tree` 输出语法树。 +- 可使用 `parse-only` 模式单独构建前端,不依赖 `sem` / `irgen` / `mir`。 + +### 1.2 Lab2 当前进度 + +Lab2 对应“语法树 -> 语义检查 -> IR”。 + +当前状态可以拆成两部分来看: + +1. `Sema` + - 已完成一版基于当前 SysY grammar 的语义检查基础实现。 + - 已支持多层作用域、变量/常量重定义检查、先声明后使用。 + - 已支持函数符号收集、函数调用检查、`main` 入口检查。 + - 已支持 `break` / `continue` 使用位置检查。 + - 已支持 `return` 与函数返回类型匹配检查。 + - 已支持 `const` 常量表达式求值、数组维度检查、全局初始化常量性检查。 + - 已支持 `int/float` 标量表达式、比较、逻辑表达式的基础类型检查。 + - 已内建 `getint`、`putch`、`getfloat`、`getarray`、`putarray` 等常见运行库函数声明。 + +2. `IRGen` + - 当前仓库原有 `IRGen` 仍是最小示例版本。 + - 当前只适合支持“局部 `int` 变量 + 常量 + 简单表达式 + `return`”这类极小子集。 + - 由于 grammar 已扩展,而 `IRGen` 尚未完全同步,所以 Lab2 目前**只完成了前半部分:Sema 基础扩展**。 + - Lab2 的 IR 生成部分仍需继续补全。 + +### 1.3 Lab3 当前进度 + +Lab3 对应“IR -> MIR -> 汇编”。 + +当前状态: + +- 仓库中保留了最小后端链路。 +- 仅适合消费当前最小 IR 子集。 +- 尚不具备对完整 SysY 程序稳定生成汇编的能力。 + +### 1.4 Lab4-Lab6 当前进度 + +当前仓库已经预留: + +- IR 分析与 Pass 目录结构 +- `Mem2Reg`、`ConstFold`、`ConstProp`、`DCE`、`CSE`、`CFGSimplify` 等文件框架 +- 循环分析、支配树、后端优化等实验入口 + +但这些阶段是否“完成”,取决于你们后续自行补全,不应默认认为仓库当前已经完全实现。 + +## 2. 推荐测试思路 + +建议把测试分成三层: + +1. `单阶段验证` + - 只验证某个阶段是否工作,例如只看 parse、只看 sema、只看 IR 输出。 + +2. `链路验证` + - 从源码一路走到 IR 或汇编,再运行程序,比对 `.out`。 + +3. `批量回归` + - 对 `test/test_case` 下多个测试统一执行,避免只靠 `simple_add.sy` 判断功能是否完成。 + +## 3. 
别人拉取当前实现后的推荐编译方式 + +如果其他同学拉取了当前仓库,建议按下面顺序准备环境并编译。 + +### 3.1 先生成 ANTLR 输出 + +当前仓库的 CMake 会收集构建目录中的 ANTLR 生成文件,但不会自动调用 ANTLR,所以第一次构建前应先执行: + +```bash +mkdir -p build/generated/antlr4 +java -jar third_party/antlr-4.13.2-complete.jar \ + -Dlanguage=Cpp \ + -visitor -no-listener \ + -Xexact-output-dir \ + -o build/generated/antlr4 \ + src/antlr4/SysY.g4 +``` + +### 3.2 如果只想验证 Lab1 + +只构建 parse-only 前端: + +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON +cmake --build build -j "$(nproc)" +``` + +构建后可直接运行: + +```bash +./scripts/test_lab1.sh test/test_case/functional +``` + +### 3.3 如果想验证当前 Lab2 的 Sema 部分 + +由于当前仓库中的 `IRGen` 还没有完全跟上新 grammar,而我们这次主要完成的是 `Sema`,所以推荐单独准备一个 `build-sema/` 目录来验证语义检查。 + +推荐命令如下: + +```bash +cmake -S . -B build-sema -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF +mkdir -p build-sema/generated +cp -r build/generated/antlr4 build-sema/generated/ +cmake --build build-sema --target frontend utils sem -j "$(nproc)" +``` + +然后编译 `sema_check`: + +```bash +g++ -std=c++17 \ + -Iinclude \ + -Isrc \ + -Ibuild-sema/generated/antlr4 \ + -Ithird_party/antlr4-runtime-4.13.2/runtime/src \ + tools/sema_check.cpp \ + build-sema/src/sem/libsem.a \ + build-sema/src/frontend/libfrontend.a \ + build-sema/src/utils/libutils.a \ + build-sema/libantlr4_runtime.a \ + -pthread \ + -o build-sema/sema_check +``` + +完成后即可运行: + +```bash +./scripts/test_lab2_sema.sh positive +./scripts/test_lab2_sema.sh negative +``` + +说明: + +- `build/` 主要用于 Lab1 parse-only 或后续全量构建 +- `build-sema/` 主要用于当前阶段单独验证 `Sema` +- `scripts/test_lab2_sema.sh` 依赖 `./build-sema/sema_check` + +### 3.4 如果后续要做全量构建 + +等 `IRGen` 与 grammar 完全同步后,可直接做全量构建: + +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF +cmake --build build -j "$(nproc)" +``` + +但在当前阶段,不建议把“全量 build 成功”作为验证 `Sema` 的唯一标准,因为 Lab2 目前完成的是语义分析前半部分,不是整套 IR 生成。 + +## 4. 
Lab1 测试方法 + +### 3.1 构建命令 + +先生成 ANTLR 输出: + +```bash +mkdir -p build/generated/antlr4 +java -jar third_party/antlr-4.13.2-complete.jar \ + -Dlanguage=Cpp \ + -visitor -no-listener \ + -Xexact-output-dir \ + -o build/generated/antlr4 \ + src/antlr4/SysY.g4 +``` + +然后使用 `parse-only` 构建: + +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON +cmake --build build -j "$(nproc)" +``` + +### 3.2 单个样例测试 + +```bash +./build/bin/compiler --emit-parse-tree test/test_case/functional/simple_add.sy +``` + +### 3.3 批量测试 + +仓库已提供 parse 批量测试脚本。为避免终端直接打印大量语法树导致输出过长,脚本会把每个用例的语法树输出写入单独日志文件。 + +```bash +./scripts/test_lab1.sh test/test_case/functional +``` + +如果希望指定日志目录,可以使用: + +```bash +./scripts/test_lab1.sh test/test_case/functional test/test_result/lab1_parse_logs +``` + +终端中会看到形如: + +```text +TEST test/test_case/functional/simple_add.sy -> test/test_result/lab1_parse_logs/simple_add.parse.log +... +ALL_PARSE_OK (...) logs: test/test_result/lab1_parse_logs +``` + +说明当前测试目录中的 `.sy` 文件都能通过语法分析;具体语法树内容可直接查看对应 `.parse.log` 文件。 + +## 5. Lab2 测试方法 + +Lab2 建议分成两部分测试:`Sema` 和 `IRGen`。 + +### 4.1 Lab2 当前推荐先测 Sema + +因为当前仓库中 `IRGen` 还未完全同步到新 grammar,所以当前阶段更适合先用“语义检查”来证明 Lab2 前半部分已经实现。 + +#### 4.1.1 当前已验证通过的正例 + +下面这些测试用例已经可以作为当前 `Sema` 的正向样例: + +```bash +./scripts/test_lab2_sema.sh positive +``` + +如果希望指定日志目录,可以使用: + +```bash +./scripts/test_lab2_sema.sh positive test/test_result/lab2_sema_positive_logs +``` + +预期现象: + +- 终端按用例打印 `TEST ... -> ...` +- 全部通过后输出 `ALL_SEMA_POSITIVE_OK (...)` +- 详细输出写入 `*.sema.log` + +#### 4.1.2 当前可用于演示的反例 + +当前已经准备好的反例位于: + +- `test/test_case/sema_negative/undef.sy` +- `test/test_case/sema_negative/break.sy` +- `test/test_case/sema_negative/ret.sy` +- `test/test_case/sema_negative/call.sy` + +执行命令: + +```bash +./scripts/test_lab2_sema.sh negative +``` + +如果希望指定日志目录,可以使用: + +```bash +./scripts/test_lab2_sema.sh negative test/test_result/lab2_sema_negative_logs +``` + +预期现象: + +- 终端按用例打印 `TEST ... 
-> ...` +- 全部符合预期后输出 `ALL_SEMA_NEGATIVE_OK (...)` +- 每个反例的详细错误信息写入对应 `.sema.log` + +例如: + +- 使用未声明变量 +- 循环外 `break` +- `void` 函数返回值 +- 函数参数个数不匹配 + +#### 4.1.3 语义错误定位信息说明 + +语义错误信息中的 `@行:列` 用于标明错误位置。 + +例如: + +```text +[error] [sema] @1:19 - 使用了未声明的标识符: a +``` + +表示: + +- `1` 是第 1 行 +- `19` 是第 19 列 + +也就是提示错误出现在源代码第 1 行第 19 列附近,便于快速定位。 + +#### 4.1.4 当前 Sema 已覆盖的主要错误类型 + +当前已实现的典型错误检测包括: + +- 未声明标识符使用 +- 同作用域重定义 +- 函数重定义 +- 缺少合法 `main` +- 函数参数数量或类型不匹配 +- `break/continue` 不在循环中 +- `return` 与函数返回类型不匹配 +- 给 `const` 对象赋值 +- 数组维度非法 +- 全局初始化不满足编译期常量要求 + +### 4.2 Lab2 后续 IR 测试方式 + +当 `IRGen` 与当前 grammar 对齐后,可使用如下命令输出 IR: + +```bash +./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy +``` + +若需要进一步验证 “IR -> 可执行程序” 链路,可使用: + +```bash +./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/ir --run +``` + +但需要强调: +在当前仓库状态下,这条命令只适合用于未来 IRGen 完成后的测试;不能拿它来证明当前已完成的 `Sema` 部分。 + +## 6. Lab3 测试方法 + +Lab3 对应汇编输出与后端链路。 + +### 5.1 构建 + +需要全量构建: + +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF +cmake --build build -j "$(nproc)" +``` + +### 5.2 单个样例输出汇编 + +```bash +./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy +``` + +### 5.3 汇编链路验证 + +```bash +./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/asm --run +``` + +`--run` 模式下会: + +1. 生成汇编 +2. 交叉编译为 AArch64 可执行文件 +3. 用 `qemu-aarch64` 运行 +4. 将输出与同名 `.out` 比对 + +## 7. Lab4 测试方法 + +Lab4 是优化实验,测试重点不只是“能不能运行”,还包括“优化前后语义一致”。 + +建议按下面顺序验证: + +1. 先确保未优化版本功能正确 +2. 接入优化后再次跑 `verify_ir.sh` 或 `verify_asm.sh` +3. 比较优化前后的 IR 或汇编输出 +4. 
在多个测试上回归,避免某个优化只在 `simple_add` 上看起来没问题 + +推荐命令: + +```bash +./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/ir --run +./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/asm --run +``` + +如果你们为优化实现了单独开关,也应额外对比: + +```bash +./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy +./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy +``` + +## 8. Lab5 测试方法 + +Lab5 的测试重点是: + +- 寄存器分配后代码仍然正确 +- spill/reload 逻辑没有破坏语义 +- 汇编仍能完整运行 + +推荐直接走后端完整链路: + +```bash +./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/asm --run +``` + +完成寄存器分配后,不应只测单个样例,建议至少覆盖: + +- `functional/` +- `performance/` 中若干较大样例 + +## 9. Lab6 测试方法 + +Lab6 重点是循环和并行相关优化,测试要分成功能正确性和优化收益两部分。 + +### 9.1 功能正确性 + +```bash +./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/ir --run +./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/asm --run +``` + +### 9.2 优化效果观察 + +你们可以对比优化前后的: + +- IR 输出 +- 汇编输出 +- 执行时间 +- 代码规模 + +例如: + +```bash +./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy +./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy +``` + +真正评估循环优化时,建议使用包含明显循环结构的功能或性能测试,而不是只看 `simple_add.sy`。 + +## 10. 当前阶段的建议结论 + +如果你要汇报当前仓库状态,可以概括为: + +1. Lab1 的语法树构建链路已经具备独立测试方式。 +2. Lab2 当前已经完成 `Sema` 基础扩展,并可通过正反例直接演示。 +3. Lab2 的 `IRGen` 还需要继续补全,当前不能把整份 Lab2 视为全部完成。 +4. 
Lab3 及后续实验目前主要还是框架和最小样例能力,完整覆盖仍需后续实现。 diff --git a/include/ir/IR.h b/include/ir/IR.h index b961192..06f837f 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -93,16 +93,18 @@ class Context { class Type { public: - enum class Kind { Void, Int32, PtrInt32 }; + enum class Kind { Void, Int1, Int32, PtrInt32 }; explicit Type(Kind k); // 使用静态共享对象获取类型。 // 同一类型可直接比较返回值是否相等,例如: // Type::GetInt32Type() == Type::GetInt32Type() static const std::shared_ptr& GetVoidType(); + static const std::shared_ptr& GetInt1Type(); static const std::shared_ptr& GetInt32Type(); static const std::shared_ptr& GetPtrInt32Type(); Kind GetKind() const; bool IsVoid() const; + bool IsInt1() const; bool IsInt32() const; bool IsPtrInt32() const; @@ -118,6 +120,7 @@ class Value { const std::string& GetName() const; void SetName(std::string n); bool IsVoid() const; + bool IsInt1() const; bool IsInt32() const; bool IsPtrInt32() const; bool IsConstant() const; @@ -152,7 +155,10 @@ class ConstantInt : public ConstantValue { }; // 后续还需要扩展更多指令类型。 -enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; +// enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; +enum class Opcode { Add, Sub, Mul, Div, Mod, Neg, Alloca, Load, Store, Ret, Cmp, Zext, Br, CondBr }; + +enum class CmpOp { Eq, Ne, Lt, Gt, Le, Ge }; // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 // 当前实现中只有 Instruction 继承自 User。 @@ -196,7 +202,14 @@ class BinaryInst : public Instruction { BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name); Value* GetLhs() const; - Value* GetRhs() const; + Value* GetRhs() const; +}; + +class UnaryInst : public Instruction { + public: + UnaryInst(Opcode op, std::shared_ptr ty, Value* operand, + std::string name); + Value* GetUnaryOperand() const; }; class ReturnInst : public Instruction { @@ -223,6 +236,37 @@ class StoreInst : public Instruction { Value* GetPtr() const; }; +class CmpInst : public Instruction { + public: + CmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, 
std::string name); + CmpOp GetCmpOp() const; + Value* GetLhs() const; + Value* GetRhs() const; + + private: + CmpOp cmp_op_; +}; + +class ZextInst : public Instruction { + public: + ZextInst(std::shared_ptr dest_ty, Value* val, std::string name); + Value* GetValue() const; +}; + +class BranchInst : public Instruction { + public: + BranchInst(BasicBlock* dest); + BasicBlock* GetDest() const; +}; + +class CondBranchInst : public Instruction { + public: + CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb); + Value* GetCond() const; + BasicBlock* GetTrueBlock() const; + BasicBlock* GetFalseBlock() const; +}; + // BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 // 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 class BasicBlock : public Value { @@ -300,10 +344,17 @@ class IRBuilder { BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name); + UnaryInst* CreateNeg(Value* operand, const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); ReturnInst* CreateRet(Value* v); + CmpInst* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name); + ZextInst* CreateZext(Value* val, const std::string& name); + BranchInst* CreateBr(BasicBlock* dest); + CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb); private: Context& ctx_; diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..fec791a 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -26,16 +26,23 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) 
override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + std::any visitBlock(SysYParser::BlockContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitLVal(SysYParser::LValContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; + std::any visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) override; + std::any visitCond(SysYParser::CondContext* ctx) override; private: enum class BlockFlow { @@ -50,8 +57,17 @@ class IRGenImpl final : public SysYBaseVisitor { const SemanticContext& sema_; ir::Function* func_; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 - std::unordered_map storage_map_; + // 名称绑定由 Sema 负责;IRGen 只维护"变量名 -> 存储槽位"的代码生成状态。 + std::unordered_map storage_map_; + + // 用于 break 和 continue 跳转的目标位置 + ir::BasicBlock* current_loop_cond_bb_ = nullptr; + ir::BasicBlock* 
current_loop_exit_bb_ = nullptr; + + int bb_cnt_ = 0; + std::string NextBlockName(const std::string& prefix = "bb") { + return prefix + "_" + std::to_string(++bb_cnt_); + } }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 9ac057b..64e2595 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -1,30 +1,69 @@ // 基于语法树的语义检查与名称绑定。 #pragma once +#include #include +#include #include "SysYParser.h" +enum class SemanticType { + Void, + Int, + Float, +}; + +struct ScalarConstant { + SemanticType type = SemanticType::Int; + double number = 0.0; +}; + +struct ObjectBinding { + enum class DeclKind { + Var, + Const, + Param, + }; + + std::string name; + SemanticType type = SemanticType::Int; + DeclKind decl_kind = DeclKind::Var; + bool is_array_param = false; + std::vector dimensions; + const SysYParser::VarDefContext* var_def = nullptr; + const SysYParser::ConstDefContext* const_def = nullptr; + const SysYParser::FuncFParamContext* func_param = nullptr; + bool has_const_value = false; + ScalarConstant const_value; +}; + +struct FunctionBinding { + std::string name; + SemanticType return_type = SemanticType::Int; + std::vector params; + const SysYParser::FuncDefContext* func_def = nullptr; + bool is_builtin = false; +}; + class SemanticContext { public: - void BindVarUse(SysYParser::VarContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; - } + void BindObjectUse(const SysYParser::LValContext* use, ObjectBinding binding); + const ObjectBinding* ResolveObjectUse( + const SysYParser::LValContext* use) const; + + void BindFunctionCall(const SysYParser::UnaryExpContext* call, + FunctionBinding binding); + const FunctionBinding* ResolveFunctionCall( + const SysYParser::UnaryExpContext* call) const; - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::VarContext* use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? 
nullptr : it->second; - } + void RegisterFunction(FunctionBinding binding); + const FunctionBinding* ResolveFunction(const std::string& name) const; private: - std::unordered_map - var_uses_; + std::unordered_map object_uses_; + std::unordered_map + function_calls_; + std::unordered_map functions_; }; -// 目前仅检查: -// - 变量先声明后使用 -// - 局部变量不允许重复定义 SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index c9396dd..201112c 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -1,17 +1,25 @@ -// 极简符号表:记录局部变量定义点。 +// 维护对象符号的多层作用域。 #pragma once #include +#include #include +#include -#include "SysYParser.h" +#include "sem/Sema.h" class SymbolTable { public: - void Add(const std::string& name, SysYParser::VarDefContext* decl); - bool Contains(const std::string& name) const; - SysYParser::VarDefContext* Lookup(const std::string& name) const; + SymbolTable(); + + void EnterScope(); + void ExitScope(); + + bool Add(const ObjectBinding& symbol); + bool ContainsInCurrentScope(std::string_view name) const; + const ObjectBinding* Lookup(std::string_view name) const; + size_t Depth() const; private: - std::unordered_map table_; + std::vector> scopes_; }; diff --git a/patch_IR_h.patch b/patch_IR_h.patch new file mode 100644 index 0000000..c1a32ce --- /dev/null +++ b/patch_IR_h.patch @@ -0,0 +1,83 @@ +--- include/ir/IR.h ++++ include/ir/IR.h +@@ -93,6 +93,7 @@ + class Type { + public: +- enum class Kind { Void, Int32, PtrInt32 }; ++ enum class Kind { Void, Int1, Int32, PtrInt32 }; + explicit Type(Kind k); + // 使用静态共享对象获取类型。 + // 同一类型可直接比较返回值是否相等,例如: + // Type::GetInt32Type() == Type::GetInt32Type() + static const std::shared_ptr& GetVoidType(); ++ static const std::shared_ptr& GetInt1Type(); + static const std::shared_ptr& GetInt32Type(); + static const std::shared_ptr& GetPtrInt32Type(); + Kind GetKind() const; + bool IsVoid() const; ++ bool IsInt1() const; + bool IsInt32() const; + 
bool IsPtrInt32() const; +@@ -118,6 +119,7 @@ + const std::string& GetName() const; + void SetName(std::string n); + bool IsVoid() const; ++ bool IsInt1() const; + bool IsInt32() const; + bool IsPtrInt32() const; + bool IsConstant() const; +@@ -153,7 +155,9 @@ + + // 后续还需要扩展更多指令类型。 +-// enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; +-enum class Opcode { Add, Sub, Mul, Div, Mod, Neg, Alloca, Load, Store, Ret }; ++enum class Opcode { Add, Sub, Mul, Div, Mod, Neg, Alloca, Load, Store, Ret, Cmp, Zext, Br, CondBr }; ++ ++enum class CmpOp { Eq, Ne, Lt, Gt, Le, Ge }; + + // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 +@@ -231,6 +235,33 @@ + Value* GetPtr() const; + }; + ++class CmpInst : public Instruction { ++ public: ++ CmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name); ++ CmpOp GetCmpOp() const; ++ Value* GetLhs() const; ++ Value* GetRhs() const; ++ private: ++ CmpOp cmp_op_; ++}; ++ ++class ZextInst : public Instruction { ++ public: ++ ZextInst(std::shared_ptr dest_ty, Value* val, std::string name); ++ Value* GetValue() const; ++}; ++ ++class BranchInst : public Instruction { ++ public: ++ BranchInst(BasicBlock* dest); ++ BasicBlock* GetDest() const; ++}; ++ ++class CondBranchInst : public Instruction { ++ public: ++ CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb); ++ Value* GetCond() const; ++ BasicBlock* GetTrueBlock() const; ++ BasicBlock* GetFalseBlock() const; ++}; ++ + // BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 +@@ -315,6 +346,10 @@ + LoadInst* CreateLoad(Value* ptr, const std::string& name); + StoreInst* CreateStore(Value* val, Value* ptr); + ReturnInst* CreateRet(Value* v); ++ CmpInst* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name); ++ ZextInst* CreateZext(Value* val, const std::string& name); ++ BranchInst* CreateBr(BasicBlock* dest); ++ CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb); + + private: diff --git a/scripts/test_lab1.sh 
b/scripts/test_lab1.sh new file mode 100755 index 0000000..538d9cc --- /dev/null +++ b/scripts/test_lab1.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +set -euo pipefail + +case_dir="${1:-test/test_case}" +log_dir="${2:-test/test_result/lab1_parse_logs}" + +if [[ ! -d "$case_dir" ]]; then + echo "测试目录不存在: $case_dir" >&2 + exit 1 +fi + +compiler="./build/bin/compiler" +if [[ ! -x "$compiler" ]]; then + echo "未找到编译器: $compiler ,请先构建 parse-only 版本。" >&2 + exit 1 +fi + +mkdir -p "$log_dir" + +mapfile -t cases < <(find "$case_dir" -name '*.sy' | sort) +if [[ ${#cases[@]} -eq 0 ]]; then + echo "未找到任何 .sy 测试文件: $case_dir" >&2 + exit 1 +fi + +for f in "${cases[@]}"; do + rel="${f#$case_dir/}" + safe_name="${rel//\//__}" + log_file="$log_dir/${safe_name%.sy}.parse.log" + echo "TEST $f -> $log_file" + if ! "$compiler" --emit-parse-tree "$f" >"$log_file" 2>&1; then + echo "FAIL $f (see $log_file)" >&2 + exit 1 + fi +done + +echo "ALL_PARSE_OK (${#cases[@]} cases) logs: $log_dir" diff --git a/scripts/test_lab2_ir1.sh b/scripts/test_lab2_ir1.sh new file mode 100755 index 0000000..65f9c3f --- /dev/null +++ b/scripts/test_lab2_ir1.sh @@ -0,0 +1,157 @@ +#!/usr/bin/env bash +# 测试 Lab2 IR 生成 - 人员 1 的任务 +# 测试内容: +# - 任务 1.1: 支持更多二元运算符(Sub, Mul, Div, Mod) +# - 任务 1.2: 支持一元运算符(正负号) +# - 任务 1.3: 支持赋值表达式 +# - 任务 1.4: 支持逗号分隔的多个变量声明 + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +COMPILER="$PROJECT_ROOT/build/bin/compiler" +TEST_DIR="$PROJECT_ROOT/test/test_case/irgen_lab1_4" +RESULT_DIR="$PROJECT_ROOT/test/test_result/lab2_ir1" + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +echo "=========================================" +echo "Lab2 IR 生成测试 - 部分任务验证" +echo "=========================================" +echo "" + +# 检查编译器是否存在 +if [[ ! 
-x "$COMPILER" ]]; then + echo -e "${RED}错误:编译器不存在或不可执行:$COMPILER${NC}" + echo "请先运行:cmake --build build" + exit 1 +fi + +# 检查测试目录是否存在 +if [[ ! -d "$TEST_DIR" ]]; then + echo -e "${RED}错误:测试目录不存在:$TEST_DIR${NC}" + exit 1 +fi + +# 创建结果目录 +mkdir -p "$RESULT_DIR" + +# 统计 +total=0 +passed=0 +failed=0 + +# 测试函数 +run_test() { + local input=$1 + local basename=$(basename "$input" .sy) + local expected_out="$TEST_DIR/$basename.out" + local actual_out="$RESULT_DIR/$basename.actual.out" + local ll_file="$RESULT_DIR/$basename.ll" + + ((total++)) || true + + echo -n "测试 $basename ... " + + # 生成 IR + if ! "$COMPILER" --emit-ir "$input" > "$ll_file" 2>&1; then + echo -e "${RED}IR 生成失败${NC}" + ((failed++)) || true + return 1 + fi + + # 如果需要运行并比对输出 + if [[ -f "$expected_out" ]]; then + # 编译并运行 + local exe_file="$RESULT_DIR/$basename" + if ! llc -O0 -filetype=obj "$ll_file" -o "$RESULT_DIR/$basename.o" 2>/dev/null; then + echo -e "${YELLOW}LLVM 编译失败 (llc)${NC}" + cat "$ll_file" + ((failed++)) || true + return 1 + fi + + if ! clang "$RESULT_DIR/$basename.o" -o "$exe_file" 2>/dev/null; then + echo -e "${YELLOW}链接失败 (clang)${NC}" + ((failed++)) || true + return 1 + fi + + # 运行程序,捕获返回值(低 8 位) + local exit_code=0 + "$exe_file" > "$actual_out" 2>&1 || exit_code=$? 
+ + # 处理返回值(LLVM/AArch64 返回的是 8 位无符号整数) + if [[ $exit_code -gt 127 ]]; then + # 转换为有符号整数 + exit_code=$((exit_code - 256)) + fi + echo "$exit_code" > "$actual_out" + + # 比对输出 + if diff -q "$expected_out" "$actual_out" > /dev/null 2>&1; then + echo -e "${GREEN}✓ 通过${NC}" + ((passed++)) || true + return 0 + else + echo -e "${RED}✗ 输出不匹配${NC}" + echo " 期望:$(cat "$expected_out")" + echo " 实际:$(cat "$actual_out")" + ((failed++)) || true + return 1 + fi + else + # 没有期望输出,只检查 IR 生成 + echo -e "${GREEN}✓ IR 生成成功${NC}" + ((passed++)) || true + return 0 + fi +} + +# 查找所有测试用例 +test_files=() +while IFS= read -r -d '' file; do + test_files+=("$file") +done < <(find "$TEST_DIR" -name "*.sy" -type f -print0 | sort -z) + +if [[ ${#test_files[@]} -eq 0 ]]; then + echo -e "${RED}未找到测试用例:$TEST_DIR${NC}" + exit 1 +fi + +echo "找到 ${#test_files[@]} 个测试用例" +echo "" + +# 运行所有测试 +for test_file in "${test_files[@]}"; do + run_test "$test_file" || true +done + +# 输出统计 +echo "" +echo "=========================================" +echo "测试结果统计" +echo "=========================================" +echo -e "总数:$total" +echo -e "通过:${GREEN}$passed${NC}" +echo -e "失败:${RED}$failed${NC}" +echo "" + +if [[ $failed -eq 0 ]]; then + echo -e "${GREEN}✓ 所有测试通过!${NC}" + echo "" + echo "测试覆盖:" + echo " ✓ 任务 1.1: 二元运算符(Sub, Mul, Div, Mod)" + echo " ✓ 任务 1.2: 一元运算符(正负号)" + echo " ✓ 任务 1.3: 赋值表达式" + echo " ✓ 任务 1.4: 逗号分隔的多变量声明" + exit 0 +else + echo -e "${RED}✗ 有 $failed 个测试失败${NC}" + exit 1 +fi diff --git a/scripts/test_lab2_sema.sh b/scripts/test_lab2_sema.sh new file mode 100755 index 0000000..ffb41d8 --- /dev/null +++ b/scripts/test_lab2_sema.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash + +set -euo pipefail + +mode="${1:-positive}" +log_dir="${2:-test/test_result/lab2_sema_logs}" + +checker="./build-sema/sema_check" +if [[ ! 
-x "$checker" ]]; then + echo "未找到语义测试驱动: $checker" >&2 + echo "请先准备 build-sema/sema_check。" >&2 + exit 1 +fi + +mkdir -p "$log_dir" + +case_files=() +expected_prefix="" + +case "$mode" in + positive) + expected_prefix="OK" + case_files=( + "test/test_case/functional/simple_add.sy" + "test/test_case/functional/09_func_defn.sy" + "test/test_case/functional/25_scope3.sy" + "test/test_case/functional/29_break.sy" + "test/test_case/functional/05_arr_defn4.sy" + "test/test_case/functional/95_float.sy" + ) + ;; + negative) + expected_prefix="ERR" + case_files=( + "test/test_case/sema_negative/undef.sy" + "test/test_case/sema_negative/break.sy" + "test/test_case/sema_negative/ret.sy" + "test/test_case/sema_negative/call.sy" + ) + ;; + *) + echo "用法: $0 [positive|negative] [log_dir]" >&2 + exit 1 + ;; +esac + +if [[ ${#case_files[@]} -eq 0 ]]; then + echo "没有可执行的测试用例" >&2 + exit 1 +fi + +for f in "${case_files[@]}"; do + if [[ ! -f "$f" ]]; then + echo "测试文件不存在: $f" >&2 + exit 1 + fi +done + +all_ok=true +for f in "${case_files[@]}"; do + base="$(basename "${f%.sy}")" + log_file="$log_dir/${base}.sema.log" + echo "TEST $f -> $log_file" + set +e + "$checker" "$f" >"$log_file" 2>&1 + status=$? + set -e + + if ! 
grep -q "^${expected_prefix} $f$" "$log_file"; then + echo "FAIL $f (see $log_file)" >&2 + all_ok=false + continue + fi + + if [[ "$mode" == "positive" && $status -ne 0 ]]; then + echo "FAIL $f (expected success, see $log_file)" >&2 + all_ok=false + continue + fi + + if [[ "$mode" == "negative" && $status -eq 0 ]]; then + echo "FAIL $f (expected semantic error, see $log_file)" >&2 + all_ok=false + continue + fi +done + +if [[ "$all_ok" != true ]]; then + exit 1 +fi + +echo "ALL_SEMA_${mode^^}_OK (${#case_files[@]} cases) logs: $log_dir" diff --git a/src/antlr4/SysY.g4 b/src/antlr4/SysY.g4 index 0907727..9acd02a 100644 --- a/src/antlr4/SysY.g4 +++ b/src/antlr4/SysY.g4 @@ -15,26 +15,27 @@ BREAK: 'break'; CONTINUE: 'continue'; RETURN: 'return'; +LE: '<='; +GE: '>='; +EQ: '=='; +NE: '!='; +AND: '&&'; +OR: '||'; + ASSIGN: '='; +LT: '<'; +GT: '>'; ADD: '+'; SUB: '-'; MUL: '*'; DIV: '/'; MOD: '%'; -EQ: '=='; -NEQ: '!='; -LT: '<'; -GT: '>'; -LE: '<='; -GE: '>='; NOT: '!'; -AND: '&&'; -OR: '||'; LPAREN: '('; RPAREN: ')'; -LBRACKET: '['; -RBRACKET: ']'; +LBRACK: '['; +RBRACK: ']'; LBRACE: '{'; RBRACE: '}'; COMMA: ','; @@ -42,39 +43,52 @@ SEMICOLON: ';'; ID: [a-zA-Z_][a-zA-Z_0-9]*; -ILITERAL: DEC_LIT | OCT_LIT | HEX_LIT; -fragment DEC_LIT: [1-9][0-9]* | '0'; -fragment OCT_LIT: '0'[0-7]+; -fragment HEX_LIT: ('0x' | '0X') [0-9a-fA-F]+; +HEX_FLOAT_LITERAL + : ('0x' | '0X') HEX_DIGIT* '.' HEX_DIGIT+ BINARY_EXPONENT + | ('0x' | '0X') HEX_DIGIT+ '.' HEX_DIGIT* BINARY_EXPONENT + | ('0x' | '0X') HEX_DIGIT+ BINARY_EXPONENT + ; + +DEC_FLOAT_LITERAL + : DEC_DIGIT+ '.' DEC_DIGIT* DEC_EXPONENT? + | '.' DEC_DIGIT+ DEC_EXPONENT? + | DEC_DIGIT+ DEC_EXPONENT + ; -FLITERAL: DEC_FLOAT_LIT | HEX_FLOAT_LIT; -fragment DEC_FLOAT_LIT - : [0-9]+ '.' [0-9]* EXPONENT? - | '.' [0-9]+ EXPONENT? - | [0-9]+ EXPONENT +HEX_INT_LITERAL + : ('0x' | '0X') HEX_DIGIT+ ; -fragment EXPONENT: ('e'|'E') ('+'|'-')? 
[0-9]+; -fragment HEX_FLOAT_LIT - : ('0x'|'0X') HEX_MANTISSA HEX_EXPONENT +OCT_INT_LITERAL + : '0' OCT_DIGIT+ ; -fragment HEX_MANTISSA - : [0-9a-fA-F]+ '.' [0-9a-fA-F]* - | '.' [0-9a-fA-F]+ - | [0-9a-fA-F]+ + +DEC_INT_LITERAL + : '0' + | [1-9] DEC_DIGIT* ; -fragment HEX_EXPONENT: ('p'|'P') ('+'|'-')? [0-9]+; WS: [ \t\r\n]+ -> skip; LINECOMMENT: '//' ~[\r\n]* -> skip; BLOCKCOMMENT: '/*' .*? '*/' -> skip; +fragment DEC_DIGIT: [0-9]; +fragment OCT_DIGIT: [0-7]; +fragment HEX_DIGIT: [0-9a-fA-F]; +fragment DEC_EXPONENT: [eE] [+-]? DEC_DIGIT+; +fragment BINARY_EXPONENT: [pP] [+-]? DEC_DIGIT+; + /*===-------------------------------------------===*/ /* Syntax rules */ /*===-------------------------------------------===*/ compUnit - : (decl | funcDef)+ EOF + : topLevelItem (topLevelItem)* EOF + ; + +topLevelItem + : decl + | funcDef ; decl @@ -83,30 +97,33 @@ decl ; constDecl - : CONST btype constDef (COMMA constDef)* SEMICOLON + : CONST bType constDef (COMMA constDef)* SEMICOLON ; -btype +varDecl + : bType varDef (COMMA varDef)* SEMICOLON + ; + +bType : INT | FLOAT ; constDef - : ID (LBRACKET constExp RBRACKET)* ASSIGN constInitVal + : ID constIndex* ASSIGN constInitVal ; -constInitVal - : constExp - | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE +varDef + : ID constIndex* (ASSIGN initVal)? ; -varDecl - : btype varDef (COMMA varDef)* SEMICOLON +constIndex + : LBRACK constExp RBRACK ; -varDef - : ID (LBRACKET constExp RBRACKET)* - | ID (LBRACKET constExp RBRACKET)* ASSIGN initVal +constInitVal + : constExp + | LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE ; initVal @@ -115,7 +132,7 @@ initVal ; funcDef - : funcType ID LPAREN funcFParams? RPAREN blockStmt + : funcType ID LPAREN funcFParams? RPAREN block ; funcType @@ -129,10 +146,10 @@ funcFParams ; funcFParam - : btype ID (LBRACKET RBRACKET (LBRACKET exp RBRACKET)*)? + : bType ID (LBRACK RBRACK (LBRACK exp RBRACK)*)? 
; -blockStmt +block : LBRACE blockItem* RBRACE ; @@ -142,53 +159,107 @@ blockItem ; stmt - : lValue ASSIGN exp SEMICOLON # assignStmt - | exp? SEMICOLON # exprStmt - | blockStmt # blockStmtNode - | IF LPAREN cond RPAREN stmt (ELSE stmt)? # ifStmt - | WHILE LPAREN cond RPAREN stmt # whileStmt - | BREAK SEMICOLON # breakStmt - | CONTINUE SEMICOLON # continueStmt - | RETURN exp? SEMICOLON # returnStmt + : lVal ASSIGN exp SEMICOLON + | exp? SEMICOLON + | block + | IF LPAREN cond RPAREN stmt (ELSE stmt)? + | WHILE LPAREN cond RPAREN stmt + | BREAK SEMICOLON + | CONTINUE SEMICOLON + | RETURN exp? SEMICOLON ; exp - : LPAREN exp RPAREN # parenExp - | lValue # lvalExp - | number # numberExp - | ID LPAREN funcRParams? RPAREN # funcCallExp - | unaryOp exp # unaryOpExp - | exp (MUL | DIV | MOD) exp # mulExp - | exp (ADD | SUB) exp # addExp - | exp (LT | GT | LE | GE) exp # relExp - | exp (EQ | NEQ) exp # eqExp - | exp AND exp # lAndExp - | exp OR exp # lOrExp + : addExp ; cond - : exp + : lOrExp + ; + +lVal + : ID (LBRACK exp RBRACK)* ; -lValue - : ID (LBRACKET exp RBRACKET)* +primaryExp + : LPAREN exp RPAREN + | lVal + | number ; number - : ILITERAL - | FLITERAL + : intConst + | floatConst ; -unaryOp +intConst + : DEC_INT_LITERAL + | OCT_INT_LITERAL + | HEX_INT_LITERAL + ; + +floatConst + : DEC_FLOAT_LITERAL + | HEX_FLOAT_LITERAL + ; + +unaryExp + : primaryExp + | ID LPAREN funcRParams? 
RPAREN + | addUnaryOp unaryExp + ; + +addUnaryOp : ADD | SUB - | NOT ; funcRParams : exp (COMMA exp)* ; +mulExp + : unaryExp + | mulExp MUL unaryExp + | mulExp DIV unaryExp + | mulExp MOD unaryExp + ; + +addExp + : mulExp + | addExp ADD mulExp + | addExp SUB mulExp + ; + +relExp + : addExp + | relExp LT addExp + | relExp GT addExp + | relExp LE addExp + | relExp GE addExp + ; + +eqExp + : relExp + | eqExp EQ relExp + | eqExp NE relExp + ; + +lAndExp + : condUnaryExp + | lAndExp AND condUnaryExp + ; + +lOrExp + : lAndExp + | lOrExp OR lAndExp + ; + +condUnaryExp + : eqExp + | NOT condUnaryExp + ; + constExp - : exp + : addExp ; diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 16c982c..5f32c65 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -17,7 +17,7 @@ ConstantInt* Context::GetConstInt(int v) { std::string Context::NextTemp() { std::ostringstream oss; - oss << "%" << ++temp_index_; + oss << "%t" << ++temp_index_; return oss.str(); } diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 90f03c4..3569ab6 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -86,4 +86,62 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { return insert_block_->Append(Type::GetVoidType(), v); } +BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, const std::string& name) { + return CreateBinary(Opcode::Sub, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, const std::string& name) { + return CreateBinary(Opcode::Mul, lhs, rhs, name); +} + +UnaryInst* IRBuilder::CreateNeg(Value* operand, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!operand) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateNeg 缺少操作数")); + } + return insert_block_->Append(Opcode::Neg, Type::GetInt32Type(), operand, name); +} + +CmpInst* IRBuilder::CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name) { + if (!insert_block_) { + throw 
std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!lhs || !rhs) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCmp 缺少操作数")); + } + return insert_block_->Append(op, lhs, rhs, name); +} + +ZextInst* IRBuilder::CreateZext(Value* val, const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!val) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateZext 缺少操作数")); + } + return insert_block_->Append(Type::GetInt32Type(), val, name); +} + +BranchInst* IRBuilder::CreateBr(BasicBlock* dest) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!dest) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateBr 缺少操作数")); + } + return insert_block_->Append(dest); +} + +CondBranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + if (!cond || !true_bb || !false_bb) { + throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCondBr 缺少操作数")); + } + return insert_block_->Append(cond, true_bb, false_bb); +} + } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 30efbb6..40cfc77 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -16,6 +16,8 @@ static const char* TypeToString(const Type& ty) { switch (ty.GetKind()) { case Type::Kind::Void: return "void"; + case Type::Kind::Int1: + return "i1"; case Type::Kind::Int32: return "i32"; case Type::Kind::PtrInt32: @@ -32,6 +34,12 @@ static const char* OpcodeToString(Opcode op) { return "sub"; case Opcode::Mul: return "mul"; + case Opcode::Div: + return "sdiv"; + case Opcode::Mod: + return "srem"; + case Opcode::Neg: + return "neg"; case Opcode::Alloca: return "alloca"; case Opcode::Load: @@ -40,6 +48,31 @@ static const char* OpcodeToString(Opcode op) { return "store"; case Opcode::Ret: 
return "ret"; + case Opcode::Cmp: + return "icmp"; + case Opcode::Zext: + return "zext"; + case Opcode::Br: + case Opcode::CondBr: + return "br"; + } + return "?"; +} + +static const char* CmpOpToString(CmpOp op) { + switch (op) { + case CmpOp::Eq: + return "eq"; + case CmpOp::Ne: + return "ne"; + case CmpOp::Lt: + return "slt"; + case CmpOp::Gt: + return "sgt"; + case CmpOp::Le: + return "sle"; + case CmpOp::Ge: + return "sge"; } return "?"; } @@ -51,6 +84,21 @@ static std::string ValueToString(const Value* v) { return v ? v->GetName() : ""; } +static std::string PrintLabel(const Value* bb) { + if (!bb) return ""; + std::string name = bb->GetName(); + if (name.empty()) return ""; + if (name[0] == '%') return name; + return "%" + name; +} + +static std::string PrintLabelDef(const Value* bb) { + if (!bb) return ""; + std::string name = bb->GetName(); + if (!name.empty() && name[0] == '%') return name.substr(1); + return name; +} + void IRPrinter::Print(const Module& module, std::ostream& os) { for (const auto& func : module.GetFunctions()) { os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() @@ -59,13 +107,15 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { if (!bb) { continue; } - os << bb->GetName() << ":\n"; + os << PrintLabelDef(bb.get()) << ":\n"; for (const auto& instPtr : bb->GetInstructions()) { const auto* inst = instPtr.get(); switch (inst->GetOpcode()) { case Opcode::Add: case Opcode::Sub: - case Opcode::Mul: { + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: { auto* bin = static_cast(inst); os << " " << bin->GetName() << " = " << OpcodeToString(bin->GetOpcode()) << " " @@ -74,6 +124,14 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { << ValueToString(bin->GetRhs()) << "\n"; break; } + case Opcode::Neg: { + auto* unary = static_cast(inst); + os << " " << unary->GetName() << " = " + << OpcodeToString(unary->GetOpcode()) << " " + << TypeToString(*unary->GetUnaryOperand()->GetType()) << 
" " + << ValueToString(unary->GetUnaryOperand()) << "\n"; + break; + } case Opcode::Alloca: { auto* alloca = static_cast(inst); os << " " << alloca->GetName() << " = alloca i32\n"; @@ -97,6 +155,35 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { << ValueToString(ret->GetValue()) << "\n"; break; } + case Opcode::Cmp: { + auto* cmp = static_cast(inst); + os << " " << cmp->GetName() << " = icmp " + << CmpOpToString(cmp->GetCmpOp()) << " " + << TypeToString(*cmp->GetLhs()->GetType()) << " " + << ValueToString(cmp->GetLhs()) << ", " + << ValueToString(cmp->GetRhs()) << "\n"; + break; + } + case Opcode::Zext: { + auto* zext = static_cast(inst); + os << " " << zext->GetName() << " = zext " + << TypeToString(*zext->GetOperand(0)->GetType()) << " " + << ValueToString(zext->GetOperand(0)) << " to " + << TypeToString(*zext->GetType()) << "\n"; + break; + } + case Opcode::Br: { + auto* br = static_cast(inst); + os << " br label " << PrintLabel(br->GetDest()) << "\n"; + break; + } + case Opcode::CondBr: { + auto* cbr = static_cast(inst); + os << " br i1 " << ValueToString(cbr->GetCond()) + << ", label " << PrintLabel(cbr->GetTrueBlock()) + << ", label " << PrintLabel(cbr->GetFalseBlock()) << "\n"; + break; + } } } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 7928716..9ae696c 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -52,7 +52,7 @@ Instruction::Instruction(Opcode op, std::shared_ptr ty, std::string name) Opcode Instruction::GetOpcode() const { return opcode_; } -bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } +bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret || opcode_ == Opcode::Br || opcode_ == Opcode::CondBr; } BasicBlock* Instruction::GetParent() const { return parent_; } @@ -61,8 +61,9 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name) : 
Instruction(op, std::move(ty), std::move(name)) { - if (op != Opcode::Add) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); + if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul && + op != Opcode::Div && op != Opcode::Mod) { + throw std::runtime_error(FormatError("ir", "BinaryInst 不支持的操作码")); } if (!lhs || !rhs) { throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); @@ -85,6 +86,29 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); } Value* BinaryInst::GetRhs() const { return GetOperand(1); } +UnaryInst::UnaryInst(Opcode op, std::shared_ptr ty, Value* operand, + std::string name) + : Instruction(op, std::move(ty), std::move(name)) { + if (op != Opcode::Neg) { + throw std::runtime_error(FormatError("ir", "UnaryInst 不支持的操作码")); + } + if (!operand) { + throw std::runtime_error(FormatError("ir", "UnaryInst 缺少操作数")); + } + if (!type_ || !operand->GetType()) { + throw std::runtime_error(FormatError("ir", "UnaryInst 缺少类型信息")); + } + if (type_->GetKind() != operand->GetType()->GetKind()) { + throw std::runtime_error(FormatError("ir", "UnaryInst 类型不匹配")); + } + if (!type_->IsInt32()) { + throw std::runtime_error(FormatError("ir", "UnaryInst 当前只支持 i32")); + } + AddOperand(operand); +} + +Value* UnaryInst::GetUnaryOperand() const { return GetOperand(0); } + ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) : Instruction(Opcode::Ret, std::move(void_ty), "") { if (!val) { @@ -148,4 +172,63 @@ Value* StoreInst::GetValue() const { return GetOperand(0); } Value* StoreInst::GetPtr() const { return GetOperand(1); } +CmpInst::CmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name) + : Instruction(Opcode::Cmp, Type::GetInt1Type(), std::move(name)), cmp_op_(cmp_op) { + if (!lhs || !rhs) { + throw std::runtime_error(FormatError("ir", "CmpInst 缺少操作数")); + } + if (!lhs->GetType() || !rhs->GetType()) { + throw std::runtime_error(FormatError("ir", "CmpInst 缺少操作数类型信息")); + } + if (lhs->GetType()->GetKind() != 
rhs->GetType()->GetKind()) { + throw std::runtime_error(FormatError("ir", "CmpInst 操作数类型不匹配")); + } + AddOperand(lhs); + AddOperand(rhs); +} + +CmpOp CmpInst::GetCmpOp() const { return cmp_op_; } +Value* CmpInst::GetLhs() const { return GetOperand(0); } +Value* CmpInst::GetRhs() const { return GetOperand(1); } + +ZextInst::ZextInst(std::shared_ptr dest_ty, Value* val, std::string name) + : Instruction(Opcode::Zext, std::move(dest_ty), std::move(name)) { + if (!val) { + throw std::runtime_error(FormatError("ir", "ZextInst 缺少操作数")); + } + if (!type_->IsInt32() || !val->GetType()->IsInt1()) { + throw std::runtime_error(FormatError("ir", "ZextInst 当前只支持 i1 到 i32")); + } + AddOperand(val); +} + +Value* ZextInst::GetValue() const { return GetOperand(0); } + +BranchInst::BranchInst(BasicBlock* dest) + : Instruction(Opcode::Br, Type::GetVoidType(), "") { + if (!dest) { + throw std::runtime_error(FormatError("ir", "BranchInst 缺少目的块")); + } + AddOperand(dest); +} + +BasicBlock* BranchInst::GetDest() const { return static_cast(GetOperand(0)); } + +CondBranchInst::CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb) + : Instruction(Opcode::CondBr, Type::GetVoidType(), "") { + if (!cond || !true_bb || !false_bb) { + throw std::runtime_error(FormatError("ir", "CondBranchInst 缺少连边操作数")); + } + if (!cond->GetType()->IsInt1()) { + throw std::runtime_error(FormatError("ir", "CondBranchInst 必须使用 i1 作为条件")); + } + AddOperand(cond); + AddOperand(true_bb); + AddOperand(false_bb); +} + +Value* CondBranchInst::GetCond() const { return GetOperand(0); } +BasicBlock* CondBranchInst::GetTrueBlock() const { return static_cast(GetOperand(1)); } +BasicBlock* CondBranchInst::GetFalseBlock() const { return static_cast(GetOperand(2)); } + } // namespace ir diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index 3e1684d..c32d640 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -10,6 +10,11 @@ const std::shared_ptr& Type::GetVoidType() { return type; } +const std::shared_ptr& 
Type::GetInt1Type() { + static const std::shared_ptr type = std::make_shared(Kind::Int1); + return type; +} + const std::shared_ptr& Type::GetInt32Type() { static const std::shared_ptr type = std::make_shared(Kind::Int32); return type; @@ -24,6 +29,8 @@ Type::Kind Type::GetKind() const { return kind_; } bool Type::IsVoid() const { return kind_ == Kind::Void; } +bool Type::IsInt1() const { return kind_ == Kind::Int1; } + bool Type::IsInt32() const { return kind_ == Kind::Int32; } bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 2e9f4c1..12a06b4 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -18,6 +18,8 @@ void Value::SetName(std::string n) { name_ = std::move(n); } bool Value::IsVoid() const { return type_ && type_->IsVoid(); } +bool Value::IsInt1() const { return type_ && type_->IsInt1(); } + bool Value::IsInt32() const { return type_ && type_->IsInt32(); } bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..1cd0db8 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -6,30 +6,20 @@ #include "ir/IR.h" #include "utils/Log.h" -namespace { - -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); - } - return lvalue.ID()->getText(); -} - -} // namespace - -std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { +std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句块")); } + bool terminated = false; for (auto* item : ctx->blockItem()) { if (item) { if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 + terminated = true; break; } } } - return {}; + return terminated ? 
BlockFlow::Terminated : BlockFlow::Continue; } IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( @@ -63,15 +53,21 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } - if (!ctx->btype() || !ctx->btype()->INT()) { + // 当前语法中 decl 包含 constDecl 或 varDecl,这里只支持 varDecl + auto* var_decl = ctx->varDecl(); + if (!var_decl) { + throw std::runtime_error(FormatError("irgen", "当前仅支持变量声明")); + } + if (!var_decl->bType() || !var_decl->bType()->INT()) { throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); + // 遍历所有 varDef + for (auto* var_def : var_decl->varDef()) { + if (var_def) { + var_def->accept(this); + } } - var_def->accept(this); - return {}; + return BlockFlow::Continue; } @@ -83,25 +79,29 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量定义")); } - if (!ctx->lValue()) { + if (!ctx->ID()) { throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); } - GetLValueName(*ctx->lValue()); - if (storage_map_.find(ctx) != storage_map_.end()) { + // 暂不支持数组声明(constIndex) + if (!ctx->constIndex().empty()) { + throw std::runtime_error(FormatError("irgen", "暂不支持数组声明")); + } + std::string var_name = ctx->ID()->getText(); + if (storage_map_.find(var_name) != storage_map_.end()) { throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); } auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[ctx] = slot; + storage_map_[var_name] = slot; ir::Value* init = nullptr; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { + if (auto* init_val = ctx->initVal()) { + if (!init_val->exp()) { throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); } - init = EvalExpr(*init_value->exp()); + init = EvalExpr(*init_val->exp()); } else { init = 
builder_.CreateConstInt(0); } builder_.CreateStore(init, slot); - return {}; + return BlockFlow::Continue; } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..4565c6e 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -25,20 +25,51 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { } -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); +std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法基本表达式")); } - return EvalExpr(*ctx->exp()); + // 处理括号表达式:LPAREN exp RPAREN + if (ctx->exp()) { + return EvalExpr(*ctx->exp()); + } + // 处理 lVal(变量使用)- 交给 visitLVal 处理 + if (ctx->lVal()) { + // 直接在这里处理变量读取,避免 accept 调用可能导致的问题 + auto* lval_ctx = ctx->lVal(); + if (!lval_ctx || !lval_ctx->ID()) { + throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); + } + const auto* decl = sema_.ResolveObjectUse(lval_ctx); + if (!decl) { + throw std::runtime_error( + FormatError("irgen", + "变量使用缺少语义绑定:" + lval_ctx->ID()->getText())); + } + std::string var_name = lval_ctx->ID()->getText(); + auto it = storage_map_.find(var_name); + if (it == storage_map_.end()) { + throw std::runtime_error( + FormatError("irgen", + "变量声明缺少存储槽位:" + var_name)); + } + return static_cast( + builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); + } + // 处理 number + if (ctx->number()) { + return ctx->number()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型")); } -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { +std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { + if (!ctx || !ctx->intConst()) { throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); } return static_cast( - 
builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); + builder_.CreateConstInt(std::stoi(ctx->intConst()->getText()))); } // 变量使用的处理流程: @@ -47,34 +78,252 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { // 3. 最后生成 load,把内存中的值读出来。 // // 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { +std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { + if (!ctx || !ctx->ID()) { throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); } - auto* decl = sema_.ResolveVarUse(ctx->var()); + const auto* decl = sema_.ResolveObjectUse(ctx); if (!decl) { throw std::runtime_error( FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); + "变量使用缺少语义绑定:" + ctx->ID()->getText())); } - auto it = storage_map_.find(decl); + // 使用变量名查找存储槽位 + std::string var_name = ctx->ID()->getText(); + auto it = storage_map_.find(var_name); if (it == storage_map_.end()) { throw std::runtime_error( FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + "变量声明缺少存储槽位:" + var_name)); } return static_cast( builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); } -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("irgen", "非法加法表达式")); +std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法加减法表达式")); + } + + // 如果是 mulExp 直接返回(addExp : mulExp) + if (ctx->mulExp() && ctx->addExp() == nullptr) { + return ctx->mulExp()->accept(this); + } + + // 处理 addExp op mulExp 的递归形式 + if (!ctx->addExp() || !ctx->mulExp()) { + throw std::runtime_error(FormatError("irgen", "非法加减法表达式结构")); + } + + ir::Value* lhs = std::any_cast(ctx->addExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->mulExp()->accept(this)); + + ir::Opcode op = 
ir::Opcode::Add; + if (ctx->ADD()) { + op = ir::Opcode::Add; + } else if (ctx->SUB()) { + op = ir::Opcode::Sub; + } else { + throw std::runtime_error(FormatError("irgen", "未知的加减运算符")); } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); + return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp())); +} + + +std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法一元表达式")); + } + + // 如果是 primaryExp 直接返回(unaryExp : primaryExp) + if (ctx->primaryExp()) { + return ctx->primaryExp()->accept(this); + } + + // 处理函数调用(unaryExp : ID LPAREN funcRParams? RPAREN) + // 当前暂不支持,留给后续扩展 + if (ctx->ID()) { + throw std::runtime_error(FormatError("irgen", "暂不支持函数调用")); + } + + // 处理一元运算符(unaryExp : addUnaryOp unaryExp) + if (ctx->addUnaryOp() && ctx->unaryExp()) { + ir::Value* operand = std::any_cast(ctx->unaryExp()->accept(this)); + + // 判断是正号还是负号 + if (ctx->addUnaryOp()->SUB()) { + // 负号:生成 sub 0, operand(LLVM IR 中没有 neg 指令) + ir::Value* zero = builder_.CreateConstInt(0); + return static_cast( + builder_.CreateSub(zero, operand, module_.GetContext().NextTemp())); + } else if (ctx->addUnaryOp()->ADD()) { + // 正号:直接返回操作数(+x 等价于 x) + return operand; + } else { + throw std::runtime_error(FormatError("irgen", "未知的一元运算符")); + } + } + + throw std::runtime_error(FormatError("irgen", "不支持的一元表达式类型")); +} + +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法乘除法表达式")); + } + + // 如果是 unaryExp 直接返回(mulExp : unaryExp) + if (ctx->unaryExp() && ctx->mulExp() == nullptr) { + return ctx->unaryExp()->accept(this); + } + + // 处理 mulExp op unaryExp 的递归形式 + if (!ctx->mulExp() || !ctx->unaryExp()) { + throw std::runtime_error(FormatError("irgen", "非法乘除法表达式结构")); + } + + ir::Value* lhs = 
std::any_cast(ctx->mulExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->unaryExp()->accept(this)); + + ir::Opcode op = ir::Opcode::Mul; + if (ctx->MUL()) { + op = ir::Opcode::Mul; + } else if (ctx->DIV()) { + op = ir::Opcode::Div; + } else if (ctx->MOD()) { + op = ir::Opcode::Mod; + } else { + throw std::runtime_error(FormatError("irgen", "未知的乘除运算符")); + } + + return static_cast( + builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + if (ctx->addExp() && ctx->relExp() == nullptr) { + return ctx->addExp()->accept(this); + } + ir::Value* lhs = std::any_cast(ctx->relExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->addExp()->accept(this)); + if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); + if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); + + ir::CmpOp op; + if (ctx->LT()) op = ir::CmpOp::Lt; + else if (ctx->GT()) op = ir::CmpOp::Gt; + else if (ctx->LE()) op = ir::CmpOp::Le; + else if (ctx->GE()) op = ir::CmpOp::Ge; + else throw std::runtime_error(FormatError("irgen", "未知的关系运算符")); + + return static_cast(builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + if (ctx->relExp() && ctx->eqExp() == nullptr) { + return ctx->relExp()->accept(this); + } + ir::Value* lhs = std::any_cast(ctx->eqExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->relExp()->accept(this)); + if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); + if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); + + ir::CmpOp op; + if (ctx->EQ()) op = ir::CmpOp::Eq; + else if (ctx->NE()) op = ir::CmpOp::Ne; + else throw std::runtime_error(FormatError("irgen", "未知的相等运算符")); + + return static_cast(builder_.CreateCmp(op, lhs, rhs, 
module_.GetContext().NextTemp())); } + +std::any IRGenImpl::visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) { + if (ctx->eqExp()) { + return ctx->eqExp()->accept(this); + } + if (ctx->NOT()) { + ir::Value* operand = std::any_cast(ctx->condUnaryExp()->accept(this)); + if (operand->GetType()->IsInt1()) { + operand = builder_.CreateZext(operand, module_.GetContext().NextTemp()); + } + ir::Value* zero = builder_.CreateConstInt(0); + return static_cast(builder_.CreateCmp(ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法条件一元表达式")); +} + +std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { + if (ctx->condUnaryExp() && ctx->lAndExp() == nullptr) { + return ctx->condUnaryExp()->accept(this); + } + + ir::AllocaInst* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); + ir::Value* zero = builder_.CreateConstInt(0); + builder_.CreateStore(zero, res_ptr); + + ir::BasicBlock* rhs_bb = func_->CreateBlock(NextBlockName("land_rhs")); + ir::BasicBlock* end_bb = func_->CreateBlock(NextBlockName("land_end")); + + ir::Value* lhs = std::any_cast(ctx->lAndExp()->accept(this)); + if (lhs->GetType()->IsInt1()) { + lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); + } + ir::Value* lhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, lhs, zero, module_.GetContext().NextTemp()); + builder_.CreateCondBr(lhs_cond, rhs_bb, end_bb); + + builder_.SetInsertPoint(rhs_bb); + ir::Value* rhs = std::any_cast(ctx->condUnaryExp()->accept(this)); + if (rhs->GetType()->IsInt1()) { + rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); + } + ir::Value* rhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, rhs, zero, module_.GetContext().NextTemp()); + ir::Value* rhs_res = builder_.CreateZext(rhs_cond, module_.GetContext().NextTemp()); + builder_.CreateStore(rhs_res, res_ptr); + builder_.CreateBr(end_bb); + + builder_.SetInsertPoint(end_bb); + return 
static_cast(builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { + if (ctx->lAndExp() && ctx->lOrExp() == nullptr) { + return ctx->lAndExp()->accept(this); + } + + ir::AllocaInst* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); + ir::Value* one = builder_.CreateConstInt(1); + builder_.CreateStore(one, res_ptr); + + ir::BasicBlock* rhs_bb = func_->CreateBlock(NextBlockName("lor_rhs")); + ir::BasicBlock* end_bb = func_->CreateBlock(NextBlockName("lor_end")); + + ir::Value* lhs = std::any_cast(ctx->lOrExp()->accept(this)); + ir::Value* zero = builder_.CreateConstInt(0); + if (lhs->GetType()->IsInt1()) { + lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); + } + ir::Value* lhs_cond = builder_.CreateCmp(ir::CmpOp::Eq, lhs, zero, module_.GetContext().NextTemp()); + builder_.CreateCondBr(lhs_cond, rhs_bb, end_bb); + + builder_.SetInsertPoint(rhs_bb); + ir::Value* rhs = std::any_cast(ctx->lAndExp()->accept(this)); + if (rhs->GetType()->IsInt1()) { + rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); + } + ir::Value* rhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, rhs, zero, module_.GetContext().NextTemp()); + ir::Value* rhs_res = builder_.CreateZext(rhs_cond, module_.GetContext().NextTemp()); + builder_.CreateStore(rhs_res, res_ptr); + builder_.CreateBr(end_bb); + + builder_.SetInsertPoint(end_bb); + return static_cast(builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { + if (!ctx || !ctx->lOrExp()) { + throw std::runtime_error(FormatError("irgen", "非法条件表达式")); + } + return ctx->lOrExp()->accept(this); +} \ No newline at end of file diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..4ee5b3e 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -29,7 +29,7 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& 
sema) // 编译单元的 IR 生成当前只实现了最小功能: // - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR; +// - 当前会读取编译单元中的 topLevelItem,找到 funcDef 后生成函数 IR; // // 当前还没有实现: // - 多个函数定义的遍历与生成; @@ -38,12 +38,15 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - auto* func = ctx->funcDef(); - if (!func) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + // 遍历所有 topLevelItem,找到 funcDef + for (auto* item : ctx->topLevelItem()) { + if (item && item->funcDef()) { + item->funcDef()->accept(this); + // 当前只支持单个函数,找到第一个后就返回 + return {}; + } } - func->accept(this); - return {}; + throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } // 函数 IR 生成当前实现了: @@ -61,12 +64,11 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { // - 入口块中的参数初始化逻辑。 // ... -// 因此这里目前只支持最小的“无参 int 函数”生成。 std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } - if (!ctx->blockStmt()) { + if (!ctx->block()) { throw std::runtime_error(FormatError("irgen", "函数体为空")); } if (!ctx->ID()) { @@ -80,7 +82,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { builder_.SetInsertPoint(func_->GetEntry()); storage_map_.clear(); - ctx->blockStmt()->accept(this); + ctx->block()->accept(this); // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 VerifyFunctionStructure(*func_); return {}; diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..e44bd0a 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -9,9 +9,9 @@ // 语句生成当前只实现了最小子集。 // 目前支持: // - return ; +// - 赋值语句:lVal = exp; // // 还未支持: -// - 赋值语句 // - if / while 等控制流 // - 空语句、块语句嵌套分发之外的更多语句形态 @@ -19,21 +19,134 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } - if (ctx->returnStmt()) { - return 
ctx->returnStmt()->accept(this); + + if (ctx->lVal() && ctx->ASSIGN()) { + if (!ctx->exp()) { + throw std::runtime_error(FormatError("irgen", "赋值语句缺少表达式")); + } + ir::Value* rhs = EvalExpr(*ctx->exp()); + auto* lval_ctx = ctx->lVal(); + if (!lval_ctx || !lval_ctx->ID()) { + throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量赋值")); + } + const auto* decl = sema_.ResolveObjectUse(lval_ctx); + if (!decl) { + throw std::runtime_error( + FormatError("irgen", "变量使用缺少语义绑定:" + lval_ctx->ID()->getText())); + } + std::string var_name = lval_ctx->ID()->getText(); + auto it = storage_map_.find(var_name); + if (it == storage_map_.end()) { + throw std::runtime_error( + FormatError("irgen", "变量声明缺少存储槽位:" + var_name)); + } + builder_.CreateStore(rhs, it->second); + return BlockFlow::Continue; } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); -} - - -std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); + + if (ctx->IF()) { + ir::Value* cond_val = std::any_cast(ctx->cond()->accept(this)); + // cond_val must be i1, if it's not we need to check if it's != 0 + if (cond_val->GetType()->IsInt32()) { + ir::Value* zero = builder_.CreateConstInt(0); + cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp()); + } + + ir::BasicBlock* then_bb = func_->CreateBlock(NextBlockName("if_then")); + ir::BasicBlock* else_bb = ctx->ELSE() ? func_->CreateBlock(NextBlockName("if_else")) : nullptr; + ir::BasicBlock* merge_bb = func_->CreateBlock(NextBlockName("if_merge")); + + builder_.CreateCondBr(cond_val, then_bb, else_bb ? 
else_bb : merge_bb); + + builder_.SetInsertPoint(then_bb); + auto then_flow = std::any_cast(ctx->stmt(0)->accept(this)); + if (then_flow == BlockFlow::Continue) { + builder_.CreateBr(merge_bb); + } + + if (ctx->ELSE()) { + builder_.SetInsertPoint(else_bb); + auto else_flow = std::any_cast(ctx->stmt(1)->accept(this)); + if (else_flow == BlockFlow::Continue) { + builder_.CreateBr(merge_bb); + } + } + + builder_.SetInsertPoint(merge_bb); + return BlockFlow::Continue; + } + + if (ctx->WHILE()) { + ir::BasicBlock* cond_bb = func_->CreateBlock(NextBlockName("while_cond")); + ir::BasicBlock* body_bb = func_->CreateBlock(NextBlockName("while_body")); + ir::BasicBlock* exit_bb = func_->CreateBlock(NextBlockName("while_exit")); + + builder_.CreateBr(cond_bb); + builder_.SetInsertPoint(cond_bb); + + ir::Value* cond_val = std::any_cast(ctx->cond()->accept(this)); + if (cond_val->GetType()->IsInt32()) { + ir::Value* zero = builder_.CreateConstInt(0); + cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp()); + } + builder_.CreateCondBr(cond_val, body_bb, exit_bb); + + builder_.SetInsertPoint(body_bb); + ir::BasicBlock* old_cond = current_loop_cond_bb_; + ir::BasicBlock* old_exit = current_loop_exit_bb_; + current_loop_cond_bb_ = cond_bb; + current_loop_exit_bb_ = exit_bb; + + auto body_flow = std::any_cast(ctx->stmt(0)->accept(this)); + if (body_flow == BlockFlow::Continue) { + builder_.CreateBr(cond_bb); + } + + current_loop_cond_bb_ = old_cond; + current_loop_exit_bb_ = old_exit; + + builder_.SetInsertPoint(exit_bb); + return BlockFlow::Continue; } - if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); + + if (ctx->BREAK()) { + if (!current_loop_exit_bb_) { + throw std::runtime_error(FormatError("irgen", "break 必须在循环内")); + } + builder_.CreateBr(current_loop_exit_bb_); + return BlockFlow::Terminated; } - ir::Value* v = EvalExpr(*ctx->exp()); - builder_.CreateRet(v); - return BlockFlow::Terminated; + + if 
(ctx->CONTINUE()) { + if (!current_loop_cond_bb_) { + throw std::runtime_error(FormatError("irgen", "continue 必须在循环内")); + } + builder_.CreateBr(current_loop_cond_bb_); + return BlockFlow::Terminated; + } + + if (ctx->RETURN()) { + if (ctx->exp()) { + ir::Value* v = EvalExpr(*ctx->exp()); + builder_.CreateRet(v); + } else { + throw std::runtime_error(FormatError("irgen", "暂不支持 void return")); + } + return BlockFlow::Terminated; + } + + if (ctx->block()) { + return ctx->block()->accept(this); + } + + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + return BlockFlow::Continue; + } + + if (ctx->SEMICOLON()) { + return BlockFlow::Continue; + } + + throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); } diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 745374c..95f0629 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -1,8 +1,13 @@ #include "sem/Sema.h" #include +#include +#include #include #include +#include +#include +#include #include "SysYBaseVisitor.h" #include "sem/SymbolTable.h" @@ -10,74 +15,258 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("sema", "非法左值")); +constexpr int kUnknownArrayDim = -1; + +struct ExprInfo { + SemanticType type = SemanticType::Int; + bool is_lvalue = false; + bool is_const_object = false; + std::vector dimensions; + bool has_const_value = false; + ScalarConstant const_value; + + bool IsScalar() const { return dimensions.empty() && type != SemanticType::Void; } + bool IsArray() const { return !dimensions.empty(); } +}; + +SemanticType ParseBType(SysYParser::BTypeContext& ctx) { + if (ctx.INT()) { + return SemanticType::Int; + } + if (ctx.FLOAT()) { + return SemanticType::Float; + } + throw std::runtime_error(FormatError("sema", "未知基础类型")); +} + +SemanticType ParseFuncType(SysYParser::FuncTypeContext& ctx) { + if (ctx.VOID()) { + return SemanticType::Void; + } + if (ctx.INT()) { + return SemanticType::Int; + } + if (ctx.FLOAT()) { + 
return SemanticType::Float; + } + throw std::runtime_error(FormatError("sema", "未知函数返回类型")); +} + +int ConvertToInt(const ScalarConstant& value) { + return static_cast(value.number); +} + +double ConvertToFloat(const ScalarConstant& value) { return value.number; } + +bool IsNumericType(SemanticType type) { + return type == SemanticType::Int || type == SemanticType::Float; +} + +bool CanImplicitlyConvert(SemanticType from, SemanticType to) { + if (from == to) { + return true; + } + if (!IsNumericType(from) || !IsNumericType(to)) { + return false; + } + return true; +} + +ScalarConstant CastConstant(const ScalarConstant& value, SemanticType to) { + if (!CanImplicitlyConvert(value.type, to)) { + throw std::runtime_error(FormatError("sema", "非法常量类型转换")); + } + ScalarConstant result; + result.type = to; + result.number = to == SemanticType::Int ? static_cast(ConvertToInt(value)) + : ConvertToFloat(value); + return result; +} + +bool IsTrue(const ScalarConstant& value) { + if (value.type == SemanticType::Float) { + return value.number != 0.0; + } + return ConvertToInt(value) != 0; +} + +ScalarConstant MakeInt(int value) { + return ScalarConstant{SemanticType::Int, static_cast(value)}; +} + +ScalarConstant MakeFloat(double value) { + return ScalarConstant{SemanticType::Float, value}; +} + +const antlr4::Token* StartToken(const antlr4::ParserRuleContext* ctx) { + return ctx ? 
ctx->getStart() : nullptr; +} + +[[noreturn]] void ThrowSemaError(const antlr4::ParserRuleContext* ctx, + std::string_view msg) { + if (const auto* tok = StartToken(ctx)) { + throw std::runtime_error( + FormatErrorAt("sema", tok->getLine(), tok->getCharPositionInLine(), msg)); + } + throw std::runtime_error(FormatError("sema", msg)); +} + +int ParseIntLiteral(SysYParser::IntConstContext& ctx) { + return std::stoi(ctx.getText(), nullptr, 0); +} + +double ParseFloatLiteral(SysYParser::FloatConstContext& ctx) { + const std::string text = ctx.getText(); + char* end = nullptr; + const double value = std::strtod(text.c_str(), &end); + if (end == nullptr || *end != '\0') { + throw std::runtime_error(FormatError("sema", "非法浮点字面量: " + text)); } - return lvalue.ID()->getText(); + return value; +} + +FunctionBinding MakeBuiltinFunction(std::string name, SemanticType return_type, + std::vector params) { + FunctionBinding fn; + fn.name = std::move(name); + fn.return_type = return_type; + fn.params = std::move(params); + fn.is_builtin = true; + return fn; +} + +ObjectBinding MakeParam(std::string name, SemanticType type, + std::vector dimensions = {}, + bool is_array_param = false) { + ObjectBinding param; + param.name = std::move(name); + param.type = type; + param.decl_kind = ObjectBinding::DeclKind::Param; + param.dimensions = std::move(dimensions); + param.is_array_param = is_array_param; + return param; } class SemaVisitor final : public SysYBaseVisitor { public: + SemaVisitor() { RegisterBuiltins(); } + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少编译单元")); + ThrowSemaError(ctx, "缺少编译单元"); + } + + CollectFunctions(*ctx); + for (auto* item : ctx->topLevelItem()) { + if (!item) { + continue; + } + item->accept(this); + } + + const FunctionBinding* main = sema_.ResolveFunction("main"); + if (!main || main->is_builtin) { + ThrowSemaError(ctx, "缺少 main 函数定义"); } - auto* func = ctx->funcDef(); - 
if (!func || !func->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + if (main->return_type != SemanticType::Int || !main->params.empty()) { + ThrowSemaError(main->func_def, "main 函数必须是无参 int main()"); + } + return {}; + } + + std::any visitTopLevelItem(SysYParser::TopLevelItemContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "缺少顶层定义"); } - if (!func->ID() || func->ID()->getText() != "main") { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + if (ctx->decl()) { + ctx->decl()->accept(this); + return {}; + } + if (ctx->funcDef()) { + ctx->funcDef()->accept(this); + return {}; + } + ThrowSemaError(ctx, "暂不支持的顶层定义"); + } + + std::any visitDecl(SysYParser::DeclContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "缺少声明"); + } + if (ctx->constDecl()) { + ctx->constDecl()->accept(this); + return {}; + } + if (ctx->varDecl()) { + ctx->varDecl()->accept(this); + return {}; + } + ThrowSemaError(ctx, "非法声明"); + } + + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + ThrowSemaError(ctx, "非法常量声明"); + } + const SemanticType type = ParseBType(*ctx->bType()); + for (auto* def : ctx->constDef()) { + DeclareConst(*def, type); + } + return {}; + } + + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { + if (!ctx || !ctx->bType()) { + ThrowSemaError(ctx, "非法变量声明"); } - func->accept(this); - if (!seen_return_) { - throw std::runtime_error( - FormatError("sema", "main 函数必须包含 return 语句")); + const SemanticType type = ParseBType(*ctx->bType()); + for (auto* def : ctx->varDef()) { + DeclareVar(*def, type); } return {}; } std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { - if (!ctx || !ctx->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + if (!ctx || !ctx->ID() || !ctx->funcType() || !ctx->block()) { + ThrowSemaError(ctx, "非法函数定义"); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw 
std::runtime_error(FormatError("sema", "当前仅支持 int main")); + + const FunctionBinding* binding = sema_.ResolveFunction(ctx->ID()->getText()); + if (!binding) { + ThrowSemaError(ctx, "函数未完成预收集: " + ctx->ID()->getText()); } - const auto& items = ctx->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); + + const FunctionBinding* prev = current_function_; + current_function_ = binding; + symbols_.EnterScope(); + for (const auto& param : binding->params) { + if (!symbols_.Add(param)) { + ThrowSemaError(ctx, "函数形参重复定义: " + param.name); + } } - ctx->blockStmt()->accept(this); + ctx->block()->accept(this); + symbols_.ExitScope(); + current_function_ = prev; return {}; } - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { + std::any visitBlock(SysYParser::BlockContext* ctx) override { if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少语句块")); + ThrowSemaError(ctx, "缺少语句块"); } - const auto& items = ctx->blockItem(); - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; - } - if (seen_return_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); + symbols_.EnterScope(); + for (auto* item : ctx->blockItem()) { + if (item) { + item->accept(this); } - current_item_index_ = i; - total_items_ = items.size(); - item->accept(this); } + symbols_.ExitScope(); return {}; } std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { if (!ctx) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + ThrowSemaError(ctx, "缺少块内语句"); } if (ctx->decl()) { ctx->decl()->accept(this); @@ -87,112 +276,770 @@ class SemaVisitor final : public SysYBaseVisitor { ctx->stmt()->accept(this); return {}; } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + ThrowSemaError(ctx, "非法块内语句"); } - std::any visitDecl(SysYParser::DeclContext* ctx) override { + std::any 
visitStmt(SysYParser::StmtContext* ctx) override { if (!ctx) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); + ThrowSemaError(ctx, "缺少语句"); } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); + + if (ctx->BREAK()) { + if (loop_depth_ == 0) { + ThrowSemaError(ctx, "break 只能出现在循环内部"); + } + return {}; + } + if (ctx->CONTINUE()) { + if (loop_depth_ == 0) { + ThrowSemaError(ctx, "continue 只能出现在循环内部"); + } + return {}; } - auto* var_def = ctx->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); + if (ctx->RETURN()) { + CheckReturn(*ctx); + return {}; } - const std::string name = GetLValueName(*var_def->lValue()); - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + if (ctx->WHILE()) { + RequireScalar(ctx->cond(), EvalCond(*ctx->cond()), "while 条件必须是标量表达式"); + ++loop_depth_; + ctx->stmt(0)->accept(this); + --loop_depth_; + return {}; } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); + if (ctx->IF()) { + RequireScalar(ctx->cond(), EvalCond(*ctx->cond()), "if 条件必须是标量表达式"); + ctx->stmt(0)->accept(this); + if (ctx->stmt().size() > 1 && ctx->stmt(1)) { + ctx->stmt(1)->accept(this); } - init->exp()->accept(this); + return {}; + } + if (ctx->block()) { + ctx->block()->accept(this); + return {}; + } + if (ctx->lVal() && ctx->ASSIGN()) { + CheckAssignment(*ctx); + return {}; + } + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + return {}; } - table_.Add(name, var_def); return {}; } - std::any visitStmt(SysYParser::StmtContext* ctx) override { - if (!ctx || !ctx->returnStmt()) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); + std::any visitExp(SysYParser::ExpContext* ctx) override { + if (!ctx || !ctx->addExp()) { + ThrowSemaError(ctx, "非法表达式"); } - ctx->returnStmt()->accept(this); - return {}; + return 
EvalExpr(*ctx->addExp()); } - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); + std::any visitCond(SysYParser::CondContext* ctx) override { + if (!ctx || !ctx->lOrExp()) { + ThrowSemaError(ctx, "非法条件表达式"); } - ctx->exp()->accept(this); - seen_return_ = true; - if (current_item_index_ + 1 != total_items_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); + return EvalExpr(*ctx->lOrExp()); + } + + std::any visitLVal(SysYParser::LValContext* ctx) override { + if (!ctx || !ctx->ID()) { + ThrowSemaError(ctx, "非法左值"); } - return {}; + return AnalyzeLVal(*ctx); } - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "非法括号表达式")); + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法基础表达式"); } - ctx->exp()->accept(this); - return {}; + if (ctx->exp()) { + return EvalExpr(*ctx->exp()); + } + if (ctx->lVal()) { + return AnalyzeLVal(*ctx->lVal()); + } + if (ctx->number()) { + return EvalExpr(*ctx->number()); + } + ThrowSemaError(ctx, "非法基础表达式"); } - std::any visitVarExp(SysYParser::VarExpContext* ctx) override { - if (!ctx || !ctx->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); + std::any visitNumber(SysYParser::NumberContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法数字字面量"); } - ctx->var()->accept(this); - return {}; + if (ctx->intConst()) { + return EvalExpr(*ctx->intConst()); + } + if (ctx->floatConst()) { + return EvalExpr(*ctx->floatConst()); + } + ThrowSemaError(ctx, "非法数字字面量"); } - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); + std::any visitIntConst(SysYParser::IntConstContext* ctx) 
override { + if (!ctx) { + ThrowSemaError(ctx, "非法整数字面量"); } - return {}; + ExprInfo expr; + expr.type = SemanticType::Int; + expr.has_const_value = true; + expr.const_value = MakeInt(ParseIntLiteral(*ctx)); + return expr; } - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); + std::any visitFloatConst(SysYParser::FloatConstContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法浮点字面量"); } - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; + ExprInfo expr; + expr.type = SemanticType::Float; + expr.has_const_value = true; + expr.const_value = MakeFloat(ParseFloatLiteral(*ctx)); + return expr; } - std::any visitVar(SysYParser::VarContext* ctx) override { - if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("sema", "非法变量引用")); + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法一元表达式"); } - const std::string name = ctx->ID()->getText(); - auto* decl = table_.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); + if (ctx->primaryExp()) { + return EvalExpr(*ctx->primaryExp()); } - sema_.BindVarUse(ctx, decl); - return {}; + if (ctx->ID()) { + return AnalyzeCall(*ctx); + } + if (ctx->addUnaryOp() && ctx->unaryExp()) { + ExprInfo operand = EvalExpr(*ctx->unaryExp()); + RequireScalar(ctx->unaryExp(), operand, "一元运算要求标量操作数"); + ExprInfo result; + result.type = operand.type; + if (ctx->addUnaryOp()->SUB() && operand.has_const_value) { + result.has_const_value = true; + result.const_value = operand.const_value; + result.const_value.number = -result.const_value.number; + } else if (ctx->addUnaryOp()->ADD() && operand.has_const_value) { + result.has_const_value = true; + result.const_value = operand.const_value; + } + return result; + } + ThrowSemaError(ctx, "非法一元表达式"); + } + + std::any 
visitMulExp(SysYParser::MulExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法乘法表达式"); + } + // 如果是 mulExp : unaryExp 形式(没有 MUL/DIV/MOD token),直接处理 unaryExp + if (!ctx->MUL() && !ctx->DIV() && !ctx->MOD()) { + return EvalExpr(*ctx->unaryExp()); + } + // 否则是 mulExp MUL/DIV/MOD unaryExp 形式 + ExprInfo lhs = EvalExpr(*ctx->mulExp()); + ExprInfo rhs = EvalExpr(*ctx->unaryExp()); + return EvalArithmetic(*ctx, lhs, rhs, ctx->MUL() ? '*' : (ctx->DIV() ? '/' : '%')); + } + + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法加法表达式"); + } + // 如果是 addExp : mulExp 形式(没有 ADD/SUB token),直接处理 mulExp + if (!ctx->ADD() && !ctx->SUB()) { + return EvalExpr(*ctx->mulExp()); + } + // 否则是 addExp ADD/SUB mulExp 形式 + ExprInfo lhs = EvalExpr(*ctx->addExp()); + ExprInfo rhs = EvalExpr(*ctx->mulExp()); + return EvalArithmetic(*ctx, lhs, rhs, ctx->ADD() ? '+' : '-'); + } + + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法关系表达式"); + } + if (ctx->relExp() == nullptr) { + return EvalExpr(*ctx->addExp()); + } + ExprInfo lhs = EvalExpr(*ctx->relExp()); + ExprInfo rhs = EvalExpr(*ctx->addExp()); + return EvalCompare(*ctx, lhs, rhs); + } + + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法相等表达式"); + } + if (ctx->eqExp() == nullptr) { + return EvalExpr(*ctx->relExp()); + } + ExprInfo lhs = EvalExpr(*ctx->eqExp()); + ExprInfo rhs = EvalExpr(*ctx->relExp()); + return EvalCompare(*ctx, lhs, rhs); + } + + std::any visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法条件一元表达式"); + } + if (ctx->eqExp()) { + return EvalExpr(*ctx->eqExp()); + } + ExprInfo operand = EvalExpr(*ctx->condUnaryExp()); + RequireScalar(ctx->condUnaryExp(), operand, "逻辑非要求标量操作数"); + ExprInfo result; + result.type = SemanticType::Int; + if (operand.has_const_value) { + result.has_const_value = true; + 
result.const_value = MakeInt(IsTrue(operand.const_value) ? 0 : 1); + } + return result; + } + + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法逻辑与表达式"); + } + if (ctx->lAndExp() == nullptr) { + return EvalExpr(*ctx->condUnaryExp()); + } + ExprInfo lhs = EvalExpr(*ctx->lAndExp()); + ExprInfo rhs = EvalExpr(*ctx->condUnaryExp()); + return EvalLogical(*ctx, lhs, rhs, true); + } + + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override { + if (!ctx) { + ThrowSemaError(ctx, "非法逻辑或表达式"); + } + if (ctx->lOrExp() == nullptr) { + return EvalExpr(*ctx->lAndExp()); + } + ExprInfo lhs = EvalExpr(*ctx->lOrExp()); + ExprInfo rhs = EvalExpr(*ctx->lAndExp()); + return EvalLogical(*ctx, lhs, rhs, false); + } + + std::any visitConstExp(SysYParser::ConstExpContext* ctx) override { + if (!ctx || !ctx->addExp()) { + ThrowSemaError(ctx, "非法常量表达式"); + } + ExprInfo expr = EvalExpr(*ctx->addExp()); + if (!expr.IsScalar() || !expr.has_const_value) { + ThrowSemaError(ctx, "要求编译期常量表达式"); + } + return expr; } SemanticContext TakeSemanticContext() { return std::move(sema_); } private: - SymbolTable table_; + ExprInfo EvalExpr(antlr4::tree::ParseTree& node) { + return std::any_cast(node.accept(this)); + } + + ExprInfo EvalCond(SysYParser::CondContext& cond) { return EvalExpr(cond); } + + ExprInfo AnalyzeLVal(SysYParser::LValContext& ctx) { + const std::string name = ctx.ID()->getText(); + const ObjectBinding* symbol = symbols_.Lookup(name); + if (!symbol) { + ThrowSemaError(&ctx, "使用了未声明的标识符:" + name); + } + + sema_.BindObjectUse(&ctx, *symbol); + + if (ctx.exp().size() > symbol->dimensions.size()) { + ThrowSemaError(&ctx, "数组下标过多: " + name); + } + + for (auto* exp : ctx.exp()) { + ExprInfo index = EvalExpr(*exp); + RequireScalar(exp, index, "数组下标必须是标量表达式"); + } + + ExprInfo result; + result.type = symbol->type; + result.is_const_object = symbol->decl_kind == ObjectBinding::DeclKind::Const; + result.is_lvalue = ctx.exp().size() == 
symbol->dimensions.size(); + result.dimensions.assign(symbol->dimensions.begin() + ctx.exp().size(), + symbol->dimensions.end()); + if (result.dimensions.empty() && symbol->has_const_value) { + result.has_const_value = true; + result.const_value = symbol->const_value; + } + return result; + } + + ExprInfo AnalyzeCall(SysYParser::UnaryExpContext& ctx) { + const std::string name = ctx.ID()->getText(); + if (const ObjectBinding* object = symbols_.Lookup(name)) { + ThrowSemaError(&ctx, "标识符不是函数: " + object->name); + } + + const FunctionBinding* fn = sema_.ResolveFunction(name); + if (!fn) { + ThrowSemaError(&ctx, "调用了未定义的函数: " + name); + } + + std::vector args; + if (ctx.funcRParams()) { + for (auto* exp : ctx.funcRParams()->exp()) { + args.push_back(EvalExpr(*exp)); + } + } + if (args.size() != fn->params.size()) { + ThrowSemaError(&ctx, "函数参数个数不匹配: " + name); + } + for (size_t i = 0; i < args.size(); ++i) { + CheckArgument(ctx, fn->params[i], args[i], i); + } + + sema_.BindFunctionCall(&ctx, *fn); + + ExprInfo result; + result.type = fn->return_type; + return result; + } + + void CheckArgument(const antlr4::ParserRuleContext& call_site, + const ObjectBinding& param, const ExprInfo& arg, + size_t index) { + if (param.dimensions.empty()) { + if (!arg.IsScalar()) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个参数需要标量实参"); + } + if (!CanImplicitlyConvert(arg.type, param.type)) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个参数类型不匹配"); + } + return; + } + + if (!arg.IsArray()) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个参数需要数组实参"); + } + if (arg.type != param.type || arg.dimensions.size() != param.dimensions.size()) { + ThrowSemaError(&call_site, "第 " + std::to_string(index + 1) + + " 个数组参数类型不匹配"); + } + for (size_t dim = 1; dim < param.dimensions.size(); ++dim) { + if (param.dimensions[dim] != kUnknownArrayDim && + arg.dimensions[dim] != param.dimensions[dim]) { + ThrowSemaError(&call_site, "第 " + 
std::to_string(index + 1) + + " 个数组参数维度不匹配"); + } + } + } + + void CheckAssignment(SysYParser::StmtContext& ctx) { + ExprInfo lhs = AnalyzeLVal(*ctx.lVal()); + if (!lhs.IsScalar() || !lhs.is_lvalue) { + ThrowSemaError(&ctx, "赋值语句左侧必须是可写标量左值"); + } + if (lhs.is_const_object) { + ThrowSemaError(&ctx, "不能给 const 对象赋值"); + } + ExprInfo rhs = EvalExpr(*ctx.exp()); + RequireScalar(ctx.exp(), rhs, "赋值语句右侧必须是标量表达式"); + if (!CanImplicitlyConvert(rhs.type, lhs.type)) { + ThrowSemaError(&ctx, "赋值语句两侧类型不兼容"); + } + } + + void CheckReturn(SysYParser::StmtContext& ctx) { + if (!current_function_) { + ThrowSemaError(&ctx, "return 语句不在函数内部"); + } + if (current_function_->return_type == SemanticType::Void) { + if (ctx.exp()) { + ThrowSemaError(&ctx, "void 函数不能返回值"); + } + return; + } + if (!ctx.exp()) { + ThrowSemaError(&ctx, "非 void 函数必须返回值"); + } + ExprInfo expr = EvalExpr(*ctx.exp()); + RequireScalar(ctx.exp(), expr, "return 表达式必须是标量"); + if (!CanImplicitlyConvert(expr.type, current_function_->return_type)) { + ThrowSemaError(&ctx, "return 表达式类型与函数返回类型不匹配"); + } + } + + void DeclareConst(SysYParser::ConstDefContext& ctx, SemanticType type) { + ObjectBinding symbol; + symbol.name = ctx.ID()->getText(); + symbol.type = type; + symbol.decl_kind = ObjectBinding::DeclKind::Const; + symbol.const_def = &ctx; + symbol.dimensions = EvalArrayDims(ctx.constIndex(), true); + + if (symbols_.ContainsInCurrentScope(symbol.name)) { + ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name); + } + if (symbols_.Depth() == 1 && sema_.ResolveFunction(symbol.name)) { + ThrowSemaError(&ctx, "全局对象与函数重名: " + symbol.name); + } + + if (!ctx.constInitVal()) { + ThrowSemaError(&ctx, "const 对象缺少初始化"); + } + if (symbol.dimensions.empty()) { + symbol.const_value = ValidateConstInitScalar(*ctx.constInitVal(), type); + symbol.has_const_value = true; + } else { + ValidateConstInitAggregate(*ctx.constInitVal(), type); + } + + if (!symbols_.Add(symbol)) { + ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name); + } + } + + void 
DeclareVar(SysYParser::VarDefContext& ctx, SemanticType type) { + ObjectBinding symbol; + symbol.name = ctx.ID()->getText(); + symbol.type = type; + symbol.decl_kind = ObjectBinding::DeclKind::Var; + symbol.var_def = &ctx; + symbol.dimensions = EvalArrayDims(ctx.constIndex(), true); + + if (symbols_.ContainsInCurrentScope(symbol.name)) { + ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name); + } + if (symbols_.Depth() == 1 && sema_.ResolveFunction(symbol.name)) { + ThrowSemaError(&ctx, "全局对象与函数重名: " + symbol.name); + } + + if (!symbols_.Add(symbol)) { + ThrowSemaError(&ctx, "重复定义标识符: " + symbol.name); + } + + if (!ctx.initVal()) { + return; + } + if (symbol.dimensions.empty()) { + ValidateVarInitScalar(*ctx.initVal(), type, symbols_.Depth() == 1); + } else { + ValidateVarInitAggregate(*ctx.initVal(), type, symbols_.Depth() == 1); + } + } + + std::vector EvalArrayDims( + const std::vector& indices, + bool require_positive) { + std::vector dims; + dims.reserve(indices.size()); + for (auto* index : indices) { + if (!index || !index->constExp()) { + ThrowSemaError(index, "数组维度缺少常量表达式"); + } + ExprInfo expr = EvalExpr(*index->constExp()); + if (!expr.IsScalar() || !expr.has_const_value) { + ThrowSemaError(index, "数组维度必须是整型常量表达式"); + } + const int dim = ConvertToInt(CastConstant(expr.const_value, SemanticType::Int)); + if (require_positive && dim <= 0) { + ThrowSemaError(index, "数组维度必须为正整数"); + } + dims.push_back(dim); + } + return dims; + } + + ScalarConstant ValidateConstInitScalar(SysYParser::ConstInitValContext& init, + SemanticType target_type) { + if (!init.constExp()) { + ThrowSemaError(&init, "标量 const 初始化必须是常量表达式"); + } + ExprInfo expr = EvalExpr(*init.constExp()); + if (!expr.IsScalar() || !expr.has_const_value) { + ThrowSemaError(&init, "标量 const 初始化必须是常量表达式"); + } + return CastConstant(expr.const_value, target_type); + } + + void ValidateConstInitAggregate(SysYParser::ConstInitValContext& init, + SemanticType target_type) { + if (init.constExp()) { + ExprInfo 
expr = EvalExpr(*init.constExp()); + if (!expr.IsScalar() || !expr.has_const_value) { + ThrowSemaError(&init, "数组 const 初始化要求常量表达式"); + } + CastConstant(expr.const_value, target_type); + return; + } + for (auto* nested : init.constInitVal()) { + if (nested) { + ValidateConstInitAggregate(*nested, target_type); + } + } + } + + void ValidateVarInitScalar(SysYParser::InitValContext& init, + SemanticType target_type, bool require_constant) { + if (!init.exp()) { + ThrowSemaError(&init, "标量初始化非法"); + } + ExprInfo expr = EvalExpr(*init.exp()); + RequireScalar(&init, expr, "标量初始化要求标量表达式"); + if (!CanImplicitlyConvert(expr.type, target_type)) { + ThrowSemaError(&init, "初始化表达式类型不兼容"); + } + if (require_constant && !expr.has_const_value) { + ThrowSemaError(&init, "全局变量初始化要求编译期常量"); + } + } + + void ValidateVarInitAggregate(SysYParser::InitValContext& init, + SemanticType target_type, bool require_constant) { + if (init.exp()) { + ExprInfo expr = EvalExpr(*init.exp()); + RequireScalar(&init, expr, "数组初始化元素必须是标量表达式"); + if (!CanImplicitlyConvert(expr.type, target_type)) { + ThrowSemaError(&init, "数组初始化元素类型不兼容"); + } + if (require_constant && !expr.has_const_value) { + ThrowSemaError(&init, "全局数组初始化要求编译期常量"); + } + return; + } + for (auto* nested : init.initVal()) { + if (nested) { + ValidateVarInitAggregate(*nested, target_type, require_constant); + } + } + } + + ExprInfo EvalArithmetic(const antlr4::ParserRuleContext& ctx, const ExprInfo& lhs, + const ExprInfo& rhs, char op) { + RequireScalar(&ctx, lhs, "算术运算要求标量操作数"); + RequireScalar(&ctx, rhs, "算术运算要求标量操作数"); + ExprInfo result; + result.type = lhs.type == SemanticType::Float || rhs.type == SemanticType::Float + ? 
SemanticType::Float + : SemanticType::Int; + if (!lhs.has_const_value || !rhs.has_const_value) { + return result; + } + + result.has_const_value = true; + const ScalarConstant lc = CastConstant(lhs.const_value, result.type); + const ScalarConstant rc = CastConstant(rhs.const_value, result.type); + if (result.type == SemanticType::Float) { + double value = 0.0; + if (op == '+') value = lc.number + rc.number; + if (op == '-') value = lc.number - rc.number; + if (op == '*') value = lc.number * rc.number; + if (op == '/') value = lc.number / rc.number; + if (op == '%') { + ThrowSemaError(&ctx, "浮点数不支持取模运算"); + } + result.const_value = MakeFloat(value); + return result; + } + + const int li = ConvertToInt(lc); + const int ri = ConvertToInt(rc); + int value = 0; + if (op == '+') value = li + ri; + if (op == '-') value = li - ri; + if (op == '*') value = li * ri; + if (op == '/') value = li / ri; + if (op == '%') value = li % ri; + result.const_value = MakeInt(value); + return result; + } + + ExprInfo EvalCompare(antlr4::ParserRuleContext& ctx, const ExprInfo& lhs, + const ExprInfo& rhs) { + RequireScalar(&ctx, lhs, "比较运算要求标量操作数"); + RequireScalar(&ctx, rhs, "比较运算要求标量操作数"); + ExprInfo result; + result.type = SemanticType::Int; + if (!lhs.has_const_value || !rhs.has_const_value) { + return result; + } + + const SemanticType promoted = + lhs.type == SemanticType::Float || rhs.type == SemanticType::Float + ? 
SemanticType::Float + : SemanticType::Int; + const ScalarConstant lc = CastConstant(lhs.const_value, promoted); + const ScalarConstant rc = CastConstant(rhs.const_value, promoted); + bool value = false; + if (auto* rel = dynamic_cast(&ctx)) { + if (rel->LT()) value = lc.number < rc.number; + if (rel->GT()) value = lc.number > rc.number; + if (rel->LE()) value = lc.number <= rc.number; + if (rel->GE()) value = lc.number >= rc.number; + } else if (auto* eq = dynamic_cast(&ctx)) { + if (eq->EQ()) value = lc.number == rc.number; + if (eq->NE()) value = lc.number != rc.number; + } + result.has_const_value = true; + result.const_value = MakeInt(value ? 1 : 0); + return result; + } + + ExprInfo EvalLogical(const antlr4::ParserRuleContext& ctx, const ExprInfo& lhs, + const ExprInfo& rhs, bool is_and) { + RequireScalar(&ctx, lhs, "逻辑运算要求标量操作数"); + RequireScalar(&ctx, rhs, "逻辑运算要求标量操作数"); + ExprInfo result; + result.type = SemanticType::Int; + if (!lhs.has_const_value || !rhs.has_const_value) { + return result; + } + const bool value = + is_and ? (IsTrue(lhs.const_value) && IsTrue(rhs.const_value)) + : (IsTrue(lhs.const_value) || IsTrue(rhs.const_value)); + result.has_const_value = true; + result.const_value = MakeInt(value ? 
1 : 0); + return result; + } + + void RequireScalar(const antlr4::ParserRuleContext* ctx, const ExprInfo& expr, + std::string_view message) { + if (!expr.IsScalar()) { + ThrowSemaError(ctx, message); + } + } + + void CollectFunctions(SysYParser::CompUnitContext& ctx) { + for (auto* item : ctx.topLevelItem()) { + if (!item || !item->funcDef()) { + continue; + } + FunctionBinding fn = BuildFunctionSignature(*item->funcDef()); + if (sema_.ResolveFunction(fn.name)) { + ThrowSemaError(item->funcDef(), "重复定义函数: " + fn.name); + } + if (symbols_.ContainsInCurrentScope(fn.name)) { + ThrowSemaError(item->funcDef(), "函数与全局对象重名: " + fn.name); + } + sema_.RegisterFunction(std::move(fn)); + } + } + + FunctionBinding BuildFunctionSignature(SysYParser::FuncDefContext& ctx) { + FunctionBinding fn; + fn.name = ctx.ID()->getText(); + fn.return_type = ParseFuncType(*ctx.funcType()); + fn.func_def = &ctx; + if (ctx.funcFParams()) { + for (auto* param : ctx.funcFParams()->funcFParam()) { + fn.params.push_back(BuildParamBinding(*param)); + } + } + return fn; + } + + ObjectBinding BuildParamBinding(SysYParser::FuncFParamContext& ctx) { + if (!ctx.ID() || !ctx.bType()) { + ThrowSemaError(&ctx, "非法函数形参"); + } + ObjectBinding param; + param.name = ctx.ID()->getText(); + param.type = ParseBType(*ctx.bType()); + param.decl_kind = ObjectBinding::DeclKind::Param; + param.func_param = &ctx; + if (!ctx.LBRACK().empty()) { + param.is_array_param = true; + param.dimensions.push_back(kUnknownArrayDim); + for (auto* exp : ctx.exp()) { + ExprInfo dim = EvalExpr(*exp); + if (!dim.IsScalar() || !dim.has_const_value) { + ThrowSemaError(&ctx, "数组形参维度必须是整型常量表达式"); + } + const int value = ConvertToInt(CastConstant(dim.const_value, SemanticType::Int)); + if (value <= 0) { + ThrowSemaError(&ctx, "数组形参维度必须为正整数"); + } + param.dimensions.push_back(value); + } + } + return param; + } + + void RegisterBuiltins() { + sema_.RegisterFunction(MakeBuiltinFunction("getint", SemanticType::Int, {})); + 
sema_.RegisterFunction(MakeBuiltinFunction("getch", SemanticType::Int, {})); + sema_.RegisterFunction( + MakeBuiltinFunction("getfloat", SemanticType::Float, {})); + sema_.RegisterFunction(MakeBuiltinFunction( + "getarray", SemanticType::Int, + {MakeParam("a", SemanticType::Int, {kUnknownArrayDim}, true)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "getfarray", SemanticType::Int, + {MakeParam("a", SemanticType::Float, {kUnknownArrayDim}, true)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putint", SemanticType::Void, {MakeParam("x", SemanticType::Int)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putch", SemanticType::Void, {MakeParam("x", SemanticType::Int)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putfloat", SemanticType::Void, {MakeParam("x", SemanticType::Float)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putarray", SemanticType::Void, + {MakeParam("n", SemanticType::Int), + MakeParam("a", SemanticType::Int, {kUnknownArrayDim}, true)})); + sema_.RegisterFunction(MakeBuiltinFunction( + "putfarray", SemanticType::Void, + {MakeParam("n", SemanticType::Int), + MakeParam("a", SemanticType::Float, {kUnknownArrayDim}, true)})); + sema_.RegisterFunction( + MakeBuiltinFunction("starttime", SemanticType::Void, {})); + sema_.RegisterFunction( + MakeBuiltinFunction("stoptime", SemanticType::Void, {})); + } + + SymbolTable symbols_; SemanticContext sema_; - bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; + const FunctionBinding* current_function_ = nullptr; + int loop_depth_ = 0; }; } // namespace +void SemanticContext::BindObjectUse(const SysYParser::LValContext* use, + ObjectBinding binding) { + object_uses_[use] = std::move(binding); +} + +const ObjectBinding* SemanticContext::ResolveObjectUse( + const SysYParser::LValContext* use) const { + auto it = object_uses_.find(use); + return it == object_uses_.end() ? 
nullptr : &it->second; +} + +void SemanticContext::BindFunctionCall(const SysYParser::UnaryExpContext* call, + FunctionBinding binding) { + function_calls_[call] = std::move(binding); +} + +const FunctionBinding* SemanticContext::ResolveFunctionCall( + const SysYParser::UnaryExpContext* call) const { + auto it = function_calls_.find(call); + return it == function_calls_.end() ? nullptr : &it->second; +} + +void SemanticContext::RegisterFunction(FunctionBinding binding) { + functions_[binding.name] = std::move(binding); +} + +const FunctionBinding* SemanticContext::ResolveFunction( + const std::string& name) const { + auto it = functions_.find(name); + return it == functions_.end() ? nullptr : &it->second; +} + SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { SemaVisitor visitor; comp_unit.accept(&visitor); diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index ffeea89..01b44bf 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -1,17 +1,39 @@ -// 维护局部变量声明的注册与查找。 +// 维护对象符号的注册与按作用域查找。 #include "sem/SymbolTable.h" -void SymbolTable::Add(const std::string& name, - SysYParser::VarDefContext* decl) { - table_[name] = decl; +#include + +SymbolTable::SymbolTable() : scopes_(1) {} + +void SymbolTable::EnterScope() { scopes_.emplace_back(); } + +void SymbolTable::ExitScope() { + if (scopes_.size() <= 1) { + throw std::runtime_error("symbol table scope underflow"); + } + scopes_.pop_back(); } -bool SymbolTable::Contains(const std::string& name) const { - return table_.find(name) != table_.end(); +bool SymbolTable::Add(const ObjectBinding& symbol) { + auto& scope = scopes_.back(); + return scope.emplace(symbol.name, symbol).second; } -SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - auto it = table_.find(name); - return it == table_.end() ? 
nullptr : it->second; +bool SymbolTable::ContainsInCurrentScope(std::string_view name) const { + const auto& scope = scopes_.back(); + return scope.find(std::string(name)) != scope.end(); } + +const ObjectBinding* SymbolTable::Lookup(std::string_view name) const { + const std::string key(name); + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(key); + if (found != it->end()) { + return &found->second; + } + } + return nullptr; +} + +size_t SymbolTable::Depth() const { return scopes_.size(); } diff --git a/sysy2022.pdf b/sysy2022.pdf new file mode 100644 index 0000000..217d6bd Binary files /dev/null and b/sysy2022.pdf differ diff --git a/test/test_case/irgen_lab1_4/01_simple_add.out b/test/test_case/irgen_lab1_4/01_simple_add.out new file mode 100644 index 0000000..00750ed --- /dev/null +++ b/test/test_case/irgen_lab1_4/01_simple_add.out @@ -0,0 +1 @@ +3 diff --git a/test/test_case/irgen_lab1_4/01_simple_add.sy b/test/test_case/irgen_lab1_4/01_simple_add.sy new file mode 100644 index 0000000..8b5f091 --- /dev/null +++ b/test/test_case/irgen_lab1_4/01_simple_add.sy @@ -0,0 +1,6 @@ +// 测试:简单加法 +int main() { + int a = 1; + int b = 2; + return a + b; +} diff --git a/test/test_case/irgen_lab1_4/02_sub_mul.out b/test/test_case/irgen_lab1_4/02_sub_mul.out new file mode 100644 index 0000000..81b5c5d --- /dev/null +++ b/test/test_case/irgen_lab1_4/02_sub_mul.out @@ -0,0 +1 @@ +37 diff --git a/test/test_case/irgen_lab1_4/02_sub_mul.sy b/test/test_case/irgen_lab1_4/02_sub_mul.sy new file mode 100644 index 0000000..8ddeca9 --- /dev/null +++ b/test/test_case/irgen_lab1_4/02_sub_mul.sy @@ -0,0 +1,8 @@ +// 测试:减法和乘法 +int main() { + int a = 10; + int b = 3; + int c = a - b; + int d = a * b; + return c + d; +} diff --git a/test/test_case/irgen_lab1_4/03_div_mod.out b/test/test_case/irgen_lab1_4/03_div_mod.out new file mode 100644 index 0000000..7ed6ff8 --- /dev/null +++ b/test/test_case/irgen_lab1_4/03_div_mod.out @@ -0,0 +1 @@ +5 diff --git 
a/test/test_case/irgen_lab1_4/03_div_mod.sy b/test/test_case/irgen_lab1_4/03_div_mod.sy new file mode 100644 index 0000000..beefc9e --- /dev/null +++ b/test/test_case/irgen_lab1_4/03_div_mod.sy @@ -0,0 +1,8 @@ +// 测试:除法和取模 +int main() { + int a = 20; + int b = 6; + int c = a / b; + int d = a % b; + return c + d; +} diff --git a/test/test_case/irgen_lab1_4/04_unary.out b/test/test_case/irgen_lab1_4/04_unary.out new file mode 100644 index 0000000..7ed6ff8 --- /dev/null +++ b/test/test_case/irgen_lab1_4/04_unary.out @@ -0,0 +1 @@ +5 diff --git a/test/test_case/irgen_lab1_4/04_unary.sy b/test/test_case/irgen_lab1_4/04_unary.sy new file mode 100644 index 0000000..c36c5e1 --- /dev/null +++ b/test/test_case/irgen_lab1_4/04_unary.sy @@ -0,0 +1,7 @@ +// 测试:一元运算符(正负号) +int main() { + int a = 5; + int b = -a; + int c = +10; + return b + c; +} diff --git a/test/test_case/irgen_lab1_4/05_assign.out b/test/test_case/irgen_lab1_4/05_assign.out new file mode 100644 index 0000000..209e3ef --- /dev/null +++ b/test/test_case/irgen_lab1_4/05_assign.out @@ -0,0 +1 @@ +20 diff --git a/test/test_case/irgen_lab1_4/05_assign.sy b/test/test_case/irgen_lab1_4/05_assign.sy new file mode 100644 index 0000000..5e1eae2 --- /dev/null +++ b/test/test_case/irgen_lab1_4/05_assign.sy @@ -0,0 +1,7 @@ +// 测试:赋值表达式 +int main() { + int a = 10; + int b = 20; + a = b; + return a; +} diff --git a/test/test_case/irgen_lab1_4/06_multi_decl.out b/test/test_case/irgen_lab1_4/06_multi_decl.out new file mode 100644 index 0000000..1e8b314 --- /dev/null +++ b/test/test_case/irgen_lab1_4/06_multi_decl.out @@ -0,0 +1 @@ +6 diff --git a/test/test_case/irgen_lab1_4/06_multi_decl.sy b/test/test_case/irgen_lab1_4/06_multi_decl.sy new file mode 100644 index 0000000..fa95159 --- /dev/null +++ b/test/test_case/irgen_lab1_4/06_multi_decl.sy @@ -0,0 +1,5 @@ +// 测试:逗号分隔的多变量声明 +int main() { + int a = 1, b = 2, c = 3; + return a + b + c; +} diff --git a/test/test_case/irgen_lab1_4/07_comprehensive.out 
b/test/test_case/irgen_lab1_4/07_comprehensive.out new file mode 100644 index 0000000..81b5c5d --- /dev/null +++ b/test/test_case/irgen_lab1_4/07_comprehensive.out @@ -0,0 +1 @@ +37 diff --git a/test/test_case/irgen_lab1_4/07_comprehensive.sy b/test/test_case/irgen_lab1_4/07_comprehensive.sy new file mode 100644 index 0000000..753aeae --- /dev/null +++ b/test/test_case/irgen_lab1_4/07_comprehensive.sy @@ -0,0 +1,14 @@ +// 测试:综合测试(所有功能) +int main() { + int a = 10, b = 5; + int c = a + b; + int d = a - b; + int e = a * 2; + int f = a / b; + int g = a % b; + int h = -c; + int i = +d; + a = b + c; + b = d + e; + return a + b + f + g + h + i; +} diff --git a/test/test_case/sema_negative/break.sy b/test/test_case/sema_negative/break.sy new file mode 100644 index 0000000..d43dfb7 --- /dev/null +++ b/test/test_case/sema_negative/break.sy @@ -0,0 +1 @@ +int main(){ break; return 0; } diff --git a/test/test_case/sema_negative/call.sy b/test/test_case/sema_negative/call.sy new file mode 100644 index 0000000..bdb5857 --- /dev/null +++ b/test/test_case/sema_negative/call.sy @@ -0,0 +1,2 @@ +int f(int x){ return x; } +int main(){ return f(); } diff --git a/test/test_case/sema_negative/ret.sy b/test/test_case/sema_negative/ret.sy new file mode 100644 index 0000000..9dabf4e --- /dev/null +++ b/test/test_case/sema_negative/ret.sy @@ -0,0 +1,2 @@ +void f(){ return 1; } +int main(){ return 0; } diff --git a/test/test_case/sema_negative/undef.sy b/test/test_case/sema_negative/undef.sy new file mode 100644 index 0000000..4d607a3 --- /dev/null +++ b/test/test_case/sema_negative/undef.sy @@ -0,0 +1 @@ +int main(){ return a; } diff --git a/tools/sema_check.cpp b/tools/sema_check.cpp new file mode 100644 index 0000000..ce6b574 --- /dev/null +++ b/tools/sema_check.cpp @@ -0,0 +1,34 @@ +#include +#include +#include + +#include "frontend/AntlrDriver.h" +#include "sem/Sema.h" +#include "utils/Log.h" + +int main(int argc, char** argv) { + if (argc < 2) { + std::cerr << "usage: sema_check 
[more.sy...]\n"; + return 2; + } + + bool failed = false; + for (int i = 1; i < argc; ++i) { + const std::string path = argv[i]; + try { + auto antlr = ParseFileWithAntlr(path); + auto* comp_unit = dynamic_cast(antlr.tree); + if (!comp_unit) { + throw std::runtime_error(FormatError("sema_check", "语法树根节点不是 compUnit")); + } + (void)RunSema(*comp_unit); + std::cout << "OK " << path << "\n"; + } catch (const std::exception& ex) { + failed = true; + std::cout << "ERR " << path << "\n"; + PrintException(std::cout, ex); + } + } + + return failed ? 1 : 0; +}