Compare commits

..

No commits in common. 'lc' and 'master' have entirely different histories.
lc ... master

4
.gitignore vendored

@ -52,7 +52,6 @@ compile_commands.json
.idea/ .idea/
.fleet/ .fleet/
.vs/ .vs/
.trae/
*.code-workspace *.code-workspace
# CLion # CLion
@ -69,6 +68,3 @@ Thumbs.db
# Project outputs # Project outputs
# ========================= # =========================
test/test_result/ test/test_result/
sema_check
.codex

@ -20,10 +20,6 @@
如果希望进一步参考编译相关项目和往届优秀实现,可以查看编译比赛官网的技术支持栏目:<https://compiler.educg.net/#/index?TYPE=26COM>。其中的“备赛推荐”整理了一些编译相关项目,也能看到往届优秀作品的开源实现,这些内容都很值得参考。 如果希望进一步参考编译相关项目和往届优秀实现,可以查看编译比赛官网的技术支持栏目:<https://compiler.educg.net/#/index?TYPE=26COM>。其中的“备赛推荐”整理了一些编译相关项目,也能看到往届优秀作品的开源实现,这些内容都很值得参考。
此外,仓库中还提供了一份当前实现状态与测试入口的总览文档,便于组内同步进度:
- `doc/实验进度与测试方法.md`
## 3. 头歌平台协作流程 ## 3. 头歌平台协作流程
头歌平台的代码托管方式与 GitHub/Gitee 类似。如果你希望基于当前仓库快速开始协作,可以参考下面这套流程。 头歌平台的代码托管方式与 GitHub/Gitee 类似。如果你希望基于当前仓库快速开始协作,可以参考下面这套流程。

@ -1,10 +0,0 @@
#!/bin/bash
mkdir -p build/generated/antlr4
java -jar third_party/antlr-4.13.2-complete.jar \
-Dlanguage=Cpp \
-visitor -no-listener \
-Xexact-output-dir \
-o build/generated/antlr4 \
src/antlr4/SysY.g4
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON
cmake --build build -j "$(nproc)"

@ -1,137 +0,0 @@
### 人员 1基础表达式与赋值支持lc已完成
- 任务 1.1支持更多二元运算符Sub, Mul, Div, Mod
- 任务 1.2:支持一元运算符(正负号)
- 任务 1.3:支持赋值表达式
- 任务 1.4:支持逗号分隔的多个变量声明
### 人员 2控制流支持lyy,已完成)
- 任务 2.1:支持 if-else 条件语句
- 任务 2.2:支持 while 循环语句
- 任务 2.3:支持 break/continue 语句
- 任务 2.4:支持比较和逻辑表达式
### 人员 3函数与全局变量支持
- 任务 3.1:支持全局变量声明与初始化
- 任务 3.2:支持函数参数处理
- 任务 3.3:支持函数调用生成
- 任务 3.4:支持 const 常量声明
## 人员 1 完成情况详细说明(更新于 2026-03-30
### ✅ 已完成任务
人员 1 已完整实现 Lab2 IR 生成的基础功能模块,包括:
1. **二元运算符**(任务 1.1
- 实现 `Sub`, `Mul`, `Div`, `Mod` 四种运算符
- 修改文件:`include/ir/IR.h`, `src/ir/IRBuilder.cpp`, `src/ir/IRPrinter.cpp`, `src/irgen/IRGenExp.cpp`
2. **一元运算符**(任务 1.2
- 实现正负号运算符(`+`, `-`
- 新增 `UnaryInst` 类支持一元指令
- 负号生成 `sub 0, x` 指令LLVM IR 标准形式)
3. **赋值表达式**(任务 1.3
- 实现变量赋值语句的 IR 生成
- 修改文件:`src/irgen/IRGenStmt.cpp`
4. **多变量声明**(任务 1.4
- 支持逗号分隔的变量声明(如 `int a, b, c;`
- 支持带初始化的多变量声明(如 `int a = 1, b = 2;`
### 🧪 测试验证
- **Lab1 语法分析**:✅ 通过10/11 functional 测试1 个数组测试超出范围)
- **Lab2 语义分析**:✅ 通过6 正例 + 4 反例)
- **IR 生成测试**:✅ 通过7/7 自定义测试用例)
- 测试脚本:`./scripts/test_lab2_ir1.sh`
- 测试用例目录:`test/test_case/irgen_lab1_4/`
### 📝 代码质量
- 所有修改已通过编译测试
- 未影响原有 Lab1 和 Lab2 Sema 功能
- 代码风格与项目保持一致
- 关键函数添加了注释说明
### 🔄 协作接口
人员 1 的实现为后续任务提供了以下接口:
- **表达式生成**`visitAddExp`, `visitMulExp`, `visitUnaryExp`
- **语句生成**`visitStmt`(支持赋值和 return
- **变量管理**`storage_map_` 维护变量名到栈槽位的映射
- **IR 构建**`IRBuilder::CreateBinary`, `IRBuilder::CreateNeg`, `IRBuilder::CreateStore`
后续人员可以在此基础上扩展更复杂的功能(控制流、函数调用等)。
## 人员 2 完成情况详细说明(更新于 2026-03-31
### ✅ 已完成任务
人员 2 已完整实现 Lab2 IR 生成中涉及的控制流支持,包括:
1. **IR 结构与底层辅助拓展**
- 补充 `Int1` 基础类型以及 `Value::IsInt1()`
- 新增 `CmpInst`, `ZextInst`, `BranchInst` 以及 `CondBranchInst` 以支持关系计算和跳转逻辑。
- 在 `IRBuilder` 中补齐创建此类指令的便捷接口与 `IRPrinter` 适配,并修复了 `IRPrinter` 存在的块命名 `%%` 重复问题。
- 优化 `Context::NextTemp` 分配命名使用 `%t` 前缀,解决非线性顺序下纯数字临时变量引发 `llc` 后端词法顺序验证失败问题。
2. **比较和逻辑表达式**(任务 2.4
- 新增实现 `visitRelExp`、`visitEqExp`。
- 实现条件二元表达式全链路短路求值 (`visitLAndExp`、`visitLOrExp`)。短路时通过控制流跳转+利用局部栈变量分配并多次赋值记录实现栈传递,规避了 `phi` 的麻烦。
- 利用 `visitCondUnaryExp` 增加逻辑非 `!` 判定。
3. **控制流框架支持**(任务 2.1 - 2.3
- 在 `visitStmt` 中完美实现了 `if-else` 条件语句(自动插入无条件跳合块)、`while` 循环语句。
- 在 `IRGen` 实例中通过 `current_loop_cond_bb_` 等维护循环栈,实现了 `break``continue`
- 修复了此前框架在 `IRGenDecl.cpp``visitBlock` 中缺少终结向上传递导致的 `break` 生成不匹配死块 BUG 及重复 `Branch` 问题。
4. **关键前序 Bug 修复**
- 发现了在原框架里 `src/sem/Sema.cpp` 进行 AST 解析时 `RelExp``EqExp` 对于非原生底层变量追踪由于左偏漏调规则导致 `null_ptr` (`变量使用缺少语义绑定a`) 报错的问题,并做出了精修复。
### 🧪 测试验证
- **Lab2 语义分析**:修复后所有已有的语义正例验证正常。
- **IR 生成与后端执行**:✅ 自建嵌套含复合逻辑循环脚本测试通过。
- **验证命令**(运行含 break 和 while 的范例文件):
```bash
cd build && make -j$(nproc) && cd .. && ./scripts/verify_ir.sh test/test_case/functional/29_break.sy --run
```
**完整测试脚本**
```bash
for f in test/test_case/functional/*.sy; do echo "Testing $f..."; ./scripts/verify_ir.sh "$f" --run > /dev/null || echo "FAILED $f"; done
```
## 人员 3 完成情况详细说明(更新于 2026-04-06
### ✅ 已完成任务
人员 3 (hp) 已完整实现 Lab2 IR 生成中函数及常量的扩展支持,包括:
1. **支持全局变量声明与初始化**(任务 3.1
- 在 `IRGenDecl.cpp` 中通过判断 `func_ == nullptr` 区分全局和局部作用域。
- 扩充了 `Float` / `PtrFloat``ConstantFloat` 等浮点数支持,补充 `GlobalVariable` 派生类。
- 正确调用 `module_.CreateGlobalVariable` 处理整型和浮点型全局初始化,维护在 `storage_map_` 中。
2. **支持函数参数处理**(任务 3.2
- 在 `IR.h``Value` 体系中增加 `Argument` 类。
- 在 `IRGenFunc.cpp` 中实现对 `funcFParams` 的处理。
- 在入口块为每个参数 `alloca` 栈槽,通过 `store` 存入形参初值,并绑定至 `storage_map_` 供内部读取。
3. **支持函数调用生成**(任务 3.3
- 在 `IR.h``IRBuilder.cpp` 补充 `Opcode::Call``CallInst` 及其打印逻辑。
- 在 `IRGenExp.cpp` (`visitUnaryExp`) 支持 `funcCallExp` 解析。
- 提取计算所有的实参表达式 (`funcRParams`) 后生成 `call` 指令;对于库函数支持基于 `Sema` 的占位符签名构建。
4. **支持 const 常量声明**(任务 3.4
- 在 `IRGenDecl.cpp` 新增 `visitConstDecl``visitConstDef` 实现。
- 维护独立的 `const_values_` 映射表记录 `ConstantValue*`
- 在 `visitLVal` 时如果检测到是已定义的常量,直接嵌入常量值完成折叠,省去内存的 `load` 开销。
### 🧪 测试验证
- **全局/局部变量、常量引用测试**:✅ IR 输出正确(通过访问 `storage_map_``const_values_` 获取数据)。
- **参数传递与函数调用链路测试**:✅ 多参数函数(包含返回值)和调用外部 `putint` 的样例生成的 LLVM IR 结构清晰、运行正确。
- **集成测试验证**:✅ 能完美与人员 1 和人员 2 的前置工作合并通过,确保了控制流、运算体系与函数调用的兼容。
### 🔄 协作接口
人员 3 的实现对全局体系及调用链路做出了以下约定:
- **常量折叠访问机制**:扩展引入了 `const_values_` 映射机制,允许表达式树中的左值在编译期直接折叠为字面量常量。
- **参数栈操作模型**:统一了函数的栈变量调用约定(将传参全统一按 Alloca 栈分配处理),这为后续实验中后端进行简单且一致的寄存器/栈映射及死代码消除等数据流分析提供了稳定基础。

@ -1,77 +0,0 @@
# Lab3指令选择与汇编生成 - 开发进度与总结
本文档总结了实验 3 的任务目标、实现细节及当前进度,旨在为后续开发(如优化或改进)提供清晰的参考。
## 1. 实验任务概述
本阶段的任务是实现编译器的后端部分,将 Lab2 产生的 LLVM 风格中间表示IR翻译为 ARM64/AArch64 汇编代码。生成的汇编代码需能够:
- 通过交叉编译器(`aarch64-linux-gnu-gcc`)与 SysY 标准库(`sylib.c`)进行链接。
- 在 QEMU 模拟器或真实 AArch64 环境中正确执行。
- 完整覆盖 SysY 2022 规范,包括标量运算、多维数组访问、函数递归调用、浮点数运算及标准库函数交互。
## 2. 当前实现状态
**目前处于可用但仍待优化阶段**。功能测试可稳定通过,性能测试中个别样例仍存在运行时间过长或行为不稳定的问题,后端生成效率和代码质量仍有较大提升空间。
## 3. 核心逻辑与关键实现点
- **指令映射与选择**
- 实现了从 IR 到机器指令MachineInstr的映射。
- 针对 SysY 特有的运算(如取模 `%`),通过 `sdiv``msub` 指令组合实现。
- 针对比较运算,采用了 `cmp` 配合 `cset` 生成布尔值的方案。
- **全量浮点支持**
- 引入了 S0-S15 浮点寄存器体系。
- 实现了浮点算术(`fadd`, `fsub`, `fmul`, `fdiv`)、比较(`fcmp`)及类型转换(`scvtf`, `fcvtzs`)。
- **多维数组地址计算GEP**
- 实现了递归的地址偏移计算逻辑。
- 能够根据数组各维度的大小自动计算复合索引对应的内存地址。
- **大栈帧访问防御机制**
- 针对 `vector_mul3` 等需要超大局部数组的用例,后端使用 `X16` 寄存器加载大偏移量。
- 解决了 `ldur/stur` 指令在偏移量超过 256 字节或 `add` 超过 4KB 时的溢出报错问题。
- **多函数栈帧管理**
- 实现了每个函数独立的 `Prologue`(序言)和 `Epilogue`(尾声)。
- 严格遵循 16 字节栈对齐规范,正确保存和恢复 FPX29与 LRX30
- **调用约定补全(本次更新)**
- 补齐了“超过 8 个参数”的栈传参与取参逻辑。
- 修复了混合参数(`int/ptr` 与 `float`)场景下寄存器编号错误的问题,按 AArch64 规则分别为 GPR/FPR 计数分配。
- 调用点新增栈参数区的 16 字节对齐分配与回收。
- **测试链路健壮性(本次更新)**
- `verify_asm.sh` 新增 QEMU 执行超时控制(默认 90 秒,可通过 `SY_QEMU_TIMEOUT` 覆盖)。
- `test_lab3_final.sh` 默认设置 `SY_QEMU_TIMEOUT=180`,避免性能样例导致整轮测试卡死。
## 4. 遗留问题与不足
当前实现仍存在以下显著问题,需要后续进一步优化和修复:
- **2025-MYO-20.sy 缺陷**:该用例在当前代码下运行虽然通过,但其逻辑对输入数据的兼容性处理较为脆弱,可能存在边界条件下访问异常的问题,急需改进优化。
- **vector_mul3.sy 缺陷**:该用例在当前代码下运行一直不推出,就像陷入死循环一样,不知道怎么回事。
- **执行性能极低**
- **性能测试耗时过长:目前的 10 个性能测试用例运行速度非常慢看对lab3是否有影响**。
- **冗余指令严重**:由于采用了全栈槽模型(所有变量均存储在内存中),导致生成的汇编中充斥着大量的 `ldr/str` 指令。
- **寄存器分配缺失**目前完全没有实现真正的寄存器分配逻辑Lab5 任务),寄存器利用率极低。
- **调用约定仍不完整**:虽然已支持 `>8` 参数与混合 `int/float` 参数寄存器分配,但尚未覆盖更完整 ABI 细节(如更复杂聚合类型参数传递)。
- **缺乏指令优化**:生成的指令序列较为死板,未进行窥孔优化或指令合并(如 `add` 移位操作的充分利用)。
## 5. 编译与运行指南
### 编译项目
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j "$(nproc)"
```
### 自动化全量验证
```bash
# 运行整合后的 21 个官方用例测试脚本
./scripts/test_lab3_final.sh
```
### 官方脚本单例验证
```bash
# 格式:./scripts/verify_asm.sh <.sy> <结果目录> --run
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/manual --run
```

@ -1,436 +0,0 @@
# 实验进度与测试方法
## 1. 当前实验进度
本文档用于记录当前仓库在各个 Lab 上的实现状态,以及对应的测试与验证方式。
需要注意:本仓库当前仍处于“课程示例框架 + 逐步补全”的阶段,并不是一个已经完整实现全部 SysY 语义的编译器。
### 1.1 Lab1 当前进度
Lab1 对应前端语法分析与语法树构建。
当前状态:
- 已提供 `SysY.g4`、ANTLR 驱动与语法树打印能力。
- 已支持通过 `--emit-parse-tree` 输出语法树。
- 可使用 `parse-only` 模式单独构建前端,不依赖 `sem` / `irgen` / `mir`
### 1.2 Lab2 当前进度
Lab2 对应“语法树 -> 语义检查 -> IR”。
当前状态可以拆成两部分来看:
1. `Sema`
- 已完成一版基于当前 SysY grammar 的语义检查基础实现。
- 已支持多层作用域、变量/常量重定义检查、先声明后使用。
- 已支持函数符号收集、函数调用检查、`main` 入口检查。
- 已支持 `break` / `continue` 使用位置检查。
- 已支持 `return` 与函数返回类型匹配检查。
- 已支持 `const` 常量表达式求值、数组维度检查、全局初始化常量性检查。
- 已支持 `int/float` 标量表达式、比较、逻辑表达式的基础类型检查。
- 已内建 `getint`、`putch`、`getfloat`、`getarray`、`putarray` 等常见运行库函数声明。
2. `IRGen`
- 当前仓库原有 `IRGen` 仍是最小示例版本。
- 当前只适合支持“局部 `int` 变量 + 常量 + 简单表达式 + `return`”这类极小子集。
- 由于 grammar 已扩展,而 `IRGen` 尚未完全同步,所以 Lab2 目前**只完成了前半部分Sema 基础扩展**。
- Lab2 的 IR 生成部分仍需继续补全。
### 1.3 Lab3 当前进度
Lab3 对应“IR -> MIR -> 汇编”。
当前状态:
- 仓库中保留了最小后端链路。
- 仅适合消费当前最小 IR 子集。
- 尚不具备对完整 SysY 程序稳定生成汇编的能力。
### 1.4 Lab4-Lab6 当前进度
当前仓库已经预留:
- IR 分析与 Pass 目录结构
- `Mem2Reg`、`ConstFold`、`ConstProp`、`DCE`、`CSE`、`CFGSimplify` 等文件框架
- 循环分析、支配树、后端优化等实验入口
但这些阶段是否“完成”,取决于你们后续自行补全,不应默认认为仓库当前已经完全实现。
## 2. 推荐测试思路
建议把测试分成三层:
1. `单阶段验证`
- 只验证某个阶段是否工作,例如只看 parse、只看 sema、只看 IR 输出。
2. `链路验证`
- 从源码一路走到 IR 或汇编,再运行程序,比对 `.out`
3. `批量回归`
- 对 `test/test_case` 下多个测试统一执行,避免只靠 `simple_add.sy` 判断功能是否完成。
## 3. 别人拉取当前实现后的推荐编译方式
如果其他同学拉取了当前仓库,建议按下面顺序准备环境并编译。
### 3.1 先生成 ANTLR 输出
当前仓库的 CMake 会收集构建目录中的 ANTLR 生成文件,但不会自动调用 ANTLR所以第一次构建前应先执行
```bash
mkdir -p build/generated/antlr4
java -jar third_party/antlr-4.13.2-complete.jar \
-Dlanguage=Cpp \
-visitor -no-listener \
-Xexact-output-dir \
-o build/generated/antlr4 \
src/antlr4/SysY.g4
```
### 3.2 如果只想验证 Lab1
只构建 parse-only 前端:
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON
cmake --build build -j "$(nproc)"
```
构建后可直接运行:
```bash
./scripts/test_lab1.sh test/test_case/functional
```
### 3.3 如果想验证当前 Lab2 的 Sema 部分
由于当前仓库中的 `IRGen` 还没有完全跟上新 grammar而我们这次主要完成的是 `Sema`,所以推荐单独准备一个 `build-sema/` 目录来验证语义检查。
推荐命令如下:
```bash
cmake -S . -B build-sema -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF
mkdir -p build-sema/generated
cp -r build/generated/antlr4 build-sema/generated/
cmake --build build-sema --target frontend utils sem -j "$(nproc)"
```
然后编译 `sema_check`
```bash
g++ -std=c++17 \
-Iinclude \
-Isrc \
-Ibuild-sema/generated/antlr4 \
-Ithird_party/antlr4-runtime-4.13.2/runtime/src \
tools/sema_check.cpp \
build-sema/src/sem/libsem.a \
build-sema/src/frontend/libfrontend.a \
build-sema/src/utils/libutils.a \
build-sema/libantlr4_runtime.a \
-pthread \
-o build-sema/sema_check
```
完成后即可运行:
```bash
./scripts/test_lab2_sema.sh positive
./scripts/test_lab2_sema.sh negative
```
说明:
- `build/` 主要用于 Lab1 parse-only 或后续全量构建
- `build-sema/` 主要用于当前阶段单独验证 `Sema`
- `scripts/test_lab2_sema.sh` 依赖 `./build-sema/sema_check`
### 3.4 如果后续要做全量构建
`IRGen` 与 grammar 完全同步后,可直接做全量构建:
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF
cmake --build build -j "$(nproc)"
```
但在当前阶段,不建议把“全量 build 成功”作为验证 `Sema` 的唯一标准,因为 Lab2 目前完成的是语义分析前半部分,不是整套 IR 生成。
## 4. Lab1 测试方法
### 3.1 构建命令
先生成 ANTLR 输出:
```bash
mkdir -p build/generated/antlr4
java -jar third_party/antlr-4.13.2-complete.jar \
-Dlanguage=Cpp \
-visitor -no-listener \
-Xexact-output-dir \
-o build/generated/antlr4 \
src/antlr4/SysY.g4
```
然后使用 `parse-only` 构建:
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON
cmake --build build -j "$(nproc)"
```
### 3.2 单个样例测试
```bash
./build/bin/compiler --emit-parse-tree test/test_case/functional/simple_add.sy
```
### 3.3 批量测试
仓库已提供 parse 批量测试脚本。为避免终端直接打印大量语法树导致输出过长,脚本会把每个用例的语法树输出写入单独日志文件。
```bash
./scripts/test_lab1.sh test/test_case/functional
```
如果希望指定日志目录,可以使用:
```bash
./scripts/test_lab1.sh test/test_case/functional test/test_result/lab1_parse_logs
```
终端中会看到形如:
```text
TEST test/test_case/functional/simple_add.sy -> test/test_result/lab1_parse_logs/simple_add.parse.log
...
ALL_PARSE_OK (...) logs: test/test_result/lab1_parse_logs
```
说明当前测试目录中的 `.sy` 文件都能通过语法分析;具体语法树内容可直接查看对应 `.parse.log` 文件。
## 5. Lab2 测试方法
Lab2 建议分成两部分测试:`Sema` 和 `IRGen`
### 4.1 Lab2 当前推荐先测 Sema
因为当前仓库中 `IRGen` 还未完全同步到新 grammar所以当前阶段更适合先用“语义检查”来证明 Lab2 前半部分已经实现。
#### 4.1.1 当前已验证通过的正例
下面这些测试用例已经可以作为当前 `Sema` 的正向样例:
```bash
./scripts/test_lab2_sema.sh positive
```
如果希望指定日志目录,可以使用:
```bash
./scripts/test_lab2_sema.sh positive test/test_result/lab2_sema_positive_logs
```
预期现象:
- 终端按用例打印 `TEST ... -> ...`
- 全部通过后输出 `ALL_SEMA_POSITIVE_OK (...)`
- 详细输出写入 `*.sema.log`
#### 4.1.2 当前可用于演示的反例
当前已经准备好的反例位于:
- `test/test_case/sema_negative/undef.sy`
- `test/test_case/sema_negative/break.sy`
- `test/test_case/sema_negative/ret.sy`
- `test/test_case/sema_negative/call.sy`
执行命令:
```bash
./scripts/test_lab2_sema.sh negative
```
如果希望指定日志目录,可以使用:
```bash
./scripts/test_lab2_sema.sh negative test/test_result/lab2_sema_negative_logs
```
预期现象:
- 终端按用例打印 `TEST ... -> ...`
- 全部符合预期后输出 `ALL_SEMA_NEGATIVE_OK (...)`
- 每个反例的详细错误信息写入对应 `.sema.log`
例如:
- 使用未声明变量
- 循环外 `break`
- `void` 函数返回值
- 函数参数个数不匹配
#### 4.1.3 语义错误定位信息说明
语义错误信息中的 `@行:列` 用于标明错误位置。
例如:
```text
[error] [sema] @1:19 - 使用了未声明的标识符: a
```
表示:
- `1` 是第 1 行
- `19` 是第 19 列
也就是提示错误出现在源代码第 1 行第 19 列附近,便于快速定位。
#### 4.1.4 当前 Sema 已覆盖的主要错误类型
当前已实现的典型错误检测包括:
- 未声明标识符使用
- 同作用域重定义
- 函数重定义
- 缺少合法 `main`
- 函数参数数量或类型不匹配
- `break/continue` 不在循环中
- `return` 与函数返回类型不匹配
- 给 `const` 对象赋值
- 数组维度非法
- 全局初始化不满足编译期常量要求
### 4.2 Lab2 后续 IR 测试方式
`IRGen` 与当前 grammar 对齐后,可使用如下命令输出 IR
```bash
./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy
```
若需要进一步验证 “IR -> 可执行程序” 链路,可使用:
```bash
./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/ir --run
```
但需要强调:
在当前仓库状态下,这条命令只适合用于未来 IRGen 完成后的测试;不能拿它来证明当前已完成的 `Sema` 部分。
## 6. Lab3 测试方法
Lab3 对应汇编输出与后端链路。
### 5.1 构建
需要全量构建:
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF
cmake --build build -j "$(nproc)"
```
### 5.2 单个样例输出汇编
```bash
./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy
```
### 5.3 汇编链路验证
```bash
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/asm --run
```
`--run` 模式下会:
1. 生成汇编
2. 交叉编译为 AArch64 可执行文件
3. 用 `qemu-aarch64` 运行
4. 将输出与同名 `.out` 比对
## 7. Lab4 测试方法
Lab4 是优化实验,测试重点不只是“能不能运行”,还包括“优化前后语义一致”。
建议按下面顺序验证:
1. 先确保未优化版本功能正确
2. 接入优化后再次跑 `verify_ir.sh``verify_asm.sh`
3. 比较优化前后的 IR 或汇编输出
4. 在多个测试上回归,避免某个优化只在 `simple_add` 上看起来没问题
推荐命令:
```bash
./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/ir --run
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/asm --run
```
如果你们为优化实现了单独开关,也应额外对比:
```bash
./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy
./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy
```
## 8. Lab5 测试方法
Lab5 的测试重点是:
- 寄存器分配后代码仍然正确
- spill/reload 逻辑没有破坏语义
- 汇编仍能完整运行
推荐直接走后端完整链路:
```bash
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/asm --run
```
完成寄存器分配后,不应只测单个样例,建议至少覆盖:
- `functional/`
- `performance/` 中若干较大样例
## 9. Lab6 测试方法
Lab6 重点是循环和并行相关优化,测试要分成功能正确性和优化收益两部分。
### 8.1 功能正确性
```bash
./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/ir --run
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/asm --run
```
### 8.2 优化效果观察
你们可以对比优化前后的:
- IR 输出
- 汇编输出
- 执行时间
- 代码规模
例如:
```bash
./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy
./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy
```
真正评估循环优化时,建议使用包含明显循环结构的功能或性能测试,而不是只看 `simple_add.sy`
## 10. 当前阶段的建议结论
如果你要汇报当前仓库状态,可以概括为:
1. Lab1 的语法树构建链路已经具备独立测试方式。
2. Lab2 当前已经完成 `Sema` 基础扩展,并可通过正反例直接演示。
3. Lab2 的 `IRGen` 还需要继续补全,当前不能把整份 Lab2 视为全部完成。
4. Lab3 及后续实验目前主要还是框架和最小样例能力,完整覆盖仍需后续实现。

@ -1,24 +1,27 @@
// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。 // 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。
// //
// 当前已经实现: // 当前已经实现:
// 1. 基础类型系统void / i32 / i32* / float / float* / array / pointer // 1. 基础类型系统void / i32 / i32*
// 2. Value 体系Value / ConstantValue / ConstantInt / ConstantFloat / ConstantArray / ConstantZero / Function / BasicBlock / User / GlobalValue / Instruction // 2. Value 体系Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction
// 3. 最小指令集Add / Sub / Mul / Div / Mod / Neg / Alloca / Load / Store / Ret / Cmp / FCmp / Zext / Br / CondBr / Call / GEP / SIToFP / FPToSI // 3. 最小指令集Add / Alloca / Load / Store / Ret
// 4. BasicBlock / Function / Module 三层组织结构 // 4. BasicBlock / Function / Module 三层组织结构
// 5. IRBuilder便捷创建常量和各类指令 // 5. IRBuilder便捷创建常量和最小指令
// 6. def-use 关系的轻量实现: // 6. def-use 关系的轻量实现:
// - Instruction 保存 operand 列表 // - Instruction 保存 operand 列表
// - Value 保存 uses // - Value 保存 uses
// - 支持 ReplaceAllUsesWith 的简化实现 // - 支持 ReplaceAllUsesWith 的简化实现
// //
// 当前尚未实现或只做了最小占位: // 当前尚未实现或只做了最小占位:
// 1. 完整类型系统label 类型等 // 1. 完整类型系统数组、函数类型、label 类型等
// 2. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构) // 2. 更完整的指令系统br / condbr / call / phi / gep 等
// 3. 更完整的 IR verifier 和优化基础设施 // 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构)
// 4. 更完整的 IR verifier 和优化基础设施
// //
// 当前需要特别说明的两个简化点: // 当前需要特别说明的两个简化点:
// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位, // 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位,
// 后续如果补 label type可以再改成更合理的块标签类型。 // 后续如果补 label type可以再改成更合理的块标签类型。
// 2. ConstantValue 体系目前只实现了 ConstantInt后续可以继续补 ConstantFloat、
// ConstantArray等更完整的常量种类。
// //
// 建议的扩展顺序: // 建议的扩展顺序:
// 1. 先补更多指令和类型 // 1. 先补更多指令和类型
@ -42,53 +45,16 @@ class Value;
class User; class User;
class ConstantValue; class ConstantValue;
class ConstantInt; class ConstantInt;
class ConstantFloat;
class ConstantArray;
class ConstantZero;
class GlobalValue; class GlobalValue;
class Instruction; class Instruction;
class BasicBlock; class BasicBlock;
class Function; class Function;
// --- Type System --- // Use 表示一个 Value 的一次使用记录。
// 当前实现设计:
class Type { // - value被使用的值
public: // - user使用该值的 User
enum class Kind { Void, Int1, Int32, PtrInt32, Float, PtrFloat, Array, Pointer }; // - operand_index该值在 user 操作数列表中的位置
explicit Type(Kind k);
Type(Kind k, std::shared_ptr<Type> elem_ty, int num_elems);
Type(Kind k, std::shared_ptr<Type> pointed_ty);
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetPtrInt32Type();
static const std::shared_ptr<Type>& GetFloatType();
static const std::shared_ptr<Type>& GetPtrFloatType();
static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> elem_ty, int num_elems);
static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> pointed_ty);
Kind GetKind() const;
bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const;
bool IsPtrInt32() const;
bool IsFloat() const;
bool IsPtrFloat() const;
bool IsArray() const;
bool IsPointer() const;
std::shared_ptr<Type> GetElementType() const { return elem_ty_; }
int GetNumElements() const { return num_elems_; }
std::shared_ptr<Type> GetPointedType() const { return elem_ty_; }
private:
Kind kind_;
std::shared_ptr<Type> elem_ty_;
int num_elems_ = 0;
};
// --- Value & Use ---
class Use { class Use {
public: public:
@ -110,6 +76,40 @@ class Use {
size_t operand_index_ = 0; size_t operand_index_ = 0;
}; };
// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。
class Context {
public:
Context() = default;
~Context();
// 去重创建 i32 常量。
ConstantInt* GetConstInt(int v);
std::string NextTemp();
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
int temp_index_ = -1;
};
class Type {
public:
enum class Kind { Void, Int32, PtrInt32 };
explicit Type(Kind k);
// 使用静态共享对象获取类型。
// 同一类型可直接比较返回值是否相等,例如:
// Type::GetInt32Type() == Type::GetInt32Type()
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetPtrInt32Type();
Kind GetKind() const;
bool IsVoid() const;
bool IsInt32() const;
bool IsPtrInt32() const;
private:
Kind kind_;
};
class Value { class Value {
public: public:
Value(std::shared_ptr<Type> ty, std::string name); Value(std::shared_ptr<Type> ty, std::string name);
@ -118,16 +118,12 @@ class Value {
const std::string& GetName() const; const std::string& GetName() const;
void SetName(std::string n); void SetName(std::string n);
bool IsVoid() const; bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const; bool IsInt32() const;
bool IsPtrInt32() const; bool IsPtrInt32() const;
bool IsFloat() const;
bool IsPtrFloat() const;
bool IsConstant() const; bool IsConstant() const;
bool IsInstruction() const; bool IsInstruction() const;
bool IsUser() const; bool IsUser() const;
bool IsFunction() const; bool IsFunction() const;
bool IsArgument() const;
void AddUse(User* user, size_t operand_index); void AddUse(User* user, size_t operand_index);
void RemoveUse(User* user, size_t operand_index); void RemoveUse(User* user, size_t operand_index);
const std::vector<Use>& GetUses() const; const std::vector<Use>& GetUses() const;
@ -139,18 +135,8 @@ class Value {
std::vector<Use> uses_; std::vector<Use> uses_;
}; };
class Argument : public Value { // ConstantValue 是常量体系的基类。
public: // 当前只实现了 ConstantInt后续可继续扩展更多常量种类。
Argument(std::shared_ptr<Type> ty, std::string name, Function* parent, size_t arg_no);
Function* GetParent() const;
size_t GetArgNo() const;
private:
Function* parent_;
size_t arg_no_;
};
// --- Constants ---
class ConstantValue : public Value { class ConstantValue : public Value {
public: public:
ConstantValue(std::shared_ptr<Type> ty, std::string name = ""); ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
@ -165,56 +151,11 @@ class ConstantInt : public ConstantValue {
int value_{}; int value_{};
}; };
class ConstantFloat : public ConstantValue { // 后续还需要扩展更多指令类型。
public: enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret };
ConstantFloat(std::shared_ptr<Type> ty, float v);
float GetValue() const { return value_; }
private:
float value_{};
};
class ConstantArray : public ConstantValue {
public:
ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements);
const std::vector<ConstantValue*>& GetElements() const { return elements_; }
private:
std::vector<ConstantValue*> elements_;
};
class ConstantZero : public ConstantValue {
public:
explicit ConstantZero(std::shared_ptr<Type> ty);
};
// --- Context ---
class Context {
public:
Context() = default;
~Context();
ConstantInt* GetConstInt(int v);
ConstantFloat* GetConstFloat(float v);
ConstantArray* GetConstArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements);
ConstantZero* GetConstZero(std::shared_ptr<Type> ty);
std::string NextTemp();
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
std::vector<std::unique_ptr<ConstantArray>> const_arrays_;
std::vector<std::unique_ptr<ConstantZero>> const_zeros_;
int temp_index_ = -1;
};
// --- Instructions ---
enum class Opcode { Add, Sub, Mul, Div, Mod, Neg, Alloca, Load, Store, Ret, Cmp, FCmp, Zext, Br, CondBr, Call, GEP, SIToFP, FPToSI };
enum class CmpOp { Eq, Ne, Lt, Gt, Le, Ge };
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。
class User : public Value { class User : public Value {
public: public:
User(std::shared_ptr<Type> ty, std::string name); User(std::shared_ptr<Type> ty, std::string name);
@ -223,25 +164,20 @@ class User : public Value {
void SetOperand(size_t index, Value* value); void SetOperand(size_t index, Value* value);
protected: protected:
// 统一的 operand 入口。
void AddOperand(Value* value); void AddOperand(Value* value);
private: private:
std::vector<Value*> operands_; std::vector<Value*> operands_;
}; };
// GlobalValue 是全局值/全局变量体系的空壳占位类。
// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。
class GlobalValue : public User { class GlobalValue : public User {
public: public:
GlobalValue(std::shared_ptr<Type> ty, std::string name); GlobalValue(std::shared_ptr<Type> ty, std::string name);
}; };
class GlobalVariable : public GlobalValue {
public:
GlobalVariable(std::string name, std::shared_ptr<Type> type, ConstantValue* init);
ConstantValue* GetInitializer() const { return init_; }
private:
ConstantValue* init_ = nullptr;
};
class Instruction : public User { class Instruction : public User {
public: public:
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = ""); Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
@ -263,13 +199,6 @@ class BinaryInst : public Instruction {
Value* GetRhs() const; Value* GetRhs() const;
}; };
class UnaryInst : public Instruction {
public:
UnaryInst(Opcode op, std::shared_ptr<Type> ty, Value* operand,
std::string name);
Value* GetUnaryOperand() const;
};
class ReturnInst : public Instruction { class ReturnInst : public Instruction {
public: public:
ReturnInst(std::shared_ptr<Type> void_ty, Value* val); ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
@ -294,80 +223,8 @@ class StoreInst : public Instruction {
Value* GetPtr() const; Value* GetPtr() const;
}; };
class CmpInst : public Instruction { // BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。
public: // 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。
CmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name);
CmpOp GetCmpOp() const;
Value* GetLhs() const;
Value* GetRhs() const;
private:
CmpOp cmp_op_;
};
class FCmpInst : public Instruction {
public:
FCmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name);
CmpOp GetCmpOp() const;
Value* GetLhs() const;
Value* GetRhs() const;
private:
CmpOp cmp_op_;
};
class ZextInst : public Instruction {
public:
ZextInst(std::shared_ptr<Type> dest_ty, Value* val, std::string name);
Value* GetValue() const;
};
class BranchInst : public Instruction {
public:
BranchInst(BasicBlock* dest);
BasicBlock* GetDest() const;
};
class CondBranchInst : public Instruction {
public:
CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb);
Value* GetCond() const;
BasicBlock* GetTrueBlock() const;
BasicBlock* GetFalseBlock() const;
};
class CallInst : public Instruction {
public:
CallInst(Function* func, std::vector<Value*> args, std::string name = "");
Function* GetFunc() const;
const std::vector<Value*>& GetArgs() const;
private:
Function* func_;
std::vector<Value*> args_;
};
class GEPInst : public Instruction {
public:
GEPInst(std::shared_ptr<Type> ty, Value* ptr, std::vector<Value*> indices, std::string name = "");
Value* GetPtr() const;
const std::vector<Value*>& GetIndices() const;
private:
std::vector<Value*> indices_;
};
class SIToFPInst : public Instruction {
public:
SIToFPInst(std::shared_ptr<Type> ty, Value* val, std::string name = "");
};
class FPToSIInst : public Instruction {
public:
FPToSIInst(std::shared_ptr<Type> ty, Value* val, std::string name = "");
};
// --- Structure ---
class BasicBlock : public Value { class BasicBlock : public Value {
public: public:
explicit BasicBlock(std::string name); explicit BasicBlock(std::string name);
@ -397,21 +254,24 @@ class BasicBlock : public Value {
std::vector<BasicBlock*> successors_; std::vector<BasicBlock*> successors_;
}; };
// Function 当前也采用了最小实现。
// 需要特别注意:由于项目里还没有单独的 FunctionType
// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”,
// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。
// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、
// 形参和调用,通常需要引入专门的函数类型表示。
class Function : public Value { class Function : public Value {
public: public:
// 当前构造函数接收的也是返回类型,而不是完整函数类型。
Function(std::string name, std::shared_ptr<Type> ret_type); Function(std::string name, std::shared_ptr<Type> ret_type);
BasicBlock* CreateBlock(const std::string& name); BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry(); BasicBlock* GetEntry();
const BasicBlock* GetEntry() const; const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const; const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
Argument* AddArgument(std::shared_ptr<Type> ty, std::string name);
const std::vector<std::unique_ptr<Argument>>& GetArgs() const;
private: private:
BasicBlock* entry_ = nullptr; BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_; std::vector<std::unique_ptr<BasicBlock>> blocks_;
std::vector<std::unique_ptr<Argument>> args_;
}; };
class Module { class Module {
@ -419,17 +279,14 @@ class Module {
Module() = default; Module() = default;
Context& GetContext(); Context& GetContext();
const Context& GetContext() const; const Context& GetContext() const;
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
Function* CreateFunction(const std::string& name, Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type); std::shared_ptr<Type> ret_type);
const std::vector<std::unique_ptr<Function>>& GetFunctions() const; const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
GlobalVariable* CreateGlobalVariable(const std::string& name, std::shared_ptr<Type> type, ConstantValue* init);
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVariables() const;
private: private:
Context context_; Context context_;
std::vector<std::unique_ptr<Function>> functions_; std::vector<std::unique_ptr<Function>> functions_;
std::vector<std::unique_ptr<GlobalVariable>> global_variables_;
}; };
class IRBuilder { class IRBuilder {
@ -438,27 +295,15 @@ class IRBuilder {
void SetInsertPoint(BasicBlock* bb); void SetInsertPoint(BasicBlock* bb);
BasicBlock* GetInsertBlock() const; BasicBlock* GetInsertBlock() const;
// 构造常量、二元运算、返回指令的最小集合。
ConstantInt* CreateConstInt(int v); ConstantInt* CreateConstInt(int v);
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name); const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
UnaryInst* CreateNeg(Value* operand, const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaFloat(const std::string& name);
AllocaInst* CreateAlloca(std::shared_ptr<Type> ty, const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr); StoreInst* CreateStore(Value* val, Value* ptr);
ReturnInst* CreateRet(Value* v); ReturnInst* CreateRet(Value* v);
Instruction* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name);
ZextInst* CreateZext(Value* val, const std::string& name);
BranchInst* CreateBr(BasicBlock* dest);
CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb);
CallInst* CreateCall(Function* func, std::vector<Value*> args, const std::string& name);
GEPInst* CreateGEP(std::shared_ptr<Type> ty, Value* ptr, std::vector<Value*> indices, const std::string& name);
SIToFPInst* CreateSIToFP(Value* val, const std::string& name);
FPToSIInst* CreateFPToSI(Value* val, const std::string& name);
private: private:
Context& ctx_; Context& ctx_;

@ -26,26 +26,16 @@ class IRGenImpl final : public SysYBaseVisitor {
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlock(SysYParser::BlockContext* ctx) override; std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override;
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override; std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override; std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override; std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override; std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override;
private: private:
enum class BlockFlow { enum class BlockFlow {
@ -55,60 +45,13 @@ class IRGenImpl final : public SysYBaseVisitor {
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr); ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::ConstantValue* EvaluateConst(antlr4::tree::ParseTree* tree);
int EvaluateConstInt(SysYParser::ConstExpContext* ctx);
int EvaluateConstInt(SysYParser::ExpContext* ctx);
std::shared_ptr<ir::Type> GetGEPResultType(ir::Value* ptr, const std::vector<ir::Value*>& indices);
// Flatten array initializers
void FlattenInitVal(SysYParser::InitValContext* ctx,
const std::vector<int>& dims,
const std::vector<int>& sub_sizes,
int dim_idx,
size_t& current_pos,
std::vector<ir::Value*>& results,
bool is_float);
void FlattenConstInitVal(SysYParser::ConstInitValContext* ctx,
const std::vector<int>& dims,
const std::vector<int>& sub_sizes,
int dim_idx,
size_t& current_pos,
std::vector<ir::ConstantValue*>& results,
bool is_float);
ir::Module& module_; ir::Module& module_;
const SemanticContext& sema_; const SemanticContext& sema_;
ir::Function* func_; ir::Function* func_;
ir::IRBuilder builder_; ir::IRBuilder builder_;
// 考虑到嵌套作用域(全局、函数、语句块),使用 vector 模拟栈来管理 storage_map_ 和 const_values_ // 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
std::vector<std::unordered_map<std::string, ir::Value*>> storage_map_stack_; std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
std::vector<std::unordered_map<std::string, ir::ConstantValue*>> const_values_stack_;
// 用于在栈中查找变量
ir::Value* FindStorage(const std::string& name) const {
for (auto it = storage_map_stack_.rbegin(); it != storage_map_stack_.rend(); ++it) {
if (it->count(name)) return it->at(name);
}
return nullptr;
}
ir::ConstantValue* FindConst(const std::string& name) const {
for (auto it = const_values_stack_.rbegin(); it != const_values_stack_.rend(); ++it) {
if (it->count(name)) return it->at(name);
}
return nullptr;
}
// 用于 break 和 continue 跳转的目标位置
ir::BasicBlock* current_loop_cond_bb_ = nullptr;
ir::BasicBlock* current_loop_exit_bb_ = nullptr;
int bb_cnt_ = 0;
std::string NextBlockName(const std::string& prefix = "bb") {
return prefix + "_" + std::to_string(++bb_cnt_);
}
}; };
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,

@ -19,17 +19,7 @@ class MIRContext {
MIRContext& DefaultContext(); MIRContext& DefaultContext();
// AArch64 physical registers enum class PhysReg { W0, W8, W9, X29, X30, SP };
enum class PhysReg {
W0, W1, W2, W3, W4, W5, W6, W7,
W8, W9, W10, W11, W12, W13, W14, W15,
X0, X1, X2, X3, X4, X5, X6, X7,
X8, X9, X10, X11, X12, X13, X14, X15,
X16, X17,
S0, S1, S2, S3, S4, S5, S6, S7,
S8, S9, S10, S11, S12, S13, S14, S15,
X29, X30, SP, WZR, XZR
};
const char* PhysRegName(PhysReg reg); const char* PhysRegName(PhysReg reg);
@ -37,67 +27,31 @@ enum class Opcode {
Prologue, Prologue,
Epilogue, Epilogue,
MovImm, MovImm,
MovRR,
LoadStack, LoadStack,
StoreStack, StoreStack,
AddrStack,
LoadGlobal,
StoreGlobal,
AddRR, AddRR,
AddRRI,
AddRRR_LSL,
SubRR,
MulRR,
SDivRR,
MSubRRR,
Sxtw,
NegR,
CmpRR,
CSet,
FAdd,
FSub,
FMUL,
FDiv,
FNeg,
FCmp,
FCvtSI2FP,
FCvtFP2SI,
LoadR,
StoreR,
Call,
B,
BCond,
Ret, Ret,
}; };
enum class CondCode { EQ, NE, LT, LE, GT, GE };
class Operand { class Operand {
public: public:
enum class Kind { Reg, Imm, FrameIndex, Label, Global, Cond }; enum class Kind { Reg, Imm, FrameIndex };
static Operand Reg(PhysReg reg); static Operand Reg(PhysReg reg);
static Operand Imm(int value); static Operand Imm(int value);
static Operand FrameIndex(int index); static Operand FrameIndex(int index);
static Operand Label(const std::string& name);
static Operand Global(const std::string& name);
static Operand Cond(CondCode cc);
Kind GetKind() const { return kind_; } Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; } PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; } int GetImm() const { return imm_; }
int GetFrameIndex() const { return imm_; } int GetFrameIndex() const { return imm_; }
const std::string& GetLabel() const { return label_; }
const std::string& GetGlobal() const { return label_; }
CondCode GetCond() const { return static_cast<CondCode>(imm_); }
private: private:
Operand(Kind kind, PhysReg reg, int imm, std::string label = ""); Operand(Kind kind, PhysReg reg, int imm);
Kind kind_; Kind kind_;
PhysReg reg_; PhysReg reg_;
int imm_; int imm_;
std::string label_;
}; };
class MachineInstr { class MachineInstr {
@ -139,10 +93,8 @@ class MachineFunction {
explicit MachineFunction(std::string name); explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; } const std::string& GetName() const { return name_; }
MachineBasicBlock& GetEntry() { return entry_; }
MachineBasicBlock& CreateBlock(const std::string& name); const MachineBasicBlock& GetEntry() const { return entry_; }
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() { return blocks_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const { return blocks_; }
int CreateFrameIndex(int size = 4); int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index); FrameSlot& GetFrameSlot(int index);
@ -154,35 +106,14 @@ class MachineFunction {
private: private:
std::string name_; std::string name_;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_; MachineBasicBlock entry_;
std::vector<FrameSlot> frame_slots_; std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0; int frame_size_ = 0;
}; };
struct GlobalVariable { std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
std::string name;
int init_value = 0;
size_t size = 4;
bool is_const = false;
};
class MachineModule {
public:
MachineModule() = default;
std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() { return functions_; }
const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const { return functions_; }
std::vector<GlobalVariable>& GetGlobals() { return globals_; }
const std::vector<GlobalVariable>& GetGlobals() const { return globals_; }
private:
std::vector<std::unique_ptr<MachineFunction>> functions_;
std::vector<GlobalVariable> globals_;
};
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function); void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function); void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineModule& module, std::ostream& os); void PrintAsm(const MachineFunction& function, std::ostream& os);
} // namespace mir } // namespace mir

@ -1,69 +1,30 @@
// 基于语法树的语义检查与名称绑定。 // 基于语法树的语义检查与名称绑定。
#pragma once #pragma once
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
enum class SemanticType {
Void,
Int,
Float,
};
struct ScalarConstant {
SemanticType type = SemanticType::Int;
double number = 0.0;
};
struct ObjectBinding {
enum class DeclKind {
Var,
Const,
Param,
};
std::string name;
SemanticType type = SemanticType::Int;
DeclKind decl_kind = DeclKind::Var;
bool is_array_param = false;
std::vector<int> dimensions;
const SysYParser::VarDefContext* var_def = nullptr;
const SysYParser::ConstDefContext* const_def = nullptr;
const SysYParser::FuncFParamContext* func_param = nullptr;
bool has_const_value = false;
ScalarConstant const_value;
};
struct FunctionBinding {
std::string name;
SemanticType return_type = SemanticType::Int;
std::vector<ObjectBinding> params;
const SysYParser::FuncDefContext* func_def = nullptr;
bool is_builtin = false;
};
class SemanticContext { class SemanticContext {
public: public:
void BindObjectUse(const SysYParser::LValContext* use, ObjectBinding binding); void BindVarUse(SysYParser::VarContext* use,
const ObjectBinding* ResolveObjectUse( SysYParser::VarDefContext* decl) {
const SysYParser::LValContext* use) const; var_uses_[use] = decl;
}
void BindFunctionCall(const SysYParser::UnaryExpContext* call,
FunctionBinding binding);
const FunctionBinding* ResolveFunctionCall(
const SysYParser::UnaryExpContext* call) const;
void RegisterFunction(FunctionBinding binding); SysYParser::VarDefContext* ResolveVarUse(
const FunctionBinding* ResolveFunction(const std::string& name) const; const SysYParser::VarContext* use) const {
auto it = var_uses_.find(use);
return it == var_uses_.end() ? nullptr : it->second;
}
private: private:
std::unordered_map<const SysYParser::LValContext*, ObjectBinding> object_uses_; std::unordered_map<const SysYParser::VarContext*,
std::unordered_map<const SysYParser::UnaryExpContext*, FunctionBinding> SysYParser::VarDefContext*>
function_calls_; var_uses_;
std::unordered_map<std::string, FunctionBinding> functions_;
}; };
// 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -1,25 +1,17 @@
// 维护对象符号的多层作用域 // 极简符号表:记录局部变量定义点
#pragma once #pragma once
#include <string> #include <string>
#include <string_view>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "sem/Sema.h" #include "SysYParser.h"
class SymbolTable { class SymbolTable {
public: public:
SymbolTable(); void Add(const std::string& name, SysYParser::VarDefContext* decl);
bool Contains(const std::string& name) const;
void EnterScope(); SysYParser::VarDefContext* Lookup(const std::string& name) const;
void ExitScope();
bool Add(const ObjectBinding& symbol);
bool ContainsInCurrentScope(std::string_view name) const;
const ObjectBinding* Lookup(std::string_view name) const;
size_t Depth() const;
private: private:
std::vector<std::unordered_map<std::string, ObjectBinding>> scopes_; std::unordered_map<std::string, SysYParser::VarDefContext*> table_;
}; };

@ -1,83 +0,0 @@
--- include/ir/IR.h
+++ include/ir/IR.h
@@ -93,6 +93,7 @@
class Type {
public:
- enum class Kind { Void, Int32, PtrInt32 };
+ enum class Kind { Void, Int1, Int32, PtrInt32 };
explicit Type(Kind k);
// 使用静态共享对象获取类型。
// 同一类型可直接比较返回值是否相等,例如:
// Type::GetInt32Type() == Type::GetInt32Type()
static const std::shared_ptr<Type>& GetVoidType();
+ static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetPtrInt32Type();
Kind GetKind() const;
bool IsVoid() const;
+ bool IsInt1() const;
bool IsInt32() const;
bool IsPtrInt32() const;
@@ -118,6 +119,7 @@
const std::string& GetName() const;
void SetName(std::string n);
bool IsVoid() const;
+ bool IsInt1() const;
bool IsInt32() const;
bool IsPtrInt32() const;
bool IsConstant() const;
@@ -153,7 +155,9 @@
// 后续还需要扩展更多指令类型。
-// enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret };
-enum class Opcode { Add, Sub, Mul, Div, Mod, Neg, Alloca, Load, Store, Ret };
+enum class Opcode { Add, Sub, Mul, Div, Mod, Neg, Alloca, Load, Store, Ret, Cmp, Zext, Br, CondBr };
+
+enum class CmpOp { Eq, Ne, Lt, Gt, Le, Ge };
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
@@ -231,6 +235,33 @@
Value* GetPtr() const;
};
+class CmpInst : public Instruction {
+ public:
+ CmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name);
+ CmpOp GetCmpOp() const;
+ Value* GetLhs() const;
+ Value* GetRhs() const;
+ private:
+ CmpOp cmp_op_;
+};
+
+class ZextInst : public Instruction {
+ public:
+ ZextInst(std::shared_ptr<Type> dest_ty, Value* val, std::string name);
+ Value* GetValue() const;
+};
+
+class BranchInst : public Instruction {
+ public:
+ BranchInst(BasicBlock* dest);
+ BasicBlock* GetDest() const;
+};
+
+class CondBranchInst : public Instruction {
+ public:
+ CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb);
+ Value* GetCond() const;
+ BasicBlock* GetTrueBlock() const;
+ BasicBlock* GetFalseBlock() const;
+};
+
// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。
@@ -315,6 +346,10 @@
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
ReturnInst* CreateRet(Value* v);
+ CmpInst* CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name);
+ ZextInst* CreateZext(Value* val, const std::string& name);
+ BranchInst* CreateBr(BasicBlock* dest);
+ CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb);
private:

@ -1,38 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
case_dir="${1:-test/test_case}"
log_dir="${2:-test/test_result/lab1_parse_logs}"
if [[ ! -d "$case_dir" ]]; then
echo "测试目录不存在: $case_dir" >&2
exit 1
fi
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建 parse-only 版本。" >&2
exit 1
fi
mkdir -p "$log_dir"
mapfile -t cases < <(find "$case_dir" -name '*.sy' | sort)
if [[ ${#cases[@]} -eq 0 ]]; then
echo "未找到任何 .sy 测试文件: $case_dir" >&2
exit 1
fi
for f in "${cases[@]}"; do
rel="${f#$case_dir/}"
safe_name="${rel//\//__}"
log_file="$log_dir/${safe_name%.sy}.parse.log"
echo "TEST $f -> $log_file"
if ! "$compiler" --emit-parse-tree "$f" >"$log_file" 2>&1; then
echo "FAIL $f (see $log_file)" >&2
exit 1
fi
done
echo "ALL_PARSE_OK (${#cases[@]} cases) logs: $log_dir"

@ -1,142 +0,0 @@
#!/usr/bin/env bash
# 实验 2 全量测试脚本 (改进版)
# 逻辑参考 verify_ir.sh 与 verify_asm.sh
# 增加了批量测试与统计功能,并确保链接 SysY 运行库 (sylib.c)
set -uo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
COMPILER="$PROJECT_ROOT/build/bin/compiler"
SYLIB="$PROJECT_ROOT/sylib/sylib.c"
RESULT_DIR="$PROJECT_ROOT/test/test_result/lab2_full"
# 检查依赖
if [[ ! -x "$COMPILER" ]]; then
echo "错误:编译器不存在,请先构建项目。"
exit 1
fi
if [[ ! -f "$SYLIB" ]]; then
echo "错误:未找到运行库 $SYLIB"
exit 1
fi
# 颜色输出
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
mkdir -p "$RESULT_DIR"
total=0
passed=0
failed=0
run_test() {
local input=$1
local base=$(basename "$input")
local stem=${base%.sy}
local input_dir=$(dirname "$input")
local out_file="$RESULT_DIR/$stem.ll"
local obj_file="$RESULT_DIR/$stem.o"
local exe_file="$RESULT_DIR/$stem"
local stdin_file="$input_dir/$stem.in"
local expected_file="$input_dir/$stem.out"
local actual_file="$RESULT_DIR/$stem.actual.out"
local stdout_file="$RESULT_DIR/$stem.stdout"
((total++)) || true
echo -n "[$total] 测试 $base ... "
# 1. 生成 IR
if ! "$COMPILER" --emit-ir "$input" > "$out_file" 2>&1; then
echo -e "${RED}IR 生成失败${NC}"
((failed++)) || true
return 1
fi
# 2. 编译 IR 到对象文件 (llc)
if ! llc -filetype=obj "$out_file" -o "$obj_file" > /dev/null 2>&1; then
echo -e "${RED}LLVM 编译失败 (llc)${NC}"
((failed++)) || true
return 1
fi
# 3. 链接运行库 (借鉴 verify_asm.sh 逻辑,但明确包含 sylib.c)
if ! clang "$obj_file" "$SYLIB" -o "$exe_file" > /dev/null 2>&1; then
echo -e "${RED}链接失败 (clang)${NC}"
((failed++)) || true
return 1
fi
# 4. 运行程序并捕获输出与退出码 (增加栈空间限制)
local status=0
ulimit -s unlimited 2>/dev/null || true
if [[ -f "$stdin_file" ]]; then
"$exe_file" < "$stdin_file" > "$stdout_file" 2>/dev/null || status=$?
else
"$exe_file" > "$stdout_file" 2>/dev/null || status=$?
fi
# 格式化实际输出 (借鉴 verify_ir.sh 格式)
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && [[ "$(tail -c 1 "$stdout_file" | wc -l)" -eq 0 ]]; then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
# 5. 比对结果
if [[ -f "$expected_file" ]]; then
# 忽略空格差异 (-b -w)
if diff -q -b -w "$expected_file" "$actual_file" > /dev/null 2>&1; then
echo -e "${GREEN} 通过${NC}"
((passed++)) || true
else
echo -e "${RED} 输出不匹配${NC}"
((failed++)) || true
fi
else
echo -e "${YELLOW}! 缺少预期输出文件${NC}"
((passed++)) || true
fi
}
# 批量运行
echo "========================================="
echo "实验 2 全量测试开始 (IR 语义验证)"
echo "========================================="
echo ""
run_batch() {
local dir=$1
if [[ ! -d "$dir" ]]; then return; fi
echo "正在测试目录: $dir"
for sy_file in $(ls "$dir"/*.sy | sort); do
run_test "$sy_file"
done
echo ""
}
run_batch "$PROJECT_ROOT/test/test_case/functional"
run_batch "$PROJECT_ROOT/test/test_case/performance"
echo "========================================="
echo "测试结果统计"
echo "========================================="
echo -e "总数:$total"
echo -e "通过:${GREEN}$passed${NC}"
echo -e "失败:${RED}$failed${NC}"
echo ""
if [[ $failed -eq 0 ]]; then
echo -e "${GREEN} 所有测试通过!实验 2 任务完成。${NC}"
exit 0
else
echo -e "${RED}$failed 个测试失败,请检查逻辑。${NC}"
exit 1
fi

@ -1,157 +0,0 @@
#!/usr/bin/env bash
# 测试 Lab2 IR 生成 - 人员 1 的任务
# 测试内容:
# - 任务 1.1: 支持更多二元运算符Sub, Mul, Div, Mod
# - 任务 1.2: 支持一元运算符(正负号)
# - 任务 1.3: 支持赋值表达式
# - 任务 1.4: 支持逗号分隔的多个变量声明
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_DIR="$PROJECT_ROOT/test/test_case/irgen_lab1_4"
RESULT_DIR="$PROJECT_ROOT/test/test_result/lab2_ir1"
# 颜色输出
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo "========================================="
echo "Lab2 IR 生成测试 - 部分任务验证"
echo "========================================="
echo ""
# 检查编译器是否存在
if [[ ! -x "$COMPILER" ]]; then
echo -e "${RED}错误:编译器不存在或不可执行:$COMPILER${NC}"
echo "请先运行cmake --build build"
exit 1
fi
# 检查测试目录是否存在
if [[ ! -d "$TEST_DIR" ]]; then
echo -e "${RED}错误:测试目录不存在:$TEST_DIR${NC}"
exit 1
fi
# 创建结果目录
mkdir -p "$RESULT_DIR"
# 统计
total=0
passed=0
failed=0
# 测试函数
run_test() {
local input=$1
local basename=$(basename "$input" .sy)
local expected_out="$TEST_DIR/$basename.out"
local actual_out="$RESULT_DIR/$basename.actual.out"
local ll_file="$RESULT_DIR/$basename.ll"
((total++)) || true
echo -n "测试 $basename ... "
# 生成 IR
if ! "$COMPILER" --emit-ir "$input" > "$ll_file" 2>&1; then
echo -e "${RED}IR 生成失败${NC}"
((failed++)) || true
return 1
fi
# 如果需要运行并比对输出
if [[ -f "$expected_out" ]]; then
# 编译并运行
local exe_file="$RESULT_DIR/$basename"
if ! llc -O0 -filetype=obj "$ll_file" -o "$RESULT_DIR/$basename.o" 2>/dev/null; then
echo -e "${YELLOW}LLVM 编译失败 (llc)${NC}"
cat "$ll_file"
((failed++)) || true
return 1
fi
if ! clang "$RESULT_DIR/$basename.o" -o "$exe_file" 2>/dev/null; then
echo -e "${YELLOW}链接失败 (clang)${NC}"
((failed++)) || true
return 1
fi
# 运行程序,捕获返回值(低 8 位)
local exit_code=0
"$exe_file" > "$actual_out" 2>&1 || exit_code=$?
# 处理返回值LLVM/AArch64 返回的是 8 位无符号整数)
if [[ $exit_code -gt 127 ]]; then
# 转换为有符号整数
exit_code=$((exit_code - 256))
fi
echo "$exit_code" > "$actual_out"
# 比对输出
if diff -q "$expected_out" "$actual_out" > /dev/null 2>&1; then
echo -e "${GREEN}✓ 通过${NC}"
((passed++)) || true
return 0
else
echo -e "${RED}✗ 输出不匹配${NC}"
echo " 期望:$(cat "$expected_out")"
echo " 实际:$(cat "$actual_out")"
((failed++)) || true
return 1
fi
else
# 没有期望输出,只检查 IR 生成
echo -e "${GREEN}✓ IR 生成成功${NC}"
((passed++)) || true
return 0
fi
}
# 查找所有测试用例
test_files=()
while IFS= read -r -d '' file; do
test_files+=("$file")
done < <(find "$TEST_DIR" -name "*.sy" -type f -print0 | sort -z)
if [[ ${#test_files[@]} -eq 0 ]]; then
echo -e "${RED}未找到测试用例:$TEST_DIR${NC}"
exit 1
fi
echo "找到 ${#test_files[@]} 个测试用例"
echo ""
# 运行所有测试
for test_file in "${test_files[@]}"; do
run_test "$test_file" || true
done
# 输出统计
echo ""
echo "========================================="
echo "测试结果统计"
echo "========================================="
echo -e "总数:$total"
echo -e "通过:${GREEN}$passed${NC}"
echo -e "失败:${RED}$failed${NC}"
echo ""
if [[ $failed -eq 0 ]]; then
echo -e "${GREEN}✓ 所有测试通过!${NC}"
echo ""
echo "测试覆盖:"
echo " ✓ 任务 1.1: 二元运算符Sub, Mul, Div, Mod"
echo " ✓ 任务 1.2: 一元运算符(正负号)"
echo " ✓ 任务 1.3: 赋值表达式"
echo " ✓ 任务 1.4: 逗号分隔的多变量声明"
exit 0
else
echo -e "${RED}✗ 有 $failed 个测试失败${NC}"
exit 1
fi

@ -1,92 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
mode="${1:-positive}"
log_dir="${2:-test/test_result/lab2_sema_logs}"
checker="./build-sema/sema_check"
if [[ ! -x "$checker" ]]; then
echo "未找到语义测试驱动: $checker" >&2
echo "请先准备 build-sema/sema_check。" >&2
exit 1
fi
mkdir -p "$log_dir"
case_files=()
expected_prefix=""
case "$mode" in
positive)
expected_prefix="OK"
case_files=(
"test/test_case/functional/simple_add.sy"
"test/test_case/functional/09_func_defn.sy"
"test/test_case/functional/25_scope3.sy"
"test/test_case/functional/29_break.sy"
"test/test_case/functional/05_arr_defn4.sy"
"test/test_case/functional/95_float.sy"
)
;;
negative)
expected_prefix="ERR"
case_files=(
"test/test_case/sema_negative/undef.sy"
"test/test_case/sema_negative/break.sy"
"test/test_case/sema_negative/ret.sy"
"test/test_case/sema_negative/call.sy"
)
;;
*)
echo "用法: $0 [positive|negative] [log_dir]" >&2
exit 1
;;
esac
if [[ ${#case_files[@]} -eq 0 ]]; then
echo "没有可执行的测试用例" >&2
exit 1
fi
for f in "${case_files[@]}"; do
if [[ ! -f "$f" ]]; then
echo "测试文件不存在: $f" >&2
exit 1
fi
done
all_ok=true
for f in "${case_files[@]}"; do
base="$(basename "${f%.sy}")"
log_file="$log_dir/${base}.sema.log"
echo "TEST $f -> $log_file"
set +e
"$checker" "$f" >"$log_file" 2>&1
status=$?
set -e
if ! grep -q "^${expected_prefix} $f$" "$log_file"; then
echo "FAIL $f (see $log_file)" >&2
all_ok=false
continue
fi
if [[ "$mode" == "positive" && $status -ne 0 ]]; then
echo "FAIL $f (expected success, see $log_file)" >&2
all_ok=false
continue
fi
if [[ "$mode" == "negative" && $status -eq 0 ]]; then
echo "FAIL $f (expected semantic error, see $log_file)" >&2
all_ok=false
continue
fi
done
if [[ "$all_ok" != true ]]; then
exit 1
fi
echo "ALL_SEMA_${mode^^}_OK (${#case_files[@]} cases) logs: $log_dir"

@ -1,124 +0,0 @@
#!/usr/bin/env bash
# Lab3 指令选择与汇编生成 - 最终全量测试脚本
# 整合了所有阶段的测试,参考 verify_asm.sh 官方逻辑
set -uo pipefail
# 路径配置
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
COMPILER="$PROJECT_ROOT/build/bin/compiler"
VERIFY_ASM="$SCRIPT_DIR/verify_asm.sh"
RESULT_DIR="$PROJECT_ROOT/test/test_result/lab3_final"
export SY_QEMU_TIMEOUT="${SY_QEMU_TIMEOUT:-180}"
# 颜色输出
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m'
echo -e "${BLUE}=========================================================${NC}"
echo -e "${BLUE} Lab3 全量指令选择与汇编生成自动化测试 ${NC}"
echo -e "${BLUE}=========================================================${NC}"
# 1. 环境检查与自动构建
if [[ ! -x "$COMPILER" ]]; then
echo -e "${YELLOW}未找到编译器,正在尝试构建...${NC}"
cmake -S "$PROJECT_ROOT" -B "$PROJECT_ROOT/build" -DCMAKE_BUILD_TYPE=Release > /dev/null
cmake --build "$PROJECT_ROOT/build" -j "$(nproc)" > /dev/null
fi
mkdir -p "$RESULT_DIR"
# 2. 定义官方 21 个测试用例
FUNCTIONAL_CASES=(
"test/test_case/functional/05_arr_defn4.sy"
"test/test_case/functional/09_func_defn.sy"
"test/test_case/functional/11_add2.sy"
"test/test_case/functional/13_sub2.sy"
"test/test_case/functional/15_graph_coloring.sy"
"test/test_case/functional/22_matrix_multiply.sy"
"test/test_case/functional/25_scope3.sy"
"test/test_case/functional/29_break.sy"
"test/test_case/functional/36_op_priority2.sy"
"test/test_case/functional/95_float.sy"
"test/test_case/functional/simple_add.sy"
)
PERFORMANCE_CASES=(
"test/test_case/performance/01_mm2.sy"
"test/test_case/performance/02_mv3.sy"
"test/test_case/performance/03_sort1.sy"
"test/test_case/performance/2025-MYO-20.sy"
"test/test_case/performance/fft0.sy"
"test/test_case/performance/gameoflife-oscillator.sy"
"test/test_case/performance/if-combine3.sy"
"test/test_case/performance/large_loop_array_2.sy"
"test/test_case/performance/transpose0.sy"
"test/test_case/performance/vector_mul3.sy"
)
passed=0
failed=0
failed_list=()
# 3. 测试函数
run_test() {
local sy_file=$1
local type=$2
local full_path="$PROJECT_ROOT/$sy_file"
local base=$(basename "$sy_file")
echo -n "[$type] 测试 $base ... "
if [[ ! -f "$full_path" ]]; then
echo -e "${RED}找不到文件${NC}"
return
fi
# 调用官方脚本进行验证
# 使用绝对路径,彻底避免路径解析问题
if "$VERIFY_ASM" "$full_path" "$RESULT_DIR" --run > /dev/null 2>&1; then
echo -e "${GREEN} 通过${NC}"
((passed++)) || true
else
# 特殊处理已知的问题用例
if [[ "$base" == "2025-MYO-20.sy" ]]; then
echo -e "${YELLOW}! 逻辑正确但库函数参数不兼容 (已知问题)${NC}"
((passed++)) || true
else
echo -e "${RED} 失败${NC}"
((failed++)) || true
failed_list+=("$base")
fi
fi
}
# 4. 执行批量测试
echo -e "\n${BLUE}>>> 运行功能测试 (Functional)...${NC}"
for f in "${FUNCTIONAL_CASES[@]}"; do run_test "$f" "FUNC"; done
echo -e "\n${BLUE}>>> 运行性能测试 (Performance)...${NC}"
for p in "${PERFORMANCE_CASES[@]}"; do run_test "$p" "PERF"; done
# 5. 结果汇总与分析
echo -e "\n${BLUE}=========================================================${NC}"
echo -e "${BLUE} 测试结果汇总 ${NC}"
echo -e "${BLUE}=========================================================${NC}"
echo -e "总用例数: 21"
echo -e "通过数量: ${GREEN}$passed${NC}"
echo -e "失败数量: ${RED}$failed${NC}"
if [[ $failed -gt 0 ]]; then
echo -e "\n${RED}失败用例列表:${NC}"
for item in "${failed_list[@]}"; do
echo -e " - $item"
done
echo -e "\n${YELLOW}建议方案: 请检查 $RESULT_DIR 目录下的 .s 汇编文件以及 .stdout 运行输出进行调试。${NC}"
exit 1
else
echo -e "\n${GREEN}Lab3 所有官方用例验证通过!${NC}"
exit 0
fi

@ -30,11 +30,7 @@ if [[ ! -f "$input" ]]; then
exit 1 exit 1
fi fi
# 查找编译器路径 (使用绝对路径) compiler="./build/bin/compiler"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2 echo "未找到编译器: $compiler ,请先构建。" >&2
exit 1 exit 1
@ -53,18 +49,10 @@ exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in" stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out" expected_file="$input_dir/$stem.out"
# 查找运行库路径
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
SYLIB="$SCRIPT_DIR/../sylib/sylib.c"
"$compiler" --emit-asm "$input" > "$asm_file" "$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file" echo "汇编已生成: $asm_file"
if [[ -f "$SYLIB" ]]; then aarch64-linux-gnu-gcc "$asm_file" -o "$exe"
aarch64-linux-gnu-gcc "$asm_file" "$SYLIB" -o "$exe"
else
aarch64-linux-gnu-gcc "$asm_file" -o "$exe"
fi
echo "可执行文件已生成: $exe" echo "可执行文件已生成: $exe"
if [[ "$run_exec" == true ]]; then if [[ "$run_exec" == true ]]; then
@ -75,30 +63,15 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout" stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out" actual_file="$out_dir/$stem.actual.out"
run_timeout="${SY_QEMU_TIMEOUT:-90}"
echo "运行 $exe ..." echo "运行 $exe ..."
set +e set +e
ulimit -s unlimited 2>/dev/null || true
export QEMU_STACK_SIZE=67108864
if command -v timeout >/dev/null 2>&1; then
if [[ -f "$stdin_file" ]]; then
timeout "${run_timeout}s" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
timeout "${run_timeout}s" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
else
if [[ -f "$stdin_file" ]]; then if [[ -f "$stdin_file" ]]; then
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else else
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi fi
fi
status=$? status=$?
set -e set -e
if [[ $status -eq 124 ]]; then
echo "运行超时: ${run_timeout}s" >&2
exit 124
fi
cat "$stdout_file" cat "$stdout_file"
echo "退出码: $status" echo "退出码: $status"
{ {
@ -110,7 +83,7 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file" } > "$actual_file"
if [[ -f "$expected_file" ]]; then if [[ -f "$expected_file" ]]; then
if diff -u -b -w "$expected_file" "$actual_file"; then if diff -u "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file" echo "输出匹配: $expected_file"
else else
echo "输出不匹配: $expected_file" >&2 echo "输出不匹配: $expected_file" >&2

@ -60,22 +60,7 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout" stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out" actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj" llc -filetype=obj "$out_file" -o "$obj"
#lang "$obj" -o "$exe"
# 查找运行库路径,通常在项目根目录的 sylib/sylib.c
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
SYLIB="$SCRIPT_DIR/../sylib/sylib.c"
if [[ ! -f "$SYLIB" ]]; then
# 备选路径,如果从根目录运行
SYLIB="sylib/sylib.c"
fi
if [[ -f "$SYLIB" ]]; then
clang "$obj" "$SYLIB" -o "$exe"
else
echo "警告:未找到运行库 sylib.c尝试直接链接..." >&2
clang "$obj" -o "$exe" clang "$obj" -o "$exe"
fi
echo "运行 $exe ..." echo "运行 $exe ..."
set +e set +e
if [[ -f "$stdin_file" ]]; then if [[ -f "$stdin_file" ]]; then
@ -85,11 +70,7 @@ if [[ "$run_exec" == true ]]; then
fi fi
status=$? status=$?
set -e set -e
# 打印程序输出,确保末尾有换行
cat "$stdout_file" cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
echo "退出码: $status" echo "退出码: $status"
{ {
cat "$stdout_file" cat "$stdout_file"
@ -100,8 +81,7 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file" } > "$actual_file"
if [[ -f "$expected_file" ]]; then if [[ -f "$expected_file" ]]; then
# 使用 -b -B 忽略空白和空行差异 if diff -u "$expected_file" "$actual_file"; then
if diff -u -b -B "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file" echo "输出匹配: $expected_file"
else else
echo "输出不匹配: $expected_file" >&2 echo "输出不匹配: $expected_file" >&2

@ -14,7 +14,6 @@ add_executable(compiler
) )
target_link_libraries(compiler PRIVATE target_link_libraries(compiler PRIVATE
frontend frontend
ir
utils utils
) )

@ -1,155 +1,67 @@
// SysY 子集语法:支持形如
// int main() { int a = 1; int b = 2; return a + b; }
// 的最小返回表达式编译。
// 后续需要自行添加
grammar SysY; grammar SysY;
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
/* Lexer rules */ /* Lexer rules */
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
CONST: 'const';
INT: 'int'; INT: 'int';
FLOAT: 'float';
VOID: 'void';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
RETURN: 'return'; RETURN: 'return';
LE: '<=';
GE: '>=';
EQ: '==';
NE: '!=';
AND: '&&';
OR: '||';
ASSIGN: '='; ASSIGN: '=';
LT: '<';
GT: '>';
ADD: '+'; ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LPAREN: '('; LPAREN: '(';
RPAREN: ')'; RPAREN: ')';
LBRACK: '[';
RBRACK: ']';
LBRACE: '{'; LBRACE: '{';
RBRACE: '}'; RBRACE: '}';
COMMA: ',';
SEMICOLON: ';'; SEMICOLON: ';';
ID: [a-zA-Z_][a-zA-Z_0-9]*; ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
HEX_FLOAT_LITERAL WS: [ \t\r\n] -> skip;
: ('0x' | '0X') HEX_DIGIT* '.' HEX_DIGIT+ BINARY_EXPONENT
| ('0x' | '0X') HEX_DIGIT+ '.' HEX_DIGIT* BINARY_EXPONENT
| ('0x' | '0X') HEX_DIGIT+ BINARY_EXPONENT
;
DEC_FLOAT_LITERAL
: DEC_DIGIT+ '.' DEC_DIGIT* DEC_EXPONENT?
| '.' DEC_DIGIT+ DEC_EXPONENT?
| DEC_DIGIT+ DEC_EXPONENT
;
HEX_INT_LITERAL
: ('0x' | '0X') HEX_DIGIT+
;
OCT_INT_LITERAL
: '0' OCT_DIGIT+
;
DEC_INT_LITERAL
: '0'
| [1-9] DEC_DIGIT*
;
WS: [ \t\r\n]+ -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip; LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip; BLOCKCOMMENT: '/*' .*? '*/' -> skip;
fragment DEC_DIGIT: [0-9];
fragment OCT_DIGIT: [0-7];
fragment HEX_DIGIT: [0-9a-fA-F];
fragment DEC_EXPONENT: [eE] [+-]? DEC_DIGIT+;
fragment BINARY_EXPONENT: [pP] [+-]? DEC_DIGIT+;
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
/* Syntax rules */ /* Syntax rules */
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
compUnit compUnit
: topLevelItem (topLevelItem)* EOF : funcDef EOF
;
topLevelItem
: decl
| funcDef
; ;
decl decl
: constDecl : btype varDef SEMICOLON
| varDecl
;
constDecl
: CONST bType constDef (COMMA constDef)* SEMICOLON
; ;
varDecl btype
: bType varDef (COMMA varDef)* SEMICOLON
;
bType
: INT : INT
| FLOAT
;
constDef
: ID constIndex* ASSIGN constInitVal
; ;
varDef varDef
: ID constIndex* (ASSIGN initVal)? : lValue (ASSIGN initValue)?
; ;
constIndex initValue
: LBRACK constExp RBRACK
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
;
initVal
: exp : exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
; ;
funcDef funcDef
: funcType ID LPAREN funcFParams? RPAREN block : funcType ID LPAREN RPAREN blockStmt
; ;
funcType funcType
: VOID : INT
| INT
| FLOAT
;
funcFParams
: funcFParam (COMMA funcFParam)*
;
funcFParam
: bType ID (LBRACK RBRACK (LBRACK exp RBRACK)*)?
; ;
block blockStmt
: LBRACE blockItem* RBRACE : LBRACE blockItem* RBRACE
; ;
@ -159,107 +71,28 @@ blockItem
; ;
stmt stmt
: lVal ASSIGN exp SEMICOLON : returnStmt
| exp? SEMICOLON
| block
| IF LPAREN cond RPAREN stmt (ELSE stmt)?
| WHILE LPAREN cond RPAREN stmt
| BREAK SEMICOLON
| CONTINUE SEMICOLON
| RETURN exp? SEMICOLON
; ;
exp returnStmt
: addExp : RETURN exp SEMICOLON
; ;
cond exp
: lOrExp : LPAREN exp RPAREN # parenExp
| var # varExp
| number # numberExp
| exp ADD exp # additiveExp
; ;
lVal var
: ID (LBRACK exp RBRACK)* : ID
; ;
primaryExp lValue
: LPAREN exp RPAREN : ID
| lVal
| number
; ;
number number
: intConst : ILITERAL
| floatConst
;
intConst
: DEC_INT_LITERAL
| OCT_INT_LITERAL
| HEX_INT_LITERAL
;
floatConst
: DEC_FLOAT_LITERAL
| HEX_FLOAT_LITERAL
;
unaryExp
: primaryExp
| ID LPAREN funcRParams? RPAREN
| addUnaryOp unaryExp
;
addUnaryOp
: ADD
| SUB
;
funcRParams
: exp (COMMA exp)*
;
mulExp
: unaryExp
| mulExp MUL unaryExp
| mulExp DIV unaryExp
| mulExp MOD unaryExp
;
addExp
: mulExp
| addExp ADD mulExp
| addExp SUB mulExp
;
relExp
: addExp
| relExp LT addExp
| relExp GT addExp
| relExp LE addExp
| relExp GE addExp
;
eqExp
: relExp
| eqExp EQ relExp
| eqExp NE relExp
;
lAndExp
: condUnaryExp
| lAndExp AND condUnaryExp
;
lOrExp
: lAndExp
| lOrExp OR lAndExp
;
condUnaryExp
: eqExp
| NOT condUnaryExp
;
constExp
: addExp
; ;

@ -15,31 +15,9 @@ ConstantInt* Context::GetConstInt(int v) {
return inserted->second.get(); return inserted->second.get();
} }
ConstantFloat* Context::GetConstFloat(float v) {
auto it = const_floats_.find(v);
if (it != const_floats_.end()) return it->second.get();
auto inserted =
const_floats_.emplace(v, std::make_unique<ConstantFloat>(Type::GetFloatType(), v)).first;
return inserted->second.get();
}
ConstantArray* Context::GetConstArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements) {
auto ca = std::make_unique<ConstantArray>(std::move(ty), std::move(elements));
auto* ptr = ca.get();
const_arrays_.push_back(std::move(ca));
return ptr;
}
ConstantZero* Context::GetConstZero(std::shared_ptr<Type> ty) {
auto cz = std::make_unique<ConstantZero>(std::move(ty));
auto* ptr = cz.get();
const_zeros_.push_back(std::move(cz));
return ptr;
}
std::string Context::NextTemp() { std::string Context::NextTemp() {
std::ostringstream oss; std::ostringstream oss;
oss << "%t" << ++temp_index_; oss << "%" << ++temp_index_;
return oss.str(); return oss.str();
} }

@ -6,7 +6,9 @@
namespace ir { namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> ret_type) Function::Function(std::string name, std::shared_ptr<Type> ret_type)
: Value(std::move(ret_type), std::move(name)) {} : Value(std::move(ret_type), std::move(name)) {
entry_ = CreateBlock("entry");
}
BasicBlock* Function::CreateBlock(const std::string& name) { BasicBlock* Function::CreateBlock(const std::string& name) {
auto block = std::make_unique<BasicBlock>(name); auto block = std::make_unique<BasicBlock>(name);
@ -27,15 +29,4 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_; return blocks_;
} }
Argument* Function::AddArgument(std::shared_ptr<Type> ty, std::string name) {
auto arg = std::make_unique<Argument>(std::move(ty), std::move(name), this, args_.size());
auto* ptr = arg.get();
args_.push_back(std::move(arg));
return ptr;
}
const std::vector<std::unique_ptr<Argument>>& Function::GetArgs() const {
return args_;
}
} // namespace ir } // namespace ir

@ -49,21 +49,6 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name); return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name);
} }
AllocaInst* IRBuilder::CreateAllocaFloat(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(Type::GetPtrFloatType(), name);
}
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
auto ptr_ty = Type::GetPointerType(ty);
return insert_block_->Append<AllocaInst>(ptr_ty, name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -72,8 +57,7 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
throw std::runtime_error( throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
} }
auto val_ty = ptr->GetType()->GetPointedType(); return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name);
return insert_block_->Append<LoadInst>(val_ty, ptr, name);
} }
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
@ -95,106 +79,11 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
} }
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v); if (!v) {
} throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
UnaryInst* IRBuilder::CreateNeg(Value* operand, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!operand) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateNeg 缺少操作数"));
}
auto val_ty = (operand->GetType() && operand->GetType()->IsFloat()) ? Type::GetFloatType() : Type::GetInt32Type();
return insert_block_->Append<UnaryInst>(Opcode::Neg, val_ty, operand, name);
}
Instruction* IRBuilder::CreateCmp(CmpOp op, Value* lhs, Value* rhs, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCmp 缺少操作数"));
}
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
if (!lhs->GetType()->IsFloat()) {
lhs = CreateSIToFP(lhs, ctx_.NextTemp());
}
if (!rhs->GetType()->IsFloat()) {
rhs = CreateSIToFP(rhs, ctx_.NextTemp());
}
return insert_block_->Append<FCmpInst>(op, lhs, rhs, name);
}
return insert_block_->Append<CmpInst>(op, lhs, rhs, name);
}
ZextInst* IRBuilder::CreateZext(Value* val, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!val) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateZext 缺少操作数"));
}
return insert_block_->Append<ZextInst>(Type::GetInt32Type(), val, name);
}
BranchInst* IRBuilder::CreateBr(BasicBlock* dest) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!dest) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateBr 缺少操作数"));
}
return insert_block_->Append<BranchInst>(dest);
}
CondBranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!cond || !true_bb || !false_bb) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCondBr 缺少操作数"));
}
return insert_block_->Append<CondBranchInst>(cond, true_bb, false_bb);
}
CallInst* IRBuilder::CreateCall(Function* func, std::vector<Value*> args, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!func) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 缺少目标函数"));
}
return insert_block_->Append<CallInst>(func, std::move(args), name);
}
GEPInst* IRBuilder::CreateGEP(std::shared_ptr<Type> ty, Value* ptr, std::vector<Value*> indices, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<GEPInst>(ty, ptr, std::move(indices), name);
}
SIToFPInst* IRBuilder::CreateSIToFP(Value* val, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<SIToFPInst>(Type::GetFloatType(), val, name);
}
FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
} }
return insert_block_->Append<FPToSIInst>(Type::GetInt32Type(), val, name); return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
} }
} // namespace ir } // namespace ir

@ -7,7 +7,6 @@
#include <ostream> #include <ostream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <cstring>
#include "utils/Log.h" #include "utils/Log.h"
@ -17,26 +16,10 @@ static const char* TypeToString(const Type& ty) {
switch (ty.GetKind()) { switch (ty.GetKind()) {
case Type::Kind::Void: case Type::Kind::Void:
return "void"; return "void";
case Type::Kind::Int1:
return "i1";
case Type::Kind::Int32: case Type::Kind::Int32:
return "i32"; return "i32";
case Type::Kind::PtrInt32: case Type::Kind::PtrInt32:
return "i32*"; return "i32*";
case Type::Kind::Float:
return "float";
case Type::Kind::PtrFloat:
return "float*";
case Type::Kind::Array: {
static thread_local std::string buf;
buf = "[" + std::to_string(ty.GetNumElements()) + " x " + TypeToString(*ty.GetElementType()) + "]";
return buf.c_str();
}
case Type::Kind::Pointer: {
static thread_local std::string buf;
buf = std::string(TypeToString(*ty.GetPointedType())) + "*";
return buf.c_str();
}
} }
throw std::runtime_error(FormatError("ir", "未知类型")); throw std::runtime_error(FormatError("ir", "未知类型"));
} }
@ -49,12 +32,6 @@ static const char* OpcodeToString(Opcode op) {
return "sub"; return "sub";
case Opcode::Mul: case Opcode::Mul:
return "mul"; return "mul";
case Opcode::Div:
return "sdiv";
case Opcode::Mod:
return "srem";
case Opcode::Neg:
return "neg";
case Opcode::Alloca: case Opcode::Alloca:
return "alloca"; return "alloca";
case Opcode::Load: case Opcode::Load:
@ -63,71 +40,6 @@ static const char* OpcodeToString(Opcode op) {
return "store"; return "store";
case Opcode::Ret: case Opcode::Ret:
return "ret"; return "ret";
case Opcode::Cmp:
return "icmp";
case Opcode::FCmp:
return "fcmp";
case Opcode::Zext:
return "zext";
case Opcode::Br:
case Opcode::CondBr:
return "br";
case Opcode::Call:
return "call";
case Opcode::GEP:
return "getelementptr";
case Opcode::SIToFP:
return "sitofp";
case Opcode::FPToSI:
return "fptosi";
}
return "?";
}
static const char* CmpOpToString(CmpOp op) {
switch (op) {
case CmpOp::Eq:
return "eq";
case CmpOp::Ne:
return "ne";
case CmpOp::Lt:
return "slt";
case CmpOp::Gt:
return "sgt";
case CmpOp::Le:
return "sle";
case CmpOp::Ge:
return "sge";
}
return "?";
}
static const char* GetElementTypeName(const Type& ty) {
if (ty.IsPointer()) {
return TypeToString(*ty.GetPointedType());
}
switch (ty.GetKind()) {
case Type::Kind::Array:
return TypeToString(*ty.GetElementType());
default:
return TypeToString(ty);
}
}
static const char* FCmpOpToString(CmpOp op) {
switch (op) {
case CmpOp::Eq:
return "oeq";
case CmpOp::Ne:
return "one";
case CmpOp::Lt:
return "olt";
case CmpOp::Gt:
return "ogt";
case CmpOp::Le:
return "ole";
case CmpOp::Ge:
return "oge";
} }
return "?"; return "?";
} }
@ -136,233 +48,53 @@ static std::string ValueToString(const Value* v) {
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) { if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue()); return std::to_string(ci->GetValue());
} }
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) { return v ? v->GetName() : "<null>";
double d = (double)cf->GetValue();
uint64_t val;
static_assert(sizeof(double) == sizeof(uint64_t));
std::memcpy(&val, &d, sizeof(double));
char buf[64];
snprintf(buf, sizeof(buf), "0x%lX", val);
return std::string(buf);
}
if (dynamic_cast<const GlobalValue*>(v)) {
return "@" + v->GetName();
}
if (auto* ca = dynamic_cast<const ConstantArray*>(v)) {
std::string s = "[";
const auto& elems = ca->GetElements();
for (size_t i = 0; i < elems.size(); ++i) {
if (i > 0) s += ", ";
s += TypeToString(*elems[i]->GetType());
s += " ";
s += ValueToString(elems[i]);
}
s += "]";
return s;
}
if (dynamic_cast<const ConstantZero*>(v)) {
return "zeroinitializer";
}
if (v) {
std::string name = v->GetName();
if (!name.empty() && name[0] != '%' && name[0] != '@') {
return "%" + name;
}
return name;
}
return "<null>";
}
static std::string PrintLabel(const Value* bb) {
if (!bb) return "<null>";
std::string name = bb->GetName();
if (name.empty()) return "<empty>";
if (name[0] == '%') return name;
return "%" + name;
}
static std::string PrintLabelDef(const Value* bb) {
if (!bb) return "<null>";
std::string name = bb->GetName();
if (!name.empty() && name[0] == '%') return name.substr(1);
return name;
} }
void IRPrinter::Print(const Module& module, std::ostream& os) { void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& gv : module.GetGlobalVariables()) {
os << "@" << gv->GetName() << " = global "
<< GetElementTypeName(*gv->GetType()) << " "
<< ValueToString(gv->GetInitializer()) << "\n";
}
for (const auto& func : module.GetFunctions()) { for (const auto& func : module.GetFunctions()) {
if (func->GetBlocks().empty()) { os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
os << "declare " << TypeToString(*func->GetType()) << " @" << func->GetName() << "("; << "() {\n";
const auto& args = func->GetArgs();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType());
}
os << ")\n";
continue;
}
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() << "(";
const auto& args = func->GetArgs();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName();
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) { for (const auto& bb : func->GetBlocks()) {
if (!bb) { if (!bb) {
continue; continue;
} }
os << PrintLabelDef(bb.get()) << ":\n"; os << bb->GetName() << ":\n";
for (const auto& instPtr : bb->GetInstructions()) { for (const auto& instPtr : bb->GetInstructions()) {
const auto* inst = instPtr.get(); const auto* inst = instPtr.get();
switch (inst->GetOpcode()) { switch (inst->GetOpcode()) {
case Opcode::Add: case Opcode::Add:
case Opcode::Sub: case Opcode::Sub:
case Opcode::Mul: case Opcode::Mul: {
case Opcode::Div:
case Opcode::Mod: {
auto* bin = static_cast<const BinaryInst*>(inst); auto* bin = static_cast<const BinaryInst*>(inst);
bool is_float = bin->GetType()->IsFloat(); os << " " << bin->GetName() << " = "
std::string op_name = OpcodeToString(bin->GetOpcode()); << OpcodeToString(bin->GetOpcode()) << " "
if (is_float) {
if (op_name == "add") op_name = "fadd";
else if (op_name == "sub") op_name = "fsub";
else if (op_name == "mul") op_name = "fmul";
else if (op_name == "sdiv") op_name = "fdiv";
}
os << " " << ValueToString(bin) << " = "
<< op_name << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " " << TypeToString(*bin->GetLhs()->GetType()) << " "
<< ValueToString(bin->GetLhs()) << ", " << ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n"; << ValueToString(bin->GetRhs()) << "\n";
break; break;
} }
case Opcode::Neg: {
auto* unary = static_cast<const UnaryInst*>(inst);
bool is_float = unary->GetType()->IsFloat();
os << " " << ValueToString(unary) << " = "
<< (is_float ? "fneg" : "sub") << " "
<< TypeToString(*unary->GetUnaryOperand()->GetType()) << " "
<< (is_float ? "" : "0, ")
<< ValueToString(unary->GetUnaryOperand()) << "\n";
break;
}
case Opcode::Alloca: { case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst); auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << ValueToString(alloca) << " = alloca " os << " " << alloca->GetName() << " = alloca i32\n";
<< GetElementTypeName(*alloca->GetType()) << "\n";
break; break;
} }
case Opcode::Load: { case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst); auto* load = static_cast<const LoadInst*>(inst);
os << " " << ValueToString(load) << " = load " os << " " << load->GetName() << " = load i32, i32* "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n"; << ValueToString(load->GetPtr()) << "\n";
break; break;
} }
case Opcode::Store: { case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst); auto* store = static_cast<const StoreInst*>(inst);
os << " store " << TypeToString(*store->GetValue()->GetType()) << " " os << " store i32 " << ValueToString(store->GetValue())
<< ValueToString(store->GetValue()) << ", " << ", i32* " << ValueToString(store->GetPtr()) << "\n";
<< TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break; break;
} }
case Opcode::Ret: { case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst); auto* ret = static_cast<const ReturnInst*>(inst);
if (auto* val = ret->GetValue()) { os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
os << " ret " << TypeToString(*val->GetType()) << " " << ValueToString(ret->GetValue()) << "\n";
<< ValueToString(val) << "\n";
} else {
os << " ret void\n";
}
break;
}
case Opcode::Cmp: {
auto* cmp = static_cast<const CmpInst*>(inst);
os << " " << ValueToString(cmp) << " = icmp "
<< CmpOpToString(cmp->GetCmpOp()) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " " << ValueToString(cmp) << " = fcmp "
<< FCmpOpToString(cmp->GetCmpOp()) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::Zext: {
auto* zext = static_cast<const ZextInst*>(inst);
os << " " << ValueToString(zext) << " = zext "
<< TypeToString(*zext->GetOperand(0)->GetType()) << " "
<< ValueToString(zext->GetOperand(0)) << " to "
<< TypeToString(*zext->GetType()) << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BranchInst*>(inst);
os << " br label " << PrintLabel(br->GetDest()) << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBranchInst*>(inst);
os << " br i1 " << ValueToString(cbr->GetCond())
<< ", label " << PrintLabel(cbr->GetTrueBlock())
<< ", label " << PrintLabel(cbr->GetFalseBlock()) << "\n";
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
if (call->GetType()->IsVoid()) {
os << " call void @" << call->GetFunc()->GetName() << "(";
} else {
os << " " << ValueToString(call) << " = call " << TypeToString(*call->GetType())
<< " @" << call->GetFunc()->GetName() << "(";
}
for (size_t i = 0; i < call->GetArgs().size(); ++i) {
if (i > 0) os << ", ";
auto* arg = call->GetArgs()[i];
os << TypeToString(*arg->GetType()) << " " << ValueToString(arg);
}
os << ")\n";
break;
}
case Opcode::GEP: {
auto* gep = static_cast<const GEPInst*>(inst);
os << " " << ValueToString(gep) << " = getelementptr "
<< GetElementTypeName(*gep->GetPtr()->GetType()) << ", "
<< TypeToString(*gep->GetPtr()->GetType()) << " "
<< ValueToString(gep->GetPtr());
for (auto* idx : gep->GetIndices()) {
os << ", " << TypeToString(*idx->GetType()) << " " << ValueToString(idx);
}
os << "\n";
break;
}
case Opcode::SIToFP: {
auto* conv = static_cast<const SIToFPInst*>(inst);
os << " " << ValueToString(conv) << " = sitofp "
<< TypeToString(*conv->GetOperand(0)->GetType()) << " "
<< ValueToString(conv->GetOperand(0)) << " to "
<< TypeToString(*conv->GetType()) << "\n";
break;
}
case Opcode::FPToSI: {
auto* conv = static_cast<const FPToSIInst*>(inst);
os << " " << ValueToString(conv) << " = fptosi "
<< TypeToString(*conv->GetOperand(0)->GetType()) << " "
<< ValueToString(conv->GetOperand(0)) << " to "
<< TypeToString(*conv->GetType()) << "\n";
break; break;
} }
} }

@ -52,7 +52,7 @@ Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
Opcode Instruction::GetOpcode() const { return opcode_; } Opcode Instruction::GetOpcode() const { return opcode_; }
bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret || opcode_ == Opcode::Br || opcode_ == Opcode::CondBr; } bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; }
BasicBlock* Instruction::GetParent() const { return parent_; } BasicBlock* Instruction::GetParent() const { return parent_; }
@ -61,9 +61,8 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name) Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) { : Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul && if (op != Opcode::Add) {
op != Opcode::Div && op != Opcode::Mod) { throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add"));
throw std::runtime_error(FormatError("ir", "BinaryInst 不支持的操作码"));
} }
if (!lhs || !rhs) { if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
@ -75,8 +74,8 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
type_->GetKind() != lhs->GetType()->GetKind()) { type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
} }
if (!type_->IsInt32() && !type_->IsFloat()) { if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32 或 float")); throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32"));
} }
AddOperand(lhs); AddOperand(lhs);
AddOperand(rhs); AddOperand(rhs);
@ -86,53 +85,37 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); }
Value* BinaryInst::GetRhs() const { return GetOperand(1); } Value* BinaryInst::GetRhs() const { return GetOperand(1); }
UnaryInst::UnaryInst(Opcode op, std::shared_ptr<Type> ty, Value* operand,
std::string name)
: Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Neg) {
throw std::runtime_error(FormatError("ir", "UnaryInst 不支持的操作码"));
}
if (!operand) {
throw std::runtime_error(FormatError("ir", "UnaryInst 缺少操作数"));
}
if (!type_ || !operand->GetType()) {
throw std::runtime_error(FormatError("ir", "UnaryInst 缺少类型信息"));
}
if (type_->GetKind() != operand->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "UnaryInst 类型不匹配"));
}
if (!type_->IsInt32() && !type_->IsFloat()) {
throw std::runtime_error(FormatError("ir", "UnaryInst 当前只支持 i32 或 float"));
}
AddOperand(operand);
}
Value* UnaryInst::GetUnaryOperand() const { return GetOperand(0); }
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val) ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") { : Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!val) {
throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值"));
}
if (!type_ || !type_->IsVoid()) { if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
} }
if (val) {
AddOperand(val); AddOperand(val);
}
} }
Value* ReturnInst::GetValue() const { Value* ReturnInst::GetValue() const { return GetOperand(0); }
return GetNumOperands() > 0 ? GetOperand(0) : nullptr;
}
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name) AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {} : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
if (!type_ || !type_->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*"));
}
}
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name) LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
: Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
if (!ptr) { if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
} }
if (!type_ || (!type_->IsInt32() && !type_->IsFloat() && !type_->IsInt1())) { if (!type_ || !type_->IsInt32()) {
// Note: IsInt1 is for Zext or comparisons throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
} }
AddOperand(ptr); AddOperand(ptr);
} }
@ -150,19 +133,12 @@ StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
if (!type_ || !type_->IsVoid()) { if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
} }
if (!val->GetType() || (!val->GetType()->IsInt32() && !val->GetType()->IsFloat() && !val->GetType()->IsPointer())) { if (!val->GetType() || !val->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32、float 或指针类型")); throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32"));
} }
if (val->GetType()->IsInt32() || val->GetType()->IsPointer()) { if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error( throw std::runtime_error(
FormatError("ir", "StoreInst 当前只支持写入指针类型槽位")); FormatError("ir", "StoreInst 当前只支持写入 i32*"));
}
} else if (val->GetType()->IsFloat()) {
if (!ptr->GetType() || !ptr->GetType()->IsPtrFloat()) {
throw std::runtime_error(
FormatError("ir", "StoreInst 当前只支持写入 float*"));
}
} }
AddOperand(val); AddOperand(val);
AddOperand(ptr); AddOperand(ptr);
@ -172,117 +148,4 @@ Value* StoreInst::GetValue() const { return GetOperand(0); }
Value* StoreInst::GetPtr() const { return GetOperand(1); } Value* StoreInst::GetPtr() const { return GetOperand(1); }
CmpInst::CmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name)
: Instruction(Opcode::Cmp, Type::GetInt1Type(), std::move(name)), cmp_op_(cmp_op) {
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "CmpInst 缺少操作数"));
}
if (!lhs->GetType() || !rhs->GetType()) {
throw std::runtime_error(FormatError("ir", "CmpInst 缺少操作数类型信息"));
}
if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "CmpInst 操作数类型不匹配"));
}
AddOperand(lhs);
AddOperand(rhs);
}
CmpOp CmpInst::GetCmpOp() const { return cmp_op_; }
Value* CmpInst::GetLhs() const { return GetOperand(0); }
Value* CmpInst::GetRhs() const { return GetOperand(1); }
FCmpInst::FCmpInst(CmpOp cmp_op, Value* lhs, Value* rhs, std::string name)
: Instruction(Opcode::FCmp, Type::GetInt1Type(), std::move(name)), cmp_op_(cmp_op) {
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数"));
}
if (!lhs->GetType() || !rhs->GetType()) {
throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数类型信息"));
}
if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "FCmpInst 操作数类型不匹配"));
}
AddOperand(lhs);
AddOperand(rhs);
}
CmpOp FCmpInst::GetCmpOp() const { return cmp_op_; }
Value* FCmpInst::GetLhs() const { return GetOperand(0); }
Value* FCmpInst::GetRhs() const { return GetOperand(1); }
ZextInst::ZextInst(std::shared_ptr<Type> dest_ty, Value* val, std::string name)
: Instruction(Opcode::Zext, std::move(dest_ty), std::move(name)) {
if (!val) {
throw std::runtime_error(FormatError("ir", "ZextInst 缺少操作数"));
}
if (!type_->IsInt32() || !val->GetType()->IsInt1()) {
throw std::runtime_error(FormatError("ir", "ZextInst 当前只支持 i1 到 i32"));
}
AddOperand(val);
}
Value* ZextInst::GetValue() const { return GetOperand(0); }
BranchInst::BranchInst(BasicBlock* dest)
: Instruction(Opcode::Br, Type::GetVoidType(), "") {
if (!dest) {
throw std::runtime_error(FormatError("ir", "BranchInst 缺少目的块"));
}
AddOperand(dest);
}
BasicBlock* BranchInst::GetDest() const { return static_cast<BasicBlock*>(GetOperand(0)); }
CondBranchInst::CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb)
: Instruction(Opcode::CondBr, Type::GetVoidType(), "") {
if (!cond || !true_bb || !false_bb) {
throw std::runtime_error(FormatError("ir", "CondBranchInst 缺少连边操作数"));
}
if (!cond->GetType()->IsInt1()) {
throw std::runtime_error(FormatError("ir", "CondBranchInst 必须使用 i1 作为条件"));
}
AddOperand(cond);
AddOperand(true_bb);
AddOperand(false_bb);
}
Value* CondBranchInst::GetCond() const { return GetOperand(0); }
BasicBlock* CondBranchInst::GetTrueBlock() const { return static_cast<BasicBlock*>(GetOperand(1)); }
BasicBlock* CondBranchInst::GetFalseBlock() const { return static_cast<BasicBlock*>(GetOperand(2)); }
CallInst::CallInst(Function* func, std::vector<Value*> args, std::string name)
: Instruction(Opcode::Call, func->GetType(), std::move(name)), func_(func), args_(std::move(args)) {
if (!func) {
throw std::runtime_error(FormatError("ir", "CallInst 缺少目标函数"));
}
AddOperand(func);
for (auto* arg : args_) {
AddOperand(arg);
}
}
Function* CallInst::GetFunc() const { return func_; }
const std::vector<Value*>& CallInst::GetArgs() const { return args_; }
GEPInst::GEPInst(std::shared_ptr<Type> ty, Value* ptr, std::vector<Value*> indices, std::string name)
: Instruction(Opcode::GEP, std::move(ty), std::move(name)), indices_(std::move(indices)) {
AddOperand(ptr);
for (auto* idx : indices_) {
AddOperand(idx);
}
}
Value* GEPInst::GetPtr() const { return GetOperand(0); }
const std::vector<Value*>& GEPInst::GetIndices() const { return indices_; }
SIToFPInst::SIToFPInst(std::shared_ptr<Type> ty, Value* val, std::string name)
: Instruction(Opcode::SIToFP, std::move(ty), std::move(name)) {
AddOperand(val);
}
FPToSIInst::FPToSIInst(std::shared_ptr<Type> ty, Value* val, std::string name)
: Instruction(Opcode::FPToSI, std::move(ty), std::move(name)) {
AddOperand(val);
}
} // namespace ir } // namespace ir

@ -18,13 +18,4 @@ const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
return functions_; return functions_;
} }
GlobalVariable* Module::CreateGlobalVariable(const std::string& name, std::shared_ptr<Type> type, ConstantValue* init) {
global_variables_.push_back(std::make_unique<GlobalVariable>(name, std::move(type), init));
return global_variables_.back().get();
}
const std::vector<std::unique_ptr<GlobalVariable>>& Module::GetGlobalVariables() const {
return global_variables_;
}
} // namespace ir } // namespace ir

@ -4,67 +4,28 @@
namespace ir { namespace ir {
Type::Type(Kind k) : kind_(k) {} Type::Type(Kind k) : kind_(k) {}
Type::Type(Kind k, std::shared_ptr<Type> elem_ty, int num_elems)
: kind_(k), elem_ty_(std::move(elem_ty)), num_elems_(num_elems) {}
const std::shared_ptr<Type>& Type::GetVoidType() { const std::shared_ptr<Type>& Type::GetVoidType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetInt1Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int1);
return type;
}
const std::shared_ptr<Type>& Type::GetInt32Type() { const std::shared_ptr<Type>& Type::GetInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetPtrInt32Type() { const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32, GetInt32Type(), 0); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetFloatType() {
static std::shared_ptr<Type> ty = std::make_shared<Type>(Kind::Float);
return ty;
}
const std::shared_ptr<Type>& Type::GetPtrFloatType() {
static std::shared_ptr<Type> ty = std::make_shared<Type>(Kind::PtrFloat, GetFloatType(), 0);
return ty;
}
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> elem_ty, int num_elems) {
return std::make_shared<Type>(Kind::Array, std::move(elem_ty), num_elems);
}
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> pointed_ty) {
return std::make_shared<Type>(Kind::Pointer, std::move(pointed_ty), 0);
}
Type::Kind Type::GetKind() const { return kind_; } Type::Kind Type::GetKind() const { return kind_; }
bool Type::IsVoid() const { return kind_ == Kind::Void; } bool Type::IsVoid() const { return kind_ == Kind::Void; }
bool Type::IsInt1() const { return kind_ == Kind::Int1; }
bool Type::IsInt32() const { return kind_ == Kind::Int32; } bool Type::IsInt32() const { return kind_ == Kind::Int32; }
bool Type::IsPtrInt32() const { bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
return kind_ == Kind::PtrInt32 || (kind_ == Kind::Pointer && GetPointedType() && GetPointedType()->IsInt32());
}
bool Type::IsFloat() const { return kind_ == Kind::Float; }
bool Type::IsPtrFloat() const {
return kind_ == Kind::PtrFloat || (kind_ == Kind::Pointer && GetPointedType() && GetPointedType()->IsFloat());
}
bool Type::IsArray() const { return kind_ == Kind::Array; }
bool Type::IsPointer() const { return kind_ == Kind::Pointer || kind_ == Kind::PtrInt32 || kind_ == Kind::PtrFloat; }
} // namespace ir } // namespace ir

@ -18,16 +18,10 @@ void Value::SetName(std::string n) { name_ = std::move(n); }
bool Value::IsVoid() const { return type_ && type_->IsVoid(); } bool Value::IsVoid() const { return type_ && type_->IsVoid(); }
bool Value::IsInt1() const { return type_ && type_->IsInt1(); }
bool Value::IsInt32() const { return type_ && type_->IsInt32(); } bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
bool Value::IsFloat() const { return type_ && type_->IsFloat(); }
bool Value::IsPtrFloat() const { return type_ && type_->IsPtrFloat(); }
bool Value::IsConstant() const { bool Value::IsConstant() const {
return dynamic_cast<const ConstantValue*>(this) != nullptr; return dynamic_cast<const ConstantValue*>(this) != nullptr;
} }
@ -44,10 +38,6 @@ bool Value::IsFunction() const {
return dynamic_cast<const Function*>(this) != nullptr; return dynamic_cast<const Function*>(this) != nullptr;
} }
bool Value::IsArgument() const {
return dynamic_cast<const Argument*>(this) != nullptr;
}
void Value::AddUse(User* user, size_t operand_index) { void Value::AddUse(User* user, size_t operand_index) {
if (!user) return; if (!user) return;
uses_.push_back(Use(this, user, operand_index)); uses_.push_back(Use(this, user, operand_index));
@ -84,29 +74,10 @@ void Value::ReplaceAllUsesWith(Value* new_value) {
} }
} }
Argument::Argument(std::shared_ptr<Type> ty, std::string name, Function* parent, size_t arg_no)
: Value(std::move(ty), std::move(name)), parent_(parent), arg_no_(arg_no) {}
Function* Argument::GetParent() const { return parent_; }
size_t Argument::GetArgNo() const { return arg_no_; }
ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name) ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {} : Value(std::move(ty), std::move(name)) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v) ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: ConstantValue(std::move(ty), ""), value_(v) {} : ConstantValue(std::move(ty), ""), value_(v) {}
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float v)
: ConstantValue(std::move(ty), ""), value_(v) {}
ConstantArray::ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements)
: ConstantValue(std::move(ty), ""), elements_(std::move(elements)) {}
ConstantZero::ConstantZero(std::shared_ptr<Type> ty)
: ConstantValue(std::move(ty), "") {}
GlobalVariable::GlobalVariable(std::string name, std::shared_ptr<Type> type, ConstantValue* init)
: GlobalValue(std::move(type), std::move(name)), init_(init) {}
} // namespace ir } // namespace ir

@ -7,44 +7,29 @@
#include "utils/Log.h" #include "utils/Log.h"
namespace { namespace {
ir::ConstantValue* BuildConstantArray(ir::Context& ctx, std::shared_ptr<ir::Type> type,
const std::vector<ir::ConstantValue*>& flattened, std::string GetLValueName(SysYParser::LValueContext& lvalue) {
size_t& pos) { if (!lvalue.ID()) {
if (!type->IsArray()) { throw std::runtime_error(FormatError("irgen", "非法左值"));
return flattened[pos++];
}
std::vector<ir::ConstantValue*> elements;
for (int i = 0; i < type->GetNumElements(); ++i) {
elements.push_back(BuildConstantArray(ctx, type->GetElementType(), flattened, pos));
} }
return ctx.GetConstArray(type, elements); return lvalue.ID()->getText();
}
} }
std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { } // namespace
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
if (!ctx) { if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块")); throw std::runtime_error(FormatError("irgen", "缺少语句块"));
} }
// 压入局部作用域
storage_map_stack_.push_back({});
const_values_stack_.push_back({});
bool terminated = false;
for (auto* item : ctx->blockItem()) { for (auto* item : ctx->blockItem()) {
if (item) { if (item) {
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
terminated = true; // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。
break; break;
} }
} }
} }
return {};
// 弹出局部作用域
storage_map_stack_.pop_back();
const_values_stack_.pop_back();
return terminated ? BlockFlow::Terminated : BlockFlow::Continue;
} }
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
@ -66,206 +51,27 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明"));
} }
std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
if (!ctx) return BlockFlow::Continue;
if (!ctx->bType() || (!ctx->bType()->INT() && !ctx->bType()->FLOAT())) {
throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float 常量声明"));
}
for (auto* def : ctx->constDef()) {
if (def) def->accept(this);
}
return BlockFlow::Continue;
}
void IRGenImpl::FlattenInitVal(SysYParser::InitValContext* ctx,
const std::vector<int>& dims,
const std::vector<int>& sub_sizes,
int dim_idx,
size_t& current_pos,
std::vector<ir::Value*>& results,
bool is_float) {
if (ctx->exp()) {
ir::Value* val = EvalExpr(*ctx->exp());
// Implicit conversion
if (is_float && !val->GetType()->IsFloat()) {
val = builder_.CreateSIToFP(val, module_.GetContext().NextTemp());
} else if (!is_float && val->GetType()->IsFloat()) {
val = builder_.CreateFPToSI(val, module_.GetContext().NextTemp());
}
results[current_pos++] = val;
} else {
// Nested { ... }
size_t start_pos = current_pos;
for (auto* item : ctx->initVal()) {
FlattenInitVal(item, dims, sub_sizes, dim_idx + 1, current_pos, results, is_float);
}
// Fill remaining with 0
size_t end_pos = start_pos + sub_sizes[dim_idx];
while (current_pos < end_pos) {
results[current_pos++] = is_float ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f)
: (ir::Value*)module_.GetContext().GetConstInt(0);
}
}
}
void IRGenImpl::FlattenConstInitVal(SysYParser::ConstInitValContext* ctx,
const std::vector<int>& dims,
const std::vector<int>& sub_sizes,
int dim_idx,
size_t& current_pos,
std::vector<ir::ConstantValue*>& results,
bool is_float) {
if (ctx->constExp()) {
ir::Value* val = std::any_cast<ir::Value*>(ctx->constExp()->accept(this));
ir::ConstantValue* cval = dynamic_cast<ir::ConstantValue*>(val);
if (!cval) throw std::runtime_error("Not a constant expression");
// Constant conversion
if (is_float && dynamic_cast<ir::ConstantInt*>(cval)) {
cval = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(cval)->GetValue());
} else if (!is_float && dynamic_cast<ir::ConstantFloat*>(cval)) {
cval = module_.GetContext().GetConstInt((int)static_cast<ir::ConstantFloat*>(cval)->GetValue());
}
results[current_pos++] = cval;
} else {
size_t start_pos = current_pos;
for (auto* item : ctx->constInitVal()) {
FlattenConstInitVal(item, dims, sub_sizes, dim_idx + 1, current_pos, results, is_float);
}
// Fill remaining with 0
size_t end_pos = start_pos + sub_sizes[dim_idx];
while (current_pos < end_pos) {
results[current_pos++] = is_float ? (ir::ConstantValue*)module_.GetContext().GetConstFloat(0.0f)
: (ir::ConstantValue*)module_.GetContext().GetConstInt(0);
}
}
}
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "常量定义缺少名称"));
}
std::string var_name = ctx->ID()->getText();
// Get dimensions
std::vector<int> dims;
for (auto* idx : ctx->constIndex()) {
dims.push_back(EvaluateConstInt(idx->constExp()));
}
bool is_float = false;
auto* parent_decl = dynamic_cast<SysYParser::ConstDeclContext*>(ctx->parent);
if (parent_decl && parent_decl->bType() && parent_decl->bType()->FLOAT()) {
is_float = true;
}
auto base_ty = is_float ? ir::Type::GetFloatType() : ir::Type::GetInt32Type();
std::shared_ptr<ir::Type> var_ty = base_ty;
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
var_ty = ir::Type::GetArrayType(var_ty, *it);
}
std::vector<int> sub_sizes(dims.size() + 1);
sub_sizes[dims.size()] = 1;
for (int i = (int)dims.size() - 1; i >= 0; --i) {
sub_sizes[i] = sub_sizes[i+1] * dims[i];
}
ir::ConstantValue* init_const = nullptr;
std::vector<ir::ConstantValue*> flattened;
if (dims.empty()) {
if (auto* init_val = ctx->constInitVal()) {
if (init_val->constExp()) {
ir::Value* val = std::any_cast<ir::Value*>(init_val->constExp()->accept(this));
init_const = dynamic_cast<ir::ConstantValue*>(val);
// Constant conversion
if (is_float && dynamic_cast<ir::ConstantInt*>(init_const)) {
init_const = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(init_const)->GetValue());
} else if (!is_float && dynamic_cast<ir::ConstantFloat*>(init_const)) {
init_const = module_.GetContext().GetConstInt((int)static_cast<ir::ConstantFloat*>(init_const)->GetValue());
}
}
}
} else {
flattened.resize(sub_sizes[0]);
if (auto* init_val = ctx->constInitVal()) {
size_t pos = 0;
FlattenConstInitVal(init_val, dims, sub_sizes, 0, pos, flattened, is_float);
} else {
auto zero = is_float ? (ir::ConstantValue*)module_.GetContext().GetConstFloat(0.0f) : (ir::ConstantValue*)module_.GetContext().GetConstInt(0);
for (auto& v : flattened) v = zero;
}
size_t pos = 0;
init_const = BuildConstantArray(module_.GetContext(), var_ty, flattened, pos);
}
// 记录常量值供后续直接使用 (only for scalars for now)
if (dims.empty() && !const_values_stack_.empty()) {
const_values_stack_.back()[var_name] = init_const;
}
if (func_ == nullptr) {
auto gv_ptr_ty = ir::Type::GetPointerType(var_ty);
auto* gv = module_.CreateGlobalVariable(var_name, gv_ptr_ty, init_const);
if (!storage_map_stack_.empty()) {
storage_map_stack_.back()[var_name] = gv;
}
} else {
// 局部作用域 - 确保 alloca 在入口块
auto* current_bb = builder_.GetInsertBlock();
builder_.SetInsertPoint(func_->GetEntry());
ir::Value* slot = builder_.CreateAlloca(var_ty, module_.GetContext().NextTemp());
builder_.SetInsertPoint(current_bb);
if (!storage_map_stack_.empty()) {
storage_map_stack_.back()[var_name] = slot;
}
if (dims.empty()) {
if (init_const) builder_.CreateStore(init_const, slot);
} else {
for (size_t i = 0; i < flattened.size(); ++i) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
size_t temp = i;
for (size_t d = 0; d < dims.size(); ++d) {
indices.push_back(builder_.CreateConstInt(temp / sub_sizes[d+1]));
temp %= sub_sizes[d+1];
}
ir::Value* ptr = builder_.CreateGEP(ir::Type::GetPointerType(base_ty), slot, indices, module_.GetContext().NextTemp());
builder_.CreateStore(flattened[i], ptr);
}
}
}
return BlockFlow::Continue;
}
// 变量声明的 IR 生成目前也是最小实现: // 变量声明的 IR 生成目前也是最小实现:
// - 先检查声明的基础类型,支持 int 和 float // - 先检查声明的基础类型,当前仅支持局部 int
// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 // - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。
//
// 和更完整的版本相比,这里还没有:
// - 一个 Decl 中多个变量定义的顺序处理;
// - const、数组、全局变量等不同声明形态
// - 更丰富的类型系统。
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) { if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量声明")); throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
} }
// 当前语法中 decl 包含 constDecl 或 varDecl if (!ctx->btype() || !ctx->btype()->INT()) {
if (auto* var_decl = ctx->varDecl()) { throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
if (!var_decl->bType() || (!var_decl->bType()->INT() && !var_decl->bType()->FLOAT())) {
throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float 变量声明"));
}
for (auto* var_def : var_decl->varDef()) {
if (var_def) {
var_def->accept(this);
} }
auto* var_def = ctx->varDef();
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
} }
} else if (auto* const_decl = ctx->constDecl()) { var_def->accept(this);
return const_decl->accept(this); return {};
} else {
throw std::runtime_error(FormatError("irgen", "当前仅支持变量声明"));
}
return BlockFlow::Continue;
} }
@ -274,145 +80,28 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
// - 标量初始化; // - 标量初始化;
// - 一个 VarDef 对应一个槽位。 // - 一个 VarDef 对应一个槽位。
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx || !ctx->ID()) { if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量定义"));
}
if (!ctx->lValue()) {
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
} }
std::string var_name = ctx->ID()->getText(); GetLValueName(*ctx->lValue());
if (!storage_map_stack_.empty() && storage_map_stack_.back().find(var_name) != storage_map_stack_.back().end()) { if (storage_map_.find(ctx) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
} }
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
// Get dimensions ir::Value* init = nullptr;
std::vector<int> dims; if (auto* init_value = ctx->initValue()) {
for (auto* idx : ctx->constIndex()) { if (!init_value->exp()) {
dims.push_back(EvaluateConstInt(idx->constExp())); throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化"));
}
// Determine base type
bool is_float = false;
auto* parent_decl = dynamic_cast<SysYParser::VarDeclContext*>(ctx->parent);
if (parent_decl && parent_decl->bType() && parent_decl->bType()->FLOAT()) {
is_float = true;
}
auto base_ty = is_float ? ir::Type::GetFloatType() : ir::Type::GetInt32Type();
std::shared_ptr<ir::Type> var_ty = base_ty;
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
var_ty = ir::Type::GetArrayType(var_ty, *it);
}
std::vector<int> sub_sizes(dims.size() + 1);
sub_sizes[dims.size()] = 1;
for (int i = (int)dims.size() - 1; i >= 0; --i) {
sub_sizes[i] = sub_sizes[i+1] * dims[i];
}
if (func_ == nullptr) {
// 全局作用域
ir::ConstantValue* init_const = nullptr;
if (dims.empty()) {
if (auto* init_val = ctx->initVal()) {
if (init_val->exp()) {
auto* val = EvalExpr(*init_val->exp());
init_const = dynamic_cast<ir::ConstantValue*>(val);
// Constant conversion
if (is_float && dynamic_cast<ir::ConstantInt*>(init_const)) {
init_const = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(init_const)->GetValue());
} else if (!is_float && dynamic_cast<ir::ConstantFloat*>(init_const)) {
init_const = module_.GetContext().GetConstInt((int)static_cast<ir::ConstantFloat*>(init_const)->GetValue());
}
}
} else {
init_const = is_float ? (ir::ConstantValue*)module_.GetContext().GetConstFloat(0.0f) : (ir::ConstantValue*)module_.GetContext().GetConstInt(0);
}
} else {
if (auto* init_val = ctx->initVal()) {
std::vector<ir::ConstantValue*> flattened(sub_sizes[0]);
// VarDef's InitVal can be an expression or { ... }
if (init_val->exp()) {
auto* val = EvalExpr(*init_val->exp());
auto* cval = dynamic_cast<ir::ConstantValue*>(val);
flattened[0] = cval;
auto zero = is_float ? (ir::ConstantValue*)module_.GetContext().GetConstFloat(0.0f) : (ir::ConstantValue*)module_.GetContext().GetConstInt(0);
for (size_t i = 1; i < flattened.size(); ++i) {
flattened[i] = zero;
}
size_t bpos = 0;
init_const = BuildConstantArray(module_.GetContext(), var_ty, flattened, bpos);
} else {
size_t fpos = 0;
std::vector<ir::Value*> flat_vals(sub_sizes[0]);
auto zero = is_float ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0);
for (auto& v : flat_vals) v = zero;
FlattenInitVal(init_val, dims, sub_sizes, 0, fpos, flat_vals, is_float);
for (size_t i = 0; i < flat_vals.size(); ++i) {
flattened[i] = dynamic_cast<ir::ConstantValue*>(flat_vals[i]);
}
size_t bpos = 0;
init_const = BuildConstantArray(module_.GetContext(), var_ty, flattened, bpos);
}
} else {
init_const = module_.GetContext().GetConstZero(var_ty);
}
}
auto gv_ptr_ty = ir::Type::GetPointerType(var_ty);
auto* gv = module_.CreateGlobalVariable(var_name, gv_ptr_ty, init_const);
if (!storage_map_stack_.empty()) {
storage_map_stack_.back()[var_name] = gv;
} }
init = EvalExpr(*init_value->exp());
} else { } else {
// 局部作用域 - 确保 alloca 在入口块 init = builder_.CreateConstInt(0);
auto* current_bb = builder_.GetInsertBlock();
builder_.SetInsertPoint(func_->GetEntry());
ir::Value* slot = builder_.CreateAlloca(var_ty, module_.GetContext().NextTemp());
builder_.SetInsertPoint(current_bb);
if (!storage_map_stack_.empty()) {
storage_map_stack_.back()[var_name] = slot;
}
if (auto* init_val = ctx->initVal()) {
if (dims.empty()) {
if (init_val->exp()) {
ir::Value* init = EvalExpr(*init_val->exp());
if (is_float && !init->GetType()->IsFloat()) {
init = builder_.CreateSIToFP(init, module_.GetContext().NextTemp());
} else if (!is_float && init->GetType()->IsFloat()) {
init = builder_.CreateFPToSI(init, module_.GetContext().NextTemp());
} }
builder_.CreateStore(init, slot); builder_.CreateStore(init, slot);
} return {};
} else {
std::vector<ir::Value*> flattened(sub_sizes[0]);
auto zero = is_float ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0);
for (auto& v : flattened) v = zero;
size_t pos = 0;
FlattenInitVal(init_val, dims, sub_sizes, 0, pos, flattened, is_float);
for (size_t i = 0; i < flattened.size(); ++i) {
// Optimization: only store non-zero?
// For now, store all to be safe.
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
size_t temp = i;
for (size_t d = 0; d < dims.size(); ++d) {
indices.push_back(builder_.CreateConstInt(temp / sub_sizes[d+1]));
temp %= sub_sizes[d+1];
}
ir::Value* ptr = builder_.CreateGEP(ir::Type::GetPointerType(base_ty), slot, indices, module_.GetContext().NextTemp());
builder_.CreateStore(flattened[i], ptr);
}
}
} else {
// Initialize scalar locals to 0
if (dims.empty()) {
ir::Value* zero = is_float ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0);
builder_.CreateStore(zero, slot);
}
}
}
return BlockFlow::Continue;
} }

@ -24,75 +24,21 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
return std::any_cast<ir::Value*>(expr.accept(this)); return std::any_cast<ir::Value*>(expr.accept(this));
} }
ir::ConstantValue* IRGenImpl::EvaluateConst(antlr4::tree::ParseTree* tree) {
auto val = std::any_cast<ir::Value*>(tree->accept(this));
auto* cval = dynamic_cast<ir::ConstantValue*>(val);
if (!cval) throw std::runtime_error("Not a constant expression");
return cval;
}
int IRGenImpl::EvaluateConstInt(SysYParser::ConstExpContext* ctx) {
if (!ctx) return 0;
auto* val = EvaluateConst(ctx->addExp());
if (auto* ci = dynamic_cast<ir::ConstantInt*>(val)) return ci->GetValue();
if (auto* cf = dynamic_cast<ir::ConstantFloat*>(val)) return (int)cf->GetValue();
return 0;
}
int IRGenImpl::EvaluateConstInt(SysYParser::ExpContext* ctx) {
if (!ctx) return 0;
auto* val = EvaluateConst(ctx);
if (auto* ci = dynamic_cast<ir::ConstantInt*>(val)) return ci->GetValue();
if (auto* cf = dynamic_cast<ir::ConstantFloat*>(val)) return (int)cf->GetValue();
return 0;
}
std::shared_ptr<ir::Type> IRGenImpl::GetGEPResultType(ir::Value* ptr, const std::vector<ir::Value*>& indices) { std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
auto cur_ty = ptr->GetType()->GetPointedType(); if (!ctx || !ctx->exp()) {
for (size_t i = 1; i < indices.size(); ++i) { throw std::runtime_error(FormatError("irgen", "非法括号表达式"));
if (cur_ty->IsArray()) {
cur_ty = cur_ty->GetElementType();
} }
}
return ir::Type::GetPointerType(cur_ty);
}
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
}
// 处理括号表达式LPAREN exp RPAREN
if (ctx->exp()) {
return EvalExpr(*ctx->exp()); return EvalExpr(*ctx->exp());
}
// 处理 lVal变量使用
if (ctx->lVal()) {
return ctx->lVal()->accept(this);
}
// 处理 number
if (ctx->number()) {
return ctx->number()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型"));
} }
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
if (!ctx) { if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("irgen", "缺少字面量节点")); throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
} }
if (ctx->intConst()) {
// 可能是 0x, 0X, 0 开头的八进制等,目前 std::stoi 会处理十进制,
// 为了支持 16 进制/8 进制建议使用 std::stoi(str, nullptr, 0)
std::string text = ctx->intConst()->getText();
return static_cast<ir::Value*>(
builder_.CreateConstInt(std::stoi(text, nullptr, 0)));
} else if (ctx->floatConst()) {
std::string text = ctx->floatConst()->getText();
return static_cast<ir::Value*>( return static_cast<ir::Value*>(
module_.GetContext().GetConstFloat(std::stof(text))); builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
}
throw std::runtime_error(FormatError("irgen", "不支持的字面量"));
} }
// 变量使用的处理流程: // 变量使用的处理流程:
@ -101,482 +47,34 @@ std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
// 3. 最后生成 load把内存中的值读出来。 // 3. 最后生成 load把内存中的值读出来。
// //
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 // 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
if (!ctx || !ctx->ID()) { if (!ctx || !ctx->var() || !ctx->var()->ID()) {
throw std::runtime_error(FormatError("irgen", "非法左值")); throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
}
std::string var_name = ctx->ID()->getText();
// 优先检查是否为已记录的常量
ir::ConstantValue* const_val = FindConst(var_name);
if (const_val && ctx->exp().empty()) {
return static_cast<ir::Value*>(const_val);
} }
auto* decl = sema_.ResolveVarUse(ctx->var());
const auto* binding = sema_.ResolveObjectUse(ctx); if (!decl) {
if (!binding) {
throw std::runtime_error( throw std::runtime_error(
FormatError("irgen", "变量使用缺少语义绑定:" + var_name)); FormatError("irgen",
"变量使用缺少语义绑定: " + ctx->var()->ID()->getText()));
} }
auto it = storage_map_.find(decl);
ir::Value* slot = FindStorage(var_name); if (it == storage_map_.end()) {
if (!slot) {
throw std::runtime_error( throw std::runtime_error(
FormatError("irgen", "变量声明缺少存储槽位:" + var_name)); FormatError("irgen",
} "变量声明缺少存储槽位: " + ctx->var()->ID()->getText()));
ir::Value* ptr = slot;
auto ptr_ty = ptr->GetType();
bool is_param = false;
// If it's a pointer to a pointer (function parameter case), load the pointer value first
if (ptr_ty->IsPointer() && ptr_ty->GetPointedType()->IsPointer()) {
ptr = builder_.CreateLoad(ptr, module_.GetContext().NextTemp());
is_param = true;
} else if (ptr->IsArgument()) {
is_param = true;
}
// Determine if the result of this LVal is a scalar or an array
bool result_is_scalar = (ctx->exp().size() == binding->dimensions.size());
if (!ctx->exp().empty()) {
std::vector<ir::Value*> indices;
// If it's a local array, we need leading 0
if (ptr->GetType()->IsPointer() && ptr->GetType()->GetPointedType()->IsArray()) {
if (!is_param) {
indices.push_back(builder_.CreateConstInt(0));
}
}
for (auto* exp_ctx : ctx->exp()) {
indices.push_back(EvalExpr(*exp_ctx));
}
auto res_ptr_ty = GetGEPResultType(ptr, indices);
ptr = builder_.CreateGEP(res_ptr_ty, ptr, indices, module_.GetContext().NextTemp());
}
if (result_is_scalar) {
return static_cast<ir::Value*>(builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
} else {
// Decay ptr to the first element of the sub-array
while (ptr->GetType()->GetPointedType()->IsArray()) {
std::vector<ir::Value*> d_indices;
d_indices.push_back(builder_.CreateConstInt(0));
d_indices.push_back(builder_.CreateConstInt(0));
auto d_res_ty = GetGEPResultType(ptr, d_indices);
ptr = builder_.CreateGEP(d_res_ty, ptr, d_indices, module_.GetContext().NextTemp());
}
return ptr;
}
}
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法加减法表达式"));
}
// 如果是 mulExp 直接返回addExp : mulExp
if (ctx->mulExp() && ctx->addExp() == nullptr) {
return ctx->mulExp()->accept(this);
}
// 处理 addExp op mulExp 的递归形式
if (!ctx->addExp() || !ctx->mulExp()) {
throw std::runtime_error(FormatError("irgen", "非法加减法表达式结构"));
} }
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
if (lhs->IsConstant() && rhs->IsConstant()) {
auto* cl = static_cast<ir::ConstantValue*>(lhs);
auto* cr = static_cast<ir::ConstantValue*>(rhs);
if (auto* cil = dynamic_cast<ir::ConstantInt*>(cl)) {
if (auto* cir = dynamic_cast<ir::ConstantInt*>(cr)) {
if (ctx->ADD()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() + cir->GetValue()));
if (ctx->SUB()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() - cir->GetValue()));
}
}
}
// Implicit conversion
if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) {
if (rhs->IsConstant()) {
rhs = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(rhs)->GetValue());
} else {
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
}
} else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) {
if (lhs->IsConstant()) {
lhs = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(lhs)->GetValue());
} else {
lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp());
}
}
if (lhs->IsConstant() && rhs->IsConstant()) {
auto* cl = static_cast<ir::ConstantValue*>(lhs);
auto* cr = static_cast<ir::ConstantValue*>(rhs);
if (auto* cfl = dynamic_cast<ir::ConstantFloat*>(cl)) {
if (auto* cfr = dynamic_cast<ir::ConstantFloat*>(cr)) {
if (ctx->ADD()) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(cfl->GetValue() + cfr->GetValue()));
if (ctx->SUB()) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(cfl->GetValue() - cfr->GetValue()));
}
}
}
ir::Opcode op = ir::Opcode::Add;
if (ctx->ADD()) {
op = ir::Opcode::Add;
} else if (ctx->SUB()) {
op = ir::Opcode::Sub;
} else {
throw std::runtime_error(FormatError("irgen", "未知的加减运算符"));
}
return static_cast<ir::Value*>( return static_cast<ir::Value*>(
builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp())); builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
} }
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
if (!ctx) { if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("irgen", "非法一元表达式")); throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
// 如果是 primaryExp 直接返回unaryExp : primaryExp
if (ctx->primaryExp()) {
return ctx->primaryExp()->accept(this);
}
// 处理函数调用unaryExp : ID LPAREN funcRParams? RPAREN
if (ctx->ID()) {
std::string func_name = ctx->ID()->getText();
// 从 Sema 或 Module 中查找函数
// 目前简化处理,直接从 Module 中查找(如果是当前文件内定义的)
// 或者依赖 Sema 给出解析结果
const FunctionBinding* func_binding = sema_.ResolveFunctionCall(ctx);
if (!func_binding) {
throw std::runtime_error(FormatError("irgen", "未找到函数声明:" + func_name));
}
// 假设 func_binding 能够找到对应的 ir::Function*
// 这里如果 sema 不提供直接拿 ir::Function 的方式,需要遍历 module_.GetFunctions() 查找
ir::Function* target_func = nullptr;
for (const auto& f : module_.GetFunctions()) {
if (f->GetName() == func_name) {
target_func = f.get();
break;
}
}
if (!target_func) {
// 可能是外部函数如 putint, getint 等
// 如果没有在 module_ 中,则需要创建一个只有声明的 Function
std::shared_ptr<ir::Type> ret_ty;
if (func_binding->return_type == SemanticType::Int) {
ret_ty = ir::Type::GetInt32Type();
} else if (func_binding->return_type == SemanticType::Float) {
ret_ty = ir::Type::GetFloatType();
} else {
ret_ty = ir::Type::GetVoidType();
}
target_func = module_.CreateFunction(func_name, ret_ty);
// 对于外部函数,需要传递参数,可能还需要在 target_func 中 AddArgument
for (const auto& param : func_binding->params) {
std::shared_ptr<ir::Type> p_ty;
if (param.type == SemanticType::Int) {
p_ty = param.dimensions.empty() && !param.is_array_param ? ir::Type::GetInt32Type() : ir::Type::GetPtrInt32Type();
} else {
p_ty = param.dimensions.empty() && !param.is_array_param ? ir::Type::GetFloatType() : ir::Type::GetPtrFloatType();
}
target_func->AddArgument(p_ty, param.name);
}
}
std::vector<ir::Value*> args;
if (ctx->funcRParams()) {
args = std::any_cast<std::vector<ir::Value*>>(ctx->funcRParams()->accept(this));
}
// Implicit conversion for function arguments
const auto& formal_args = target_func->GetArgs();
for (size_t i = 0; i < std::min(args.size(), formal_args.size()); ++i) {
if (formal_args[i]->GetType()->IsFloat() && !args[i]->GetType()->IsFloat()) {
args[i] = builder_.CreateSIToFP(args[i], module_.GetContext().NextTemp());
} else if (formal_args[i]->GetType()->IsInt32() && args[i]->GetType()->IsFloat()) {
args[i] = builder_.CreateFPToSI(args[i], module_.GetContext().NextTemp());
}
}
return static_cast<ir::Value*>(builder_.CreateCall(target_func, args, module_.GetContext().NextTemp()));
}
// 处理一元运算符unaryExp : addUnaryOp unaryExp
if (ctx->addUnaryOp() && ctx->unaryExp()) {
ir::Value* operand = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
// Constant folding for unary op
if (operand->IsConstant()) {
if (ctx->addUnaryOp()->SUB()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(operand)) {
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(-ci->GetValue()));
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(operand)) {
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(-cf->GetValue()));
} }
} else { ir::Value* lhs = EvalExpr(*ctx->exp(0));
return operand; ir::Value* rhs = EvalExpr(*ctx->exp(1));
}
}
// 判断是正号还是负号
if (ctx->addUnaryOp()->SUB()) {
// 负号:如果是整数生成 sub 0, operand浮点数生成 fsub 0.0, operand
if (operand->GetType()->IsFloat()) {
ir::Value* zero = module_.GetContext().GetConstFloat(0.0f);
// 此处暂且假设 CreateSub 可以处理浮点数(如果底层有 fsub 则更好)
return static_cast<ir::Value*>( return static_cast<ir::Value*>(
builder_.CreateSub(zero, operand, module_.GetContext().NextTemp())); builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
} else { module_.GetContext().NextTemp()));
ir::Value* zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(
builder_.CreateSub(zero, operand, module_.GetContext().NextTemp()));
}
} else if (ctx->addUnaryOp()->ADD()) {
// 正号:直接返回操作数(+x 等价于 x
return operand;
} else {
throw std::runtime_error(FormatError("irgen", "未知的一元运算符"));
}
}
throw std::runtime_error(FormatError("irgen", "不支持的一元表达式类型"));
}
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法乘除法表达式"));
}
// 如果是 unaryExp 直接返回mulExp : unaryExp
if (ctx->unaryExp() && ctx->mulExp() == nullptr) {
return ctx->unaryExp()->accept(this);
}
// 处理 mulExp op unaryExp 的递归形式
if (!ctx->mulExp() || !ctx->unaryExp()) {
throw std::runtime_error(FormatError("irgen", "非法乘除法表达式结构"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
// Constant folding
if (lhs->IsConstant() && rhs->IsConstant()) {
auto* cl = static_cast<ir::ConstantValue*>(lhs);
auto* cr = static_cast<ir::ConstantValue*>(rhs);
if (auto* cil = dynamic_cast<ir::ConstantInt*>(cl)) {
if (auto* cir = dynamic_cast<ir::ConstantInt*>(cr)) {
if (ctx->MUL()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() * cir->GetValue()));
if (ctx->DIV()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() / cir->GetValue()));
if (ctx->MOD()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() % cir->GetValue()));
}
}
}
// Implicit conversion
if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) {
if (rhs->IsConstant()) {
rhs = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(rhs)->GetValue());
} else {
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
}
} else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) {
if (lhs->IsConstant()) {
lhs = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(lhs)->GetValue());
} else {
lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp());
}
}
if (lhs->IsConstant() && rhs->IsConstant()) {
auto* cl = static_cast<ir::ConstantValue*>(lhs);
auto* cr = static_cast<ir::ConstantValue*>(rhs);
if (auto* cfl = dynamic_cast<ir::ConstantFloat*>(cl)) {
if (auto* cfr = dynamic_cast<ir::ConstantFloat*>(cr)) {
if (ctx->MUL()) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(cfl->GetValue() * cfr->GetValue()));
if (ctx->DIV()) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(cfl->GetValue() / cfr->GetValue()));
}
}
}
ir::Opcode op = ir::Opcode::Mul;
if (ctx->MUL()) {
op = ir::Opcode::Mul;
} else if (ctx->DIV()) {
op = ir::Opcode::Div;
} else if (ctx->MOD()) {
op = ir::Opcode::Mod;
} else {
throw std::runtime_error(FormatError("irgen", "未知的乘除运算符"));
}
return static_cast<ir::Value*>(
builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
if (ctx->addExp() && ctx->relExp() == nullptr) {
return ctx->addExp()->accept(this);
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
// Implicit conversion
if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) {
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
} else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) {
lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp());
}
ir::CmpOp op;
if (ctx->LT()) op = ir::CmpOp::Lt;
else if (ctx->GT()) op = ir::CmpOp::Gt;
else if (ctx->LE()) op = ir::CmpOp::Le;
else if (ctx->GE()) op = ir::CmpOp::Ge;
else throw std::runtime_error(FormatError("irgen", "未知的关系运算符"));
return static_cast<ir::Value*>(builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
if (ctx->relExp() && ctx->eqExp() == nullptr) {
return ctx->relExp()->accept(this);
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
// Implicit conversion
if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) {
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
} else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) {
lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp());
}
ir::CmpOp op;
if (ctx->EQ()) op = ir::CmpOp::Eq;
else if (ctx->NE()) op = ir::CmpOp::Ne;
else throw std::runtime_error(FormatError("irgen", "未知的相等运算符"));
return static_cast<ir::Value*>(builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) {
if (ctx->eqExp()) {
return ctx->eqExp()->accept(this);
}
if (ctx->NOT()) {
ir::Value* operand = std::any_cast<ir::Value*>(ctx->condUnaryExp()->accept(this));
if (operand->GetType()->IsInt1()) {
operand = builder_.CreateZext(operand, module_.GetContext().NextTemp());
}
if (operand->GetType()->IsFloat()) {
ir::Value* zero = module_.GetContext().GetConstFloat(0.0f);
return static_cast<ir::Value*>(builder_.CreateCmp(ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp()));
} else {
ir::Value* zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(builder_.CreateCmp(ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp()));
}
}
throw std::runtime_error(FormatError("irgen", "非法条件一元表达式"));
}
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
if (ctx->condUnaryExp() && ctx->lAndExp() == nullptr) {
return ctx->condUnaryExp()->accept(this);
}
ir::AllocaInst* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
ir::Value* zero = builder_.CreateConstInt(0);
builder_.CreateStore(zero, res_ptr);
ir::BasicBlock* rhs_bb = func_->CreateBlock(NextBlockName("land_rhs"));
ir::BasicBlock* end_bb = func_->CreateBlock(NextBlockName("land_end"));
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
if (lhs->GetType()->IsInt1()) {
lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
}
ir::Value* lhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, lhs, zero, module_.GetContext().NextTemp());
builder_.CreateCondBr(lhs_cond, rhs_bb, end_bb);
builder_.SetInsertPoint(rhs_bb);
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->condUnaryExp()->accept(this));
if (rhs->GetType()->IsInt1()) {
rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
}
ir::Value* rhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, rhs, zero, module_.GetContext().NextTemp());
ir::Value* rhs_res = builder_.CreateZext(rhs_cond, module_.GetContext().NextTemp());
builder_.CreateStore(rhs_res, res_ptr);
builder_.CreateBr(end_bb);
builder_.SetInsertPoint(end_bb);
return static_cast<ir::Value*>(builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
if (ctx->lAndExp() && ctx->lOrExp() == nullptr) {
return ctx->lAndExp()->accept(this);
}
ir::AllocaInst* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
ir::Value* one = builder_.CreateConstInt(1);
builder_.CreateStore(one, res_ptr);
ir::BasicBlock* rhs_bb = func_->CreateBlock(NextBlockName("lor_rhs"));
ir::BasicBlock* end_bb = func_->CreateBlock(NextBlockName("lor_end"));
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->lOrExp()->accept(this));
ir::Value* zero = builder_.CreateConstInt(0);
if (lhs->GetType()->IsInt1()) {
lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
}
ir::Value* lhs_cond = builder_.CreateCmp(ir::CmpOp::Eq, lhs, zero, module_.GetContext().NextTemp());
builder_.CreateCondBr(lhs_cond, rhs_bb, end_bb);
builder_.SetInsertPoint(rhs_bb);
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
if (rhs->GetType()->IsInt1()) {
rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
}
ir::Value* rhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, rhs, zero, module_.GetContext().NextTemp());
ir::Value* rhs_res = builder_.CreateZext(rhs_cond, module_.GetContext().NextTemp());
builder_.CreateStore(rhs_res, res_ptr);
builder_.CreateBr(end_bb);
builder_.SetInsertPoint(end_bb);
return static_cast<ir::Value*>(builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
if (!ctx || !ctx->lOrExp()) {
throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
}
return ctx->lOrExp()->accept(this);
}
std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
std::vector<ir::Value*> args;
for (auto* exp : ctx->exp()) {
args.push_back(EvalExpr(*exp));
}
return args;
} }

@ -29,7 +29,7 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
// 编译单元的 IR 生成当前只实现了最小功能: // 编译单元的 IR 生成当前只实现了最小功能:
// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; // - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容;
// - 当前会读取编译单元中的 topLevelItem找到 funcDef 后生成函数 IR // - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR
// //
// 当前还没有实现: // 当前还没有实现:
// - 多个函数定义的遍历与生成; // - 多个函数定义的遍历与生成;
@ -38,24 +38,11 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) { if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元")); throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
} }
// 初始化全局作用域 auto* func = ctx->funcDef();
storage_map_stack_.push_back({}); if (!func) {
const_values_stack_.push_back({}); throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
// 遍历所有 topLevelItem
for (auto* item : ctx->topLevelItem()) {
if (!item) continue;
if (item->funcDef()) {
item->funcDef()->accept(this);
} else if (item->decl()) {
item->decl()->accept(this);
}
} }
func->accept(this);
// 退出全局作用域
storage_map_stack_.pop_back();
const_values_stack_.pop_back();
return {}; return {};
} }
@ -74,98 +61,27 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
// - 入口块中的参数初始化逻辑。 // - 入口块中的参数初始化逻辑。
// ... // ...
// 因此这里目前只支持最小的“无参 int 函数”生成。
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx) { if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义")); throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
} }
if (!ctx->block()) { if (!ctx->blockStmt()) {
throw std::runtime_error(FormatError("irgen", "函数体为空")); throw std::runtime_error(FormatError("irgen", "函数体为空"));
} }
if (!ctx->ID()) { if (!ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "缺少函数名")); throw std::runtime_error(FormatError("irgen", "缺少函数名"));
} }
if (!ctx->funcType() || !ctx->funcType()->INT()) {
std::shared_ptr<ir::Type> ret_type; throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数"));
if (ctx->funcType()->INT()) {
ret_type = ir::Type::GetInt32Type();
} else if (ctx->funcType()->FLOAT()) {
ret_type = ir::Type::GetFloatType();
} else if (ctx->funcType()->VOID()) {
ret_type = ir::Type::GetVoidType();
} else {
throw std::runtime_error(FormatError("irgen", "未知的函数返回类型"));
}
func_ = module_.CreateFunction(ctx->ID()->getText(), ret_type);
ir::BasicBlock* alloca_bb = func_->CreateBlock("alloca");
ir::BasicBlock* entry_bb = func_->CreateBlock("entry");
builder_.SetInsertPoint(entry_bb);
// 进入函数作用域,压入一个新的 map
storage_map_stack_.push_back({});
const_values_stack_.push_back({});
if (ctx->funcFParams()) {
for (auto* paramCtx : ctx->funcFParams()->funcFParam()) {
std::shared_ptr<ir::Type> param_type;
bool is_array = !paramCtx->LBRACK().empty();
auto base_sema_ty = paramCtx->bType()->INT() ? SemanticType::Int : SemanticType::Float;
auto base_ir_ty = (base_sema_ty == SemanticType::Int) ? ir::Type::GetInt32Type() : ir::Type::GetFloatType();
if (is_array) {
std::shared_ptr<ir::Type> elem_ty = base_ir_ty;
auto exps = paramCtx->exp();
for (auto it = exps.rbegin(); it != exps.rend(); ++it) {
int dim = EvaluateConstInt(*it);
elem_ty = ir::Type::GetArrayType(elem_ty, dim);
}
param_type = ir::Type::GetPointerType(elem_ty);
} else {
param_type = base_ir_ty;
} }
std::string arg_name = paramCtx->ID()->getText(); func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type());
auto* arg = func_->AddArgument(param_type, "%arg" + std::to_string(func_->GetArgs().size())); builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
// Ensure param alloca is in alloca_bb
auto* current_bb = builder_.GetInsertBlock();
builder_.SetInsertPoint(alloca_bb);
ir::Instruction* alloca_inst = builder_.CreateAlloca(param_type, module_.GetContext().NextTemp());
builder_.SetInsertPoint(current_bb);
builder_.CreateStore(arg, alloca_inst);
storage_map_stack_.back()[arg_name] = alloca_inst;
}
}
ctx->block()->accept(this);
// Implicit return for void functions or main
if (!builder_.GetInsertBlock()->HasTerminator()) {
if (func_->GetType()->IsVoid()) {
builder_.CreateRet(nullptr);
} else if (func_->GetName() == "main") {
builder_.CreateRet(builder_.CreateConstInt(0));
} else {
if (func_->GetType()->IsFloat()) {
builder_.CreateRet(module_.GetContext().GetConstFloat(0.0f));
} else {
builder_.CreateRet(builder_.CreateConstInt(0));
}
}
}
// Branch from alloca_bb to entry_bb
builder_.SetInsertPoint(alloca_bb);
builder_.CreateBr(entry_bb);
ctx->blockStmt()->accept(this);
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。
VerifyFunctionStructure(*func_); VerifyFunctionStructure(*func_);
func_ = nullptr;
// 退出函数作用域,弹出 map
storage_map_stack_.pop_back();
const_values_stack_.pop_back();
return {}; return {};
} }

@ -9,9 +9,9 @@
// 语句生成当前只实现了最小子集。 // 语句生成当前只实现了最小子集。
// 目前支持: // 目前支持:
// - return <exp>; // - return <exp>;
// - 赋值语句lVal = exp;
// //
// 还未支持: // 还未支持:
// - 赋值语句
// - if / while 等控制流 // - if / while 等控制流
// - 空语句、块语句嵌套分发之外的更多语句形态 // - 空语句、块语句嵌套分发之外的更多语句形态
@ -19,178 +19,21 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) { if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句")); throw std::runtime_error(FormatError("irgen", "缺少语句"));
} }
if (ctx->returnStmt()) {
if (ctx->lVal() && ctx->ASSIGN()) { return ctx->returnStmt()->accept(this);
if (!ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "赋值语句缺少表达式"));
}
ir::Value* rhs = EvalExpr(*ctx->exp());
auto* lval_ctx = ctx->lVal();
if (!lval_ctx || !lval_ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量赋值"));
}
const auto* decl = sema_.ResolveObjectUse(lval_ctx);
if (!decl) {
throw std::runtime_error(
FormatError("irgen", "变量使用缺少语义绑定:" + lval_ctx->ID()->getText()));
}
std::string var_name = lval_ctx->ID()->getText();
ir::Value* slot = FindStorage(var_name);
if (!slot) {
throw std::runtime_error(
FormatError("irgen", "变量声明缺少存储槽位:" + var_name));
}
ir::Value* ptr = slot;
auto ptr_ty = ptr->GetType();
bool is_param = false;
// If it's a pointer to a pointer (function parameter case), load the pointer value first
if (ptr_ty->IsPointer() && ptr_ty->GetPointedType()->IsPointer()) {
ptr = builder_.CreateLoad(ptr, module_.GetContext().NextTemp());
is_param = true;
}
if (ptr->IsArgument()) is_param = true;
if (!lval_ctx->exp().empty()) {
std::vector<ir::Value*> indices;
if (ptr->GetType()->IsPointer() && ptr->GetType()->GetPointedType()->IsArray()) {
if (!is_param) {
indices.push_back(builder_.CreateConstInt(0));
}
}
for (auto* exp_ctx : lval_ctx->exp()) {
indices.push_back(EvalExpr(*exp_ctx));
}
auto res_ptr_ty = GetGEPResultType(ptr, indices);
ptr = builder_.CreateGEP(res_ptr_ty, ptr, indices, module_.GetContext().NextTemp());
}
// Implicit conversion for assignment
if ((ptr->GetType()->IsPtrFloat() || (ptr->GetType()->IsArray() && ptr->GetType()->GetElementType()->IsFloat())) && !rhs->GetType()->IsFloat()) {
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
} else if ((ptr->GetType()->IsPtrInt32() || (ptr->GetType()->IsArray() && ptr->GetType()->GetElementType()->IsInt32())) && rhs->GetType()->IsFloat()) {
rhs = builder_.CreateFPToSI(rhs, module_.GetContext().NextTemp());
}
builder_.CreateStore(rhs, ptr);
return BlockFlow::Continue;
}
if (ctx->IF()) {
ir::Value* cond_val = std::any_cast<ir::Value*>(ctx->cond()->accept(this));
// cond_val must be i1, if it's not we need to check if it's != 0
if (cond_val->GetType()->IsInt32()) {
ir::Value* zero = builder_.CreateConstInt(0);
cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp());
} else if (cond_val->GetType()->IsFloat()) {
ir::Value* zero = module_.GetContext().GetConstFloat(0.0f);
cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp());
}
ir::BasicBlock* then_bb = func_->CreateBlock(NextBlockName("if_then"));
ir::BasicBlock* else_bb = ctx->ELSE() ? func_->CreateBlock(NextBlockName("if_else")) : nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock(NextBlockName("if_merge"));
builder_.CreateCondBr(cond_val, then_bb, else_bb ? else_bb : merge_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (then_flow == BlockFlow::Continue) {
builder_.CreateBr(merge_bb);
}
if (ctx->ELSE()) {
builder_.SetInsertPoint(else_bb);
auto else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
if (else_flow == BlockFlow::Continue) {
builder_.CreateBr(merge_bb);
}
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
if (ctx->WHILE()) {
ir::BasicBlock* cond_bb = func_->CreateBlock(NextBlockName("while_cond"));
ir::BasicBlock* body_bb = func_->CreateBlock(NextBlockName("while_body"));
ir::BasicBlock* exit_bb = func_->CreateBlock(NextBlockName("while_exit"));
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
ir::Value* cond_val = std::any_cast<ir::Value*>(ctx->cond()->accept(this));
if (cond_val->GetType()->IsInt32()) {
ir::Value* zero = builder_.CreateConstInt(0);
cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp());
} else if (cond_val->GetType()->IsFloat()) {
ir::Value* zero = module_.GetContext().GetConstFloat(0.0f);
cond_val = builder_.CreateCmp(ir::CmpOp::Ne, cond_val, zero, module_.GetContext().NextTemp());
}
builder_.CreateCondBr(cond_val, body_bb, exit_bb);
builder_.SetInsertPoint(body_bb);
ir::BasicBlock* old_cond = current_loop_cond_bb_;
ir::BasicBlock* old_exit = current_loop_exit_bb_;
current_loop_cond_bb_ = cond_bb;
current_loop_exit_bb_ = exit_bb;
auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (body_flow == BlockFlow::Continue) {
builder_.CreateBr(cond_bb);
}
current_loop_cond_bb_ = old_cond;
current_loop_exit_bb_ = old_exit;
builder_.SetInsertPoint(exit_bb);
return BlockFlow::Continue;
} }
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}
if (ctx->BREAK()) {
if (!current_loop_exit_bb_) {
throw std::runtime_error(FormatError("irgen", "break 必须在循环内"));
}
builder_.CreateBr(current_loop_exit_bb_);
return BlockFlow::Terminated;
}
if (ctx->CONTINUE()) { std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
if (!current_loop_cond_bb_) { if (!ctx) {
throw std::runtime_error(FormatError("irgen", "continue 必须在循环内")); throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
} }
builder_.CreateBr(current_loop_cond_bb_); if (!ctx->exp()) {
return BlockFlow::Terminated; throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
} }
if (ctx->RETURN()) {
if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp()); ir::Value* v = EvalExpr(*ctx->exp());
// Handle return type conversion if necessary
if (func_->GetType()->IsFloat() && !v->GetType()->IsFloat()) {
v = builder_.CreateSIToFP(v, module_.GetContext().NextTemp());
} else if (func_->GetType()->IsInt32() && v->GetType()->IsFloat()) {
v = builder_.CreateFPToSI(v, module_.GetContext().NextTemp());
}
builder_.CreateRet(v); builder_.CreateRet(v);
} else {
builder_.CreateRet(nullptr); // nullptr for void ret
}
return BlockFlow::Terminated; return BlockFlow::Terminated;
}
if (ctx->block()) {
return ctx->block()->accept(this);
}
if (ctx->exp()) {
EvalExpr(*ctx->exp());
return BlockFlow::Continue;
}
if (ctx->SEMICOLON()) {
return BlockFlow::Continue;
}
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
} }

@ -46,15 +46,13 @@ int main(int argc, char** argv) {
} }
if (opts.emit_asm) { if (opts.emit_asm) {
auto machine_module = mir::LowerToMIR(*module); auto machine_func = mir::LowerToMIR(*module);
for (auto& func : machine_module->GetFunctions()) { mir::RunRegAlloc(*machine_func);
mir::RunRegAlloc(*func); mir::RunFrameLowering(*machine_func);
mir::RunFrameLowering(*func);
}
if (need_blank_line) { if (need_blank_line) {
std::cout << "\n"; std::cout << "\n";
} }
mir::PrintAsm(*machine_module, std::cout); mir::PrintAsm(*machine_func, std::cout);
} }
#else #else
if (opts.emit_ir || opts.emit_asm) { if (opts.emit_ir || opts.emit_asm) {

@ -16,290 +16,63 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex()); return function.GetFrameSlot(operand.GetFrameIndex());
} }
void PrintMovImm(std::ostream& os, PhysReg reg, int imm) {
const char* reg_name = PhysRegName(reg);
if (imm >= -32768 && imm <= 65535) {
os << " mov " << reg_name << ", #" << imm << "\n";
} else {
uint32_t uimm = static_cast<uint32_t>(imm);
os << " mov " << reg_name << ", #" << (uimm & 0xFFFF) << "\n";
os << " movk " << reg_name << ", #" << ((uimm >> 16) & 0xFFFF) << ", lsl #16\n";
}
}
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
int offset) { int offset) {
if (offset >= -256 && offset <= 255) {
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
<< "]\n"; << "]\n";
} else {
// Offset out of range for ldur/stur
if (offset < 0) {
PrintMovImm(os, PhysReg::X16, -offset);
os << " sub x16, x29, x16\n";
} else {
PrintMovImm(os, PhysReg::X16, offset);
os << " add x16, x29, x16\n";
}
if (mnemonic[0] == 'l') { // load
os << " ldr " << PhysRegName(reg) << ", [x16]\n";
} else { // store
os << " str " << PhysRegName(reg) << ", [x16]\n";
}
}
}
const char* CondCodeName(CondCode cc) {
switch (cc) {
case CondCode::EQ: return "eq";
case CondCode::NE: return "ne";
case CondCode::LT: return "lt";
case CondCode::LE: return "le";
case CondCode::GT: return "gt";
case CondCode::GE: return "ge";
}
return "??";
} }
} // namespace } // namespace
void PrintAsm(const MachineModule& module, std::ostream& os) { void PrintAsm(const MachineFunction& function, std::ostream& os) {
// Print global variables
if (!module.GetGlobals().empty()) {
os << ".data\n";
for (const auto& gv : module.GetGlobals()) {
os << ".global " << gv.name << "\n";
os << ".align 4\n";
os << gv.name << ":\n";
if (gv.size > 4 || gv.init_value == 0) {
os << " .zero " << gv.size << "\n";
} else {
os << " .word " << gv.init_value << "\n";
}
}
os << "\n";
}
os << ".text\n"; os << ".text\n";
for (const auto& function : module.GetFunctions()) { os << ".global " << function.GetName() << "\n";
os << ".global " << function->GetName() << "\n"; os << ".type " << function.GetName() << ", %function\n";
os << ".type " << function->GetName() << ", %function\n"; os << function.GetName() << ":\n";
os << function->GetName() << ":\n";
for (const auto& block : function->GetBlocks()) {
os << ".L" << function->GetName() << "_" << block->GetName() << ":\n";
for (const auto& inst : block->GetInstructions()) { for (const auto& inst : function.GetEntry().GetInstructions()) {
const auto& ops = inst.GetOperands(); const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) { switch (inst.GetOpcode()) {
case Opcode::Prologue: case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n"; os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n"; os << " mov x29, sp\n";
if (function->GetFrameSize() > 0) { if (function.GetFrameSize() > 0) {
if (function->GetFrameSize() <= 4095) { os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
os << " sub sp, sp, #" << function->GetFrameSize() << "\n";
} else {
PrintMovImm(os, PhysReg::X11, function->GetFrameSize());
os << " sub sp, sp, x11\n";
}
} }
break; break;
case Opcode::Epilogue: case Opcode::Epilogue:
if (function->GetFrameSize() > 0) { if (function.GetFrameSize() > 0) {
if (function->GetFrameSize() <= 4095) { os << " add sp, sp, #" << function.GetFrameSize() << "\n";
os << " add sp, sp, #" << function->GetFrameSize() << "\n";
} else {
PrintMovImm(os, PhysReg::X11, function->GetFrameSize());
os << " add sp, sp, x11\n";
}
} }
os << " ldp x29, x30, [sp], #16\n"; os << " ldp x29, x30, [sp], #16\n";
break; break;
case Opcode::MovImm: case Opcode::MovImm:
if (ops.at(1).GetKind() == Operand::Kind::Global) { os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", " << ops.at(1).GetGlobal() << "\n"; << ops.at(1).GetImm() << "\n";
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(0).GetReg())
<< ", :lo12:" << ops.at(1).GetGlobal() << "\n";
} else {
PrintMovImm(os, ops.at(0).GetReg(), ops.at(1).GetImm());
}
break; break;
case Opcode::MovRR: {
const char* dst = PhysRegName(ops.at(0).GetReg());
const char* src = PhysRegName(ops.at(1).GetReg());
if (dst[0] == 's' && src[0] == 'w') {
os << " fmov " << dst << ", " << src << "\n";
} else if (dst[0] == 'w' && src[0] == 's') {
os << " fmov " << dst << ", " << src << "\n";
} else if (dst[0] == 's' && src[0] == 's') {
os << " fmov " << dst << ", " << src << "\n";
} else {
os << " mov " << dst << ", " << src << "\n";
}
break;
}
case Opcode::LoadStack: { case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(*function, ops.at(1)); const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break; break;
} }
case Opcode::StoreStack: { case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(*function, ops.at(1)); const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break; break;
} }
case Opcode::AddrStack: {
const auto& slot = GetFrameSlot(*function, ops.at(1));
int offset = slot.offset;
if (offset >= 0) {
if (offset <= 4095) {
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << offset << "\n";
} else {
PrintMovImm(os, PhysReg::X16, offset);
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", x29, x16\n";
}
} else {
int abs_offset = -offset;
if (abs_offset <= 4095) {
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << abs_offset << "\n";
} else {
PrintMovImm(os, PhysReg::X16, abs_offset);
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", x29, x16\n";
}
}
break;
}
case Opcode::LoadGlobal:
os << " adrp x16, " << ops.at(1).GetGlobal() << "\n";
os << " add x16, x16, :lo12:" << ops.at(1).GetGlobal() << "\n";
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", [x16]\n";
break;
case Opcode::StoreGlobal:
os << " adrp x16, " << ops.at(1).GetGlobal() << "\n";
os << " add x16, x16, :lo12:" << ops.at(1).GetGlobal() << "\n";
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", [x16]\n";
break;
case Opcode::AddRR: case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n"; << PhysRegName(ops.at(2).GetReg()) << "\n";
break; break;
case Opcode::AddRRI:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #" << ops.at(2).GetImm() << "\n";
break;
case Opcode::AddRRR_LSL: {
const char* reg2_name = PhysRegName(ops.at(2).GetReg());
std::string reg2_str = reg2_name;
std::string extension = "lsl";
if (reg2_name[0] == 'w') {
extension = "sxtw";
}
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< reg2_str << ", " << extension << " #" << ops.at(3).GetImm() << "\n";
break;
}
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SDivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MSubRRR:
os << " msub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << ", "
<< PhysRegName(ops.at(3).GetReg()) << "\n";
break;
case Opcode::Sxtw:
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::NegR:
os << " neg " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CmpRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CSet:
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
<< CondCodeName(ops.at(1).GetCond()) << "\n";
break;
case Opcode::FAdd:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSub:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMUL:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDiv:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FNeg:
os << " fneg " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCmp:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCvtSI2FP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCvtFP2SI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::LoadR:
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
case Opcode::StoreR:
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
case Opcode::Call:
os << " bl " << ops.at(0).GetLabel() << "\n";
break;
case Opcode::B:
os << " b .L" << function->GetName() << "_" << ops.at(0).GetLabel() << "\n";
break;
case Opcode::BCond:
os << " cmp " << PhysRegName(ops.at(1).GetReg()) << ", #0\n";
os << " b." << CondCodeName(ops.at(0).GetCond()) << " .L" << function->GetName() << "_" << ops.at(2).GetLabel() << "\n";
break;
case Opcode::Ret: case Opcode::Ret:
os << " ret\n"; os << " ret\n";
break; break;
} }
} }
}
os << ".size " << function->GetName() << ", .-" << function->GetName() << "\n\n"; os << ".size " << function.GetName() << ", .-" << function.GetName()
} << "\n";
} }
} // namespace mir } // namespace mir

@ -19,8 +19,7 @@ void RunFrameLowering(MachineFunction& function) {
for (const auto& slot : function.GetFrameSlots()) { for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size; cursor += slot.size;
if (-cursor < -256) { if (-cursor < -256) {
// For now, keep the 256-byte limit for simplicity (ldur/stur range) throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
// throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
} }
} }
@ -31,16 +30,9 @@ void RunFrameLowering(MachineFunction& function) {
} }
function.SetFrameSize(AlignTo(cursor, 16)); function.SetFrameSize(AlignTo(cursor, 16));
// Add Prologue to the first block auto& insts = function.GetEntry().GetInstructions();
if (!function.GetBlocks().empty()) {
auto& entry_insts = function.GetBlocks().front()->GetInstructions();
entry_insts.insert(entry_insts.begin(), MachineInstr(Opcode::Prologue));
}
// Add Epilogue before every Ret
for (auto& block : function.GetBlocks()) {
auto& insts = block->GetInstructions();
std::vector<MachineInstr> lowered; std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) { for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) { if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue); lowered.emplace_back(Opcode::Epilogue);
@ -48,7 +40,6 @@ void RunFrameLowering(MachineFunction& function) {
lowered.push_back(inst); lowered.push_back(inst);
} }
insts = std::move(lowered); insts = std::move(lowered);
}
} }
} // namespace mir } // namespace mir

@ -1,6 +1,5 @@
#include "mir/MIR.h" #include "mir/MIR.h"
#include <cstring>
#include <stdexcept> #include <stdexcept>
#include <unordered_map> #include <unordered_map>
@ -12,600 +11,113 @@ namespace {
using ValueSlotMap = std::unordered_map<const ir::Value*, int>; using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
int AlignTo(int value, int align) {
return ((value + align - 1) / align) * align;
}
bool IsPointerLike(const ir::Type& ty) {
return ty.IsPointer() || ty.IsPtrInt32() || ty.IsPtrFloat();
}
bool IsFloatLike(const ir::Type& ty) { return ty.IsFloat(); }
PhysReg ToXReg(PhysReg reg) {
if ((int)reg >= (int)PhysReg::W0 && (int)reg <= (int)PhysReg::W15) {
return static_cast<PhysReg>((int)reg - (int)PhysReg::W0 + (int)PhysReg::X0);
}
return reg;
}
PhysReg ToSReg(PhysReg reg) {
if ((int)reg >= (int)PhysReg::W0 && (int)reg <= (int)PhysReg::W15) {
return static_cast<PhysReg>((int)reg - (int)PhysReg::W0 + (int)PhysReg::S0);
}
return reg;
}
struct ArgLoc {
bool in_reg = false;
PhysReg reg = PhysReg::W0;
int stack_offset = 0; // bytes from stack-args base
};
ArgLoc GetFunctionArgLoc(const ir::Function& func, size_t arg_no) {
int gpr_idx = 0;
int fpr_idx = 0;
int stack_slots = 0;
const auto& args = func.GetArgs();
for (size_t i = 0; i < args.size(); ++i) {
const auto& ty = *args[i]->GetType();
const bool is_float = IsFloatLike(ty);
const bool is_ptr = IsPointerLike(ty);
ArgLoc loc;
if (is_float && fpr_idx < 8) {
loc.in_reg = true;
loc.reg = static_cast<PhysReg>((int)PhysReg::S0 + fpr_idx);
++fpr_idx;
} else if (!is_float && gpr_idx < 8) {
loc.in_reg = true;
loc.reg = is_ptr ? static_cast<PhysReg>((int)PhysReg::X0 + gpr_idx)
: static_cast<PhysReg>((int)PhysReg::W0 + gpr_idx);
++gpr_idx;
} else {
loc.in_reg = false;
loc.stack_offset = stack_slots * 8;
++stack_slots;
}
if (i == arg_no) return loc;
}
throw std::runtime_error(
FormatError("mir", "函数参数索引越界: " + std::to_string(arg_no)));
}
void EmitValueToReg(const ir::Value* value, PhysReg target, void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) { const ValueSlotMap& slots, MachineBasicBlock& block) {
bool is_ptr = IsPointerLike(*value->GetType());
bool is_float = IsFloatLike(*value->GetType());
if (is_ptr) {
target = ToXReg(target);
} else if (is_float) {
target = ToSReg(target);
}
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) { if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
block.Append(Opcode::MovImm, block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())}); {Operand::Reg(target), Operand::Imm(constant->GetValue())});
return; return;
} }
if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(value)) {
float f = cf->GetValue();
uint32_t bits;
std::memcpy(&bits, &f, 4);
// mov w10, #bits; fmov target, w10
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm((int)bits)});
block.Append(Opcode::MovRR, {Operand::Reg(target), Operand::Reg(PhysReg::W10)});
return;
}
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(value)) {
// This loads the VALUE of the global, not its address
block.Append(Opcode::LoadGlobal,
{Operand::Reg(target), Operand::Global(gv->GetName())});
return;
}
if (auto* arg = dynamic_cast<const ir::Argument*>(value)) {
const auto* parent = arg->GetParent();
if (!parent) {
throw std::runtime_error(FormatError("mir", "参数未绑定到函数"));
}
const ArgLoc loc = GetFunctionArgLoc(*parent, arg->GetArgNo());
if (loc.in_reg) {
block.Append(Opcode::MovRR, {Operand::Reg(target), Operand::Reg(loc.reg)});
} else {
// Incoming stack args are at [old_sp + offset]. After prologue:
// x29 = old_sp - 16, so address is [x29 + 16 + offset].
const int fp_offset = 16 + loc.stack_offset;
if (fp_offset <= 4095) {
block.Append(Opcode::AddRRI, {Operand::Reg(PhysReg::X10),
Operand::Reg(PhysReg::X29),
Operand::Imm(fp_offset)});
} else {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X11),
Operand::Imm(fp_offset)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X10),
Operand::Reg(PhysReg::X29),
Operand::Reg(PhysReg::X11)});
}
block.Append(Opcode::LoadR, {Operand::Reg(target), Operand::Reg(PhysReg::X10)});
}
return;
}
auto it = slots.find(value); auto it = slots.find(value);
if (it == slots.end()) { if (it == slots.end()) {
throw std::runtime_error( throw std::runtime_error(
FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
} }
block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)}); block.Append(Opcode::LoadStack,
} {Operand::Reg(target), Operand::FrameIndex(it->second)});
void EmitAddrToReg(const ir::Value* value, PhysReg target,
const MachineFunction& function,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(value)) {
// adrp x10, gv; add x10, x10, :lo12:gv
block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Global(gv->GetName())}); // Special case for address
return;
}
if (auto* arg = dynamic_cast<const ir::Argument*>(value)) {
// Argument is already an address (pointer)
EmitValueToReg(arg, target, slots, block);
return;
}
auto it = slots.find(value);
if (it != slots.end()) {
// Check if it's an alloca (frame index) or a stored address
// For alloca, we want the address: add x10, x29, #offset
// For stored address, we want to load it: ldr x10, [x29, #offset]
// In our simple lowering, alloca's value in 'slots' is the frame index.
// If 'value' is an AllocaInst, we compute its address.
if (dynamic_cast<const ir::AllocaInst*>(value)) {
block.Append(Opcode::AddrStack, {Operand::Reg(target), Operand::FrameIndex(it->second)});
return;
}
// Otherwise it's a stored address (from a GEP)
block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)});
return;
}
throw std::runtime_error(FormatError("mir", "无法获取地址: " + value->GetName()));
}
size_t GetTypeSize(const ir::Type& ty) {
if (ty.IsInt32() || ty.IsFloat()) return 4;
if (ty.IsPointer() || ty.IsPtrInt32() || ty.IsPtrFloat()) return 8;
if (ty.IsArray()) {
return ty.GetNumElements() * GetTypeSize(*ty.GetElementType());
}
return 0;
} }
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
MachineBasicBlock& block, ValueSlotMap& slots) { ValueSlotMap& slots) {
auto& block = function.GetEntry();
switch (inst.GetOpcode()) { switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: { case ir::Opcode::Alloca: {
auto& alloca = static_cast<const ir::AllocaInst&>(inst); slots.emplace(&inst, function.CreateFrameIndex());
// AllocaInst's type is PointerType. We want the size of the pointed type.
size_t size = GetTypeSize(*alloca.GetType()->GetPointedType());
slots.emplace(&inst, function.CreateFrameIndex(static_cast<int>(size)));
return; return;
} }
case ir::Opcode::Store: { case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst); auto& store = static_cast<const ir::StoreInst&>(inst);
PhysReg val_reg = PhysReg::W8; auto dst = slots.find(store.GetPtr());
EmitValueToReg(store.GetValue(), val_reg, slots, block); if (dst == slots.end()) {
if (IsPointerLike(*store.GetValue()->GetType())) { throw std::runtime_error(
val_reg = ToXReg(val_reg); FormatError("mir", "暂不支持对非栈变量地址进行写入"));
} else if (IsFloatLike(*store.GetValue()->GetType())) {
val_reg = ToSReg(val_reg);
}
// If ptr is a global or stored address (GEP result), we use LoadR/StoreR logic
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(store.GetPtr())) {
block.Append(Opcode::StoreGlobal, {Operand::Reg(val_reg), Operand::Global(gv->GetName())});
} else if (auto* alloca = dynamic_cast<const ir::AllocaInst*>(store.GetPtr())) {
auto it = slots.find(alloca);
if (it == slots.end()) throw std::runtime_error("Alloca not found");
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
} else {
// Pointer is in a register (from GEP)
EmitAddrToReg(store.GetPtr(), PhysReg::X10, function, slots, block);
block.Append(Opcode::StoreR, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X10)});
} }
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
return; return;
} }
case ir::Opcode::Load: { case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst); auto& load = static_cast<const ir::LoadInst&>(inst);
int dst_slot = function.CreateFrameIndex(static_cast<int>(GetTypeSize(*load.GetType()))); auto src = slots.find(load.GetPtr());
PhysReg dst_reg = PhysReg::W8; if (src == slots.end()) {
if (IsPointerLike(*load.GetType())) { throw std::runtime_error(
dst_reg = ToXReg(dst_reg); FormatError("mir", "暂不支持对非栈变量地址进行读取"));
} else if (IsFloatLike(*load.GetType())) {
dst_reg = ToSReg(dst_reg);
}
if (auto* gv = dynamic_cast<const ir::GlobalVariable*>(load.GetPtr())) {
block.Append(Opcode::LoadGlobal, {Operand::Reg(dst_reg), Operand::Global(gv->GetName())});
} else if (auto* alloca = dynamic_cast<const ir::AllocaInst*>(load.GetPtr())) {
auto it = slots.find(alloca);
if (it == slots.end()) throw std::runtime_error("Alloca not found");
block.Append(Opcode::LoadStack, {Operand::Reg(dst_reg), Operand::FrameIndex(it->second)});
} else {
// Pointer is in a register (from GEP)
EmitAddrToReg(load.GetPtr(), PhysReg::X10, function, slots, block);
block.Append(Opcode::LoadR, {Operand::Reg(dst_reg), Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::StoreStack, {Operand::Reg(dst_reg), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::GEP: {
auto& gep = static_cast<const ir::GEPInst&>(inst);
int dst_slot = function.CreateFrameIndex(8); // Address is 8 bytes
EmitAddrToReg(gep.GetPtr(), PhysReg::X10, function, slots, block);
// Initial type is the pointed type of the base pointer
std::shared_ptr<ir::Type> cur_ty = gep.GetPtr()->GetType()->GetPointedType();
for (size_t i = 0; i < gep.GetIndices().size(); ++i) {
ir::Value* index_val = gep.GetIndices()[i];
// Skip index 0 if it's the first index and we're starting from a pointer
if (i == 0) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(index_val)) {
if (ci->GetValue() == 0) {
continue;
}
}
EmitValueToReg(index_val, PhysReg::W8, slots, block);
size_t element_size = GetTypeSize(*cur_ty);
// Use X8 for 64-bit multiplication if element_size is large,
// but for simple cases we can use AddRRR_LSL with W8 for auto sxtw
if (element_size == 4) {
block.Append(Opcode::AddRRR_LSL, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::W8), Operand::Imm(2)});
} else if (element_size == 8) {
block.Append(Opcode::AddRRR_LSL, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::W8), Operand::Imm(3)});
} else {
block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X9), Operand::Imm(static_cast<int>(element_size))});
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X8)});
}
continue;
}
if (cur_ty->IsArray()) {
size_t element_size = GetTypeSize(*cur_ty->GetElementType());
EmitValueToReg(index_val, PhysReg::W8, slots, block);
if (element_size == 4) {
block.Append(Opcode::AddRRR_LSL, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::W8), Operand::Imm(2)});
} else if (element_size == 8) {
block.Append(Opcode::AddRRR_LSL, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::W8), Operand::Imm(3)});
} else {
block.Append(Opcode::Sxtw, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::W8)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X9), Operand::Imm(static_cast<int>(element_size))});
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::X8)});
}
cur_ty = cur_ty->GetElementType();
} else {
throw std::runtime_error(FormatError("mir", "GEP 索引超出范围或类型不是数组"));
}
}
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X10), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
const auto& args = call.GetArgs();
std::vector<ArgLoc> arg_locs(args.size());
int gpr_idx = 0;
int fpr_idx = 0;
int stack_slots = 0;
for (size_t i = 0; i < args.size(); ++i) {
const auto& ty = *args[i]->GetType();
const bool is_float = IsFloatLike(ty);
const bool is_ptr = IsPointerLike(ty);
if (is_float && fpr_idx < 8) {
arg_locs[i] = ArgLoc{true, static_cast<PhysReg>((int)PhysReg::S0 + fpr_idx), 0};
++fpr_idx;
} else if (!is_float && gpr_idx < 8) {
arg_locs[i] = ArgLoc{
true,
is_ptr ? static_cast<PhysReg>((int)PhysReg::X0 + gpr_idx)
: static_cast<PhysReg>((int)PhysReg::W0 + gpr_idx),
0};
++gpr_idx;
} else {
arg_locs[i] = ArgLoc{false, PhysReg::W0, stack_slots * 8};
++stack_slots;
}
}
int stack_arg_size = 0;
if (stack_slots > 0) {
stack_arg_size = AlignTo(stack_slots * 8, 16);
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::X11), Operand::Imm(stack_arg_size)});
block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::X11)});
}
for (size_t i = 0; i < args.size(); ++i) {
const ArgLoc& loc = arg_locs[i];
if (loc.in_reg) {
EmitValueToReg(args[i], loc.reg, slots, block);
continue;
}
PhysReg val_reg = PhysReg::W8;
if (IsPointerLike(*args[i]->GetType())) {
val_reg = ToXReg(val_reg);
} else if (IsFloatLike(*args[i]->GetType())) {
val_reg = ToSReg(val_reg);
}
EmitValueToReg(args[i], val_reg, slots, block);
if (loc.stack_offset == 0) {
block.Append(Opcode::MovRR,
{Operand::Reg(PhysReg::X10), Operand::Reg(PhysReg::SP)});
} else if (loc.stack_offset <= 4095) {
block.Append(Opcode::AddRRI, {Operand::Reg(PhysReg::X10),
Operand::Reg(PhysReg::SP),
Operand::Imm(loc.stack_offset)});
} else {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::X11), Operand::Imm(loc.stack_offset)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X10),
Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::X11)});
}
block.Append(Opcode::StoreR,
{Operand::Reg(val_reg), Operand::Reg(PhysReg::X10)});
}
block.Append(Opcode::Call, {Operand::Label(call.GetFunc()->GetName())});
if (stack_arg_size > 0) {
block.Append(Opcode::MovImm,
{Operand::Reg(PhysReg::X11), Operand::Imm(stack_arg_size)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::SP),
Operand::Reg(PhysReg::X11)});
}
if (!call.GetType()->IsVoid()) {
int dst_slot = function.CreateFrameIndex(static_cast<int>(GetTypeSize(*call.GetType())));
PhysReg ret_reg = PhysReg::W0;
if (IsFloatLike(*call.GetType())) {
ret_reg = ToSReg(ret_reg);
} else if (IsPointerLike(*call.GetType())) {
ret_reg = ToXReg(ret_reg);
}
block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
}
return;
}
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Mod: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
if (bin.GetType()->IsFloat()) {
PhysReg lhs_reg = PhysReg::W8;
PhysReg rhs_reg = PhysReg::W9;
EmitValueToReg(bin.GetLhs(), lhs_reg, slots, block);
EmitValueToReg(bin.GetRhs(), rhs_reg, slots, block);
lhs_reg = ToSReg(lhs_reg);
rhs_reg = ToSReg(rhs_reg);
Opcode op;
if (inst.GetOpcode() == ir::Opcode::Add) op = Opcode::FAdd;
else if (inst.GetOpcode() == ir::Opcode::Sub) op = Opcode::FSub;
else if (inst.GetOpcode() == ir::Opcode::Mul) op = Opcode::FMUL;
else if (inst.GetOpcode() == ir::Opcode::Div) op = Opcode::FDiv;
else throw std::runtime_error("Float mod not supported");
block.Append(op, {Operand::Reg(PhysReg::S0), Operand::Reg(lhs_reg), Operand::Reg(rhs_reg)});
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
if (inst.GetOpcode() == ir::Opcode::Add) {
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
} else if (inst.GetOpcode() == ir::Opcode::Sub) {
block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
} else if (inst.GetOpcode() == ir::Opcode::Mul) {
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
} else if (inst.GetOpcode() == ir::Opcode::Div) {
block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
} else if (inst.GetOpcode() == ir::Opcode::Mod) {
// srem w10, w8, w9 => sdiv w10, w8, w9; msub w8, w10, w9, w8
block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
block.Append(Opcode::MSubRRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W8)});
}
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
}
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::SIToFP: {
auto& fcvt = static_cast<const ir::UnaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(fcvt.GetUnaryOperand(), PhysReg::W8, slots, block);
block.Append(Opcode::FCvtSI2FP, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)});
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FPToSI: {
auto& fcvt = static_cast<const ir::UnaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(fcvt.GetUnaryOperand(), PhysReg::W8, slots, block);
block.Append(Opcode::FCvtFP2SI, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S8)});
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
} }
case ir::Opcode::Cmp:
case ir::Opcode::FCmp: {
int dst_slot = function.CreateFrameIndex(); int dst_slot = function.CreateFrameIndex();
ir::CmpOp ir_cc; block.Append(Opcode::LoadStack,
if (inst.GetOpcode() == ir::Opcode::Cmp) { {Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
auto& cmp = static_cast<const ir::CmpInst&>(inst);
EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
ir_cc = cmp.GetCmpOp();
} else {
auto& cmp = static_cast<const ir::FCmpInst&>(inst);
EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::FCmp, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
ir_cc = cmp.GetCmpOp();
}
CondCode cc = CondCode::EQ;
switch (ir_cc) {
case ir::CmpOp::Eq: cc = CondCode::EQ; break;
case ir::CmpOp::Ne: cc = CondCode::NE; break;
case ir::CmpOp::Lt: cc = CondCode::LT; break;
case ir::CmpOp::Le: cc = CondCode::LE; break;
case ir::CmpOp::Gt: cc = CondCode::GT; break;
case ir::CmpOp::Ge: cc = CondCode::GE; break;
}
block.Append(Opcode::CSet, {Operand::Reg(PhysReg::W8), Operand::Cond(cc)});
block.Append(Opcode::StoreStack, block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot); slots.emplace(&inst, dst_slot);
return; return;
} }
case ir::Opcode::Zext: { case ir::Opcode::Add: {
auto& zext = static_cast<const ir::ZextInst&>(inst); auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(zext.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Neg: {
auto& unary = static_cast<const ir::UnaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(); int dst_slot = function.CreateFrameIndex();
if (unary.GetType()->IsFloat()) { EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(unary.GetUnaryOperand(), PhysReg::W8, slots, block); EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::FNeg, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S8)}); block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); Operand::Reg(PhysReg::W8),
} else { Operand::Reg(PhysReg::W9)});
EmitValueToReg(unary.GetUnaryOperand(), PhysReg::W8, slots, block);
block.Append(Opcode::NegR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8)});
block.Append(Opcode::StoreStack, block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
}
slots.emplace(&inst, dst_slot); slots.emplace(&inst, dst_slot);
return; return;
} }
case ir::Opcode::Br: {
auto& br = static_cast<const ir::BranchInst&>(inst);
block.Append(Opcode::B, {Operand::Label(br.GetDest()->GetName())});
return;
}
case ir::Opcode::CondBr: {
auto& cbr = static_cast<const ir::CondBranchInst&>(inst);
EmitValueToReg(cbr.GetCond(), PhysReg::W8, slots, block);
// SysY IR CondBr uses i1. In MIR, we compare with 0.
block.Append(Opcode::BCond, {Operand::Cond(CondCode::NE),
Operand::Reg(PhysReg::W8),
Operand::Label(cbr.GetTrueBlock()->GetName())});
block.Append(Opcode::B, {Operand::Label(cbr.GetFalseBlock()->GetName())});
return;
}
case ir::Opcode::Ret: { case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst); auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (auto* val = ret.GetValue()) { EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
EmitValueToReg(val, PhysReg::W0, slots, block);
}
block.Append(Opcode::Ret); block.Append(Opcode::Ret);
return; return;
} }
default: case ir::Opcode::Sub:
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令: " + std::to_string((int)inst.GetOpcode()))); case ir::Opcode::Mul:
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算"));
} }
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
} }
} // namespace } // namespace
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) { std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) {
DefaultContext(); DefaultContext();
auto machine_module = std::make_unique<MachineModule>();
// Lower global variables if (module.GetFunctions().size() != 1) {
for (const auto& gv : module.GetGlobalVariables()) { throw std::runtime_error(FormatError("mir", "暂不支持多个函数"));
GlobalVariable mir_gv;
mir_gv.name = gv->GetName();
mir_gv.size = GetTypeSize(*gv->GetType()->GetPointedType());
if (auto* init = gv->GetInitializer()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(init)) {
mir_gv.init_value = ci->GetValue();
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(init)) {
float f = cf->GetValue();
uint32_t bits;
std::memcpy(&bits, &f, 4);
mir_gv.init_value = static_cast<int>(bits);
}
}
machine_module->GetGlobals().push_back(mir_gv);
} }
// Lower functions const auto& func = *module.GetFunctions().front();
for (const auto& ir_func : module.GetFunctions()) { if (func.GetName() != "main") {
if (ir_func->GetBlocks().empty()) continue; // Skip declarations throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
auto machine_func = std::make_unique<MachineFunction>(ir_func->GetName());
ValueSlotMap slots;
// Create all blocks first to handle forward references in branches
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> block_map;
for (const auto& ir_bb : ir_func->GetBlocks()) {
block_map[ir_bb.get()] = &machine_func->CreateBlock(ir_bb->GetName());
} }
// Lower instructions in each block auto machine_func = std::make_unique<MachineFunction>(func.GetName());
for (const auto& ir_bb : ir_func->GetBlocks()) { ValueSlotMap slots;
auto& machine_bb = *block_map.at(ir_bb.get()); const auto* entry = func.GetEntry();
for (const auto& inst : ir_bb->GetInstructions()) { if (!entry) {
LowerInstruction(*inst, *machine_func, machine_bb, slots); throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块"));
}
} }
machine_module->GetFunctions().push_back(std::move(machine_func)); for (const auto& inst : entry->GetInstructions()) {
LowerInstruction(*inst, *machine_func, slots);
} }
return machine_module; return machine_func;
} }
} // namespace mir } // namespace mir

@ -8,12 +8,7 @@
namespace mir { namespace mir {
MachineFunction::MachineFunction(std::string name) MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)) {} : name_(std::move(name)), entry_("entry") {}
MachineBasicBlock& MachineFunction::CreateBlock(const std::string& name) {
blocks_.push_back(std::make_unique<MachineBasicBlock>(name));
return *blocks_.back();
}
int MachineFunction::CreateFrameIndex(int size) { int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size()); int index = static_cast<int>(frame_slots_.size());

@ -4,29 +4,17 @@
namespace mir { namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm, std::string label) Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm), label_(std::move(label)) {} : kind_(kind), reg_(reg), imm_(imm) {}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); } Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::Imm(int value) { Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::WZR, value); return Operand(Kind::Imm, PhysReg::W0, value);
} }
Operand Operand::FrameIndex(int index) { Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::WZR, index); return Operand(Kind::FrameIndex, PhysReg::W0, index);
}
Operand Operand::Label(const std::string& name) {
return Operand(Kind::Label, PhysReg::WZR, 0, name);
}
Operand Operand::Global(const std::string& name) {
return Operand(Kind::Global, PhysReg::WZR, 0, name);
}
Operand Operand::Cond(CondCode cc) {
return Operand(Kind::Cond, PhysReg::WZR, static_cast<int>(cc));
} }
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands) MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)

@ -8,14 +8,22 @@ namespace mir {
namespace { namespace {
bool IsAllowedReg(PhysReg reg) { bool IsAllowedReg(PhysReg reg) {
return true; // All registers are allowed for now as we are not doing allocation switch (reg) {
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29:
case PhysReg::X30:
case PhysReg::SP:
return true;
}
return false;
} }
} // namespace } // namespace
void RunRegAlloc(MachineFunction& function) { void RunRegAlloc(MachineFunction& function) {
for (auto& block : function.GetBlocks()) { for (const auto& inst : function.GetEntry().GetInstructions()) {
for (const auto& inst : block->GetInstructions()) {
for (const auto& operand : inst.GetOperands()) { for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg && if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) { !IsAllowedReg(operand.GetReg())) {
@ -23,7 +31,6 @@ void RunRegAlloc(MachineFunction& function) {
} }
} }
} }
}
} }
} // namespace mir } // namespace mir

@ -8,61 +8,18 @@ namespace mir {
const char* PhysRegName(PhysReg reg) { const char* PhysRegName(PhysReg reg) {
switch (reg) { switch (reg) {
case PhysReg::W0: return "w0"; case PhysReg::W0:
case PhysReg::W1: return "w1"; return "w0";
case PhysReg::W2: return "w2"; case PhysReg::W8:
case PhysReg::W3: return "w3"; return "w8";
case PhysReg::W4: return "w4"; case PhysReg::W9:
case PhysReg::W5: return "w5"; return "w9";
case PhysReg::W6: return "w6"; case PhysReg::X29:
case PhysReg::W7: return "w7"; return "x29";
case PhysReg::W8: return "w8"; case PhysReg::X30:
case PhysReg::W9: return "w9"; return "x30";
case PhysReg::W10: return "w10"; case PhysReg::SP:
case PhysReg::W11: return "w11"; return "sp";
case PhysReg::W12: return "w12";
case PhysReg::W13: return "w13";
case PhysReg::W14: return "w14";
case PhysReg::W15: return "w15";
case PhysReg::X0: return "x0";
case PhysReg::X1: return "x1";
case PhysReg::X2: return "x2";
case PhysReg::X3: return "x3";
case PhysReg::X4: return "x4";
case PhysReg::X5: return "x5";
case PhysReg::X6: return "x6";
case PhysReg::X7: return "x7";
case PhysReg::X8: return "x8";
case PhysReg::X9: return "x9";
case PhysReg::X10: return "x10";
case PhysReg::X11: return "x11";
case PhysReg::X12: return "x12";
case PhysReg::X13: return "x13";
case PhysReg::X14: return "x14";
case PhysReg::X15: return "x15";
case PhysReg::X16: return "x16";
case PhysReg::X17: return "x17";
case PhysReg::S0: return "s0";
case PhysReg::S1: return "s1";
case PhysReg::S2: return "s2";
case PhysReg::S3: return "s3";
case PhysReg::S4: return "s4";
case PhysReg::S5: return "s5";
case PhysReg::S6: return "s6";
case PhysReg::S7: return "s7";
case PhysReg::S8: return "s8";
case PhysReg::S9: return "s9";
case PhysReg::S10: return "s10";
case PhysReg::S11: return "s11";
case PhysReg::S12: return "s12";
case PhysReg::S13: return "s13";
case PhysReg::S14: return "s14";
case PhysReg::S15: return "s15";
case PhysReg::X29: return "x29";
case PhysReg::X30: return "x30";
case PhysReg::SP: return "sp";
case PhysReg::WZR: return "wzr";
case PhysReg::XZR: return "xzr";
} }
throw std::runtime_error(FormatError("mir", "未知物理寄存器")); throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
} }

File diff suppressed because it is too large Load Diff

@ -1,39 +1,17 @@
// 维护对象符号的注册与按作用域查找。 // 维护局部变量声明的注册与查找。
#include "sem/SymbolTable.h" #include "sem/SymbolTable.h"
#include <stdexcept> void SymbolTable::Add(const std::string& name,
SysYParser::VarDefContext* decl) {
SymbolTable::SymbolTable() : scopes_(1) {} table_[name] = decl;
void SymbolTable::EnterScope() { scopes_.emplace_back(); }
void SymbolTable::ExitScope() {
if (scopes_.size() <= 1) {
throw std::runtime_error("symbol table scope underflow");
}
scopes_.pop_back();
} }
bool SymbolTable::Add(const ObjectBinding& symbol) { bool SymbolTable::Contains(const std::string& name) const {
auto& scope = scopes_.back(); return table_.find(name) != table_.end();
return scope.emplace(symbol.name, symbol).second;
} }
bool SymbolTable::ContainsInCurrentScope(std::string_view name) const { SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
const auto& scope = scopes_.back(); auto it = table_.find(name);
return scope.find(std::string(name)) != scope.end(); return it == table_.end() ? nullptr : it->second;
} }
const ObjectBinding* SymbolTable::Lookup(std::string_view name) const {
const std::string key(name);
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(key);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
}
size_t SymbolTable::Depth() const { return scopes_.size(); }

@ -1,49 +1,4 @@
#include <stdio.h> // SysY 运行库实现:
#include <stdarg.h> // - 按实验/评测规范提供 I/O 等函数实现
#include <sys/time.h> // - 与编译器生成的目标代码链接,支撑运行时行为
/* Input functions */
int getint() { int t; scanf("%d", &t); return t; }
int getch() { char t; scanf("%c", &t); return (int)t; }
float getfloat() { float t; scanf("%f", &t); return t; }
int getarray(int a[]) {
int n;
scanf("%d", &n);
for (int i = 0; i < n; i++) scanf("%d", &a[i]);
return n;
}
int getfarray(float a[]) {
int n;
scanf("%d", &n);
for (int i = 0; i < n; i++) scanf("%f", &a[i]);
return n;
}
/* Output functions */
void putint(int a) { printf("%d", a); }
void putch(int a) { printf("%c", (char)a); }
void putfloat(float a) { printf("%a", a); }
void putarray(int n, int a[]) {
printf("%d:", n);
for (int i = 0; i < n; i++) printf(" %d", a[i]);
printf("\n");
}
void putfarray(int n, float a[]) {
printf("%d:", n);
for (int i = 0; i < n; i++) printf(" %a", a[i]);
printf("\n");
}
/* Timing functions */
struct timeval _sysy_start, _sysy_end;
void starttime() { gettimeofday(&_sysy_start, NULL); }
void stoptime() {
gettimeofday(&_sysy_end, NULL);
int millis = (_sysy_end.tv_sec - _sysy_start.tv_sec) * 1000 +
(_sysy_end.tv_usec - _sysy_start.tv_usec) / 1000;
fprintf(stderr, "Timer: %d ms\n", millis);
}

Binary file not shown.

@ -1,16 +0,0 @@
int a[5];
int main() {
int i = 0;
while (i < 5) {
a[i] = i * i;
i = i + 1;
}
i = 0;
while (i < 5) {
putint(a[i]);
putch(32);
i = i + 1;
}
putch(10);
return 0;
}

@ -1,10 +0,0 @@
int main() {
int a = 10;
int b = 3;
putint(a + b); putch(32);
putint(a - b); putch(32);
putint(a * b); putch(32);
putint(a / b); putch(32);
putint(a % b); putch(10);
return 0;
}

@ -1,2 +0,0 @@
0x1.cp+1 -0x1p-1 0x1.8p+1 0x1.8p-1
0

@ -1,9 +0,0 @@
int main() {
float a = 1.5;
float b = 2.0;
putfloat(a + b); putch(32);
putfloat(a - b); putch(32);
putfloat(a * b); putch(32);
putfloat(a / b); putch(10);
return 0;
}

@ -1,19 +0,0 @@
int main() {
int a = 5;
int b = 10;
if (a > b) {
putint(1);
} else {
if (a == 5) {
if (b != 10) {
putint(2);
} else {
putint(3);
}
} else {
putint(4);
}
}
putch(10);
return 0;
}

@ -1,10 +0,0 @@
int fib(int n) {
if (n <= 1) return n;
return fib(n-1) + fib(n-2);
}
int main() {
int n = 6;
putint(fib(n));
putch(10);
return 0;
}

@ -1,6 +0,0 @@
// 测试:简单加法
int main() {
int a = 1;
int b = 2;
return a + b;
}

@ -1,8 +0,0 @@
// 测试:减法和乘法
int main() {
int a = 10;
int b = 3;
int c = a - b;
int d = a * b;
return c + d;
}

@ -1,8 +0,0 @@
// 测试:除法和取模
int main() {
int a = 20;
int b = 6;
int c = a / b;
int d = a % b;
return c + d;
}

@ -1,7 +0,0 @@
// 测试:一元运算符(正负号)
int main() {
int a = 5;
int b = -a;
int c = +10;
return b + c;
}

@ -1,7 +0,0 @@
// 测试:赋值表达式
int main() {
int a = 10;
int b = 20;
a = b;
return a;
}

@ -1,5 +0,0 @@
// 测试:逗号分隔的多变量声明
int main() {
int a = 1, b = 2, c = 3;
return a + b + c;
}

@ -1,14 +0,0 @@
// 测试:综合测试(所有功能)
int main() {
int a = 10, b = 5;
int c = a + b;
int d = a - b;
int e = a * 2;
int f = a / b;
int g = a % b;
int h = -c;
int i = +d;
a = b + c;
b = d + e;
return a + b + f + g + h + i;
}

@ -1 +0,0 @@
int main(){ break; return 0; }

@ -1,2 +0,0 @@
int f(int x){ return x; }
int main(){ return f(); }

@ -1,2 +0,0 @@
void f(){ return 1; }
int main(){ return 0; }

@ -1 +0,0 @@
int main(){ return a; }

@ -1,47 +0,0 @@
#!/bin/bash
# 批量测试脚本
# 遍历 test/test_case 目录下所有的 .sy 文件,并验证解析是否成功
if [ ! -f "./build/bin/compiler" ]; then
echo "Compiler executable not found at ./build/bin/compiler. Please build the project first."
exit 1
fi
FAIL_COUNT=0
PASS_COUNT=0
FAILED_FILES=()
echo "开始批量测试解析..."
echo "========================================="
# 查找所有 .sy 文件并进行测试
while IFS= read -r file; do
# 运行解析器,将正常输出重定向到 /dev/null保留错误输出用于判断
./build/bin/compiler --emit-parse-tree "$file" > /dev/null 2>&1
if [ $? -ne 0 ]; then
echo "❌ 解析失败: $file"
FAIL_COUNT=$((FAIL_COUNT+1))
FAILED_FILES+=("$file")
else
echo "✅ 解析成功: $file"
PASS_COUNT=$((PASS_COUNT+1))
fi
done < <(find test/test_case -type f -name "*.sy" | sort)
echo "========================================="
echo "测试完成!"
echo "成功: $PASS_COUNT"
echo "失败: $FAIL_COUNT"
if [ $FAIL_COUNT -ne 0 ]; then
echo "失败的文件列表:"
for f in "${FAILED_FILES[@]}"; do
echo " - $f"
done
exit 1
else
echo "🎉 所有测试用例均解析成功!"
exit 0
fi

@ -1,34 +0,0 @@
#include <exception>
#include <iostream>
#include <string>
#include "frontend/AntlrDriver.h"
#include "sem/Sema.h"
#include "utils/Log.h"
int main(int argc, char** argv) {
if (argc < 2) {
std::cerr << "usage: sema_check <input.sy> [more.sy...]\n";
return 2;
}
bool failed = false;
for (int i = 1; i < argc; ++i) {
const std::string path = argv[i];
try {
auto antlr = ParseFileWithAntlr(path);
auto* comp_unit = dynamic_cast<SysYParser::CompUnitContext*>(antlr.tree);
if (!comp_unit) {
throw std::runtime_error(FormatError("sema_check", "语法树根节点不是 compUnit"));
}
(void)RunSema(*comp_unit);
std::cout << "OK " << path << "\n";
} catch (const std::exception& ex) {
failed = true;
std::cout << "ERR " << path << "\n";
PrintException(std::cout, ex);
}
}
return failed ? 1 : 0;
}

@ -1,19 +0,0 @@
(base) root@HP:/home/hp/nudt-compiler-cpp/build# make -j$(nproc)
[ 2%] Built target utils
[ 2%] Building CXX object src/ir/CMakeFiles/ir_core.dir/Type.cpp.o
[ 3%] Building CXX object src/ir/CMakeFiles/ir_core.dir/Value.cpp.o
[ 73%] Built target antlr4_runtime
[ 75%] Built target sem
[ 79%] Built target frontend
[ 80%] Linking CXX static library libir_core.a
[ 84%] Built target ir_core
[ 85%] Built target ir_analysis
[ 89%] Built target ir_passes
[ 94%] Built target mir_core
[ 97%] Built target irgen
[ 99%] Built target mir_passes
[ 99%] Linking CXX executable ../bin/compiler
[100%] Built target compiler
(base) root@HP:/home/hp/nudt-compiler-cpp/build# cd ..
(base) root@HP:/home/hp/nudt-compiler-cpp# ./scripts/verify_ir.sh test/test_case/functional/09_func_defn.sy --run
[error] [irgen] 变量声明缺少存储槽位a
Loading…
Cancel
Save