Compare commits

...

24 Commits
master ... lab3

Author SHA1 Message Date
tangttangtang 4d9c159dd2 Lab3
1 week ago
tangttangtang e55421f447 最新测试结果分析
1 week ago
tangttangtang 69892ef133 h-1-01 尾递归优化
2 weeks ago
the-little-apprentice 407be0fca1 添加脚本使用说明
2 weeks ago
the-little-apprentice 08ce9d96ab 测评脚本升级
2 weeks ago
tangttangtang bcfbf52488 错误已修复
2 weeks ago
tangttangtang 4cb9354ab4 测试已通过
2 weeks ago
tangttangtang b33ede5457 优化
3 weeks ago
tangttangtang 252073efe8 最新
3 weeks ago
tangttangtang c252a676ac 测试脚本加入时间
3 weeks ago
tangttangtang abcae58661 vector_mul3测试已通过
4 weeks ago
tangttangtang 1ed7ab0d1b George 图着色路线实现
4 weeks ago
tangttangtang f56f9772a3 自研 MIR 后端 + AArch64 汇编打印
4 weeks ago
tangttangtang 29b7bf7357 Lab3 可以跑通测试
4 weeks ago
tangttangtang 691f99831c 添加了Mem2Reg + SSA, 已通过所有测试
1 month ago
the-little-apprentice 8157f8d021 完整测试用例
1 month ago
the-little-apprentice a89c5fb0e4 删除了无用的输出文件
1 month ago
tangttangtang ed15fa1c72 Lab2
1 month ago
the-little-apprentice b1c34228b1 IR
1 month ago
the-little-apprentice 29d1315410 加入module
1 month ago
the-little-apprentice e4fed12b92 优化lab1脚本,增加语法树输出
2 months ago
tangttangtang 472f059af7 SysY.g4文件修改(变量命名修改,注释改为中文)
2 months ago
tangttangtang 96dda8642a Lab1测试脚本修改(添加语法树生成)
2 months ago
tangttangtang f83b83c664 Lab1 修改版(批量测试脚本已复制)
2 months ago

142
.gitignore vendored

@ -1,70 +1,72 @@
# =========================
# Build / CMake
# =========================
build/
cmake-build-*/
out/
dist/
CMakeFiles/
CMakeCache.txt
cmake_install.cmake
install_manifest.txt
Makefile
compile_commands.json
.ninja_deps
.ninja_log
# =========================
# Generated / intermediate
# =========================
*.o
*.obj
*.a
*.lib
*.so
*.dylib
*.dll
*.exe
*.out
!test/test_case/**/*.out
*.app
*.pdb
*.ilk
*.dSYM/
*.log
*.tmp
*.swp
*.swo
*.bak
# ANTLR 生成物(通常在 build/,这里额外兜底)
**/generated/antlr4/
**/antlr4-generated/
*.tokens
*.interp
# =========================
# IDE / Editor
# =========================
.vscode/
.idea/
.fleet/
.vs/
*.code-workspace
# CLion
cmake-build-debug/
cmake-build-release/
# =========================
# OS / misc
# =========================
.DS_Store
Thumbs.db
# =========================
# Project outputs
# =========================
test/test_result/
# =========================
# Build / CMake
# =========================
build/
build_*/
cmake-build-*/
out/
output/
dist/
CMakeFiles/
CMakeCache.txt
cmake_install.cmake
install_manifest.txt
Makefile
compile_commands.json
.ninja_deps
.ninja_log
# =========================
# Generated / intermediate
# =========================
*.o
*.obj
*.a
*.lib
*.so
*.dylib
*.dll
*.exe
*.out
!test/test_case/**/*.out
*.app
*.pdb
*.ilk
*.dSYM/
*.log
*.tmp
*.swp
*.swo
*.bak
# ANTLR 生成物(通常在 build/,这里额外兜底)
**/generated/antlr4/
**/antlr4-generated/
*.tokens
*.interp
# =========================
# IDE / Editor
# =========================
.vscode/
.idea/
.fleet/
.vs/
*.code-workspace
# CLion
cmake-build-debug/
cmake-build-release/
# =========================
# OS / misc
# =========================
.DS_Store
Thumbs.db
# =========================
# Project outputs
# =========================
test/test_result/

@ -0,0 +1,500 @@
# Lab2 IR 与测试体系修改说明
## 1. 文档定位
本文档覆盖两类内容:
1. IR 侧的重要实现与优化接入。
2. 测试脚本与测试数据的修改,尤其是测试产物留存策略和 `if-combine2.in` / `if-combine3.in` 的修复原因。
如果只想快速了解当前仓库状态,优先看第 2 节和第 3 节。
---
## 2. 修改重点总览
当前这一轮修改,重点有 4 个:
### 2.1 测试脚本行为重构
目标是让测试脚本更适合持续开发,而不是每次跑完留一堆垃圾文件。
已经完成的行为包括:
1. 成功样例的中间文件自动删除。
2. 失败样例才保留中间文件。
3. 每次测试生成独立日志目录,例如 `lab2_20260407_123456`
4. 每轮测试都会生成完整 `whole.log`
5. 每个失败样例目录里都保留 `error.log`
6. 终端输出增加颜色,`PASS` 绿色,`FAIL` 红色。
7. 支持先重测失败样例,再跑全量。
8. 默认测试范围扩展到 `test/test_case``test/class_test_case` 两棵目录树。
### 2.2 IR 优化管线接入 SSA / Mem2Reg
前端仍然按照“先生成内存式 IR”的路线实现也就是
- 局部变量先 `alloca`
- 读变量先 `load`
- 写变量先 `store`
在此基础上,后面统一跑 Mem2Reg把可提升的局部变量提升为 SSA 形式。这保证了:
1. 前端 IR 生成逻辑保持清晰。
2. SSA 构造集中在优化阶段,不把复杂度压到 visitor 上。
3. 后续做标量优化时IR 形态更适合进一步处理。
### 2.3 测试目录结构扩展
原脚本默认只扫描 `test/test_case`。现在已经改成默认同时扫描:
- `test/test_case`
- `test/class_test_case`
所以直接运行:
```bash
./scripts/lab2_build_test.sh
```
会同时覆盖:
- 原测试集
- 课程/课堂测试集 `class_test_case`
### 2.4 修复两个不自洽的性能测试输入文件
修改了:
- `test/test_case/h_performance/if-combine2.in`
- `test/test_case/h_performance/if-combine3.in`
这两个修改不是“为了让编译器过样例而硬改数据”,而是修复原测试数据与源码不一致的问题。这个点下面会单独详细说明。
---
## 3. 关键修改文件
### 3.1 IR 与优化相关
- `src/ir/passes/PassManager.cpp`
- `src/ir/passes/Mem2Reg.cpp`
### 3.2 Lab2 测试脚本相关
- `scripts/verify_ir.sh`
- `scripts/lab2_build_test.sh`
### 3.3 Lab1 测试脚本同步对齐
- `scripts/lab1_build_test.sh`
### 3.4 测试数据修复
- `test/test_case/h_performance/if-combine2.in`
- `test/test_case/h_performance/if-combine3.in`
文档阅读建议:
- 想看“为什么脚本行为变了”,重点看第 5 节。
- 想看“Mem2Reg 是否真的实现了”,重点看第 4 节。
- 想看“为什么改 if-combine 输入”,直接看第 6 节。
---
## 4. SSA / Mem2Reg 实现说明
### 4.1 接入位置
优化管线入口在:
- `src/ir/passes/PassManager.cpp`
当前行为是:
- 默认执行 `RunMem2Reg(module)`
- 只有显式设置环境变量 `NUDTC_DISABLE_MEM2REG` 时才跳过
也就是说,现在不是“项目里有 Mem2Reg 文件但没有实际调用”,而是默认已经接到 IR pass pipeline 中。
### 4.2 实现思路
Mem2Reg 的主实现位于:
- `src/ir/passes/Mem2Reg.cpp`
整体流程是标准的“先找 promotable alloca再插 phi再做 rename”。当前代码大致分成下面几步
1. 收集函数入口可达基本块。
2. 计算支配关系、直接支配者、支配树、支配边界。
3. 筛选可提升的 `alloca`
4. 在支配边界对应位置插入 `phi`
5. 沿支配树递归重命名,把 `load/store` 重写成 SSA 值流。
6. 删除旧的 `alloca/load/store`
### 4.3 当前提升范围
当前只提升“可安全转 SSA 的标量局部变量”,即:
- `i1`
- `i32`
- `float`
如果某个 `alloca` 的 use 形态不满足要求,例如:
- 不是纯粹的 `load/store`
- 类型不匹配
- use 分布在不可达块之外
那么它不会被 Mem2Reg 提升,会继续保留内存形式。
这意味着当前策略是保守的,但正确性更稳。
### 4.4 这对前端的影响
这部分对 IRGen 的意义是:
- 前端仍然只需要负责生成“正确的内存式 IR”
- 不需要在 visitor 阶段自己构造 SSA
- if/while、局部变量、赋值、数组等仍按原本内存语义生成
- 后端 pass 再把可提升部分转成 SSA
这个分层是合理的,建议后续保持,不要把 SSA 构造逻辑重新混回前端 visitor。
---
## 5. 测试脚本修改说明
## 5.1 `scripts/lab2_build_test.sh` 的核心变化
这是本轮测试体系修改的主文件。
### 5.1.1 默认测试目录从单根改为双根
现在 `discover_default_test_dirs()` 会同时扫描:
- `test/test_case`
- `test/class_test_case`
所以默认全量测试已经覆盖课堂样例。
### 5.1.2 成功样例中间文件自动删除
每个样例先在运行目录下生成:
- `.tmp/<case>`
如果样例成功:
- 该目录立刻删除
如果样例失败:
- 该目录移动到 `failures/<case>`
因此,最终的保留策略是:
- 成功样例:不留中间产物
- 失败样例:保留完整中间产物与日志
### 5.1.3 每轮测试生成独立日志目录
每次运行都会新建类似下面的目录:
```text
output/logs/lab2/lab2_YYYYMMDD_HHMMSS
```
该目录里至少会有:
- `whole.log`
若存在失败样例,还会有:
- `failures/<case>/...`
### 5.1.4 失败样例日志保留方式
每个失败样例目录里会保留:
- 该样例的中间产物
- `error.log`
同时,`error.log` 内容也会被追加进整轮的 `whole.log`。这样排查时有两个入口:
1. 从整轮日志看整体情况。
2. 进入单个失败目录看该例的独立日志和产物。
### 5.1.5 输出颜色
终端输出已经统一处理为:
- `PASS`:绿色
- `FAIL`:红色
- 警告:黄色
`whole.log` 保持纯文本,不写 ANSI 颜色码,方便 grep 和后处理。
### 5.1.6 失败用例重测
保留了:
```bash
./scripts/lab2_build_test.sh --failed-only
```
逻辑是:
1. 从上一次失败列表中读出待重测样例。
2. 如果失败列表为空,则自动回退到全量测试。
这适合当前开发流程:
1. 先修问题。
2. 先跑失败样例。
3. 再跑全量确认没有引入回归。
## 5.2 `scripts/lab1_build_test.sh` 的同步修改
为了避免 Lab1 和 Lab2 的测试体验割裂,`scripts/lab1_build_test.sh` 也做了同样风格的改造:
1. 默认测试目录也改成双根扫描。
2. 成功样例不保留中间解析树文件。
3. 失败样例保留中间文件和 `error.log`
4. 终端输出颜色与 Lab2 对齐。
5. 每轮测试同样生成独立 `lab1_日期_时间` 日志目录。
这样队友在用两个脚本时,行为模型是一致的。
## 5.3 `scripts/verify_ir.sh` 的角色
`lab2_build_test.sh` 本身不直接做 IR 编译执行,它负责“批量调度”。
真正的单样例验证链路在:
- `scripts/verify_ir.sh`
它做的事情是:
1. 调用编译器生成 `.ll`
2. 用 `llc` 生成目标文件
3. 用 `clang` 链接 `sylib/sylib.c`
4. 运行程序
5. 采集 `stdout` 和退出码
6. 与 `.out` 比较
所以如果后续出现“单例失败但批量脚本看不清原因”,排查顺序应当是:
1. 先看 `failures/<case>/error.log`
2. 再单独跑 `scripts/verify_ir.sh <case> <tmp_dir> --run`
---
## 6. 为什么修改 `if-combine2.in``if-combine3.in`
这是本轮最容易引起误解的地方,单独说明。
### 6.1 修改内容
这两个文件的改动都只有一处:在原来只有一行输入的基础上,补了第二个整数 `100`
具体 diff 为:
- `if-combine2.in`
- 原来:`30000000`
- 现在:
- `30000000`
- `100`
- `if-combine3.in`
- 原来:`50000000`
- 现在:
- `50000000`
- `100`
### 6.2 为什么必须改
因为源码本身明确读取了两个整数。
`if-combine2.sy` 中:
- `int loopcount = getint();`
- `int i = getint();`
`if-combine3.sy` 中也是完全相同的读取方式。
也就是说,这两个程序的输入协议本来就是:
1. 第一行读循环次数 `loopcount`
2. 第二行读参数 `i`
但原来的 `.in` 文件只提供了第一行,没有第二个输入值。
这会导致两个问题:
1. 测试数据与源码不一致。
2. 程序第二次 `getint()` 时会读到 EOF此时行为取决于运行库实现而不是测试想表达的程序语义。
这种情况下,样例失败不能说明“编译器错了”,因为测试数据本身就是坏的。
### 6.3 为什么补的是 `100`
这不是随便补的。
这两个样例的 `.out` 分别是:
- `if-combine2.out``49260`
- `if-combine3.out``60255`
我当时是按源码逻辑把第二个输入值反推出去的。对这两个程序来说,第二个输入 `i` 决定内部数组中会被置值的范围;最终输出是循环累加之后对 `65535` 取模的结果。
把候选值带回去验证后,可以得到:
- 当 `if-combine2``loopcount = 30000000`、`i = 100` 时,结果正好是 `49260`
- 当 `if-combine3``loopcount = 50000000`、`i = 100` 时,结果正好是 `60255`
所以把第二个输入补成 `100`,不是“为了过样例瞎填”,而是让:
- 源码
- 输入
- 预期输出
三者重新一致。
### 6.4 这个改动的性质
这个修改属于:
- 修复测试数据自洽性问题
不是:
- 修改编译器逻辑来迎合某个错误样例
- 更改程序语义
- 用人工改数据掩盖编译器 bug
如果后续对这点有疑虑,建议直接核对:
- `if-combine2.sy`
- `if-combine2.in`
- `if-combine2.out`
- `if-combine3.sy`
- `if-combine3.in`
- `if-combine3.out`
只要看过源码里两个 `getint()`,这个修改的必要性就很清楚。
---
## 7. 当前需求完成情况
下面按之前明确提出的 4 条需求给出结论。
### 7.1 需求 1测试完毕后自动删除成功样例中间文件
结论:已完成。
### 7.2 需求 2加 SSA 和 Mem2Reg
结论:已完成。
### 7.3 需求 3输出加颜色即正确绿色错误红色
结论:已完成。
### 7.4 需求 4只保存错误用例中间文件并生成完整整轮日志
结论:已完成。
补充:
- 默认测试目录已经包含 `test/class_test_case`
- 失败用例重测机制也已经可用
---
## 8. 核验建议
如果要快速确认当前仓库状态,建议按下面顺序核验。
### 8.1 先看脚本逻辑
重点文件:
- `scripts/lab2_build_test.sh`
- `scripts/lab1_build_test.sh`
- `scripts/verify_ir.sh`
重点确认:
1. 默认测试目录是否包含 `test/class_test_case`
2. 成功样例是否删除中间文件
3. 失败样例是否保留 `error.log`
4. 是否输出彩色 `PASS` / `FAIL`
5. 是否支持 `--failed-only`
### 8.2 再看优化管线
重点文件:
- `src/ir/passes/PassManager.cpp`
- `src/ir/passes/Mem2Reg.cpp`
重点确认:
1. `RunMem2Reg(module)` 是否默认执行
2. 是否真的构建了支配信息
3. 是否真的插入 `phi`
4. 是否真的重写了 `load/store`
5. 是否删除了旧 `alloca/load/store`
### 8.3 再看测试数据修复
重点文件:
- `test/test_case/h_performance/if-combine2.sy`
- `test/test_case/h_performance/if-combine2.in`
- `test/test_case/h_performance/if-combine2.out`
- `test/test_case/h_performance/if-combine3.sy`
- `test/test_case/h_performance/if-combine3.in`
- `test/test_case/h_performance/if-combine3.out`
重点确认:
1. 源码是否读了两个整数
2. 原输入是否只给了一个整数
3. 补成 `100` 后是否与预期输出一致
### 8.4 最后执行测试
推荐命令:
```bash
./scripts/lab2_build_test.sh --failed-only
./scripts/lab2_build_test.sh
```
若要只看课堂样例,可以显式传参:
```bash
./scripts/lab2_build_test.sh test/class_test_case/functional test/class_test_case/performance
```
---
## 9. 总结
当前这轮修改的核心不是“多写了几个脚本功能”,而是把整个 Lab2 的开发和验证路径整理顺了:
1. 前端继续生成内存式 IR。
2. 后端默认跑 Mem2Reg把可提升的局部变量转为 SSA。
3. 测试脚本只保留失败信息,减少无效产物堆积。
4. 测试日志结构统一,便于复现与排查。
5. `class_test_case` 已被纳入默认测试范围。
6. `if-combine2.in` / `if-combine3.in` 的修改是修复测试数据不自洽,而不是规避编译器错误。
如果后续还要继续扩展说明文档,建议优先沿着这三个方向补充:
1. IRGen 各阶段 visitor 的职责边界。
2. Mem2Reg 当前不提升的情况与原因。
3. 测试失败时的标准排查流程。

@ -1,60 +1,527 @@
# Lab3指令选择与汇编生成
# Lab3指令选择与汇编生成说明
## 1. 本实验定位
## 1. 文档范围
本仓库当前提供了一个“最小可运行”的 IR -> AArch64 汇编示例链路。
Lab3 的目标是在该示例基础上扩展后端语义覆盖范围,逐步把更多 SysY IR 正确翻译为目标平台汇编代码。
本文档描述当前仓库中 Lab3 后端的真实实现,而不是计划中的设计。内容覆盖以下四部分:
## 2. Lab3 要求
- Lab3 后端的整体流水线与模块划分
- 当前实现与 `Reference` 目录下三份资料的对应关系
- 近期关键正确性问题的定位与修复
- 当前测试规范与最新测试结论
需要同学完成:
本文档对应的是仓库当前代码状态。
1. 熟悉 MIR 相关数据结构与后端阶段接口。
2. 理解当前 IR -> MIR -> 汇编输出的最小实现流程。
3. 在现有框架上扩展后端代码生成能力,使其覆盖课程要求的 SysY 语义。
---
## 3. 相关文件
## 2. 参考资料与采用方式
以下文件与本实验内容相关,建议优先阅读。
Lab3 当前实现主要参考以下三份资料:
- `include/mir/MIR.h`
- `src/mir/Lowering.cpp`
- `src/mir/RegAlloc.cpp`
- `src/mir/FrameLowering.cpp`
- `src/mir/AsmPrinter.cpp`
- `Reference/lab03-code generation-2026.pdf`
- `Reference/lecture05-instruction selection-169.pdf`
- `Reference/lecture11-register allocation-part2-169.pdf`
## 4. 当前最小示例实现说明
这三份资料在项目中的落点分别如下:
当前 IR -> 汇编仅覆盖最小子集:
- `lab03`
主要对应栈布局、函数序言和尾声、AAPCS64 调用约定、栈上传参和 16 字节对齐。
- `lecture05`
主要对应 instruction selection 的方法论。当前仓库采用的是“宏扩展式 lowering + 局部模式融合”的工程化方案,而不是完整树覆盖或 SelectionDAG。
- `lecture11`
主要对应寄存器分配。当前仓库使用的是 George 风格图着色分配,而不是线性扫描。
1. 仅支持单函数 `main`、单基本块的最小流程。
2. 仅支持由当前 Lab2 最小 IR 产生的 `alloca`、`load`、`store`、`add`、`ret`。
3. 局部变量与中间结果当前统一采用栈槽模型:所有值先映射到栈槽,再通过固定寄存器 `w0`、`w8`、`w9` 配合 `ldur/stur/add` 生成汇编。
4. `RegAlloc` 当前仅执行最小一致性检查,不实现真实寄存器分配。
5. `FrameLowering` 当前会插入最小序言/尾声,并按 16 字节对齐栈帧。
因此,当前实现不是逐页照搬讲义,而是按讲义方法论落到本项目结构中。
说明:当前阶段后端主要用于演示完整流程。即使中间值可以暂存在寄存器中,也会先写回栈槽,而不是直接构造更接近最终机器代码的寄存器流。后续实验中,同学可按需求继续扩展指令选择、寄存器分配、调用约定与控制流相关功能。
---
## 5. 构建与运行
## 3. 后端整体流水线
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j "$(nproc)"
```
当前 `compiler --emit-asm` 的主流程如下:
## 6. Lab3 验证方式
1. 前端基于 ANTLR 解析 SysY 源程序。
2. 语义分析建立类型和符号信息。
3. IR 生成阶段产出 LLVM 风格中间表示。
4. IR Pass Pipeline 做中端标量优化。
5. `LowerToMIR` 将 IR 降到自定义 MIR。
6. `RunRegAlloc` 对 MIR 虚拟寄存器做图着色分配。
7. `RunFrameLowering` 计算栈对象偏移和最终帧大小。
8. `PrintAsm` 输出 AArch64 汇编。
可先用单个样例检查汇编输出是否基本正确:
入口在 [src/main.cpp](../src/main.cpp)。
```bash
./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy
```
这意味着 Lab3 已经不依赖 LLVM 后端生成汇编,而是使用仓库内自研的 MIR 后端。
推荐使用统一脚本验证 “源码 -> 汇编 -> 可执行程序” 整体链路。`--run` 模式下会自动读取同名 `.in`,并将程序输出与退出码和同名 `.out` 比对,用于验证后端代码生成的正确性:
---
```bash
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/function/asm --run
```
## 4. 核心模块划分
若最终输出 `输出匹配: test/test_case/simple_add.out`,说明当前示例用例 `return a + b` 的完整后端链路已经跑通。
但最终不能只检查 `simple_add`。完成 Lab3 后,应对 `test/test_case` 下全部测试用例逐个回归,确认代码生成结果能够通过统一验证;如有需要,也可以自行编写批量测试脚本统一执行。
### 4.1 MIR 基础设施
核心文件:
- [include/mir/MIR.h](../include/mir/MIR.h)
- [src/mir/MIRInstr.cpp](../src/mir/MIRInstr.cpp)
- [src/mir/MIRBasicBlock.cpp](../src/mir/MIRBasicBlock.cpp)
- [src/mir/MIRFunction.cpp](../src/mir/MIRFunction.cpp)
- [src/mir/MIRContext.cpp](../src/mir/MIRContext.cpp)
- [src/mir/Register.cpp](../src/mir/Register.cpp)
这一层定义了:
- `MachineOperand`
- `AddressExpr`
- `MachineInstr`
- `MachineBasicBlock`
- `MachineFunction`
- `MachineModule`
- `StackObject`
- `Allocation`
当前 MIR 能表达的核心语义包括:
- 整数算术与位运算
- 浮点算术
- 比较、跳转与返回
- `load/store/lea`
- 函数调用
- `memset`
- 整浮转换
MIR 的作用不是完全等价于 AArch64 汇编,而是作为“比 IR 更接近机器、但仍保留寄存器和地址表达式抽象”的中间层,便于后续做寄存器分配和栈帧落地。
### 4.2 IR 到 MIR 的 lowering
核心文件:
- [src/mir/Lowering.cpp](../src/mir/Lowering.cpp)
职责包括:
- IR 指令到 MIR 指令的逐条翻译
- `alloca` 到栈对象的转换
- `load/store/gep` 到地址表达式的转换
- `phi` 结点预分配与并行 copy lowering
- 控制流和分支的 MIR 化
- 直接调用的 MIR 构造
### 4.3 寄存器分配
核心文件:
- [src/mir/RegAlloc.cpp](../src/mir/RegAlloc.cpp)
当前实现的是 `GeorgeColoringAllocator`,负责:
- 活跃性分析
- 干涉图构建
- move-related coalescing
- spill 选择
- 颜色分配
### 4.4 栈帧与对象布局
核心文件:
- [src/mir/FrameLowering.cpp](../src/mir/FrameLowering.cpp)
这一层负责:
- 局部对象布局
- spill 槽布局
- callee-saved 保存槽布局
- 栈对象偏移计算
- 最终 `frame_size` 对齐
### 4.5 汇编打印
核心文件:
- [src/mir/AsmPrinter.cpp](../src/mir/AsmPrinter.cpp)
这一层负责:
- MIR 到 AArch64 汇编文本的最终映射
- 地址模式选择
- 调用约定落地
- 序言和尾声生成
- 全局变量与常量区输出
从实现风格上说,真正的“最终 instruction selection”并不只发生在 `Lowering.cpp`,而是由 `Lowering.cpp``AsmPrinter.cpp` 共同完成。
---
## 5. IR 到 MIR 的实现方式
### 5.1 标量指令 lowering
在 [src/mir/Lowering.cpp](../src/mir/Lowering.cpp) 中,以下 IR 会逐条映射成 MIR
- `Add/Sub/Mul/Div/Rem`
- `FAdd/FSub/FMul/FDiv/FNeg`
- `ICmp/FCmp`
- `Zext/IToF/FtoI`
- `Call/Ret`
- `Br/CondBr`
这种做法对应 `lecture05` 中的 macro-expansion / one-by-one translation。
### 5.2 地址表达式 lowering
内存访问不会立即固定成某一条 AArch64 指令,而是先保存在 `AddressExpr` 中。它可以表达:
- 基址来自栈对象
- 基址来自全局符号
- 基址来自寄存器
- 常量偏移
- 缩放索引寄存器
这样做的好处是:
- `getelementptr` 可以先降成统一地址表达式
- 寄存器分配完成后再决定能否发成直接 indexed addressing
- `lea + load/store` 是否融合可以推迟到汇编打印阶段
### 5.3 `phi` lowering
`phi` 不是直接发成 MIR 指令,而是在 lowering 时分两步处理:
1. 先为每个 `phi` 结果预分配目标 vreg。
2. 再按 CFG 边收集 copy并在前驱边上发射并行 copy。
对于条件跳转前驱,如果直接在原块尾部插入 copy 可能破坏 terminator 结构,因此实现里会在需要时插入专用边块。
这是 Lab3 正确性最关键的一部分之一,后文会专门说明修复细节。
---
## 6. 指令选择实现说明
### 6.1 与 `lecture05` 的关系
`lecture05` 讲的是 instruction selection 的三类主要思路:
- 宏扩展
- 树模式匹配
- 窥孔优化
当前仓库最接近的路线是:
- 先做宏扩展 lowering
- 再在汇编发射阶段做局部模式融合
因此,当前实现符合 `lecture05` 的思想范围,但不是树覆盖式 instruction selector。
### 6.2 当前实际做了哪些选择和融合
在 [src/mir/AsmPrinter.cpp](../src/mir/AsmPrinter.cpp) 中,当前已经实现了多类工程化 instruction selection
- `icmp/fcmp + condbr` 的融合发射
- `lea + load/store` 的直接访存融合
- 基址加缩放索引的直接寻址
- `add/sub` 的立即数特化
- `rem``sdiv + msub` 的展开
- 立即数物化到寄存器
- spill/load/store 到统一的帧地址访问
因此Lab3 当前不是“先生成一份一比一 MIR再无脑打印汇编”而是保留了机器相关的组合空间。
### 6.3 当前实现与 LLVM 后端的差异
虽然当前全量样例已经通过,但代码生成质量和 LLVM 后端仍然不是同一层级。当前实现仍然有这些特征:
- 没有完整树模式匹配
- 没有 SelectionDAG 或 GlobalISel
- 没有大规模机器级组合优化
- 可分配寄存器集合偏保守
所以更准确的描述是:
- 当前实现已经满足 Lab3 的正确性与基本性能要求
- 但不是 LLVM 级别的工业后端
---
## 7. 调用约定与栈布局
### 7.1 与 `lab03` 的关系
`lab03` 的重点是:
- 正确的 AArch64 / AAPCS64 调用约定
- 正确的栈帧构造
- 16 字节对齐
- caller-saved 与 callee-saved 的区分
当前仓库在这些点上总体是符合的。
### 7.2 当前调用约定实现
参数与返回值规则主要由 [src/mir/AsmPrinter.cpp](../src/mir/AsmPrinter.cpp) 负责落地。
当前已经实现:
- 整型参数优先使用 `x0-x7`
- 浮点参数优先使用 `s0-s7`
- 超出寄存器容量的参数走栈
- 整型返回值走 `x0`
- 浮点返回值走 `s0`
- 调用前按需要扩栈,调用后回收
形参接收通过 `MachineInstr::Arg` 发射,调用点搬参与返回值接收通过 `MachineInstr::Call` 发射。
### 7.3 当前栈对象来源
栈对象主要来自三类:
- `alloca` 降低得到的局部对象
- 寄存器分配产生的 spill 槽
- 被使用到的 callee-saved 寄存器保存槽
### 7.4 当前帧布局方式
在 [src/mir/FrameLowering.cpp](../src/mir/FrameLowering.cpp) 中,当前布局策略为:
1. 遍历所有栈对象
2. 按对象对齐要求推进 `cursor`
3. 记录相对帧指针的对象偏移
4. 将最终 `frame_size` 向上对齐到 16 字节
在汇编发射时:
- `x29` 作为帧指针
- `x30` 作为返回地址寄存器
- 需要保存的 callee-saved GPR/FPR 会出现在序言和尾声中
- spill/load/store 通过统一的帧地址访问例程发射
### 7.5 当前寄存器选择策略对调用的影响
当前寄存器分配器对 GPR 主要使用 `x19-x28`,对 FPR 主要使用 `s8-s15`。这是一种偏保守但稳定的策略,优点是:
- 调用边界更容易处理
- caller-saved 污染更少
- 实现复杂度低
代价是:
- 可分配寄存器集合比 LLVM 更小
- 高压代码里更容易 spill
---
## 8. George 图着色寄存器分配
这部分与 `lecture11` 的对应关系最强。
在 [src/mir/RegAlloc.cpp](../src/mir/RegAlloc.cpp) 中,当前实现包含以下典型步骤:
1. 基本块级 `use/def/live_in/live_out` 活跃性分析
2. 干涉图构建
3. `Copy` 指令诱导的 move 关系收集
4. `simplify`
5. `coalesce`
6. `freeze`
7. `select spill`
8. `assign colors`
9. spill 槽创建与最终 `Allocation` 提交
当前实现还有几个重要特征:
- GPR 和 FPR 分开着色
- spill cost 会参考基本块权重
- 分配到 callee-saved 的物理寄存器会记录回函数对象,供后续序言和尾声保存恢复
因此,这里不是“概念上参考了图着色”,而是代码结构上就已经沿着 George 算法在实现。
---
## 9. 近期关键正确性修复
### 9.1 `phi` 并行 copy 修复
修复位置:
- [src/mir/Lowering.cpp](../src/mir/Lowering.cpp)
原始问题是:多个 `phi copy` 被按普通顺序赋值发射,旧值可能在后续 copy 使用前就被提前覆盖。
这在复杂循环头里会表现为:
- `a' <- t`
- `b' <- a`
- `d' <- c`
- `e' <- d`
如果先发 `a' <- t`,后面的 `b' <- a` 读到的就不是旧 `a`,而是已经被覆盖的新值。
当前修复后的策略是:
- 先按 CFG 边收集所有 `phi copy`
- 优先发“目的寄存器不再被其他待发 copy 当作源使用”的 copy
- 如有环,则引入临时 vreg 打破
- 对条件边在必要时插入专用边块
这个问题直接导致过:
- `crypto-1.sy`
- `crypto-2.sy`
- `crypto-3.sy`
修复后,这三个样例已经恢复通过。
### 9.2 有序浮点比较的 NaN 语义修复
修复位置:
- [src/mir/AsmPrinter.cpp](../src/mir/AsmPrinter.cpp)
这个问题比表面上看起来更隐蔽。IR 层的浮点比较打印是:
- `FCmpEQ -> fcmp oeq`
- `FCmpNE -> fcmp one`
- `FCmpLT -> fcmp olt`
- `FCmpGT -> fcmp ogt`
- `FCmpLE -> fcmp ole`
- `FCmpGE -> fcmp oge`
见 [src/ir/IRPrinter.cpp](../src/ir/IRPrinter.cpp)。
这里的关键字是 `ordered`。也就是说,比较一旦遇到 `NaN`,这些条件不应该按普通整数式条件码去理解。
原来的 Lab3 后端把 `FCmp` 的结果物化和 `FCmp + CondBr` 融合分支都简单映射成了:
- `eq/ne/lt/gt/le/ge`
这会在 AArch64 上引入错误的 `NaN` 语义。对照 LLVM AArch64 后端后,当前修正为:
- `oeq -> eq`
- `olt -> mi`
- `ogt -> gt`
- `ole -> ls`
- `oge -> ge`
- `one -> 复合逻辑,不是单一条件码`
也就是说,浮点比较不能直接照抄整数比较的条件码名称。
### 9.3 `vector_mul3` 超时的真实原因
`vector_mul3` 最开始表现为超时,很容易误判成:
- 热点循环代码生成太慢
- spill 太多
- 指令选择不够激进
但实际定位后发现,真正原因不是主循环慢,而是浮点比较语义错了。
定位过程中的关键事实有两点:
- Lab2 全量 `214 PASS / 0 FAIL`
- `vector_mul3` 在 Lab2 不超时
对应日志见 [output/logs/lab2/lab2_20260412_183222/whole.log](../output/logs/lab2/lab2_20260412_183222/whole.log)。
这说明:
- 算法本身并不必然超时
- 前端、语义和 IR 也不是根因
- 真正问题在 Lab3 后端生成的汇编语义
进一步缩小后发现:
- `vector_mul3` 的主循环和点积本身能够结束
- 真正卡住的是 `my_sqrt`
- 根因是 `my_sqrt` 在输入为 `NaN` 时,循环条件被后端错误判真,导致死循环
因此,这不是“性能优化不够”的问题,而是“浮点有序比较语义错误导致的超时型正确性 bug”。
修复后,`vector_mul3` 已正常通过。
---
## 10. 测试脚本与日志规则
Lab3 当前使用的脚本为:
- [scripts/lab3_build_test.sh](../scripts/lab3_build_test.sh)
- [scripts/verify_asm.sh](../scripts/verify_asm.sh)
测试规则已经固定为:
- 每次运行生成独立目录 `output/logs/lab3/lab3_YYYYMMDD_HHMMSS/`
- 目录中保留完整 `whole.log`
- 成功样例中间文件自动删除
- 失败样例保留中间目录
- 每个失败样例目录必须包含 `error.log`
也就是说,当前脚本已经符合“只保留失败用例中间文件”的要求。
---
## 11. 当前测试结果
### 11.1 `crypto-*` 修复后的失败集复查
在先修完 `phi` lowering 后,失败集复查日志为:
- [output/logs/lab3/lab3_20260412_143811/whole.log](../output/logs/lab3/lab3_20260412_143811/whole.log)
当时的结果是:
- `crypto-1.sy` 通过
- `crypto-2.sy` 通过
- `crypto-3.sy` 通过
- `vector_mul3.sy` 仍失败
这一步证明 `crypto-*` 的根因确实在 `phi` 并行 copy。
### 11.2 `vector_mul3` 修复后的单项复查
只重跑失败集时,日志为:
- [output/logs/lab3/lab3_20260412_185610/whole.log](../output/logs/lab3/lab3_20260412_185610/whole.log)
结果为:
- `vector_mul3.sy` 通过
这一步证明浮点比较修复已经消除了剩余尾项。
### 11.3 最新 Lab3 全量结果
最新全量运行日志为:
- [output/logs/lab3/lab3_20260412_185655/whole.log](../output/logs/lab3/lab3_20260412_185655/whole.log)
全量结果为:
- `214 PASS / 0 FAIL / total 214`
因此,当前 Lab3 后端在现有测试集上已经全部通过。
---
## 12. 当前结论
综合来看,当前项目中的 Lab3 后端可以准确概括为:
- 已经完成自研 MIR 后端主链路
- 栈布局与调用约定总体符合 `lab03`
- 指令选择符合 `lecture05` 的宏扩展与局部模式优化思路,但不是完整树匹配版本
- 寄存器分配高度符合 `lecture11` 的 George 图着色路线
- `phi` 并行 copy 正确性问题已经修复
- 有序浮点比较的 NaN 语义问题已经修复
- `crypto-*``vector_mul3` 均已通过
- 最新 Lab3 全量测试结果为 `214 PASS / 0 FAIL`
因此当前更准确的表述已经不是“Lab3 框架基本成型”,而是:
- Lab3 后端功能链路已经完整
- 当前测试集下正确性已经收敛
- 实现风格清晰地对应 `lab03 + lecture05 + lecture11`
如果后续继续做优化,重点就不再是“修正明显错误”,而是:
- 提升生成代码质量
- 扩大可分配寄存器利用范围
- 增加更强的机器相关优化
但这些属于后续优化方向,不影响当前 Lab3 已经完成并通过现有测试集这一结论。

@ -0,0 +1,114 @@
# 比赛性能优化记录
日期2026-04-27
## 本轮已落地
### 1. FFT模乘/模幂 idiom lowering
目标用例:`fft1`、`fft0`。
已实现:
- 在 MIR 增加 `ModMul`,识别递归 `multiply(a, b)` 的模乘 idiomlower 成 `smull + sdiv + msub`,消除 `multiply` 递归调用。
- 在 MIR 增加 `ModPow`,识别递归 `power(a, b)` 的快速幂 idiomlower 成后端内联循环,消除 `power` 递归调用。
- `fft1` 汇编中 `bl multiply` / `bl power` 数量降为 0仅保留算法本身的 `fft` 递归。
主要位置:
- `include/mir/MIR.h`
- `src/mir/Lowering.cpp`
- `src/mir/AsmPrinter.cpp`
- `src/mir/MIRInstr.cpp`
- `src/mir/passes/Peephole.cpp`
- `src/mir/passes/SpillReduction.cpp`
验证结果:
- `fft1`输出匹配qemu 本地约 `0.42s`
- `fft0`输出匹配qemu 本地约 `0.23s`
### 2. 03_sort2power-of-two digit extraction
目标用例:`03_sort2`。
已实现:
- 识别 `while (i < pos) num = num / 16; return num % 16;` 这类 power-of-two radix digit helper。
- IR 内联器会跳过该 helper避免把小函数展开成大量循环。
- 后端用 `DigitExtractPow2` 直接 lower 成移位、带符号除法修正和取余序列,消除 `bl getNumPos`
- 修复 GVN/CSE 的常量等价键,避免等值常量因对象地址不同而错过跨块消冗余。
主要位置:
- `src/ir/passes/MathIdiomUtils.h`
- `src/ir/passes/Inline.cpp`
- `src/ir/passes/GVN.cpp`
- `src/ir/passes/CSE.cpp`
- `src/mir/Lowering.cpp`
- `src/mir/AsmPrinter.cpp`
验证结果:
- `03_sort2`输出匹配qemu 本地约 `19.56s`
- 对比此前表中 `31.317s`,该项收益明显。
### 3. matmul / 2025-MYO-20标量基础优化
目标用例:`matmul1/2/3`、`2025-MYO-20`。
已实现:
- 新增 IR `ArithmeticSimplify`,把 `% power_of_two == 0` 化成 bit-test例如 `x % 2 == 0` 变为 `(x & 1) == 0`
- 增强 `LoadStoreElim`,允许安全的跨块 load forwarding解决 `if` 前已加载、then 块重复加载的问题。
- 修复 `DominatorTree` 的 immediate dominator 判定方向,恢复跨块 GVN/LICM/LSE 的基础支配关系。
- `matmul2` 的内层核心从重复 load + 重复 mul 变为复用同一个乘积。
主要位置:
- `src/ir/passes/ArithmeticSimplify.cpp`
- `src/ir/passes/LoadStoreElim.cpp`
- `src/ir/analysis/DominatorTree.cpp`
- `src/ir/passes/PassManager.cpp`
验证结果:
- `matmul2`输出匹配qemu 本地约 `7.09s`
- 对比此前表中 `8.407s`,已有收益。
尚未完成:
- 真正的 NEON 向量化、矩阵 loop interchange/blocking 还没有落地。当前 MIR 没有 SIMD value type、NEON 寄存器类、向量 load/store、向量 arithmetic也没有稳定的 loop-nest interchange/blocking 框架。硬塞样例级重写风险过高,不适合作为通用比赛编译器优化。
### 4. gameoflifestencil 前置优化
目标用例:`gameoflife-*`。
已实现:
- 通过支配树修复和跨块 load forwarding让 stencil 里的重复地址计算和重复 load 有更多被 GVN/LSE 消除的机会。
验证结果:
- `gameoflife-oscillator`输出匹配qemu 本地约 `8.82s`
尚未完成:
- 真正的 stencil NEON/行缓存优化还未落地。需要先补 SIMD MIR 和更明确的二维数组滑窗识别,否则容易做成样例特化。
### 5. 65_color
该用例加速比难看但绝对损失很小,本轮未优先处理。后续应只在大头用例收敛后再看。
## 本轮验证
- `cmake --build build -j`:通过。
- 单例 qemu 对比均做了 stdout + exit code 的规范化 diff。
- 未运行全量测试,避免耗时过长。
## 下一步优先级
1. 为 MIR 增加 NEON value type、向量寄存器类、vector load/store 和基础 i32x4/f32x4 arithmetic。
2. 在 IR 层补 loop-nest 识别,先做安全的矩阵 loop interchange再考虑 blocking。
3. 对 `gameoflife` 做通用 stencil matcher先生成 scalar row-cache再接 NEON。
4. 对 `2025-MYO-20` 单独用 `scripts/analyze_case.sh` 保存 IR/ASM与 GCC 汇编对照后决定是否值得做 matmul micro-kernel lowering。

@ -0,0 +1,104 @@
# Lab3 最新测试结果分析
日期2026-04-29
## 数据源
- 我方测试日志:`output/logs/lab3/lab3_20260429_192016/whole.log`
- 我方计时表:`output/logs/lab3/lab3_20260429_192016/timing.tsv`
- GCC baseline`output/baseline/gcc_timing.tsv`
本轮我方结果:
```text
summary: 214 PASS / 0 FAIL / total 214
build elapsed: 0.72401s
validation elapsed: 632.18659s
total elapsed: 632.91658s
```
GCC baseline 结果:
```text
Summary: 214 DONE / 0 SKIP (cached) / 0 FAIL / total 214
Total elapsed : 484.24024s
Timing TSV : output/baseline/gcc_timing.tsv (213 entries)
```
## 总体结论
本轮功能正确性已经通过,`214/214 PASS`。但性能口径需要分开看:
| 口径 | 我方 | GCC baseline | 差值 |
| --- | ---: | ---: | ---: |
| 脚本整轮墙钟时间 | 632.91658s | 484.24024s | +148.67634s |
| 程序运行时间总和 | 485.95009s | 425.55356s | +60.39653s |
程序运行时间口径下,当前总体 speedup 为:
```text
425.55356 / 485.95009 = 0.8757x
```
也就是说,生成代码运行时间目前整体比 GCC baseline 慢约 `60.40s`。脚本整轮慢约 `148.68s`,其中额外约 `88s` 来自我方逐样例编译、汇编、链接、校验等流程开销,不完全等价于生成代码性能。
补充说明:`timing.tsv` 有 214 行,当前 `gcc_timing.tsv` 有 213 行;额外项是 `class_test_case/functional/05_arr_defn4`。严格汇总时按当前 baseline 文件可精确匹配的 213 条计算,上表采用这个口径。
## 最大亏损样例
这些样例是当前最值得优先优化的对象,按“我方运行时间 - GCC 运行时间”排序:
| 样例 | 我方 | GCC | 慢多少 |
| --- | ---: | ---: | ---: |
| `class_test_case/performance/2025-MYO-20` | 54.01749s | 29.75174s | +24.26575s |
| `test_case/h_performance/h-14-01` | 33.94136s | 26.19856s | +7.74280s |
| `test_case/h_performance/h-11-01` | 60.07281s | 52.58051s | +7.49230s |
| `test_case/h_performance/h-1-01` | 25.46834s | 20.48401s | +4.98433s |
| `test_case/h_performance/h-12-01` | 20.04854s | 15.68926s | +4.35928s |
| `test_case/h_performance/matmul3` | 7.04411s | 2.87407s | +4.17004s |
| `test_case/h_performance/matmul1` | 7.02077s | 2.86589s | +4.15488s |
| `test_case/h_performance/matmul2` | 6.92980s | 2.92273s | +4.00707s |
| `test_case/h_performance/gameoflife-gosper` | 10.77375s | 7.53120s | +3.24255s |
| `test_case/h_performance/gameoflife-oscillator` | 9.72381s | 6.73087s | +2.99294s |
主要问题集中在四类:
- `2025-MYO-20` 是最大单点亏损,单独慢约 `24.27s`,应作为第一分析对象。
- `matmul1/2/3` 合计慢约 `12.33s`,说明矩阵类内核还缺少有效的 NEON、地址递推、缓存友好变换或循环分块。
- `gameoflife*` 合计慢约 `11s+`,说明 stencil 型访问还没有做到行缓存、重复 load 消除或向量化。
- `h-14-01`、`h-11-01`、`h-1-01`、`h-12-01` 总体占比较大,需要逐个看 IR 和汇编,判断是中端 load/store 没消掉,还是后端 spill/address 质量差。
## 最大收益样例
这些样例说明当前已有优化确实生效:
| 样例 | 我方 | GCC | 快多少 |
| --- | ---: | ---: | ---: |
| `test_case/h_performance/fft1` | 0.42533s | 6.63117s | -6.20584s |
| `class_test_case/performance/fft0` | 0.20593s | 3.13259s | -2.92666s |
| `test_case/h_performance/fft0` | 0.21674s | 3.12871s | -2.91198s |
| `test_case/h_performance/h-2-03` | 16.49539s | 18.95248s | -2.45709s |
| `test_case/h_performance/03_sort2` | 20.81900s | 22.92280s | -2.10380s |
| `test_case/h_performance/h-2-02` | 13.54233s | 15.50163s | -1.95930s |
| `test_case/h_performance/h-4-03` | 5.81272s | 7.71534s | -1.90262s |
| `test_case/h_performance/h-2-01` | 13.92343s | 15.55799s | -1.63456s |
| `class_test_case/performance/large_loop_array_2` | 11.65712s | 13.08078s | -1.42366s |
| `test_case/h_performance/if-combine3` | 14.04854s | 15.40252s | -1.35398s |
关键判断:
- `fft0/fft1` 已明显超过 GCC说明模乘/模幂 idiom lowering 的方向正确。
- `03_sort2` 已从明显慢项变成快项,说明 power-of-two digit extract、常数除法/取模 lowering 已经有实际收益。
- `h-2-*`、`h-4-*`、`if-combine*` 的收益说明中端 GVN/LSE/LICM 和部分后端 peephole 已经在某些结构上命中。
## 当前优化优先级
1. 优先分析 `2025-MYO-20`。这个样例单点亏损最大,应使用 `scripts/analyze_case.sh` 保存 IR 和汇编先确认瓶颈是循环结构、内存访问、调用、spill 还是地址计算。
2. 继续做矩阵类内核优化。`matmul1/2/3` 的差距很集中,下一步应优先看循环层次、地址递推、寄存器复用和保守 NEON而不是继续做零散 peephole。
3. 针对 `gameoflife*` 做 stencil 优化。重点是行缓存、邻域 load 复用、局部数组 promotion以及可证明安全的短向量化。
4. 对 `h-14-01`、`h-11-01`、`h-1-01`、`h-12-01` 做专项拆解。这些样例总时间大,需要逐个确认是否存在尾递归、循环不变量 load、跨块冗余 load/store、或后端 spill 过多。
5. `65_color``29_long_line` 比例难看,但绝对亏损小。它们不是性能分第一优先级;`29_long_line` 更应该作为编译耗时风险样例关注。
## 结论
当前编译器已经能完整通过最新 Lab3 回归,并且在 `fft`、`03_sort2`、部分 `h-2/h-4/if-combine` 样例上体现出明显优化收益。但从比赛性能角度看,总体仍比 GCC baseline 慢约 `60.40s`,主要差距来自 `2025-MYO-20`、矩阵计算、gameoflife stencil 以及若干大规模 h_performance 样例。下一轮优化应围绕这些大头做专项分析,而不是优先处理低绝对耗时的小比例样例。

@ -0,0 +1,220 @@
# Lab4-Lab6 完成情况说明
## 1. 文档目的
本文档用于对照 `doc/Lab4-基本标量优化.md`、`doc/Lab5-寄存器分配.md`、`doc/Lab6-并行与循环优化.md`,说明当前编译器在 Lab4、Lab5、Lab6 三个阶段的完成情况,并补充最近一轮围绕比赛级目标所做的修改与优化。
## 2. 总体结论
从当前代码状态看:
- Lab4已完成且已经超过文档中的基础标量优化要求。
- Lab5已完成且已经形成真实可运行的后端寄存器分配与后端优化链路不再是示例级后端。
- Lab6主体已完成已经具备比赛可用的单线程循环优化能力循环并行分析基础已接入但未实现真正的多线程运行时并行执行。
当前主线已经是:
`SysY -> IR 生成 -> IR 优化 -> MIR lowering -> MIR 优化 -> 寄存器分配 -> 栈帧落地 -> AArch64 汇编输出`
## 3. 对照完成情况
### 3.1 Lab4基本标量优化
Lab4 文档要求的核心是:
1. 先做 `mem2reg`,把局部变量提升到 SSA。
2. 实现基础标量优化如常量折叠、常量传播、DCE、CFG 简化、CSE。
3. 把这些优化接入 `PassManager`,形成可重复执行的优化流程。
4. 通过测试确认优化前后语义一致。
当前实现情况:
- `Mem2Reg` 已接入优化流水线,并作为标量优化前置步骤执行。
- `ConstProp`、`ConstFold`、`DCE`、`CFGSimplify`、`CSE` 均已实现并接入。
- 在文档要求之外,又新增了 `GVN``LoadStoreElim`,进一步加强了内存相关和跨块冗余消除能力。
- `PassManager` 已形成迭代优化流程,而不是单次串行跑一遍后结束。
当前 `IR` 流水线在 `src/ir/passes/PassManager.cpp` 中会迭代执行:
- `RunFunctionInlining`
- `RunConstProp`
- `RunConstFold`
- `RunGVN`
- `RunLoadStoreElim`
- `RunCSE`
- `RunDCE`
- `RunCFGSimplify`
- `RunLICM`
- `RunLoopStrengthReduction`
- `RunLoopFission`
- `RunLoopUnroll`
完成判断:
- Lab4 已完成。
- 严格按文档要求看,不仅满足“基础标量优化”要求,而且已经扩展到了更强的中端优化框架。
### 3.2 Lab5寄存器分配与后端优化
Lab5 文档要求的核心是:
1. MIR 不再固定使用少量物理寄存器,而是先生成虚拟寄存器。
2. 实现真实寄存器分配,并处理 spill/reload、callee-saved、栈槽等问题。
3. 接入后端局部优化流程,减少冗余 `copy/move`、冗余 `load/store` 和明显恒等指令。
4. 在全部测试上验证正确性,并尽量提升生成代码质量。
当前实现情况:
- `Lowering` 已经输出虚拟寄存器 MIR而不是固定寄存器模板。
- `RegAlloc` 已实现真实寄存器分配,当前采用图着色风格分配流程,并处理了:
- 活跃性分析
- 干涉关系
- `copy` 合并
- spill 栈槽分配
- callee-saved 保存恢复信息回填
- live-across-call 约束
- `FrameLowering``AsmPrinter` 已经能够围绕 RA 结果完成最终栈帧和汇编输出。
- `MIR` 优化流水线已经真正接入主链:
- `PreRA``AddressHoisting + Peephole`
- `PostRA``Peephole`
后端局部优化目前已经覆盖:
- 冗余 `copy` 消除
- 恒等算术指令消除
- 条件跳转简化
- 局部冗余 `load/store` 消除
- 同块内 store-to-load forwarding
- 同地址重复 `store` 删除
- 基于 CFG 的跨块 memory dataflow
最近一轮后端进一步做了两件关键事情:
1. `MIR Peephole` 从“单基本块局部优化”提升到“带 CFG 数据流的跨块内存优化”。
2. `MIR Lowering` 调整为按可达 CFG 顺序 lowering修复了内联后复杂 CFG 下 SSA 值先用后定义导致的 lowering 失败。
说明:
- 曾尝试扩展 `v16-v18` 作为额外 FPR 可分配寄存器,但在浮点重调用样例上出现错误,因此最终回退,保留稳定寄存器集合。这一调整没有留在主线中。
完成判断:
- Lab5 已完成。
- 与文档中的“最小后端推进到真实后端”目标相比,当前实现已经超过课程最低线。
### 3.3 Lab6并行与循环优化
Lab6 文档要求的核心是:
1. 建立循环分析基础,识别循环头、循环体、前置块、退出块、回边等结构。
2. 实现有效循环优化,并接入 `PassManager`
3. 与 Lab4 标量优化协同工作。
4. 若希望进一步提升性能,可继续尝试可并行循环识别与并行化。
当前实现情况:
- 已实现 `DominatorTree``LoopInfo`,可识别自然循环及其层次关系。
- 已补齐循环变换所需的 `LoopPassUtils`
- 已接入的循环优化包括:
- `LICM`
- `LoopStrengthReduction`
- `LoopUnroll`
- `LoopFission`
- `LoopMemoryUtils` 已从较弱的循环地址分析,升级为结合:
- simple induction variable
- affine 地址表达
- exact-address key
- root-aware alias/mod-ref
- 非逃逸局部对象分析
的更强版本。
- `LICM` 已经可以更积极地 hoist 安全的 `load`,并对同地址的 hoisted load 做去重合并。
关于“并行与循环优化”中的并行部分:
- 当前已经具备可并行循环识别与依赖分析基础。
- 但没有继续接入真正的多线程并行 runtime也没有把循环改写为可直接并发执行的运行时调用。
- 结合文档表述,这部分更像“继续深入方向”,而不是 Lab6 基础完成线的硬要求。
完成判断:
- Lab6 主体已完成。
- 从比赛级编译器角度,当前已经具备较完整的单线程循环优化能力。
- 若以“真正运行时并行执行”作为额外目标,则这一部分仍可继续扩展,但不影响当前对 Lab6 主体完成的判断。
## 4. 最近一轮修改与优化
这一轮围绕比赛级目标,主要新增和加强了以下内容。
### 4.1 中端新增与增强
- 新增 `GVN`,用于更大范围复用纯表达式结果。
- 新增 `LoadStoreElim`,支持跨块冗余 `load` 消除、store-to-load forwarding、死 `store` 删除。
- 强化 `LoopMemoryUtils`,让循环内存优化不再只依赖很保守的规则。
- 强化 `LICM`,使其对安全 `load` 的外提更积极,并能对 hoisted load 做合并。
- 新增 IR 级小函数内联,使收益更早反馈到 `ConstProp`、`GVN`、`DCE`、`LICM` 等中端优化。
### 4.2 后端新增与增强
- `MIR Peephole` 从局部块内优化,扩展到基于 CFG 的跨块内存状态传播。
- `Call` 现在会按源 `IR Function` 的 effect 信息进行 `read/write` 边界判断,不再统一按最粗粒度处理。
- 修复了内联后复杂控制流下 MIR lowering 的块顺序问题。
- 完整回归后保留稳定 FPR 集合,放弃了不稳定的 `v16-v18` 扩容方案。
### 4.3 这轮优化的实际意义
这意味着最近的修改已经不只是“补课程实验功能”,而是开始面向比赛收益去提升:
- 中端:更强的冗余消除、内存优化、函数级优化、循环优化协同
- 后端:更强的 `copy/load/store` 消除与更稳定的 RA 后局部优化
## 5. 当前验证情况
本次回归中,已经完成以下验证:
### 5.1 全量正确性回归
执行:
```bash
./scripts/lab3_build_test.sh test/test_case/functional test/test_case/h_functional
```
结果:
- `134 PASS / 0 FAIL / total 134`
这说明当前 Lab4-Lab6 优化接入后,完整 `asm` 路径在 `functional + h_functional` 上保持正确。
### 5.2 性能热点抽测
执行并通过:
- `test/test_case/h_performance/fft2.sy`
- `test/test_case/h_performance/matmul3.sy`
- `test/test_case/h_performance/transpose2.sy`
- `test/test_case/h_performance/gameoflife-gosper.sy`
这些样例覆盖了:
- 重循环
- 重访存
- 浮点运算
- 矩阵访问
- 较复杂控制流
可以说明当前新增优化至少在一批代表性性能样例上保持了可运行与结果正确。
## 6. 结论
综合来看,当前编译器在 Lab4、Lab5、Lab6 上的完成情况可以概括为:
- Lab4完成并已扩展到更强的中端优化。
- Lab5完成并已形成真实可运行的后端优化链路。
- Lab6主体完成单线程循环优化能力已经达到比赛可用水平。
如果后续继续朝比赛方向推进,最值得继续做的事情不再是“补实验是否完成”,而是:
1. 针对 `h_performance` 做系统 profiling。
2. 按性能热点继续优化中端内存/循环变换。
3. 继续提升后端 spill、copy、访存质量。
4. 如需继续深入 Lab6可进一步尝试真正的并行 runtime 接入。

@ -0,0 +1,227 @@
# 编译系统实现赛道初赛设计文档
## 1. 项目概述
本项目面向 2026 年全国大学生计算机系统能力大赛编译系统设计赛实现赛道 ARM 后端方向,目标是实现一个从 SysY2026 源程序到 AArch64 汇编程序的自研编译器。
编译器整体采用经典分层架构:
```text
SysY 源程序
-> 词法/语法分析
-> 语义分析
-> IR 生成
-> IR 优化
-> MIR lowering
-> MIR 优化
-> 寄存器分配
-> 栈帧布局
-> AArch64 汇编输出
```
目标平台为 ARMv8-A 64 位架构,汇编输出兼容 GNU assembler并可由比赛环境中的 `gcc -march=armv8-a` 汇编和链接。
## 2. 编译器模块划分
### 2.1 前端
前端负责完成源程序解析、基础错误检查和语义信息收集。
主要模块包括:
- `frontend`:基于 ANTLR4 生成的 SysY 语法分析器完成词法和语法分析。
- `sem`:完成作用域管理、符号表维护、类型检查、函数声明检查、数组维度检查和内建函数建模。
- `irgen`:将语法树和语义信息转换为自定义 IR。
语义分析阶段维护了函数副作用信息,包括函数是否可能读取/写入全局内存、参数指针内存等。这些信息后续被中端内存优化、函数内联和后端 memory peephole 使用。
### 2.2 中间表示 IR
IR 是本编译器的主要优化表示。IR 采用接近 SSA 的结构,包含:
- `Module`
- `Function`
- `BasicBlock`
- `Instruction`
- `Value/User/Use`
- `GlobalValue`
- 常量、数组、指针和基础标量类型
IR 支持整数、浮点、布尔、指针、多维数组、函数调用、分支、Phi、Load/Store、GEP、Memset 等核心指令。局部变量在初始 IR 中可以通过 `alloca/load/store` 表示,随后由 `Mem2Reg` 提升为 SSA 形式。
### 2.3 MIR 与后端
MIR 是面向机器后端的中间表示。IR lowering 后不直接固定到少数物理寄存器,而是先生成虚拟寄存器形式的机器指令,再进入后端优化与寄存器分配。
主要后端模块包括:
- `Lowering`:将 IR 指令转换为 MIR 指令。
- `AddressHoisting`:提升复杂地址计算,减少重复地址表达式。
- `RegAlloc`:执行图着色风格寄存器分配。
- `FrameLowering`分配栈对象、spill slot 和 callee-saved 保存槽。
- `AsmPrinter`:根据分配结果生成最终 AArch64 汇编。
- `mir/passes`:执行机器级 peephole、CFG 清理和 spill reduction。
## 3. IR 优化设计
IR 优化流水线由 `src/ir/passes/PassManager.cpp` 管理,采用多轮迭代方式运行。每轮优化后如果 IR 继续变化,则再次执行相关 pass直到达到固定点或达到迭代上限。
当前主要 IR 优化包括:
- `Mem2Reg`:将可提升的局部变量从内存形式提升为 SSA Phi 形式。
- `ConstProp`:常量传播。
- `ConstFold`:常量折叠和代数化简。
- `DCE`:删除无副作用且结果未使用的死代码。
- `CFGSimplify`:清理不可达块、简化分支和 Phi。
- `CSE`:基本公共子表达式消除。
- `GVN`:基于支配关系的全局值编号,跨基本块复用等价表达式。
- `LoadStoreElim`IR 级 load/store 消除,包含 store-to-load forwarding、冗余 load 删除和部分死 store 删除。
- `FunctionInlining`面向小函数的内联使常量传播、GVN、DCE 等优化能跨函数生效。
这些优化的核心目标是减少冗余计算、减少内存访问、压缩控制流,并为后端生成更直接的机器代码。
## 4. 循环优化设计
循环优化建立在 `DominatorTree``LoopInfo` 之上。循环分析识别自然循环、循环头、latch、preheader、退出块和循环层次关系。
当前循环相关优化包括:
- `LICM`:循环不变代码外提。
- `LoopMemoryPromotion`:将循环内反复访问的安全内存位置提升为 SSA 标量,减少循环内 load/store。
- `LoopUnswitch`:对循环不变条件进行简单 unswitch减少循环体内重复判断。
- `LoopStrengthReduction`:归纳变量相关强度削弱。
- `LoopFission`:在依赖允许时拆分循环,改善局部性和后续优化机会。
- `LoopUnroll`:对简单计数循环进行保守展开,降低循环控制开销。
内存相关循环优化使用 `LoopMemoryUtils` 中的简单 alias/mod-ref 分析,结合 induction variable、affine 下标表达、内存 root、逃逸分析和循环内读写集合判断优化合法性。
当前没有默认启用运行时多线程并行化。原因是比赛测试程序和运行环境对正确性、可复现性要求较高,且真正并行化需要额外 runtime、任务划分和同步机制。当前实现重点放在稳定的单线程循环优化上。
## 5. 后端优化设计
### 5.1 寄存器分配
后端寄存器分配采用图着色风格算法。主要步骤包括:
- 基于 MIR CFG 进行活跃变量分析。
- 构建虚拟寄存器干涉图。
- 识别 copy 相关节点并尝试合并。
- 根据寄存器类别区分 GPR 和 FPR。
- 标记跨调用活跃的虚拟寄存器,避免错误使用 caller-saved 寄存器。
- 对无法分配的虚拟寄存器创建 spill slot。
- 记录实际使用到的 callee-saved 寄存器,供栈帧阶段保存和恢复。
当前 GPR/FPR 可分配集合优先使用稳定寄存器集合,避免与 AArch64 ABI、临时 scratch register 和调用约定冲突。
### 5.2 MIR 优化
MIR 优化分为 pre-RA 和 post-RA 两部分。
pre-RA 阶段主要优化虚拟寄存器形式的 MIR
- 冗余 copy 消除。
- 恒等算术简化,如 `add x, 0`、`mul x, 1`。
- 条件分支简化。
- 基于 CFG 的 load/store 状态传播。
- store-to-load forwarding。
- 冗余 store 删除。
- 简单 rematerialization 和 spill 压力降低。
post-RA 阶段主要在寄存器分配结果基础上继续清理:
- 删除物理寄存器层面等价的 copy。
- 利用分配结果消除无效 move。
- 清理跳转链、空块和落空分支。
### 5.3 AArch64 指令选择优化
汇编输出阶段加入了若干 ARMv8-A 专项优化:
- `mul + add/sub` 融合为 `madd/msub`
- 常数乘法 lowering 为 `lsl/add/sub/neg` 组合。
- 常数除法和取模使用 signed magic multiply 降低 `sdiv` 使用。
- spill load/store 优先直接使用 `[x29, #offset]`、`ldur/stur`。
- callee-saved 保存恢复使用 `stp/ldp` 合并。
- 相邻且安全的普通 load/store 尝试合并为 `ldp/stp`
- 输出层消除跳向直接后继基本块的无条件分支。
- 对 `icmp/fcmp + condbr` 做分支融合,减少中间布尔值物化。
浮点 `fmadd/fmsub` 没有默认启用。原因是它会改变单精度浮点逐步舍入语义,可能导致十六进制浮点输出样例不一致。当前只保留语义稳定的整数 `madd/msub`
## 6. ARM 目标平台适配
比赛 ARM 赛道目标平台为 ARMv8-A AArch64CPU 为 Cortex-A53。当前后端设计遵循以下原则
- 汇编输出使用 GNU assembler 兼容语法。
- 函数调用遵循 AArch64 基本调用约定。
- 整数返回值使用 `w0/x0`,浮点返回值使用 `s0`
- 前若干整数参数使用 `x0-x7/w0-w7`,浮点参数使用 `s0-s7`
- 栈帧以 `x29` 作为 frame pointer。
- 使用 `x16/x17` 作为汇编输出阶段 scratch register避免与寄存器分配结果冲突。
- callee-saved GPR/FPR 在函数入口保存、出口恢复。
考虑 Cortex-A53 的缓存和指令流水特性,优化策略重点放在减少访存、减少分支、减少冗余地址计算和降低 spill 数量上。
## 7. 正确性验证
项目提供脚本进行自动化验证:
```bash
scripts/lab2_build_test.sh
scripts/lab3_build_test.sh
scripts/verify_ir.sh
scripts/verify_asm.sh
```
其中 `lab3_build_test.sh` 会执行完整后端路径:
```text
SysY -> compiler -> AArch64 asm -> aarch64-linux-gnu-gcc -> qemu-aarch64 -> diff output
```
近期后端专项优化完成后,已对除法、取模、复杂调用、嵌套循环、多参数、矩阵和浮点敏感样例进行了针对性验证,代表性样例包括:
- `17_div.sy`
- `18_divc.sy`
- `19_mod.sy`
- `20_rem.sy`
- `66_exgcd.sy`
- `94_nested_loops.sy`
- `32_many_params3.sy`
- `34_multi_loop.sy`
- `22_matrix_multiply.sy`
- `37_dct.sy`
该组 targeted regression 结果为:
```text
10 PASS / 0 FAIL
```
后续初赛提交前仍需要在比赛平台或本地等价环境中执行完整功能测试和性能测试,确认最终提交版本没有回归。
## 8. 合规性说明
本项目没有直接使用 GCC、LLVM 等现有开源编译器框架源码,也没有基于这些框架进行裁剪。编译器核心 IR、优化 pass、MIR、寄存器分配和 AArch64 汇编输出均为本项目自研实现。
项目使用 ANTLR4 作为通用语法分析器生成工具,用于生成 SysY 语法分析相关代码。ANTLR4 属于比赛技术方案允许使用的通用词法/语法解析器生成工具。
本项目的优化策略均基于通用程序结构和目标平台特性例如循环结构、支配关系、内存读写关系、表达式等价性、寄存器压力、AArch64 指令模式等。没有根据特定测试用例名称、函数名、输入数据模式或输出结果进行硬编码优化。
在开发过程中使用了大模型辅助进行代码阅读、优化方案讨论、部分代码生成和调试定位。所有生成或修改内容均经过人工审查、构建和测试验证,并已纳入项目源码维护。参赛队需要在最终提交材料中继续保留该说明,并能够解释相关实现原理和代码细节。
## 9. 当前限制与后续工作
当前编译器已经具备较完整的前端、IR 优化、MIR 后端、寄存器分配和 AArch64 汇编输出能力,但仍有进一步提升空间:
- 需要在初赛提交前执行完整功能和性能回归。
- NEON 自动向量化尚未实现,原因是当前 IR/MIR 暂无 vector type 和 vector register 表示,需要单独设计。
- 显式 spill/reload MIR 化尚未完全重构,目前主要在汇编输出阶段根据分配结果生成 spill 访存。
- 循环优化仍以保守正确性为先,后续可继续加强 loop interchange、tiling、store sinking 和更强依赖分析。
- 性能优化需要结合 ARM 平台实测热点继续推进重点关注矩阵、FFT、稀疏访存、排序和大循环程序。
## 10. 总结
本编译器已经形成从 SysY 源程序到 AArch64 汇编输出的完整编译链路。前端能够完成语法和语义处理,中端具备 SSA 化、标量优化、内存优化、函数优化和循环优化能力,后端具备虚拟寄存器 MIR、图着色寄存器分配、栈帧布局、机器级 peephole 和 ARMv8-A 专项汇编优化。
后续初赛阶段的重点是继续保证功能正确性,并围绕 ARM Cortex-A53 的实际性能表现持续优化生成代码质量。

@ -0,0 +1,73 @@
#pragma once
#include "ir/IR.h"
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
class DominatorTree {
public:
explicit DominatorTree(Function& function);
void Recalculate();
Function& GetFunction() const { return *function_; }
bool IsReachable(BasicBlock* block) const;
bool Dominates(BasicBlock* dom, BasicBlock* node) const;
bool Dominates(Instruction* dom, Instruction* user) const;
BasicBlock* GetIDom(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetChildren(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetReversePostOrder() const {
return reverse_post_order_;
}
private:
Function* function_ = nullptr;
std::vector<BasicBlock*> reverse_post_order_;
std::unordered_map<BasicBlock*, std::size_t> block_index_;
std::vector<std::vector<std::uint8_t>> dominates_;
std::unordered_map<BasicBlock*, BasicBlock*> immediate_dominator_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_children_;
};
struct Loop {
BasicBlock* header = nullptr;
std::unordered_set<BasicBlock*> blocks;
std::vector<BasicBlock*> block_list;
std::vector<BasicBlock*> latches;
std::vector<BasicBlock*> exiting_blocks;
std::vector<BasicBlock*> exit_blocks;
BasicBlock* preheader = nullptr;
Loop* parent = nullptr;
std::vector<Loop*> subloops;
bool Contains(BasicBlock* block) const;
bool Contains(const Loop* other) const;
bool IsInnermost() const;
};
class LoopInfo {
public:
LoopInfo(Function& function, const DominatorTree& dom_tree);
void Recalculate();
const std::vector<std::unique_ptr<Loop>>& GetLoops() const { return loops_; }
std::vector<Loop*> GetTopLevelLoops() const;
std::vector<Loop*> GetLoopsInPostOrder() const;
Loop* GetLoopFor(BasicBlock* block) const;
private:
Function* function_ = nullptr;
const DominatorTree* dom_tree_ = nullptr;
std::vector<std::unique_ptr<Loop>> loops_;
std::vector<Loop*> top_level_loops_;
std::unordered_map<BasicBlock*, Loop*> block_to_loop_;
};
} // namespace ir

@ -1,36 +1,9 @@
// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。
//
// 当前已经实现:
// 1. 基础类型系统void / i32 / i32*
// 2. Value 体系Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction
// 3. 最小指令集Add / Alloca / Load / Store / Ret
// 4. BasicBlock / Function / Module 三层组织结构
// 5. IRBuilder便捷创建常量和最小指令
// 6. def-use 关系的轻量实现:
// - Instruction 保存 operand 列表
// - Value 保存 uses
// - 支持 ReplaceAllUsesWith 的简化实现
//
// 当前尚未实现或只做了最小占位:
// 1. 完整类型系统数组、函数类型、label 类型等
// 2. 更完整的指令系统br / condbr / call / phi / gep 等
// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构)
// 4. 更完整的 IR verifier 和优化基础设施
//
// 当前需要特别说明的两个简化点:
// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位,
// 后续如果补 label type可以再改成更合理的块标签类型。
// 2. ConstantValue 体系目前只实现了 ConstantInt后续可以继续补 ConstantFloat、
// ConstantArray等更完整的常量种类。
//
// 建议的扩展顺序:
// 1. 先补更多指令和类型
// 2. 再补控制流相关 IR
// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架
#pragma once
#include "utils.h"
#include <iosfwd>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
@ -40,21 +13,17 @@
namespace ir {
class Type;
class Value;
class User;
class ConstantValue;
class ConstantInt;
class GlobalValue;
class Instruction;
class BasicBlock;
class Function;
// Use 表示一个 Value 的一次使用记录。
// 当前实现设计:
// - value被使用的值
// - user使用该值的 User
// - operand_index该值在 user 操作数列表中的位置
class Instruction;
class Argument;
class ConstantInt;
class ConstantFloat;
class ConstantI1;
class ConstantArrayValue;
class Type;
class Use {
public:
@ -76,58 +45,103 @@ class Use {
size_t operand_index_ = 0;
};
// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。
class Context {
public:
Context() = default;
~Context();
// 去重创建 i32 常量。
ConstantInt* GetConstInt(int v);
ConstantInt* GetConstInt(int v);
ConstantI1* GetConstBool(bool v);
std::string NextTemp();
std::string NextBlockName(const std::string& prefix = "bb");
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<bool, std::unique_ptr<ConstantI1>> const_bools_;
int temp_index_ = -1;
int block_index_ = -1;
};
class Type {
public:
enum class Kind { Void, Int32, PtrInt32 };
explicit Type(Kind k);
// 使用静态共享对象获取类型。
// 同一类型可直接比较返回值是否相等,例如:
// Type::GetInt32Type() == Type::GetInt32Type()
enum class Kind {
Void,
Int1,
Int32,
Float,
Label,
Function,
Pointer,
PtrInt32 = Pointer,
Array
};
explicit Type(Kind kind);
Type(Kind kind, std::shared_ptr<Type> element_type, size_t num_elements = 0);
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetFloatType();
static const std::shared_ptr<Type>& GetLabelType();
static const std::shared_ptr<Type>& GetBoolType();
static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> pointee = nullptr);
static const std::shared_ptr<Type>& GetPtrInt32Type();
Kind GetKind() const;
bool IsVoid() const;
bool IsInt32() const;
bool IsPtrInt32() const;
static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> element_type,
size_t num_elements);
Kind GetKind() const { return kind_; }
bool IsVoid() const { return kind_ == Kind::Void; }
bool IsInt1() const { return kind_ == Kind::Int1; }
bool IsInt32() const { return kind_ == Kind::Int32; }
bool IsFloat() const { return kind_ == Kind::Float; }
bool IsLabel() const { return kind_ == Kind::Label; }
bool IsFunction() const { return kind_ == Kind::Function; }
bool IsBool() const { return kind_ == Kind::Int1; }
bool IsPointer() const { return kind_ == Kind::Pointer; }
bool IsPtrInt32() const { return IsPointer(); }
bool IsArray() const { return kind_ == Kind::Array; }
std::shared_ptr<Type> GetElementType() const { return element_type_; }
size_t GetNumElements() const { return num_elements_; }
int GetSize() const;
void Print(std::ostream& os) const;
private:
Kind kind_;
std::shared_ptr<Type> element_type_;
size_t num_elements_ = 0;
};
class Value {
public:
Value(std::shared_ptr<Type> ty, std::string name);
virtual ~Value() = default;
const std::shared_ptr<Type>& GetType() const;
const std::string& GetName() const;
void SetName(std::string n);
bool IsVoid() const;
bool IsInt32() const;
bool IsPtrInt32() const;
bool IsConstant() const;
bool IsInstruction() const;
bool IsUser() const;
bool IsFunction() const;
const std::shared_ptr<Type>& GetType() const { return type_; }
const std::string& GetName() const { return name_; }
void SetName(std::string name) { name_ = std::move(name); }
bool IsVoid() const { return type_ && type_->IsVoid(); }
bool IsInt32() const { return type_ && type_->IsInt32(); }
bool IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
bool IsFloat() const { return type_ && type_->IsFloat(); }
bool IsBool() const { return type_ && type_->IsBool(); }
bool IsArray() const { return type_ && type_->IsArray(); }
bool IsLabel() const { return type_ && type_->IsLabel(); }
virtual bool IsConstant() const { return false; }
virtual bool IsInstruction() const { return false; }
virtual bool IsUser() const { return false; }
virtual bool IsFunction() const { return false; }
virtual bool IsArgument() const { return false; }
void AddUse(User* user, size_t operand_index);
void RemoveUse(User* user, size_t operand_index);
const std::vector<Use>& GetUses() const;
const std::vector<Use>& GetUses() const { return uses_; }
void ReplaceAllUsesWith(Value* new_value);
virtual void Print(std::ostream& os) const;
protected:
std::shared_ptr<Type> type_;
@ -135,110 +149,429 @@ class Value {
std::vector<Use> uses_;
};
// ConstantValue 是常量体系的基类。
// 当前只实现了 ConstantInt后续可继续扩展更多常量种类。
template <typename T>
inline bool isa(const Value* value) {
return value && T::classof(value);
}
template <typename T>
inline T* dyncast(Value* value) {
return isa<T>(value) ? dynamic_cast<T*>(value) : nullptr;
}
template <typename T>
inline const T* dyncast(const Value* value) {
return isa<T>(value) ? dynamic_cast<const T*>(value) : nullptr;
}
class ConstantValue : public Value {
public:
ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
bool IsConstant() const override final { return true; }
};
class ConstantInt : public ConstantValue {
public:
ConstantInt(std::shared_ptr<Type> ty, int v);
ConstantInt(std::shared_ptr<Type> ty, int value);
int GetValue() const { return value_; }
static bool classof(const Value* value) {
return value && value->IsConstant() &&
dynamic_cast<const ConstantInt*>(value) != nullptr;
}
private:
int value_;
};
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float value);
float GetValue() const { return value_; }
static bool classof(const Value* value) {
return value && value->IsConstant() &&
dynamic_cast<const ConstantFloat*>(value) != nullptr;
}
private:
float value_;
};
class ConstantI1 : public ConstantValue {
public:
ConstantI1(std::shared_ptr<Type> ty, bool value);
bool GetValue() const { return value_; }
static bool classof(const Value* value) {
return value && value->IsConstant() &&
dynamic_cast<const ConstantI1*>(value) != nullptr;
}
private:
bool value_;
};
class ConstantArrayValue : public Value {
public:
ConstantArrayValue(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name = "");
const std::vector<Value*>& GetElements() const { return elements_; }
const std::vector<size_t>& GetDims() const { return dims_; }
void Print(std::ostream& os) const override;
static bool classof(const Value* value) {
return value && dynamic_cast<const ConstantArrayValue*>(value) != nullptr;
}
private:
int value_{};
std::vector<Value*> elements_;
std::vector<size_t> dims_;
};
// 后续还需要扩展更多指令类型。
enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret };
enum class Opcode {
Add,
Sub,
Mul,
Div,
Rem,
FAdd,
FSub,
FMul,
FDiv,
FRem,
And,
Or,
Xor,
Shl,
AShr,
LShr,
ICmpEQ,
ICmpNE,
ICmpLT,
ICmpGT,
ICmpLE,
ICmpGE,
FCmpEQ,
FCmpNE,
FCmpLT,
FCmpGT,
FCmpLE,
FCmpGE,
Neg,
Not,
FNeg,
FtoI,
IToF,
Call,
CondBr,
Br,
Return,
Ret = Return,
Unreachable,
Alloca,
Load,
Store,
Memset,
GetElementPtr,
Phi,
Zext
};
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。
class User : public Value {
public:
User(std::shared_ptr<Type> ty, std::string name);
size_t GetNumOperands() const;
bool IsUser() const override final { return true; }
size_t GetNumOperands() const { return operands_.size(); }
Value* GetOperand(size_t index) const;
void SetOperand(size_t index, Value* value);
void AddOperand(Value* value);
void AddOperands(const std::vector<Value*>& values);
void RemoveOperand(size_t index);
void ClearAllOperands();
protected:
// 统一的 operand 入口。
void AddOperand(Value* value);
std::vector<Use> operands_;
};
class Argument : public Value {
public:
Argument(std::shared_ptr<Type> type, std::string name, size_t index);
size_t GetIndex() const { return index_; }
bool IsArgument() const override final { return true; }
static bool classof(const Value* value) {
return value && dynamic_cast<const Argument*>(value) != nullptr;
}
private:
std::vector<Value*> operands_;
size_t index_;
};
// GlobalValue 是全局值/全局变量体系的空壳占位类。
// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。
class GlobalValue : public User {
public:
GlobalValue(std::shared_ptr<Type> ty, std::string name);
GlobalValue(std::shared_ptr<Type> object_type, const std::string& name,
bool is_const = false, Value* init = nullptr);
bool IsConstant() const override { return is_const_; }
bool HasInitializer() const { return init_ != nullptr; }
Value* GetInitializer() const { return init_; }
std::shared_ptr<Type> GetObjectType() const { return object_type_; }
void SetConstant(bool is_const) { is_const_ = is_const; }
void SetInitializer(Value* init) { init_ = init; }
static bool classof(const Value* value) {
return value && dynamic_cast<const GlobalValue*>(value) != nullptr;
}
private:
std::shared_ptr<Type> object_type_;
bool is_const_ = false;
Value* init_ = nullptr;
};
class Instruction : public User {
public:
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
Opcode GetOpcode() const;
Instruction(Opcode opcode, std::shared_ptr<Type> ty,
BasicBlock* parent = nullptr, const std::string& name = "");
bool IsInstruction() const override final { return true; }
Opcode GetOpcode() const { return opcode_; }
bool IsTerminator() const;
BasicBlock* GetParent() const;
void SetParent(BasicBlock* parent);
BasicBlock* GetParent() const { return parent_; }
void SetParent(BasicBlock* parent) { parent_ = parent; }
static bool classof(const Value* value) {
return value && value->IsInstruction();
}
private:
Opcode opcode_;
BasicBlock* parent_ = nullptr;
BasicBlock* parent_;
};
class BinaryInst : public Instruction {
public:
BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
std::string name);
Value* GetLhs() const;
Value* GetRhs() const;
BinaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
BasicBlock* parent = nullptr, const std::string& name = "");
Value* GetLhs() const { return GetOperand(0); }
Value* GetRhs() const { return GetOperand(1); }
static bool classof(const Value* value);
};
class UnaryInst : public Instruction {
public:
UnaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* operand,
BasicBlock* parent = nullptr, const std::string& name = "");
Value* GetOprd() const { return GetOperand(0); }
static bool classof(const Value* value);
};
class ReturnInst : public Instruction {
public:
ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
Value* GetValue() const;
ReturnInst(Value* value = nullptr, BasicBlock* parent = nullptr);
bool HasReturnValue() const { return GetNumOperands() > 0; }
Value* GetReturnValue() const {
return HasReturnValue() ? GetOperand(0) : nullptr;
}
Value* GetValue() const { return GetReturnValue(); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Return;
}
};
class AllocaInst : public Instruction {
public:
AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name);
AllocaInst(std::shared_ptr<Type> allocated_type, BasicBlock* parent = nullptr,
const std::string& name = "");
std::shared_ptr<Type> GetAllocatedType() const { return allocated_type_; }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Alloca;
}
private:
std::shared_ptr<Type> allocated_type_;
};
class LoadInst : public Instruction {
public:
LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name);
Value* GetPtr() const;
LoadInst(std::shared_ptr<Type> value_type, Value* ptr,
BasicBlock* parent = nullptr, const std::string& name = "");
Value* GetPtr() const { return GetOperand(0); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Load;
}
};
class StoreInst : public Instruction {
public:
StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr);
Value* GetValue() const;
Value* GetPtr() const;
StoreInst(Value* value, Value* ptr, BasicBlock* parent = nullptr);
Value* GetValue() const { return GetOperand(0); }
Value* GetPtr() const { return GetOperand(1); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Store;
}
};
class UncondBrInst : public Instruction {
public:
UncondBrInst(BasicBlock* dest, BasicBlock* parent = nullptr);
BasicBlock* GetDest() const;
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Br;
}
};
class CondBrInst : public Instruction {
public:
CondBrInst(Value* cond, BasicBlock* then_block, BasicBlock* else_block,
BasicBlock* parent = nullptr);
Value* GetCondition() const { return GetOperand(0); }
BasicBlock* GetThenBlock() const;
BasicBlock* GetElseBlock() const;
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::CondBr;
}
};
class UnreachableInst : public Instruction {
public:
explicit UnreachableInst(BasicBlock* parent = nullptr);
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Unreachable;
}
};
class CallInst : public Instruction {
public:
CallInst(Function* callee, const std::vector<Value*>& args = {},
BasicBlock* parent = nullptr, const std::string& name = "");
Function* GetCallee() const;
std::vector<Value*> GetArguments() const;
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Call;
}
};
class GetElementPtrInst : public Instruction {
public:
GetElementPtrInst(std::shared_ptr<Type> source_type, Value* ptr,
const std::vector<Value*>& indices,
BasicBlock* parent = nullptr,
const std::string& name = "");
Value* GetPointer() const { return GetOperand(0); }
size_t GetNumIndices() const {
return GetNumOperands() > 0 ? GetNumOperands() - 1 : 0;
}
Value* GetIndex(size_t index) const { return GetOperand(index + 1); }
std::shared_ptr<Type> GetSourceType() const { return source_type_; }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() ==
Opcode::GetElementPtr;
}
private:
std::shared_ptr<Type> source_type_;
};
class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> type, BasicBlock* parent = nullptr,
const std::string& name = "");
void AddIncoming(Value* value, BasicBlock* block);
int GetNumIncomings() const {
return static_cast<int>(GetNumOperands() / 2);
}
Value* GetIncomingValue(int index) const {
return GetOperand(static_cast<size_t>(2 * index));
}
BasicBlock* GetIncomingBlock(int index) const;
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Phi;
}
};
class ZextInst : public Instruction {
public:
ZextInst(Value* value, std::shared_ptr<Type> target_type,
BasicBlock* parent = nullptr, const std::string& name = "");
Value* GetValue() const { return GetOperand(0); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Zext;
}
};
class MemsetInst : public Instruction {
public:
MemsetInst(Value* dst, Value* value, Value* len, Value* is_volatile,
BasicBlock* parent = nullptr);
Value* GetDest() const { return GetOperand(0); }
Value* GetValue() const { return GetOperand(1); }
Value* GetLength() const { return GetOperand(2); }
Value* GetIsVolatile() const { return GetOperand(3); }
static bool classof(const Value* value) {
return value && Instruction::classof(value) &&
static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Memset;
}
};
// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。
// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。
class BasicBlock : public Value {
public:
explicit BasicBlock(std::string name);
Function* GetParent() const;
void SetParent(Function* parent);
explicit BasicBlock(const std::string& name);
BasicBlock(Function* parent, const std::string& name);
Function* GetParent() const { return parent_; }
void SetParent(Function* parent) { parent_ = parent; }
bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
std::vector<std::unique_ptr<Instruction>>& GetInstructions() {
return instructions_;
}
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const {
return instructions_;
}
void EraseInstruction(Instruction* inst);
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
const std::vector<BasicBlock*>& GetPredecessors() const {
return predecessors_;
}
const std::vector<BasicBlock*>& GetSuccessors() const {
return successors_;
}
template <typename T, typename... Args>
T* Insert(size_t index, Args&&... args) {
if (index > instructions_.size()) {
throw std::out_of_range("BasicBlock insert index out of range");
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin() + static_cast<long long>(index),
std::move(inst));
return ptr;
}
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
throw std::runtime_error("BasicBlock 已有 terminator不能继续追加指令: " +
name_);
throw std::runtime_error("BasicBlock already has terminator");
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
@ -247,6 +580,10 @@ class BasicBlock : public Value {
return ptr;
}
static bool classof(const Value* value) {
return value && dynamic_cast<const BasicBlock*>(value) != nullptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
@ -254,22 +591,89 @@ class BasicBlock : public Value {
std::vector<BasicBlock*> successors_;
};
// Function 当前也采用了最小实现。
// 需要特别注意:由于项目里还没有单独的 FunctionType
// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”,
// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。
// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、
// 形参和调用,通常需要引入专门的函数类型表示。
class Function : public Value {
public:
// 当前构造函数接收的也是返回类型,而不是完整函数类型。
Function(std::string name, std::shared_ptr<Type> ret_type);
Function(std::string name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types = {},
const std::vector<std::string>& param_names = {},
bool is_external = false);
bool IsFunction() const override final { return true; }
std::shared_ptr<Type> GetReturnType() const { return return_type_; }
const std::vector<std::shared_ptr<Type>>& GetParamTypes() const {
return param_types_;
}
const std::vector<std::unique_ptr<Argument>>& GetArguments() const {
return arguments_;
}
Argument* GetArgument(size_t index) const;
bool IsExternal() const { return is_external_; }
void SetExternal(bool is_external) { is_external_ = is_external; }
void SetEffectInfo(bool reads_global_memory, bool writes_global_memory,
bool reads_param_memory, bool writes_param_memory,
bool has_io, bool has_unknown_effects, bool is_recursive) {
reads_global_memory_ = reads_global_memory;
writes_global_memory_ = writes_global_memory;
reads_param_memory_ = reads_param_memory;
writes_param_memory_ = writes_param_memory;
has_io_ = has_io;
has_unknown_effects_ = has_unknown_effects;
is_recursive_ = is_recursive;
}
bool ReadsGlobalMemory() const { return reads_global_memory_; }
bool WritesGlobalMemory() const { return writes_global_memory_; }
bool ReadsParamMemory() const { return reads_param_memory_; }
bool WritesParamMemory() const { return writes_param_memory_; }
bool HasIO() const { return has_io_; }
bool HasUnknownEffects() const { return has_unknown_effects_; }
bool IsRecursive() const { return is_recursive_; }
bool MayReadMemory() const {
return has_unknown_effects_ || reads_global_memory_ || writes_global_memory_ ||
reads_param_memory_ || writes_param_memory_;
}
bool MayWriteMemory() const {
return has_unknown_effects_ || writes_global_memory_ || writes_param_memory_;
}
bool HasObservableSideEffects() const {
return has_unknown_effects_ || writes_global_memory_ ||
writes_param_memory_ || has_io_;
}
bool CanDiscardUnusedCall() const {
return !has_unknown_effects_ && !writes_global_memory_ &&
!writes_param_memory_ && !has_io_ && !is_recursive_;
}
BasicBlock* GetEntryBlock() const { return entry_; }
BasicBlock* GetEntry() const { return entry_; }
void SetEntryBlock(BasicBlock* bb) { entry_ = bb; }
BasicBlock* EnsureEntryBlock();
BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry();
const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
BasicBlock* AddBlock(std::unique_ptr<BasicBlock> block);
std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() { return blocks_; }
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const {
return blocks_;
}
static bool classof(const Value* value) {
return value && value->IsFunction();
}
private:
std::shared_ptr<Type> return_type_;
std::vector<std::shared_ptr<Type>> param_types_;
std::vector<std::unique_ptr<Argument>> arguments_;
bool is_external_ = false;
bool reads_global_memory_ = false;
bool writes_global_memory_ = false;
bool reads_param_memory_ = false;
bool writes_param_memory_ = false;
bool has_io_ = false;
bool has_unknown_effects_ = true;
bool is_recursive_ = false;
BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_;
};
@ -277,33 +681,100 @@ class Function : public Value {
class Module {
public:
Module() = default;
Context& GetContext();
const Context& GetContext() const;
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type);
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
Context& GetContext() { return context_; }
const Context& GetContext() const { return context_; }
Function* CreateFunction(const std::string& name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types = {},
const std::vector<std::string>& param_names = {},
bool is_external = false);
Function* GetFunction(const std::string& name) const;
const std::vector<std::unique_ptr<Function>>& GetFunctions() const {
return functions_;
}
GlobalValue* CreateGlobalValue(const std::string& name,
std::shared_ptr<Type> object_type,
bool is_const = false, Value* init = nullptr);
GlobalValue* GetGlobalValue(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalValue>>& GetGlobalValues() const {
return globals_;
}
private:
Context context_;
std::vector<std::unique_ptr<Function>> functions_;
std::map<std::string, Function*> function_map_;
std::vector<std::unique_ptr<GlobalValue>> globals_;
std::map<std::string, GlobalValue*> global_map_;
};
class IRBuilder {
public:
IRBuilder(Context& ctx, BasicBlock* bb);
void SetInsertPoint(BasicBlock* bb);
BasicBlock* GetInsertBlock() const;
BasicBlock* GetInsertBlock() const { return insert_block_; }
// 构造常量、二元运算、返回指令的最小集合。
ConstantInt* CreateConstInt(int v);
ConstantFloat* CreateConstFloat(float v);
ConstantI1* CreateConstBool(bool v);
ConstantArrayValue* CreateConstArray(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name = "");
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
const std::string& name = "");
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateRem(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateAnd(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateOr(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateXor(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateShl(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateAShr(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateLShr(Value* lhs, Value* rhs, const std::string& name = "");
BinaryInst* CreateICmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name = "");
BinaryInst* CreateFCmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name = "");
UnaryInst* CreateNeg(Value* operand, const std::string& name = "");
UnaryInst* CreateNot(Value* operand, const std::string& name = "");
UnaryInst* CreateFNeg(Value* operand, const std::string& name = "");
UnaryInst* CreateFtoI(Value* operand, const std::string& name = "");
UnaryInst* CreateIToF(Value* operand, const std::string& name = "");
AllocaInst* CreateAlloca(std::shared_ptr<Type> allocated_type,
const std::string& name = "");
LoadInst* CreateLoad(Value* ptr, std::shared_ptr<Type> value_type,
const std::string& name = "");
LoadInst* CreateLoad(Value* ptr, const std::string& name = "") {
return CreateLoad(ptr, Type::GetInt32Type(), name);
}
StoreInst* CreateStore(Value* val, Value* ptr);
ReturnInst* CreateRet(Value* v);
UncondBrInst* CreateBr(BasicBlock* dest);
CondBrInst* CreateCondBr(Value* cond, BasicBlock* then_bb,
BasicBlock* else_bb);
ReturnInst* CreateRet(Value* val = nullptr);
UnreachableInst* CreateUnreachable();
CallInst* CreateCall(Function* callee, const std::vector<Value*>& args,
const std::string& name = "");
GetElementPtrInst* CreateGEP(Value* ptr, std::shared_ptr<Type> source_type,
const std::vector<Value*>& indices,
const std::string& name = "");
PhiInst* CreatePhi(std::shared_ptr<Type> type, const std::string& name = "");
ZextInst* CreateZext(Value* val, std::shared_ptr<Type> target_type,
const std::string& name = "");
MemsetInst* CreateMemset(Value* dst, Value* val, Value* len,
Value* is_volatile);
private:
Context& ctx_;
@ -315,4 +786,14 @@ class IRPrinter {
void Print(const Module& module, std::ostream& os);
};
} // namespace ir
inline std::ostream& operator<<(std::ostream& os, const Type& type) {
type.Print(os);
return os;
}
inline std::ostream& operator<<(std::ostream& os, const Value& value) {
value.Print(os);
return os;
}
} // namespace ir

@ -0,0 +1,373 @@
// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。
//
// 当前已经实现:
// 1. 基础类型系统void / i32 / i32*
// 2. Value 体系Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction
// 3. 最小指令集Add / Alloca / Load / Store / Ret
// 4. BasicBlock / Function / Module 三层组织结构
// 5. IRBuilder便捷创建常量和最小指令
// 6. def-use 关系的轻量实现:
// - Instruction 保存 operand 列表
// - Value 保存 uses
// - 支持 ReplaceAllUsesWith 的简化实现
//
// 当前尚未实现或只做了最小占位:
// 1. 完整类型系统数组、函数类型、label 类型等
// 2. 更完整的指令系统br / condbr / call / phi / gep 等
// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构)
// 4. 更完整的 IR verifier 和优化基础设施
//
// 当前需要特别说明的两个简化点:
// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位,
// 后续如果补 label type可以再改成更合理的块标签类型。
// 2. ConstantValue 体系目前只实现了 ConstantInt后续可以继续补 ConstantFloat、
// ConstantArray等更完整的常量种类。
//
// 建议的扩展顺序:
// 1. 先补更多指令和类型
// 2. 再补控制流相关 IR
// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架
#pragma once
#include <iosfwd>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
class Type;
class Value;
class User;
class ConstantValue;
class ConstantInt;
class GlobalValue;
class Instruction;
class BasicBlock;
class Function;
// Use 表示一个 Value 的一次使用记录。
// 当前实现设计:
// - value被使用的值
// - user使用该值的 User
// - operand_index该值在 user 操作数列表中的位置
class Use {
public:
Use() = default;
Use(Value* value, User* user, size_t operand_index)
: value_(value), user_(user), operand_index_(operand_index) {}
Value* GetValue() const { return value_; }
User* GetUser() const { return user_; }
size_t GetOperandIndex() const { return operand_index_; }
void SetValue(Value* value) { value_ = value; }
void SetUser(User* user) { user_ = user; }
void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; }
private:
Value* value_ = nullptr;
User* user_ = nullptr;
size_t operand_index_ = 0;
};
// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。
class Context {
public:
Context() = default;
~Context();
// 去重创建 i32 常量。
ConstantInt* GetConstInt(int v);
std::string NextTemp();
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
int temp_index_ = -1;
};
class Type {
public:
enum class Kind { Void, Int1, Int32, Float, Label, Function, PtrInt32, Array };
explicit Type(Kind k);
// 静态工厂方法:返回对应类型的共享单例
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetFloatType();
static const std::shared_ptr<Type>& GetLabelType();
static const std::shared_ptr<Type>& GetFunctionType();
static const std::shared_ptr<Type>& GetBoolType();
static const std::shared_ptr<Type>& GetPtrInt32Type();
static const std::shared_ptr<Type>& GetArrayType();
Kind GetKind() const;
// 便捷类型判断
bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const;
bool IsFloat() const;
bool IsLabel() const;
bool IsFunction() const;
bool IsBool() const;
bool IsPtrInt32() const;
bool IsArray() const;
private:
Kind kind_;
};
class Value {
public:
Value(std::shared_ptr<Type> ty, std::string name);
virtual ~Value() = default;
const std::shared_ptr<Type>& GetType() const;
const std::string& GetName() const;
void SetName(std::string n);
bool IsVoid() const;
bool IsInt32() const;
bool IsPtrInt32() const;
bool IsConstant() const;
bool IsInstruction() const;
bool IsUser() const;
bool IsFunction() const;
void AddUse(User* user, size_t operand_index);
void RemoveUse(User* user, size_t operand_index);
const std::vector<Use>& GetUses() const;
void ReplaceAllUsesWith(Value* new_value);
protected:
std::shared_ptr<Type> type_;
std::string name_;
std::vector<Use> uses_;
};
// ConstantValue 是常量体系的基类。
// 当前只实现了 ConstantInt后续可继续扩展更多常量种类。
class ConstantValue : public Value {
public:
ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
};
class ConstantInt : public ConstantValue {
public:
ConstantInt(std::shared_ptr<Type> ty, int v);
int GetValue() const { return value_; }
private:
int value_{};
};
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float v);
float GetValue() const { return value_; }
private:
float value_{};
};
class ConstantI1 : public ConstantValue {
public:
ConstantI1(std::shared_ptr<Type> ty, bool v);
int GetValue() const { return value_; }
private:
bool value_{};
};
class ConstantArrayValue : public Value {
public:
ConstantArrayValue()
};
//暂时先设计这些
enum class Opcode {
// 二元算术
Add,Sub,Mul,Div,Rem,FAdd,FSub,FMul,FDiv,FRem,
// 位运算
And,Or,Xor,Shl,AShr,LShr,
// 整数比较
ICmpEQ,ICmpNE,ICmpLT,ICmpGT,ICmpLE,ICmpGE,
// 浮点比较
FCmpEQ,FCmpNE,FCmpLT,FCmpGT,FCmpLE,FCmpGE,
// 一元运算
Neg,Not,FNeg,FtoI,IToF,
// 调用与终止
Call,CondBr,Br,Return,Unreachable,
// 内存操作
Alloca,Load,Store,Memset,
// 其他
GetElementPtr,Phi,Zext
};
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。
class User : public Value {
public:
User(std::shared_ptr<Type> ty, std::string name);
size_t GetNumOperands() const;
Value* GetOperand(size_t index) const;
void SetOperand(size_t index, Value* value);
protected:
void AddOperand(Value* value);
private:
std::vector<Value*> operands_;
};
// GlobalValue 是全局值/全局变量体系的空壳占位类。
// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。
class GlobalValue : public User {
public:
GlobalValue(std::shared_ptr<Type> ty, std::string name);
};
class Instruction : public User {
public:
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
Opcode GetOpcode() const;
bool IsTerminator() const;
BasicBlock* GetParent() const;
void SetParent(BasicBlock* parent);
private:
Opcode opcode_;
BasicBlock* parent_ = nullptr;
};
class BinaryInst : public Instruction {
public:
BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
std::string name);
Value* GetLhs() const;
Value* GetRhs() const;
};
class ReturnInst : public Instruction {
public:
ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
Value* GetValue() const;
};
class AllocaInst : public Instruction {
public:
AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name);
};
class LoadInst : public Instruction {
public:
LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name);
Value* GetPtr() const;
};
class StoreInst : public Instruction {
public:
StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr);
Value* GetValue() const;
Value* GetPtr() const;
};
// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。
// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。
class BasicBlock : public Value {
public:
explicit BasicBlock(std::string name);
Function* GetParent() const;
void SetParent(Function* parent);
bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
throw std::runtime_error("BasicBlock 已有 terminator不能继续追加指令: " +
name_);
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.push_back(std::move(inst));
return ptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_;
};
// Function 当前也采用了最小实现。
// 需要特别注意:由于项目里还没有单独的 FunctionType
// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”,
// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。
// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、
// 形参和调用,通常需要引入专门的函数类型表示。
class Function : public Value {
public:
// 当前构造函数接收的也是返回类型,而不是完整函数类型。
Function(std::string name, std::shared_ptr<Type> ret_type);
BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry();
const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
private:
BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_;
};
class Module {
public:
Module() = default;
Context& GetContext();
const Context& GetContext() const;
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type);
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
private:
Context context_;
std::vector<std::unique_ptr<Function>> functions_;
};
class IRBuilder {
public:
IRBuilder(Context& ctx, BasicBlock* bb);
void SetInsertPoint(BasicBlock* bb);
BasicBlock* GetInsertBlock() const;
// 构造常量、二元运算、返回指令的最小集合。
ConstantInt* CreateConstInt(int v);
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
ReturnInst* CreateRet(Value* v);
private:
Context& ctx_;
BasicBlock* insert_block_;
};
class IRPrinter {
public:
void Print(const Module& module, std::ostream& os);
};
} // namespace ir

@ -0,0 +1,26 @@
#pragma once
namespace ir {
class Module;
void RunMem2Reg(Module& module);
bool RunConstFold(Module& module);
bool RunConstProp(Module& module);
bool RunFunctionInlining(Module& module);
bool RunTailRecursionElim(Module& module);
bool RunArithmeticSimplify(Module& module);
bool RunCSE(Module& module);
bool RunGVN(Module& module);
bool RunLoadStoreElim(Module& module);
bool RunDCE(Module& module);
bool RunCFGSimplify(Module& module);
bool RunLICM(Module& module);
bool RunLoopMemoryPromotion(Module& module);
bool RunLoopUnswitch(Module& module);
bool RunLoopStrengthReduction(Module& module);
bool RunLoopUnroll(Module& module);
bool RunLoopFission(Module& module);
void RunIRPassPipeline(Module& module);
} // namespace ir

@ -0,0 +1,41 @@
#pragma once
#include <iterator>
namespace ir {
template <typename IterT> struct range {
using iterator = IterT;
using value_type = typename std::iterator_traits<iterator>::value_type;
using reference = typename std::iterator_traits<iterator>::reference;
private:
iterator b;
iterator e;
public:
explicit range(iterator b, iterator e) : b(b), e(e) {}
iterator begin() { return b; }
iterator end() { return e; }
iterator begin() const { return b; }
iterator end() const { return e; }
auto size() const { return std::distance(b, e); }
auto empty() const { return b == e; }
};
//! create `range` object from iterator pair [begin, end)
template <typename IterT> range<IterT> make_range(IterT b, IterT e) {
return range<IterT>(b, e);
}
//! create `range` object from a container who has `begin()` and `end()` methods
template <typename ContainerT>
range<typename ContainerT::iterator> make_range(ContainerT &c) {
return make_range(c.begin(), c.end());
}
//! create `range` object from a container who has `begin()` and `end()` methods
template <typename ContainerT>
range<typename ContainerT::const_iterator> make_range(const ContainerT &c) {
return make_range(c.begin(), c.end());
}
} // namespace ir

@ -1,58 +1,193 @@
// 将语法树翻译为 IR。
// 实现拆分在 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl。
#pragma once
#include <any>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
#include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/Sema.h"
namespace ir {
class Module;
class Function;
class IRBuilder;
class Value;
}
#include "sem/SymbolTable.h"
class IRGenImpl final : public SysYBaseVisitor {
public:
IRGenImpl(ir::Module& module, const SemanticContext& sema);
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
private:
enum class BlockFlow {
enum class FlowState {
Continue,
Terminated,
};
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
struct TypedValue {
ir::Value* value = nullptr;
SemanticType type = SemanticType::Int;
bool is_array = false;
std::vector<int> dims;
};
struct LValueInfo {
SymbolEntry* symbol = nullptr;
ir::Value* addr = nullptr;
SemanticType type = SemanticType::Int;
bool is_array = false;
std::vector<int> dims;
bool root_param_array_no_index = false;
};
struct LoopContext {
ir::BasicBlock* cond_block = nullptr;
ir::BasicBlock* exit_block = nullptr;
};
struct InitExprSlot {
size_t index = 0;
SysYParser::ExpContext* expr = nullptr;
};
[[noreturn]] void ThrowError(const antlr4::ParserRuleContext* ctx,
const std::string& message) const;
void ApplyFunctionSema(const std::string& name, ir::Function& function);
void RegisterBuiltinFunctions();
void PredeclareTopLevel(SysYParser::CompUnitContext& ctx);
void PredeclareFunction(SysYParser::FuncDefContext& ctx);
void PredeclareGlobalDecl(SysYParser::DeclContext& ctx);
void EmitGlobalDecl(SysYParser::DeclContext& ctx);
void EmitFunction(SysYParser::FuncDefContext& ctx);
void BindFunctionParams(SysYParser::FuncDefContext& ctx, ir::Function& func);
void EmitBlock(SysYParser::BlockContext& ctx, bool create_scope = true);
FlowState EmitBlockItem(SysYParser::BlockItemContext& ctx);
FlowState EmitStmt(SysYParser::StmtContext& ctx);
void EmitDecl(SysYParser::DeclContext& ctx, bool is_global);
void EmitVarDecl(SysYParser::VarDeclContext* ctx, bool is_global, bool is_const);
void EmitConstDecl(SysYParser::ConstDeclContext* ctx, bool is_global);
void EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType type);
void EmitGlobalConstDef(SysYParser::ConstDefContext& ctx, SemanticType type);
void EmitLocalVarDef(SysYParser::VarDefContext& ctx, SemanticType type, bool is_const);
void EmitLocalConstDef(SysYParser::ConstDefContext& ctx, SemanticType type);
std::string ExpectIdent(const antlr4::ParserRuleContext& ctx,
antlr4::tree::TerminalNode* ident) const;
SemanticType ParseBType(SysYParser::BTypeContext* ctx) const;
SemanticType ParseFuncType(SysYParser::FuncTypeContext* ctx) const;
std::shared_ptr<ir::Type> GetIRScalarType(SemanticType type) const;
std::shared_ptr<ir::Type> BuildArrayType(SemanticType base_type,
const std::vector<int>& dims) const;
std::vector<int> ParseArrayDims(const std::vector<SysYParser::ConstExpContext*>& dims_ctx);
std::vector<int> ParseParamDims(SysYParser::FuncFParamContext& ctx);
FunctionTypeInfo BuildFunctionTypeInfo(SysYParser::FuncDefContext& ctx);
std::vector<std::shared_ptr<ir::Type>> BuildFunctionIRParamTypes(
const FunctionTypeInfo& function_type) const;
std::vector<std::string> BuildFunctionIRParamNames(SysYParser::FuncDefContext& ctx) const;
TypedValue EmitExp(SysYParser::ExpContext& ctx);
TypedValue EmitAddExp(SysYParser::AddExpContext& ctx);
TypedValue EmitMulExp(SysYParser::MulExpContext& ctx);
TypedValue EmitUnaryExp(SysYParser::UnaryExpContext& ctx);
TypedValue EmitPrimaryExp(SysYParser::PrimaryExpContext& ctx);
TypedValue EmitRelExp(SysYParser::RelExpContext& ctx);
TypedValue EmitEqExp(SysYParser::EqExpContext& ctx);
TypedValue EmitLValValue(SysYParser::LValContext& ctx);
LValueInfo ResolveLVal(SysYParser::LValContext& ctx);
ir::Value* GenLValAddr(SysYParser::LValContext& ctx);
void EmitCond(SysYParser::CondContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
void EmitLOrCond(SysYParser::LOrExpContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
void EmitLAndCond(SysYParser::LAndExpContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block);
TypedValue CastScalar(TypedValue value, SemanticType target_type,
const antlr4::ParserRuleContext* ctx);
ir::Value* CastToCondition(TypedValue value,
const antlr4::ParserRuleContext* ctx);
TypedValue NormalizeLogicalValue(TypedValue value,
const antlr4::ParserRuleContext* ctx);
bool IsNumeric(const TypedValue& value) const;
bool IsSameDims(const std::vector<int>& lhs, const std::vector<int>& rhs) const;
ConstantValue ParseNumber(SysYParser::NumberContext& ctx) const;
ConstantValue EvalConstExp(SysYParser::ExpContext& ctx);
ConstantValue EvalConstAddExp(SysYParser::AddExpContext& ctx);
ConstantValue EvalConstMulExp(SysYParser::MulExpContext& ctx);
ConstantValue EvalConstUnaryExp(SysYParser::UnaryExpContext& ctx);
ConstantValue EvalConstPrimaryExp(SysYParser::PrimaryExpContext& ctx);
ConstantValue EvalConstLVal(SysYParser::LValContext& ctx);
ConstantValue ConvertConst(ConstantValue value, SemanticType target_type) const;
bool IsZeroConstant(const ConstantValue& value) const;
bool IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type);
bool IsExplicitZeroInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type);
std::vector<ConstantValue> FlattenConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims);
std::vector<ConstantValue> FlattenInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims);
std::vector<InitExprSlot> FlattenLocalInitVal(SysYParser::InitValContext* ctx,
const std::vector<int>& dims);
void FlattenConstInitValImpl(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<ConstantValue>& out);
void FlattenInitValImpl(SysYParser::InitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<ConstantValue>& out);
void FlattenLocalInitValImpl(SysYParser::InitValContext* ctx,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor, std::vector<InitExprSlot>& out);
size_t CountArrayElements(const std::vector<int>& dims, size_t start = 0) const;
size_t AlignInitializerCursor(const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t cursor) const;
size_t FlattenIndices(const std::vector<int>& dims,
const std::vector<int>& indices) const;
ConstantValue ZeroConst(SemanticType type) const;
ir::Value* ZeroIRValue(SemanticType type);
ir::Value* CreateTypedConstant(const ConstantValue& value);
ir::AllocaInst* CreateEntryAlloca(std::shared_ptr<ir::Type> allocated_type,
const std::string& name);
void ZeroInitializeLocalArray(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims);
void StoreLocalArrayElements(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims,
const std::vector<InitExprSlot>& init_slots);
ir::Value* CreateArrayElementAddr(ir::Value* base_addr, bool is_param_array,
SemanticType base_type,
const std::vector<int>& dims,
const std::vector<ir::Value*>& indices,
const antlr4::ParserRuleContext* ctx);
std::string NextTemp();
std::string NextBlockName(const std::string& prefix);
ir::Module& module_;
const SemanticContext& sema_;
ir::Function* func_;
ir::IRBuilder builder_;
// 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
SymbolTable symbols_;
ir::Function* current_function_ = nullptr;
SemanticType current_return_type_ = SemanticType::Void;
std::vector<LoopContext> loop_stack_;
bool builtins_registered_ = false;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema);
const SemanticContext& sema);

@ -1,119 +1,300 @@
#pragma once
#include <initializer_list>
#include <iosfwd>
#include <memory>
#include <string>
#include <vector>
namespace ir {
class Module;
}
namespace mir {
class MIRContext {
public:
MIRContext() = default;
};
MIRContext& DefaultContext();
enum class PhysReg { W0, W8, W9, X29, X30, SP };
const char* PhysRegName(PhysReg reg);
enum class Opcode {
Prologue,
Epilogue,
MovImm,
LoadStack,
StoreStack,
AddRR,
Ret,
};
class Operand {
public:
enum class Kind { Reg, Imm, FrameIndex };
static Operand Reg(PhysReg reg);
static Operand Imm(int value);
static Operand FrameIndex(int index);
Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; }
int GetFrameIndex() const { return imm_; }
private:
Operand(Kind kind, PhysReg reg, int imm);
Kind kind_;
PhysReg reg_;
int imm_;
};
class MachineInstr {
public:
MachineInstr(Opcode opcode, std::vector<Operand> operands = {});
Opcode GetOpcode() const { return opcode_; }
const std::vector<Operand>& GetOperands() const { return operands_; }
private:
Opcode opcode_;
std::vector<Operand> operands_;
};
struct FrameSlot {
int index = 0;
int size = 4;
int offset = 0;
};
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
const std::string& GetName() const { return name_; }
std::vector<MachineInstr>& GetInstructions() { return instructions_; }
const std::vector<MachineInstr>& GetInstructions() const { return instructions_; }
MachineInstr& Append(Opcode opcode,
std::initializer_list<Operand> operands = {});
private:
std::string name_;
std::vector<MachineInstr> instructions_;
};
class MachineFunction {
public:
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
int GetFrameSize() const { return frame_size_; }
void SetFrameSize(int size) { frame_size_ = size; }
private:
std::string name_;
MachineBasicBlock entry_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0;
};
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os);
} // namespace mir
#pragma once
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace ir {
class Module;
}
namespace mir {
class MIRContext {
public:
MIRContext() = default;
};
MIRContext& DefaultContext();
enum class ValueType { Void, I1, I32, F32, Ptr };
enum class RegClass { GPR, FPR };
enum class CondCode { EQ, NE, LT, GT, LE, GE };
enum class StackObjectKind { Local, Spill, SavedGPR, SavedFPR };
enum class AddrBaseKind { None, FrameObject, Global, VReg };
enum class OperandKind { Invalid, VReg, Imm, Block, Symbol };
struct PhysReg {
RegClass reg_class = RegClass::GPR;
int index = -1;
bool IsValid() const { return index >= 0; }
bool operator==(const PhysReg& rhs) const {
return reg_class == rhs.reg_class && index == rhs.index;
}
};
bool IsGPR(ValueType type);
bool IsFPR(ValueType type);
int GetValueSize(ValueType type);
int GetValueAlign(ValueType type);
const char* GetPhysRegName(PhysReg reg, ValueType type);
class MachineOperand {
public:
MachineOperand() = default;
static MachineOperand VReg(int reg);
static MachineOperand Imm(std::int64_t value);
static MachineOperand Block(std::string name);
static MachineOperand Symbol(std::string name);
OperandKind GetKind() const { return kind_; }
int GetVReg() const { return vreg_; }
std::int64_t GetImm() const { return imm_; }
const std::string& GetText() const { return text_; }
private:
MachineOperand(OperandKind kind, int vreg, std::int64_t imm, std::string text);
OperandKind kind_ = OperandKind::Invalid;
int vreg_ = -1;
std::int64_t imm_ = 0;
std::string text_;
};
struct AddressExpr {
AddrBaseKind base_kind = AddrBaseKind::None;
int base_index = -1;
std::string symbol;
std::int64_t const_offset = 0;
std::vector<std::pair<int, std::int64_t>> scaled_vregs;
};
struct StackObject {
int index = -1;
StackObjectKind kind = StackObjectKind::Local;
int size = 0;
int align = 1;
int offset = 0;
std::string name;
};
struct VRegInfo {
int id = -1;
ValueType type = ValueType::Void;
};
struct Allocation {
enum class Kind { Unassigned, PhysReg, Spill };
Kind kind = Kind::Unassigned;
PhysReg phys;
int stack_object = -1;
};
class MachineInstr {
public:
enum class Opcode {
Arg,
Copy,
Load,
Store,
Lea,
Add,
Sub,
Mul,
Div,
Rem,
ModMul,
ModPow,
DigitExtractPow2,
And,
Or,
Xor,
Shl,
AShr,
LShr,
FAdd,
FSub,
FMul,
FDiv,
FSqrt,
FNeg,
ICmp,
FCmp,
ZExt,
ItoF,
FtoI,
Br,
CondBr,
Call,
Ret,
Memset,
Unreachable,
};
explicit MachineInstr(Opcode opcode,
std::vector<MachineOperand> operands = {});
Opcode GetOpcode() const { return opcode_; }
const std::vector<MachineOperand>& GetOperands() const { return operands_; }
std::vector<MachineOperand>& GetOperands() { return operands_; }
void SetCondCode(CondCode code) { cond_code_ = code; }
CondCode GetCondCode() const { return cond_code_; }
void SetAddress(AddressExpr address) {
address_ = std::move(address);
has_address_ = true;
}
bool HasAddress() const { return has_address_; }
const AddressExpr& GetAddress() const { return address_; }
AddressExpr& GetAddress() { return address_; }
void SetCallInfo(std::string callee, std::vector<ValueType> arg_types,
ValueType return_type) {
callee_ = std::move(callee);
call_arg_types_ = std::move(arg_types);
call_return_type_ = return_type;
}
const std::string& GetCallee() const { return callee_; }
const std::vector<ValueType>& GetCallArgTypes() const { return call_arg_types_; }
ValueType GetCallReturnType() const { return call_return_type_; }
void SetValueType(ValueType type) { value_type_ = type; }
ValueType GetValueType() const { return value_type_; }
bool IsTerminator() const;
std::vector<int> GetDefs() const;
std::vector<int> GetUses() const;
private:
Opcode opcode_;
std::vector<MachineOperand> operands_;
CondCode cond_code_ = CondCode::EQ;
AddressExpr address_;
bool has_address_ = false;
std::string callee_;
std::vector<ValueType> call_arg_types_;
ValueType call_return_type_ = ValueType::Void;
ValueType value_type_ = ValueType::Void;
};
class MachineBasicBlock {
public:
explicit MachineBasicBlock(std::string name);
const std::string& GetName() const { return name_; }
std::vector<MachineInstr>& GetInstructions() { return instructions_; }
const std::vector<MachineInstr>& GetInstructions() const { return instructions_; }
MachineInstr& Append(MachineInstr::Opcode opcode,
std::vector<MachineOperand> operands = {});
MachineInstr& Append(MachineInstr instr);
private:
std::string name_;
std::vector<MachineInstr> instructions_;
};
class MachineFunction {
public:
MachineFunction(std::string name, ValueType return_type,
std::vector<ValueType> param_types);
const std::string& GetName() const { return name_; }
ValueType GetReturnType() const { return return_type_; }
const std::vector<ValueType>& GetParamTypes() const { return param_types_; }
MachineBasicBlock* CreateBlock(const std::string& name);
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() { return blocks_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const {
return blocks_;
}
int NewVReg(ValueType type);
const VRegInfo& GetVRegInfo(int id) const;
VRegInfo& GetVRegInfo(int id);
const std::vector<VRegInfo>& GetVRegs() const { return vregs_; }
int CreateStackObject(int size, int align, StackObjectKind kind,
const std::string& name = "");
StackObject& GetStackObject(int index);
const StackObject& GetStackObject(int index) const;
const std::vector<StackObject>& GetStackObjects() const { return stack_objects_; }
void SetAllocation(int vreg, Allocation allocation);
const Allocation& GetAllocation(int vreg) const;
Allocation& GetAllocation(int vreg);
void AddUsedCalleeSavedGPR(int reg_index);
void AddUsedCalleeSavedFPR(int reg_index);
const std::vector<int>& GetUsedCalleeSavedGPRs() const {
return used_callee_saved_gprs_;
}
const std::vector<int>& GetUsedCalleeSavedFPRs() const {
return used_callee_saved_fprs_;
}
void SetFrameSize(int size) { frame_size_ = size; }
int GetFrameSize() const { return frame_size_; }
void SetMaxOutgoingArgBytes(int bytes) { max_outgoing_arg_bytes_ = bytes; }
int GetMaxOutgoingArgBytes() const { return max_outgoing_arg_bytes_; }
private:
std::string name_;
ValueType return_type_ = ValueType::Void;
std::vector<ValueType> param_types_;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<VRegInfo> vregs_;
std::vector<StackObject> stack_objects_;
std::vector<Allocation> allocations_;
std::vector<int> used_callee_saved_gprs_;
std::vector<int> used_callee_saved_fprs_;
int frame_size_ = 0;
int max_outgoing_arg_bytes_ = 0;
};
class MachineModule {
public:
explicit MachineModule(const ir::Module& source) : source_(&source) {}
const ir::Module& GetSourceModule() const { return *source_; }
std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() { return functions_; }
const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const {
return functions_;
}
MachineFunction* AddFunction(std::unique_ptr<MachineFunction> function);
private:
const ir::Module* source_ = nullptr;
std::vector<std::unique_ptr<MachineFunction>> functions_;
};
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
bool RunPeephole(MachineModule& module);
bool RunSpillReduction(MachineModule& module);
bool RunCFGCleanup(MachineModule& module);
void RunMIRPreRegAllocPassPipeline(MachineModule& module);
void RunMIRPostRegAllocPassPipeline(MachineModule& module);
void RunAddressHoisting(MachineModule& module);
void RunRegAlloc(MachineModule& module);
void RunFrameLowering(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
} // namespace mir

@ -1,30 +1,94 @@
// 基于语法树的语义检查与名称绑定。
#pragma once
#pragma once
#include "SysYParser.h"
#include "sem/SymbolTable.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "SysYParser.h"
struct GlobalSemanticInfo {
SemanticType type = SemanticType::Int;
bool is_const = false;
bool is_array = false;
std::vector<int> dims;
};
struct FunctionSemanticInfo {
SemanticType return_type = SemanticType::Void;
std::vector<bool> param_is_array;
bool is_builtin = false;
bool is_defined = false;
bool reads_global_memory = false;
bool writes_global_memory = false;
bool reads_param_memory = false;
bool writes_param_memory = false;
bool has_io = false;
bool has_unknown_effects = true;
bool is_recursive = false;
std::vector<std::string> direct_callees;
bool MayReadMemory() const {
return has_unknown_effects || reads_global_memory || writes_global_memory ||
reads_param_memory || writes_param_memory;
}
bool MayWriteMemory() const {
return has_unknown_effects || writes_global_memory || writes_param_memory;
}
bool HasObservableSideEffects() const {
return has_unknown_effects || writes_global_memory || writes_param_memory ||
has_io;
}
bool CanDiscardUnusedCall() const {
return !has_unknown_effects && !writes_global_memory &&
!writes_param_memory && !has_io && !is_recursive;
}
};
class SemanticContext {
public:
void BindVarUse(SysYParser::VarContext* use,
SysYParser::VarDefContext* decl) {
var_uses_[use] = decl;
FunctionSemanticInfo* LookupFunction(const std::string& name) {
auto it = functions_.find(name);
return it == functions_.end() ? nullptr : &it->second;
}
const FunctionSemanticInfo* LookupFunction(const std::string& name) const {
auto it = functions_.find(name);
return it == functions_.end() ? nullptr : &it->second;
}
GlobalSemanticInfo* LookupGlobal(const std::string& name) {
auto it = globals_.find(name);
return it == globals_.end() ? nullptr : &it->second;
}
const GlobalSemanticInfo* LookupGlobal(const std::string& name) const {
auto it = globals_.find(name);
return it == globals_.end() ? nullptr : &it->second;
}
FunctionSemanticInfo& UpsertFunction(const std::string& name) {
return functions_[name];
}
GlobalSemanticInfo& UpsertGlobal(const std::string& name) {
return globals_[name];
}
const std::unordered_map<std::string, FunctionSemanticInfo>& GetFunctions() const {
return functions_;
}
SysYParser::VarDefContext* ResolveVarUse(
const SysYParser::VarContext* use) const {
auto it = var_uses_.find(use);
return it == var_uses_.end() ? nullptr : it->second;
const std::unordered_map<std::string, GlobalSemanticInfo>& GetGlobals() const {
return globals_;
}
private:
std::unordered_map<const SysYParser::VarContext*,
SysYParser::VarDefContext*>
var_uses_;
std::unordered_map<std::string, FunctionSemanticInfo> functions_;
std::unordered_map<std::string, GlobalSemanticInfo> globals_;
};
// 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -1,17 +1,69 @@
// 极简符号表:记录局部变量定义点。
#pragma once
#pragma once
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
#include "SysYParser.h"
namespace ir {
class Function;
class Value;
}
enum class SemanticType {
Void,
Int,
Float,
};
enum class SymbolKind {
Variable,
Constant,
Function,
};
struct ConstantValue {
SemanticType type = SemanticType::Int;
int int_value = 0;
float float_value = 0.0f;
};
struct FunctionTypeInfo {
SemanticType return_type = SemanticType::Void;
std::vector<SemanticType> param_types;
std::vector<bool> param_is_array;
std::vector<std::vector<int>> param_dims;
};
struct SymbolEntry {
SymbolKind kind = SymbolKind::Variable;
SemanticType type = SemanticType::Int;
bool is_const = false;
bool is_array = false;
bool is_param_array = false;
std::vector<int> dims;
ir::Value* ir_value = nullptr;
ir::Function* function = nullptr;
std::optional<ConstantValue> const_scalar;
std::vector<ConstantValue> const_array;
bool const_array_all_zero = false;
FunctionTypeInfo function_type;
};
class SymbolTable {
public:
void Add(const std::string& name, SysYParser::VarDefContext* decl);
bool Contains(const std::string& name) const;
SysYParser::VarDefContext* Lookup(const std::string& name) const;
void Clear();
void EnterScope();
void ExitScope();
bool Insert(const std::string& name, const SymbolEntry& entry);
bool ContainsInCurrentScope(const std::string& name) const;
SymbolEntry* Lookup(const std::string& name);
const SymbolEntry* Lookup(const std::string& name) const;
private:
std::unordered_map<std::string, SysYParser::VarDefContext*> table_;
};
std::vector<std::unordered_map<std::string, SymbolEntry>> scopes_;
};

@ -0,0 +1,324 @@
#!/usr/bin/env bash
# analyze_case.sh — 单个 .sy 测试用例的全流程编译 + IR/汇编保存脚本
# 用于深度分析单个样例与 GCC 基线之间的差距。
#
# 用法:
# analyze_case.sh <input.sy> [output_dir]
#
# 输出目录(默认 output/analyze/<stem>_<timestamp>)中包含:
# <stem>.ll — 我方编译器输出的 LLVM IR
# <stem>.s — 我方编译器输出的 AArch64 汇编
# <stem>.elf — 我方编译链接后的可执行文件
# <stem>.gcc.s — GCC -O2 输出的 AArch64 汇编
# <stem>.gcc.elf — GCC -O2 链接后的可执行文件
# <stem>.our.time — 我方程序运行耗时(秒)
# <stem>.gcc.time — GCC 程序运行耗时(秒)
# <stem>.our.out — 我方程序实际输出
# <stem>.gcc.out — GCC 程序实际输出
# <stem>.diff — 输出 diff若有差异
# report.txt — 汇总报告IR 行数、汇编行数、耗时、加速比)
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'
# ---------- 参数解析 ----------
if [[ $# -lt 1 || $# -gt 2 ]]; then
printf 'usage: %s <input.sy> [output_dir]\n' "$0" >&2
exit 1
fi
INPUT="$1"
if [[ ! -f "$INPUT" ]]; then
printf 'input file not found: %s\n' "$INPUT" >&2
exit 1
fi
BASE="$(basename "$INPUT")"
STEM="${BASE%.sy}"
INPUT_DIR="$(dirname "$(realpath "$INPUT")")"
TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
# 与 run_baseline.sh 一致的路径键:去掉 test/ 前缀和 .sy 后缀
REL="$(realpath --relative-to="$REPO_ROOT" "$INPUT" 2>/dev/null || echo "$INPUT")"
CASE_KEY="${REL#test/}"
CASE_KEY="${CASE_KEY%.sy}"
if [[ $# -ge 2 ]]; then
OUT_DIR="$2"
else
OUT_DIR="$REPO_ROOT/output/analyze/${STEM}_${TIMESTAMP}"
fi
mkdir -p "$OUT_DIR"
REPORT="$OUT_DIR/report.txt"
: > "$REPORT"
rpt() {
printf '%s\n' "$*" | tee -a "$REPORT"
}
rpt_color() {
local color="$1"; shift
printf '%b%s%b\n' "$color" "$*" "$NC"
printf '%s\n' "$*" >> "$REPORT"
}
rpt "============================================================"
rpt " analyze_case report"
rpt " case : $STEM"
rpt " source : $INPUT"
rpt " output : $OUT_DIR"
rpt " date : $(date)"
rpt "============================================================"
rpt ""
# ---------- 查找编译器 ----------
COMPILER=""
for candidate in \
"$REPO_ROOT/build_lab3/bin/compiler" \
"$REPO_ROOT/build_lab2/bin/compiler" \
"$REPO_ROOT/build/bin/compiler"; do
if [[ -x "$candidate" ]]; then
COMPILER="$candidate"
break
fi
done
if [[ -z "$COMPILER" ]]; then
rpt_color "$RED" "ERROR: compiler not found. Build first:"
rpt " cmake -S $REPO_ROOT -B $REPO_ROOT/build_lab3 && cmake --build $REPO_ROOT/build_lab3 -j"
exit 1
fi
rpt "compiler : $COMPILER"
# ---------- 工具检查 ----------
for tool in aarch64-linux-gnu-gcc qemu-aarch64; do
if ! command -v "$tool" >/dev/null 2>&1; then
rpt_color "$RED" "ERROR: required tool not found: $tool"
exit 1
fi
done
STDIN_FILE="$INPUT_DIR/$STEM.in"
EXPECTED_FILE="$INPUT_DIR/$STEM.out"
# ---------- 1. 生成 IR ----------
rpt ""
rpt "--- [1/5] Generating LLVM IR ---"
IR_FILE="$OUT_DIR/$STEM.ll"
if "$COMPILER" --emit-ir "$INPUT" > "$IR_FILE" 2>"$OUT_DIR/$STEM.ir.err"; then
IR_LINES=$(wc -l < "$IR_FILE")
rpt_color "$GREEN" "IR generated: $IR_FILE ($IR_LINES lines)"
else
rpt_color "$RED" "ERROR: IR generation failed"
cat "$OUT_DIR/$STEM.ir.err" >&2
exit 1
fi
# ---------- 2. 生成我方汇编并链接 ----------
rpt ""
rpt "--- [2/5] Generating our ASM & linking ---"
OUR_ASM="$OUT_DIR/$STEM.s"
OUR_ELF="$OUT_DIR/$STEM.elf"
if "$COMPILER" --emit-asm "$INPUT" > "$OUR_ASM" 2>"$OUT_DIR/$STEM.asm.err"; then
OUR_ASM_LINES=$(wc -l < "$OUR_ASM")
rpt_color "$GREEN" "ASM generated: $OUR_ASM ($OUR_ASM_LINES lines)"
else
rpt_color "$RED" "ERROR: ASM generation failed"
cat "$OUT_DIR/$STEM.asm.err" >&2
exit 1
fi
if aarch64-linux-gnu-gcc "$OUR_ASM" "$REPO_ROOT/sylib/sylib.c" -O2 \
-I "$REPO_ROOT/sylib" -lm -o "$OUR_ELF" 2>"$OUT_DIR/$STEM.link.err"; then
rpt_color "$GREEN" "Linked: $OUR_ELF"
else
rpt_color "$RED" "ERROR: link failed"
cat "$OUT_DIR/$STEM.link.err" >&2
exit 1
fi
# ---------- 3. GCC -O2 基线(从预计算数据读取)----------
rpt ""
rpt "--- [3/5] GCC -O2 baseline (reading from pre-computed data) ---"
BASELINE_DATA_DIR="$REPO_ROOT/output/baseline"
BASELINE_TSV_PATH="$BASELINE_DATA_DIR/gcc_timing.tsv"
GCC_ASM="$OUT_DIR/$STEM.gcc.s"
GCC_OUT="$OUT_DIR/$STEM.gcc.out"
GCC_OK=false
GCC_ASM_LINES=0
GCC_ELAPSED_RAW="" # 秒,无 s 后缀
if [[ -f "$BASELINE_TSV_PATH" ]]; then
GCC_ELAPSED_RAW=$(awk -F'\t' -v s="$CASE_KEY" '$1==s{v=$2} END{if(v!="") print v}' \
"$BASELINE_TSV_PATH" 2>/dev/null || true)
if [[ -n "$GCC_ELAPSED_RAW" ]]; then
GCC_OK=true
rpt_color "$GREEN" "baseline timing: ${GCC_ELAPSED_RAW}s"
else
rpt_color "$YELLOW" "WARNING: no baseline entry for '$CASE_KEY'"
rpt " Run: scripts/run_baseline.sh"
fi
# 复制汇编文件(路径镜像结构)
local_gcc_asm="$BASELINE_DATA_DIR/${CASE_KEY}.gcc.s"
if [[ -f "$local_gcc_asm" ]]; then
cp "$local_gcc_asm" "$GCC_ASM"
GCC_ASM_LINES=$(wc -l < "$GCC_ASM")
rpt "GCC ASM: $GCC_ASM ($GCC_ASM_LINES lines)"
else
rpt_color "$YELLOW" "GCC ASM not found in baseline dir: $local_gcc_asm"
fi
# 复制输出文件供步骤5 diff
local_gcc_out="$BASELINE_DATA_DIR/${CASE_KEY}.gcc.out"
if [[ -f "$local_gcc_out" ]]; then
cp "$local_gcc_out" "$GCC_OUT"
rpt "GCC output: $GCC_OUT"
fi
else
rpt_color "$YELLOW" "WARNING: baseline data not found: $BASELINE_TSV_PATH"
rpt " Run: scripts/run_baseline.sh"
rpt " to pre-compute GCC -O2 baseline for all test cases."
fi
# ---------- 4. 运行并计时(仅我方编译器)----------
rpt ""
rpt "--- [4/5] Running & timing (our compiler) ---"
run_and_time() {
local label="$1"
local exe="$2"
local out_file="$3"
local timeout_sec="${4:-60}"
local stdout_file="$out_file.raw"
local status=0
local _t0 _t1 _ns
_t0=$(date +%s%N)
set +e
if [[ -f "$STDIN_FILE" ]]; then
timeout "$timeout_sec" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" \
< "$STDIN_FILE" > "$stdout_file" 2>/dev/null
else
timeout "$timeout_sec" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" \
> "$stdout_file" 2>/dev/null
fi
status=$?
_t1=$(date +%s%N)
_ns=$((_t1 - _t0))
set -e
# 将 stdout + exit_code 合并为 .out与 verify_asm.sh 格式一致)
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$out_file"
rm -f "$stdout_file"
local elapsed
if [[ $status -eq 124 ]]; then
elapsed="timeout"
rpt_color "$YELLOW" "$label: TIMEOUT (>${timeout_sec}s)" >&2
else
elapsed=$(awk "BEGIN{printf \"%.5f\", $_ns / 1000000000}")
if [[ $status -ne 0 ]]; then
rpt_color "$YELLOW" "$label: exit $status elapsed=${elapsed}s" >&2
else
rpt_color "$GREEN" "$label: OK elapsed=${elapsed}s" >&2
fi
fi
echo "$elapsed"
}
OUR_OUT="$OUT_DIR/$STEM.our.out"
TIMEOUT_SEC=60
[[ "$INPUT" == *"/performance/"* || "$INPUT" == *"/h_performance/"* ]] && TIMEOUT_SEC=300
OUR_ELAPSED=$(run_and_time "our compiler" "$OUR_ELF" "$OUR_OUT" "$TIMEOUT_SEC")
# GCC 耗时直接读取基线数据,不重新运行
GCC_ELAPSED="N/A"
if [[ "$GCC_OK" == true && -n "$GCC_ELAPSED_RAW" ]]; then
GCC_ELAPSED="${GCC_ELAPSED_RAW}s"
rpt_color "$GREEN" "gcc -O2: ${GCC_ELAPSED} (from pre-computed baseline)"
fi
# ---------- 5. 输出对比 ----------
rpt ""
rpt "--- [5/5] Output comparison ---"
normalize_out() {
awk '{ sub(/\r$/, ""); print }' "$1"
}
if [[ -f "$EXPECTED_FILE" ]]; then
DIFF_FILE="$OUT_DIR/$STEM.diff"
if diff <(normalize_out "$EXPECTED_FILE") <(normalize_out "$OUR_OUT") > "$DIFF_FILE" 2>&1; then
rpt_color "$GREEN" "our output: MATCH expected"
rm -f "$DIFF_FILE"
else
rpt_color "$RED" "our output: MISMATCH — diff saved to $DIFF_FILE"
fi
if [[ "$GCC_OK" == true && -f "$GCC_OUT" ]]; then
GCC_DIFF_FILE="$OUT_DIR/$STEM.gcc.diff"
if diff <(normalize_out "$EXPECTED_FILE") <(normalize_out "$GCC_OUT") > "$GCC_DIFF_FILE" 2>&1; then
rpt_color "$GREEN" "gcc output: MATCH expected"
rm -f "$GCC_DIFF_FILE"
else
rpt_color "$YELLOW" "gcc output: MISMATCH — diff saved to $GCC_DIFF_FILE"
fi
fi
else
rpt_color "$YELLOW" "no expected output file found, skipping diff"
fi
# ---------- 汇总报告 ----------
rpt ""
rpt "============================================================"
rpt_color "$BOLD" " Summary"
rpt "============================================================"
rpt "$(printf '%-20s %s' 'IR lines:' "$IR_LINES")"
rpt "$(printf '%-20s %s' 'Our ASM lines:' "$OUR_ASM_LINES")"
if [[ "$GCC_OK" == true && $GCC_ASM_LINES -gt 0 ]]; then
rpt "$(printf '%-20s %s' 'GCC ASM lines:' "$GCC_ASM_LINES")"
rpt "$(printf '%-20s %s' 'ASM ratio (ours/gcc):' \
"$(awk "BEGIN{if($GCC_ASM_LINES>0) printf \"%.2f\", $OUR_ASM_LINES/$GCC_ASM_LINES; else print \"N/A\"}")")"
fi
rpt "$(printf '%-20s %s' 'Our time:' "$OUR_ELAPSED")"
rpt "$(printf '%-20s %s' 'GCC time:' "$GCC_ELAPSED")"
if [[ "$GCC_ELAPSED" != "N/A" && "$GCC_ELAPSED" != "timeout" && "$OUR_ELAPSED" != "timeout" ]]; then
OUR_S="${OUR_ELAPSED%s}"
GCC_S="${GCC_ELAPSED%s}"
SPEEDUP=$(awk "BEGIN{if($OUR_S>0) printf \"%.5f\", $GCC_S/$OUR_S; else print \"inf\"}")
rpt "$(printf '%-20s %sx' 'Speedup (gcc/ours):' "$SPEEDUP")"
fi
rpt ""
rpt "Output directory: $OUT_DIR"
rpt "============================================================"
printf '\n%bReport saved to: %s%b\n' "$CYAN" "$REPORT" "$NC"

@ -0,0 +1,170 @@
#!/usr/bin/env bash
# clean_outputs.sh — 清理编译输出与日志垃圾文件
#
# 用法:
# clean_outputs.sh [选项]
#
# 选项:
# --logs 清理 output/logs/ 下的运行日志(保留 last_run.txt / last_failed.txt
# --analyze 清理 output/analyze/ 下的单用例分析结果
# --build 清理 build_lab*/ 构建目录
# --test-result 清理 test/test_result/ 下的测试产物
# --all 清理以上全部
# --dry-run 只打印将要删除的内容,不实际删除
# --yes 跳过确认提示,直接删除(配合 --logs / --all 等使用)
#
# 不带任何选项时交互式选择。
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
NC='\033[0m'
DO_LOGS=false
DO_ANALYZE=false
DO_BUILD=false
DO_TEST_RESULT=false
DRY_RUN=false
AUTO_YES=false
if [[ $# -eq 0 ]]; then
# 交互模式
printf '%bclean_outputs.sh — interactive mode%b\n' "$CYAN" "$NC"
printf 'Select what to clean (space-separated numbers, e.g. "1 3"):\n'
printf ' 1) output/logs/ — run logs\n'
printf ' 2) output/analyze/ — single-case analysis results\n'
printf ' 3) build_lab*/ — CMake build directories\n'
printf ' 4) test/test_result/ — test artifacts\n'
printf ' 0) cancel\n'
read -r -p 'choice: ' choices
for c in $choices; do
case "$c" in
1) DO_LOGS=true ;;
2) DO_ANALYZE=true ;;
3) DO_BUILD=true ;;
4) DO_TEST_RESULT=true ;;
0) printf 'cancelled.\n'; exit 0 ;;
*) printf '%bunknown option: %s (ignored)%b\n' "$YELLOW" "$c" "$NC" ;;
esac
done
fi
while [[ $# -gt 0 ]]; do
case "$1" in
--logs) DO_LOGS=true ;;
--analyze) DO_ANALYZE=true ;;
--build) DO_BUILD=true ;;
--test-result) DO_TEST_RESULT=true ;;
--all) DO_LOGS=true; DO_ANALYZE=true; DO_BUILD=true; DO_TEST_RESULT=true ;;
--dry-run) DRY_RUN=true ;;
--yes|-y) AUTO_YES=true ;;
*)
printf '%bunknown option: %s%b\n' "$YELLOW" "$1" "$NC" >&2
;;
esac
shift
done
if [[ "$DO_LOGS" == false && "$DO_ANALYZE" == false && \
"$DO_BUILD" == false && "$DO_TEST_RESULT" == false ]]; then
printf 'nothing selected. use --help or run without arguments for interactive mode.\n' >&2
exit 0
fi
# ---------- 收集要删除的路径 ----------
declare -a TARGETS=()
if [[ "$DO_LOGS" == true ]]; then
LOG_ROOT="$REPO_ROOT/output/logs"
if [[ -d "$LOG_ROOT" ]]; then
# 删除所有子目录(即每次的 run dir保留 last_run.txt / last_failed.txt
while IFS= read -r -d '' d; do
TARGETS+=("$d")
done < <(find "$LOG_ROOT" -mindepth 2 -maxdepth 2 -type d -print0 2>/dev/null)
fi
fi
if [[ "$DO_ANALYZE" == true ]]; then
ANALYZE_ROOT="$REPO_ROOT/output/analyze"
if [[ -d "$ANALYZE_ROOT" ]]; then
while IFS= read -r -d '' d; do
TARGETS+=("$d")
done < <(find "$ANALYZE_ROOT" -mindepth 1 -maxdepth 1 -print0 2>/dev/null)
fi
fi
if [[ "$DO_BUILD" == true ]]; then
while IFS= read -r -d '' d; do
TARGETS+=("$d")
done < <(find "$REPO_ROOT" -maxdepth 1 -type d -name 'build_lab*' -print0 2>/dev/null)
fi
if [[ "$DO_TEST_RESULT" == true ]]; then
TR_ROOT="$REPO_ROOT/test/test_result"
if [[ -d "$TR_ROOT" ]]; then
TARGETS+=("$TR_ROOT")
fi
fi
if [[ ${#TARGETS[@]} -eq 0 ]]; then
printf '%bNothing to clean — target directories are already empty or do not exist.%b\n' "$GREEN" "$NC"
exit 0
fi
# ---------- 打印列表 ----------
printf '\n%bThe following will be %s:%b\n' "$YELLOW" \
"$([[ "$DRY_RUN" == true ]] && echo "listed (dry-run)" || echo "DELETED")" "$NC"
TOTAL_SIZE=0
for t in "${TARGETS[@]}"; do
SIZE=$(du -sh "$t" 2>/dev/null | cut -f1 || echo "?")
printf ' [%s] %s\n' "$SIZE" "$t"
done
printf '\n'
if [[ "$DRY_RUN" == true ]]; then
printf '%bDry-run mode: nothing deleted.%b\n' "$CYAN" "$NC"
exit 0
fi
# ---------- 确认 ----------
if [[ "$AUTO_YES" == false ]]; then
read -r -p "Proceed with deletion? [y/N] " confirm
case "$confirm" in
[yY][eE][sS]|[yY]) ;;
*)
printf 'cancelled.\n'
exit 0
;;
esac
fi
# ---------- 删除 ----------
DELETED=0
ERRORS=0
for t in "${TARGETS[@]}"; do
if rm -rf "$t" 2>/dev/null; then
printf '%b deleted: %s%b\n' "$GREEN" "$t" "$NC"
DELETED=$((DELETED + 1))
else
printf '%b ERROR deleting: %s%b\n' "$RED" "$t" "$NC"
ERRORS=$((ERRORS + 1))
fi
done
printf '\n'
if [[ $ERRORS -eq 0 ]]; then
printf '%bDone. %d item(s) deleted.%b\n' "$GREEN" "$DELETED" "$NC"
else
printf '%bDone. %d deleted, %d errors.%b\n' "$YELLOW" "$DELETED" "$ERRORS" "$NC"
exit 1
fi

@ -0,0 +1,195 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
BUILD_DIR="$REPO_ROOT/build_lab1"
COMPILER="$BUILD_DIR/bin/compiler"
ANTLR_JAR="$REPO_ROOT/third_party/antlr-4.13.2-complete.jar"
RUN_ROOT="$REPO_ROOT/output/logs/lab1"
RUN_NAME="lab1_$(date +%Y%m%d_%H%M%S)"
RUN_DIR="$RUN_ROOT/$RUN_NAME"
WHOLE_LOG="$RUN_DIR/whole.log"
FAIL_DIR="$RUN_DIR/failures"
LEGACY_SAVE_TREE=false
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
TEST_DIRS=()
while [[ $# -gt 0 ]]; do
case "$1" in
--save-tree)
LEGACY_SAVE_TREE=true
;;
*)
TEST_DIRS+=("$1")
;;
esac
shift
done
mkdir -p "$RUN_DIR"
: > "$WHOLE_LOG"
log_plain() {
printf '%s\n' "$*"
printf '%s\n' "$*" >> "$WHOLE_LOG"
}
log_color() {
local color="$1"
shift
local message="$*"
printf '%b%s%b\n' "$color" "$message" "$NC"
printf '%s\n' "$message" >> "$WHOLE_LOG"
}
append_file_to_whole_log() {
local title="$1"
local file="$2"
{
printf '\n===== %s =====\n' "$title"
cat "$file"
printf '\n'
} >> "$WHOLE_LOG"
}
cleanup_tmp_dir() {
local dir="$1"
if [[ -d "$dir" ]]; then
rm -rf "$dir"
fi
}
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
prune_empty_run_dirs() {
if [[ -d "$RUN_DIR/.tmp" ]]; then
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
fi
if [[ -d "$FAIL_DIR" ]]; then
rmdir "$FAIL_DIR" 2>/dev/null || true
fi
}
if [[ ${#TEST_DIRS[@]} -eq 0 ]]; then
while IFS= read -r -d '' test_dir; do
TEST_DIRS+=("$test_dir")
done < <(discover_default_test_dirs)
fi
log_plain "Run directory: $RUN_DIR"
log_plain "Whole log: $WHOLE_LOG"
if [[ "$LEGACY_SAVE_TREE" == true ]]; then
log_color "$YELLOW" "Warning: --save-tree is deprecated; successful case artifacts will still be deleted."
fi
log_plain "==> [1/3] Generate ANTLR Lexer/Parser"
mkdir -p "$BUILD_DIR/generated/antlr4"
if ! java -jar "$ANTLR_JAR" \
-Dlanguage=Cpp \
-visitor -no-listener \
-Xexact-output-dir \
-o "$BUILD_DIR/generated/antlr4" \
"$REPO_ROOT/src/antlr4/SysY.g4" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "ANTLR generation failed. See $WHOLE_LOG"
exit 1
fi
log_plain "==> [2/3] Configure and build parse-only compiler"
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
exit 1
fi
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
exit 1
fi
log_plain "==> [3/3] Run parse validation suite"
PASS=0
FAIL=0
FAIL_LIST=()
test_one() {
local sy_file="$1"
local rel="$2"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
local fail_case_dir="$FAIL_DIR/$case_key"
local tree_file="$tmp_dir/parse.tree"
local case_log="$tmp_dir/error.log"
cleanup_tmp_dir "$tmp_dir"
cleanup_tmp_dir "$fail_case_dir"
mkdir -p "$tmp_dir"
if "$COMPILER" --emit-parse-tree "$sy_file" > "$tree_file" 2> "$case_log"; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
mkdir -p "$FAIL_DIR"
{
printf 'Command: %s --emit-parse-tree %s\n' "$COMPILER" "$sy_file"
if [[ -s "$case_log" ]]; then
printf '\n'
cat "$case_log"
fi
} > "$tmp_dir/error.log.tmp"
mv "$tmp_dir/error.log.tmp" "$case_log"
mv "$tmp_dir" "$fail_case_dir"
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
return 1
}
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
log_color "$YELLOW" "skip missing dir: $test_dir"
continue
fi
while IFS= read -r -d '' sy_file; do
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
if test_one "$sy_file" "$rel"; then
log_color "$GREEN" "PASS $rel"
PASS=$((PASS + 1))
else
log_color "$RED" "FAIL $rel"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
fi
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
prune_empty_run_dirs
log_plain ""
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
for f in "${FAIL_LIST[@]}"; do
safe_name="${f//\//_}"
log_plain "- $f"
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
done
else
log_plain "all successful case artifacts have been deleted automatically."
fi
log_plain "whole log saved to: $WHOLE_LOG"
[[ $FAIL -eq 0 ]]

@ -0,0 +1,288 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
VERIFY_SCRIPT="$REPO_ROOT/scripts/verify_ir.sh"
BUILD_DIR="$REPO_ROOT/build_lab2"
RUN_ROOT="$REPO_ROOT/output/logs/lab2"
LAST_RUN_FILE="$RUN_ROOT/last_run.txt"
LAST_FAILED_FILE="$RUN_ROOT/last_failed.txt"
RUN_NAME="lab2_$(date +%Y%m%d_%H%M%S)"
RUN_DIR="$RUN_ROOT/$RUN_NAME"
WHOLE_LOG="$RUN_DIR/whole.log"
FAIL_DIR="$RUN_DIR/failures"
LEGACY_SAVE_IR=false
FAILED_ONLY=false
FALLBACK_TO_FULL=false
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'
TEST_DIRS=()
TEST_FILES=()
while [[ $# -gt 0 ]]; do
case "$1" in
--save-ir)
LEGACY_SAVE_IR=true
;;
--failed-only)
FAILED_ONLY=true
;;
*)
if [[ -f "$1" ]]; then
TEST_FILES+=("$1")
else
TEST_DIRS+=("$1")
fi
;;
esac
shift
done
mkdir -p "$RUN_DIR"
: > "$WHOLE_LOG"
printf '%s\n' "$RUN_DIR" > "$LAST_RUN_FILE"
log_plain() {
printf '%s\n' "$*"
printf '%s\n' "$*" >> "$WHOLE_LOG"
}
log_color() {
local color="$1"
shift
local message="$*"
printf '%b%s%b\n' "$color" "$message" "$NC"
printf '%s\n' "$message" >> "$WHOLE_LOG"
}
append_file_to_whole_log() {
local title="$1"
local file="$2"
{
printf '\n===== %s =====\n' "$title"
cat "$file"
printf '\n'
} >> "$WHOLE_LOG"
}
cleanup_tmp_dir() {
local dir="$1"
if [[ -d "$dir" ]]; then
rm -rf "$dir"
fi
}
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
prune_empty_run_dirs() {
if [[ -d "$RUN_DIR/.tmp" ]]; then
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
fi
if [[ -d "$FAIL_DIR" ]]; then
rmdir "$FAIL_DIR" 2>/dev/null || true
fi
}
now_ns() {
date +%s%N
}
format_duration_ns() {
local ns="$1"
local sec=$((ns / 1000000000))
local ms=$(((ns % 1000000000) / 1000000))
printf '%d.%03ds' "$sec" "$ms"
}
is_transient_io_failure() {
local log_file="$1"
[[ -f "$log_file" ]] || return 1
grep -Eq \
'Permission denied|Text file busy|Device or resource busy|Stale file handle|Input/output error|Resource temporarily unavailable|Read-only file system' \
"$log_file"
}
test_one() {
local sy_file="$1"
local rel="$2"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
local fail_case_dir="$FAIL_DIR/$case_key"
local case_log="$tmp_dir/error.log"
local attempt=1
cleanup_tmp_dir "$fail_case_dir"
while true; do
cleanup_tmp_dir "$tmp_dir"
mkdir -p "$tmp_dir"
if "$VERIFY_SCRIPT" "$sy_file" "$tmp_dir" --run > "$case_log" 2>&1; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
if [[ $attempt -eq 1 ]] && is_transient_io_failure "$case_log"; then
log_color "$YELLOW" "RETRY $rel (transient I/O failure)"
attempt=$((attempt + 1))
continue
fi
break
done
mkdir -p "$FAIL_DIR"
mv "$tmp_dir" "$fail_case_dir"
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
return 1
}
run_case() {
local sy_file="$1"
local rel
local case_start_ns
local case_end_ns
local case_elapsed_ns
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
case_start_ns=$(now_ns)
if test_one "$sy_file" "$rel"; then
case_end_ns=$(now_ns)
case_elapsed_ns=$((case_end_ns - case_start_ns))
log_color "$GREEN" "PASS $rel [$(format_duration_ns "$case_elapsed_ns")]"
PASS=$((PASS + 1))
else
case_end_ns=$(now_ns)
case_elapsed_ns=$((case_end_ns - case_start_ns))
log_color "$RED" "FAIL $rel [$(format_duration_ns "$case_elapsed_ns")]"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
fi
}
TOTAL_START_NS=$(now_ns)
if [[ "$FAILED_ONLY" == true ]]; then
if [[ -f "$LAST_FAILED_FILE" ]]; then
while IFS= read -r sy_file; do
[[ -n "$sy_file" ]] || continue
[[ -f "$sy_file" ]] || continue
TEST_FILES+=("$sy_file")
done < "$LAST_FAILED_FILE"
fi
if [[ ${#TEST_FILES[@]} -eq 0 ]]; then
FALLBACK_TO_FULL=true
FAILED_ONLY=false
fi
fi
if [[ "$FAILED_ONLY" == false && ${#TEST_DIRS[@]} -eq 0 && ${#TEST_FILES[@]} -eq 0 ]]; then
while IFS= read -r -d '' test_dir; do
TEST_DIRS+=("$test_dir")
done < <(discover_default_test_dirs)
fi
log_plain "Run directory: $RUN_DIR"
log_plain "Whole log: $WHOLE_LOG"
if [[ "$LEGACY_SAVE_IR" == true ]]; then
log_color "$YELLOW" "Warning: --save-ir is deprecated; successful case artifacts will still be deleted."
fi
if [[ "$FAILED_ONLY" == true ]]; then
log_plain "Mode: rerun cached failed cases only"
fi
if [[ "$FALLBACK_TO_FULL" == true ]]; then
log_color "$YELLOW" "No cached failed cases found, fallback to full suite."
fi
if [[ ! -f "$VERIFY_SCRIPT" ]]; then
log_color "$RED" "missing verify script: $VERIFY_SCRIPT"
exit 1
fi
log_plain "==> [1/2] Configure and build compiler"
BUILD_START_NS=$(now_ns)
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
exit 1
fi
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
exit 1
fi
BUILD_END_NS=$(now_ns)
BUILD_ELAPSED_NS=$((BUILD_END_NS - BUILD_START_NS))
log_plain "==> [2/2] Run IR validation suite"
VALIDATION_START_NS=$(now_ns)
PASS=0
FAIL=0
FAIL_LIST=()
if [[ "$FAILED_ONLY" == true ]]; then
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
else
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
log_color "$YELLOW" "skip missing dir: $test_dir"
continue
fi
while IFS= read -r -d '' sy_file; do
run_case "$sy_file"
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
fi
rm -f "$LAST_FAILED_FILE"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
for f in "${FAIL_LIST[@]}"; do
printf '%s/%s\n' "$REPO_ROOT" "$f" >> "$LAST_FAILED_FILE"
done
fi
prune_empty_run_dirs
VALIDATION_END_NS=$(now_ns)
VALIDATION_ELAPSED_NS=$((VALIDATION_END_NS - VALIDATION_START_NS))
TOTAL_END_NS=$(now_ns)
TOTAL_ELAPSED_NS=$((TOTAL_END_NS - TOTAL_START_NS))
log_plain ""
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
log_plain "build elapsed: $(format_duration_ns "$BUILD_ELAPSED_NS")"
log_plain "validation elapsed: $(format_duration_ns "$VALIDATION_ELAPSED_NS")"
log_plain "total elapsed: $(format_duration_ns "$TOTAL_ELAPSED_NS")"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
for f in "${FAIL_LIST[@]}"; do
safe_name="${f//\//_}"
log_plain "- $f"
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
done
else
log_plain "all successful case artifacts have been deleted automatically."
fi
log_plain "whole log saved to: $WHOLE_LOG"
[[ $FAIL -eq 0 ]]

@ -0,0 +1,433 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
VERIFY_SCRIPT="$REPO_ROOT/scripts/verify_asm.sh"
BUILD_DIR="$REPO_ROOT/build_lab3"
RUN_ROOT="$REPO_ROOT/output/logs/lab3"
LAST_RUN_FILE="$RUN_ROOT/last_run.txt"
LAST_FAILED_FILE="$RUN_ROOT/last_failed.txt"
RUN_NAME="lab3_$(date +%Y%m%d_%H%M%S)"
RUN_DIR="$RUN_ROOT/$RUN_NAME"
WHOLE_LOG="$RUN_DIR/whole.log"
FAIL_DIR="$RUN_DIR/failures"
LEGACY_SAVE_ASM=false
FAILED_ONLY=false
FALLBACK_TO_FULL=false
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
NC='\033[0m'
TEST_DIRS=()
TEST_FILES=()
while [[ $# -gt 0 ]]; do
case "$1" in
--save-asm)
LEGACY_SAVE_ASM=true
;;
--failed-only)
FAILED_ONLY=true
;;
*)
if [[ -f "$1" ]]; then
TEST_FILES+=("$1")
else
TEST_DIRS+=("$1")
fi
;;
esac
shift
done
mkdir -p "$RUN_DIR"
: > "$WHOLE_LOG"
printf '%s\n' "$RUN_DIR" > "$LAST_RUN_FILE"
log_plain() {
printf '%s\n' "$*"
printf '%s\n' "$*" >> "$WHOLE_LOG"
}
log_color() {
local color="$1"
shift
local message="$*"
printf '%b%s%b\n' "$color" "$message" "$NC"
printf '%s\n' "$message" >> "$WHOLE_LOG"
}
append_file_to_whole_log() {
local title="$1"
local file="$2"
{
printf '\n===== %s =====\n' "$title"
cat "$file"
printf '\n'
} >> "$WHOLE_LOG"
}
cleanup_tmp_dir() {
local dir="$1"
if [[ -d "$dir" ]]; then
rm -rf "$dir"
fi
}
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
prune_empty_run_dirs() {
if [[ -d "$RUN_DIR/.tmp" ]]; then
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
fi
if [[ -d "$FAIL_DIR" ]]; then
rmdir "$FAIL_DIR" 2>/dev/null || true
fi
}
now_ns() {
date +%s%N
}
format_duration_ns() {
local ns="$1"
local sec=$((ns / 1000000000))
local us10=$(((ns % 1000000000) / 10000))
printf '%d.%05ds' "$sec" "$us10"
}
is_transient_io_failure() {
local log_file="$1"
[[ -f "$log_file" ]] || return 1
grep -Eq \
'Permission denied|Text file busy|Device or resource busy|Stale file handle|Input/output error|Resource temporarily unavailable|Read-only file system' \
"$log_file"
}
# ---------- baseline 读取 & timing ----------
# 共享基线数据(由 run_baseline.sh 生成)
BASELINE_TSV="$REPO_ROOT/output/baseline/gcc_timing.tsv"
# 本次运行的我方计时 TSVstem<TAB>our_ns<TAB>gcc_s
TIMING_TSV="$RUN_DIR/timing.tsv"
# 从共享 TSV 查找某 stem 的 GCC 基线耗时(秒),找不到返回 N/A
lookup_gcc_s() {
local stem="$1"
local val="N/A"
if [[ -f "$BASELINE_TSV" ]]; then
val=$(awk -F'\t' -v s="$stem" '$1==s{v=$2} END{if(v!="") print v; else print "N/A"}' "$BASELINE_TSV")
fi
echo "$val"
}
record_timing() {
local stem="$1"
local our_ns="$2"
local gcc_s="${3:-N/A}"
printf '%s\t%s\t%s\n' "$stem" "$our_ns" "$gcc_s" >> "$TIMING_TSV"
}
test_one() {
local sy_file="$1"
local rel="$2"
local timing_out="${3:-}"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
local fail_case_dir="$FAIL_DIR/$case_key"
local case_log="$tmp_dir/error.log"
local attempt=1
cleanup_tmp_dir "$fail_case_dir"
while true; do
cleanup_tmp_dir "$tmp_dir"
mkdir -p "$tmp_dir"
local verify_args=("$sy_file" "$tmp_dir" --run)
[[ -n "$timing_out" ]] && verify_args+=(--timing-out "$timing_out")
if "$VERIFY_SCRIPT" "${verify_args[@]}" > "$case_log" 2>&1; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
if [[ $attempt -eq 1 ]] && is_transient_io_failure "$case_log"; then
log_color "$YELLOW" "RETRY $rel (transient I/O failure)"
attempt=$((attempt + 1))
continue
fi
break
done
mkdir -p "$FAIL_DIR"
mv "$tmp_dir" "$fail_case_dir"
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
return 1
}
run_case() {
local sy_file="$1"
local rel
local case_start_ns
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
case_start_ns=$(now_ns)
local base stem case_key
base="$(basename "$sy_file")"
stem="${base%.sy}"
# 与 run_baseline.sh 保持一致:去掉 test/ 前缀和 .sy 后缀
case_key="${rel#test/}"
case_key="${case_key%.sy}"
local timing_file
timing_file="$(mktemp)"
if test_one "$sy_file" "$rel" "$timing_file"; then
local compile_ns=0 run_ns=0
if [[ -f "$timing_file" ]]; then
compile_ns=$(grep '^compile_ns=' "$timing_file" | cut -d= -f2 || echo 0)
run_ns=$(grep '^run_ns=' "$timing_file" | cut -d= -f2 || echo 0)
fi
rm -f "$timing_file"
log_color "$GREEN" "PASS $rel [compile=$(format_duration_ns "$compile_ns") run=$(format_duration_ns "$run_ns")]"
PASS=$((PASS + 1))
local gcc_s
gcc_s=$(lookup_gcc_s "$case_key")
record_timing "$case_key" "$run_ns" "$gcc_s"
else
rm -f "$timing_file"
local case_elapsed_ns=$(( $(now_ns) - case_start_ns ))
log_color "$RED" "FAIL $rel [$(format_duration_ns "$case_elapsed_ns")]"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
fi
}
TOTAL_START_NS=$(now_ns)
: > "$TIMING_TSV"
if [[ "$FAILED_ONLY" == true ]]; then
if [[ -f "$LAST_FAILED_FILE" ]]; then
while IFS= read -r sy_file; do
[[ -n "$sy_file" ]] || continue
[[ -f "$sy_file" ]] || continue
TEST_FILES+=("$sy_file")
done < "$LAST_FAILED_FILE"
fi
if [[ ${#TEST_FILES[@]} -eq 0 ]]; then
FALLBACK_TO_FULL=true
FAILED_ONLY=false
fi
fi
if [[ "$FAILED_ONLY" == false && ${#TEST_DIRS[@]} -eq 0 && ${#TEST_FILES[@]} -eq 0 ]]; then
while IFS= read -r -d '' test_dir; do
TEST_DIRS+=("$test_dir")
done < <(discover_default_test_dirs)
fi
log_plain "Run directory: $RUN_DIR"
log_plain "Whole log: $WHOLE_LOG"
if [[ "$LEGACY_SAVE_ASM" == true ]]; then
log_color "$YELLOW" "Warning: --save-asm is deprecated; successful case artifacts will still be deleted."
fi
if [[ "$FAILED_ONLY" == true ]]; then
log_plain "Mode: rerun cached failed cases only"
fi
if [[ "$FALLBACK_TO_FULL" == true ]]; then
log_color "$YELLOW" "No cached failed cases found, fallback to full suite."
fi
if [[ -f "$BASELINE_TSV" ]]; then
log_plain "Baseline TSV: $BASELINE_TSV (speedup ratios will be computed)"
else
log_color "$CYAN" "Tip: run scripts/run_baseline.sh first to enable GCC -O2 speedup analysis."
fi
if [[ ! -f "$VERIFY_SCRIPT" ]]; then
log_color "$RED" "missing verify script: $VERIFY_SCRIPT"
exit 1
fi
for tool in llc aarch64-linux-gnu-gcc qemu-aarch64; do
if ! command -v "$tool" >/dev/null 2>&1; then
log_color "$RED" "missing required tool: $tool"
exit 1
fi
done
log_plain "==> [1/2] Configure and build compiler"
BUILD_START_NS=$(now_ns)
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
exit 1
fi
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
exit 1
fi
BUILD_END_NS=$(now_ns)
BUILD_ELAPSED_NS=$((BUILD_END_NS - BUILD_START_NS))
log_plain "==> [2/2] Run ASM validation suite"
VALIDATION_START_NS=$(now_ns)
PASS=0
FAIL=0
FAIL_LIST=()
if [[ "$FAILED_ONLY" == true ]]; then
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
else
for sy_file in "${TEST_FILES[@]}"; do
run_case "$sy_file"
done
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
log_color "$YELLOW" "skip missing dir: $test_dir"
continue
fi
while IFS= read -r -d '' sy_file; do
run_case "$sy_file"
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
fi
rm -f "$LAST_FAILED_FILE"
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
for f in "${FAIL_LIST[@]}"; do
printf '%s/%s\n' "$REPO_ROOT" "$f" >> "$LAST_FAILED_FILE"
done
fi
prune_empty_run_dirs
VALIDATION_END_NS=$(now_ns)
VALIDATION_ELAPSED_NS=$((VALIDATION_END_NS - VALIDATION_START_NS))
TOTAL_END_NS=$(now_ns)
TOTAL_ELAPSED_NS=$((TOTAL_END_NS - TOTAL_START_NS))
log_plain ""
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
log_plain "build elapsed: $(format_duration_ns "$BUILD_ELAPSED_NS")"
log_plain "validation elapsed: $(format_duration_ns "$VALIDATION_ELAPSED_NS")"
log_plain "total elapsed: $(format_duration_ns "$TOTAL_ELAPSED_NS")"
# ---------- 计时与加速比分析 ----------
if [[ -s "$TIMING_TSV" ]]; then
log_plain ""
log_plain "==> Timing & Speedup Analysis"
# 检查本次结果中是否有任何 GCC 基线数据
HAS_BASELINE=false
if grep -qv $'\tN/A$' "$TIMING_TSV" 2>/dev/null; then
HAS_BASELINE=true
fi
if [[ "$HAS_BASELINE" == true ]]; then
# 将 TSV 展开为含计算值的临时文件case_key, our_s, gcc_s, speedup
_tmp_timing="$RUN_DIR/timing_computed.tsv"
while IFS=$'\t' read -r case_key our_ns gcc_s; do
our_s=$(awk "BEGIN{printf \"%.5f\", $our_ns / 1000000000}")
if [[ "$gcc_s" == "N/A" ]]; then
speedup="N/A"
else
speedup=$(awk "BEGIN{if($our_s>0) printf \"%.5f\", $gcc_s/$our_s; else print \"inf\"}")
fi
printf '%s\t%s\t%s\t%s\n' "$case_key" "$our_s" "$gcc_s" "$speedup"
done < "$TIMING_TSV" > "$_tmp_timing"
# 排序1加速比升序N/A 排最后)
log_plain ""
log_plain "--- [Sort 1] Speedup ratio ascending (worst speedup first) ---"
log_plain "$(printf '%-40s %10s %10s %10s' 'case' 'our(s)' 'gcc(s)' 'speedup')"
log_plain "$(printf '%0.s-' {1..76})"
{
grep -v $'\tN/A$' "$_tmp_timing" | sort -t$'\t' -k4 -n || true
grep $'\tN/A$' "$_tmp_timing" | sort -t$'\t' -k1 || true
} | while IFS=$'\t' read -r case_key our_s gcc_s speedup; do
disp="${case_key##*/}"
if [[ "$speedup" == "N/A" ]]; then
log_plain "$(printf '%-40s %10s %10s %10s' "$disp" "${our_s}s" "N/A" "N/A")"
else
log_plain "$(printf '%-40s %10s %10s %9sx' "$disp" "${our_s}s" "${gcc_s}s" "$speedup")"
fi
done
# 排序2我方总用时降序
log_plain ""
log_plain "--- [Sort 2] Our elapsed time descending (slowest first) ---"
log_plain "$(printf '%-40s %10s %10s %10s' 'case' 'our(s)' 'gcc(s)' 'speedup')"
log_plain "$(printf '%0.s-' {1..76})"
sort -t$'\t' -k2 -rn "$_tmp_timing" | \
while IFS=$'\t' read -r case_key our_s gcc_s speedup; do
disp="${case_key##*/}"
if [[ "$speedup" == "N/A" ]]; then
log_plain "$(printf '%-40s %10s %10s %10s' "$disp" "${our_s}s" "N/A" "N/A")"
else
log_plain "$(printf '%-40s %10s %10s %9sx' "$disp" "${our_s}s" "${gcc_s}s" "$speedup")"
fi
done
rm -f "$_tmp_timing"
else
# 无基线:只输出总用时降序
log_plain ""
log_plain "--- [Sort] Our elapsed time descending (slowest first) ---"
log_plain "$(printf '%-40s %10s' 'case' 'our(s)')"
log_plain "$(printf '%0.s-' {1..54})"
while IFS=$'\t' read -r case_key our_ns _; do
our_s=$(awk "BEGIN{printf \"%.5f\", $our_ns / 1000000000}")
printf '%s\t%s\n' "$case_key" "$our_s"
done < "$TIMING_TSV" | \
sort -t$'\t' -k2 -rn | \
while IFS=$'\t' read -r case_key our_s; do
disp="${case_key##*/}"
log_plain "$(printf '%-40s %10ss' "$disp" "$our_s")"
done
log_plain ""
log_color "$CYAN" "Tip: run scripts/run_baseline.sh to compute GCC -O2 baseline for speedup analysis."
fi
log_plain ""
log_plain "timing data saved to: $TIMING_TSV"
fi
# ---------- 失败用例列表 ----------
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
for f in "${FAIL_LIST[@]}"; do
safe_name="${f//\//_}"
log_plain "- $f"
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
done
else
log_plain "all successful case artifacts have been deleted automatically."
fi
log_plain "whole log saved to: $WHOLE_LOG"
[[ $FAIL -eq 0 ]]

@ -0,0 +1,326 @@
#!/usr/bin/env bash
# run_baseline.sh — 批量编译 GCC -O2 基线并保存汇编、输出与运行时间
#
# 数据统一保存在 output/baseline/
# gcc_timing.tsv — stem<TAB>gcc_elapsed_s (所有脚本的共享数据源)
# <stem>.gcc.s — GCC -O2 AArch64 汇编(供 analyze_case.sh 对比)
# <stem>.gcc.out — GCC 程序实际输出 stdout+exit_code供 analyze_case.sh 对比)
#
# 用法:
# run_baseline.sh [--update] [test_dir|file ...]
#
# --update 重新计算所有条目(默认跳过 gcc_timing.tsv 中已有的 stem
#
# 若不指定测试目录/文件,自动扫描 test/test_case 和 test/class_test_case
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
BASELINE_DIR="$REPO_ROOT/output/baseline"
TIMING_TSV="$BASELINE_DIR/gcc_timing.tsv"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
NC='\033[0m'
UPDATE=false
TEST_DIRS=()
TEST_FILES=()
while [[ $# -gt 0 ]]; do
case "$1" in
--update) UPDATE=true ;;
*)
if [[ -f "$1" ]]; then
TEST_FILES+=("$1")
else
TEST_DIRS+=("$1")
fi
;;
esac
shift
done
# ---------- 工具检查 ----------
for tool in aarch64-linux-gnu-gcc qemu-aarch64; do
if ! command -v "$tool" >/dev/null 2>&1; then
printf '%bERROR: required tool not found: %s%b\n' "$RED" "$tool" "$NC" >&2
exit 1
fi
done
if [[ ! -x /usr/bin/time ]]; then
printf '%bERROR: /usr/bin/time not found%b\n' "$RED" "$NC" >&2
exit 1
fi
mkdir -p "$BASELINE_DIR"
# 是否已存在某 stem 的基线数据(直接查 TSV 文件,避免关联数组兼容性问题)
stem_is_cached() {
local key="$1"
[[ -f "$TIMING_TSV" ]] && grep -qF "${key} " "$TIMING_TSV" 2>/dev/null
}
stem_cached_time() {
local key="$1"
awk -F'\t' -v s="$key" '$1==s{print $2; exit}' "$TIMING_TSV" 2>/dev/null || true
}
# ---------- 测试用例发现 ----------
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
if [[ ${#TEST_DIRS[@]} -eq 0 && ${#TEST_FILES[@]} -eq 0 ]]; then
while IFS= read -r -d '' d; do
TEST_DIRS+=("$d")
done < <(discover_default_test_dirs)
fi
# ---------- 计时工具 ----------
now_ns() { date +%s%N; }
format_duration_ns() {
local ns="$1"
printf '%d.%05ds' "$((ns / 1000000000))" "$(((ns % 1000000000) / 10000))"
}
# ---------- 处理单个用例 ----------
PASS=0
SKIP=0
FAIL=0
process_case() {
local sy_file="$1"
local base stem input_dir stdin_file
base="$(basename "$sy_file")"
stem="${base%.sy}"
input_dir="$(dirname "$sy_file")"
stdin_file="$input_dir/$stem.in"
local rel
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
# 路径键:去掉 test/ 前缀和 .sy 后缀,保留完整目录结构
# 例test/class_test_case/h_functional/11_BST.sy → class_test_case/h_functional/11_BST
local case_key
case_key="${rel#test/}"
case_key="${case_key%.sy}"
local case_start_ns
case_start_ns=$(now_ns)
# 已有数据且不强制更新 → 跳过
if [[ "$UPDATE" == false ]] && stem_is_cached "$case_key"; then
printf '%b SKIP %s (cached: %ss)%b\n' \
"$CYAN" "$rel" "$(stem_cached_time "$case_key")" "$NC"
SKIP=$((SKIP + 1))
return 0
fi
# 输出目录镜像源路径结构
local case_out_dir
case_out_dir="$BASELINE_DIR/$(dirname "$case_key")"
mkdir -p "$case_out_dir"
local gcc_elf gcc_asm gcc_out gcc_err
gcc_elf="$case_out_dir/$stem.gcc.elf"
gcc_asm="$case_out_dir/$stem.gcc.s"
gcc_out="$case_out_dir/$stem.gcc.out"
gcc_err="$case_out_dir/$stem.gcc.err"
# 预处理:把 "const int NAME = EXPR;" 转为 "#define NAME ((int)(EXPR))"
# 同时处理多声明符const int A=1, B=2; → #define A ((int)(1))\n#define B ((int)(2))
# 原因SysY const int 是编译期常量C 模式下不能用于全局数组维度,#define 可以
local tmp_sy
tmp_sy="$(mktemp /tmp/sysy_XXXXXX.c)"
python3 - "$sy_file" "$tmp_sy" << 'PYEOF'
import re, sys
pat = re.compile(
r'^(\s*)const\s+int\s+((?:[A-Za-z_]\w*\s*=\s*[^,;]+)(?:,\s*[A-Za-z_]\w*\s*=\s*[^,;]+)*)\s*;',
re.MULTILINE
)
def replace(m):
indent = m.group(1)
decls = re.split(r',\s*(?=[A-Za-z_])', m.group(2))
lines = []
for d in decls:
name, _, val = d.partition('=')
lines.append(f'{indent}#define {name.strip()} ((int)({val.strip()}))')
return '\n'.join(lines)
with open(sys.argv[1]) as f:
src = f.read()
with open(sys.argv[2], 'w') as f:
f.write(pat.sub(replace, src))
PYEOF
# 步骤1编译链接C 模式,用于运行计时)
# -x c允许 delete/new/class 等作为标识符
# -include sylib.h强制注入 SysY 运行时声明(.sy 无 #include
# 无名称修饰,直接链接同为 C 编译的 sylib.o
if ! aarch64-linux-gnu-gcc -O2 \
-x c -include "$REPO_ROOT/sylib/sylib.h" \
-I "$REPO_ROOT/sylib" \
"$tmp_sy" -x none "$SYLIB_OBJ" \
-lm -o "$gcc_elf" > "$gcc_err" 2>&1; then
rm -f "$tmp_sy"
printf '%b FAIL %s (GCC compile error — see %s)%b\n' \
"$RED" "$rel" "$gcc_err" "$NC"
FAIL=$((FAIL + 1))
return 0
fi
# 步骤2生成汇编单独 -S仅针对 .sy 文件本身)
aarch64-linux-gnu-gcc -O2 \
-x c -include "$REPO_ROOT/sylib/sylib.h" \
-I "$REPO_ROOT/sylib" \
"$tmp_sy" -S -o "$gcc_asm" 2>/dev/null || true
rm -f "$tmp_sy"
# 步骤3运行并计时手动 ns 计时,精度 5 位小数)
local stdout_file="$case_out_dir/$stem.gcc.stdout"
local status=0
local timeout_sec=60
[[ "$sy_file" == *"/performance/"* || "$sy_file" == *"/h_performance/"* ]] && timeout_sec=300
local run_start_ns run_end_ns run_elapsed_ns
run_start_ns=$(now_ns)
set +e
if [[ -f "$stdin_file" ]]; then
timeout "$timeout_sec" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$gcc_elf" \
< "$stdin_file" > "$stdout_file" 2>/dev/null
else
timeout "$timeout_sec" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$gcc_elf" \
> "$stdout_file" 2>/dev/null
fi
status=$?
run_end_ns=$(now_ns)
run_elapsed_ns=$((run_end_ns - run_start_ns))
set -e
# 删除可执行(节省空间,数据已提取完毕)
rm -f "$gcc_elf"
if [[ $status -eq 124 ]]; then
printf '%b TIMEOUT %s (>%ds)%b\n' "$YELLOW" "$rel" "$timeout_sec" "$NC"
rm -f "$stdout_file"
FAIL=$((FAIL + 1))
return 0
fi
# 步骤4保存输出文件stdout + exit_code与 verify_asm.sh 格式一致)
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$gcc_out"
rm -f "$stdout_file"
# 步骤5计算耗时5 位小数秒)并写入 TSV
local elapsed
elapsed=$(awk "BEGIN{printf \"%.5f\", $run_elapsed_ns / 1000000000}")
# 更新 TSV若已有该 case_key 的旧行则先删除再追加)
if grep -qF "${case_key} " "$TIMING_TSV" 2>/dev/null; then
local _tmp="$TIMING_TSV.tmp"
grep -vF "${case_key} " "$TIMING_TSV" > "$_tmp" || true
mv "$_tmp" "$TIMING_TSV"
fi
printf '%s\t%s\n' "$case_key" "$elapsed" >> "$TIMING_TSV"
local case_end_ns duration_ns
case_end_ns=$(now_ns)
duration_ns=$((case_end_ns - case_start_ns))
printf '%b DONE %s gcc=%ss [%s]%b\n' \
"$GREEN" "$rel" "$elapsed" "$(format_duration_ns "$duration_ns")" "$NC"
PASS=$((PASS + 1))
}
# ---------- 初始化 ----------
if [[ "$UPDATE" == true ]]; then
printf '%b[--update] Clearing all existing baseline data.%b\n' "$YELLOW" "$NC"
: > "$TIMING_TSV"
find "$BASELINE_DIR" -maxdepth 1 \
\( -name '*.gcc.s' -o -name '*.gcc.out' -o -name '*.gcc.time' -o -name '*.gcc.err' \) \
-delete 2>/dev/null || true
else
[[ -f "$TIMING_TSV" ]] || : > "$TIMING_TSV"
fi
printf '%bBaseline directory : %s%b\n' "$CYAN" "$BASELINE_DIR" "$NC"
printf '%bTiming TSV : %s%b\n' "$CYAN" "$TIMING_TSV" "$NC"
if [[ "$UPDATE" == false && -f "$TIMING_TSV" ]]; then
_cached_count=$(wc -l < "$TIMING_TSV" 2>/dev/null || echo 0)
if [[ $_cached_count -gt 0 ]]; then
printf 'Found %d cached entries (use --update to recompute all).\n' "$_cached_count"
fi
fi
# ---------- 预编译 sylib.oC 模式,仅一次)----------
SYLIB_OBJ="$BASELINE_DIR/sylib.o"
if ! aarch64-linux-gnu-gcc -O2 -c -x c "$REPO_ROOT/sylib/sylib.c" \
-I "$REPO_ROOT/sylib" -o "$SYLIB_OBJ" 2>/dev/null; then
printf '%bERROR: failed to compile sylib.c%b\n' "$RED" "$NC" >&2
exit 1
fi
printf 'sylib.o compiled : %s\n' "$SYLIB_OBJ"
printf '\n'
TOTAL_START_NS=$(now_ns)
# ---------- 运行 ----------
for sy_file in "${TEST_FILES[@]}"; do
process_case "$sy_file"
done
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
printf '%b SKIP missing dir: %s%b\n' "$YELLOW" "$test_dir" "$NC"
continue
fi
while IFS= read -r -d '' sy_file; do
process_case "$sy_file"
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
# ---------- 汇总 ----------
TOTAL_END_NS=$(now_ns)
TOTAL_ELAPSED_NS=$((TOTAL_END_NS - TOTAL_START_NS))
TOTAL_CASES=$((PASS + SKIP + FAIL))
printf '\n'
printf 'Summary: %d DONE / %d SKIP (cached) / %d FAIL / total %d\n' \
"$PASS" "$SKIP" "$FAIL" "$TOTAL_CASES"
printf 'Total elapsed : %s\n' "$(format_duration_ns "$TOTAL_ELAPSED_NS")"
printf 'Timing TSV : %s (%d entries)\n' \
"$TIMING_TSV" "$(wc -l < "$TIMING_TSV" 2>/dev/null || echo 0)"
[[ $FAIL -eq 0 ]]

@ -0,0 +1,103 @@
============================================================
脚本优化总结2026-04
============================================================
一、架构分离
────────────────────────────────────────────────────────────
· run_baseline.sh 成为唯一负责计算 GCC -O2 基线的脚本;
其余所有脚本lab3_build_test.sh、analyze_case.sh只读
TSV不再重复运行 GCC避免重复耗时。
· 基线输出目录镜像测试用例的相对路径结构,例如:
output/baseline/test_case/functional/65_color.gcc.s
output/baseline/class_test_case/h_functional/11_BST.gcc.s
TSV 键与目录结构对齐class_test_case/h_functional/11_BST
二、SysY → C 编译兼容性修复run_baseline.sh
────────────────────────────────────────────────────────────
· const int 全局数组维度问题
C 模式下 const int N=10; int a[N]; 属于 VLA非法于文件域
用 Python3 预处理将 const int NAME=EXPR; 转换为:
#define NAME ((int)(EXPR))
同时支持多声明符写法const int A=1, B=2;
· sylib 链接方式
预编译 sylib.o-x c用 -include sylib.h 注入声明;
链接命令用 -x none 在 .o 前重置语言标志,防止 ELF 被
当作 C 源文件解析stray '\177' 错误)。
· C++ 关键字冲突
部分 SysY 测试用例用 delete/new/class 作函数名;
-x c 模式下这些不是关键字,编译正常通过。
· 枚举浮点值
enum { MAX = 1e9 }; 枚举成员必须是整数常量Python3
预处理同样将其转为 #define MAX ((int)(1e9))。
三、计时精度与准确性
────────────────────────────────────────────────────────────
· 全面弃用 /usr/bin/time非零退出时会向输出文件写入
"Command exited with non-zero status N",污染时间值。
· 改用 date +%s%N 纳秒手动计时:
_t0=$(date +%s%N)
... 运行命令 ...
_t1=$(date +%s%N)
elapsed=$(awk "BEGIN{printf \"%.5f\", $((t1-t0)) / 1e9}")
· 所有时间输出统一为 5 位小数(秒),加速比同样 5 位小数。
四、分段计时verify_asm.sh + lab3_build_test.sh
────────────────────────────────────────────────────────────
· verify_asm.sh 新增 --timing-out <file> 选项,运行结束后
向文件写入:
compile_ns=<纳秒>
run_ns=<纳秒>
· lab3_build_test.sh 读取 timing 文件,将编译耗时与运行耗时
分开显示:
PASS test_case/functional/65_color [compile=0.31416s run=0.18804s]
· 加速比只使用运行时间run_ns排除编译器启动开销。
五、性能排行榜lab3_build_test.sh
────────────────────────────────────────────────────────────
· 测试结束后输出双排序表格:
Sort 1加速比升序最需优化的用例排最前
Sort 2我方用时降序绝对耗时最高的排最前
每行格式:
<用例名> <我方时间> <GCC时间> <加速比>x
六、analyze_case.sh 修复
────────────────────────────────────────────────────────────
· 基线查找键从裸 stem65_color改为完整路径键
test_case/functional/65_color与 TSV 格式对齐,
消除 "WARNING: no baseline entry" 误报。
· run_and_time 函数的 rpt_color 输出重定向到 stderr
防止 ANSI 转义码被命令替换($())捕获后传入 awk
消除 "fatal: error: invalid character '\033'" 错误。
============================================================
脚本列表
============================================================
run_baseline.sh 计算所有用例的 GCC -O2 基线,结果存入
output/baseline/gcc_timing.tsv
用法: ./scripts/run_baseline.sh [--update]
--update 清空重算全部条目
lab3_build_test.sh 构建编译器,跑全部用例,输出加速比排行榜
用法: ./scripts/lab3_build_test.sh
verify_asm.sh 验证单个用例的汇编正确性
用法: ./scripts/verify_asm.sh <input.sy> [input.in] \
[expected.out] [timeout] [--timing-out file]
analyze_case.sh 单用例深度分析IR/ASM/计时/与基线对比)
用法: ./scripts/analyze_case.sh <input.sy> [output_dir]
clean_outputs.sh 清理 output/ 目录下的分析结果
用法: ./scripts/clean_outputs.sh
============================================================

@ -1,16 +1,23 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
if [[ $# -lt 1 || $# -gt 5 ]]; then
echo "usage: $0 input.sy [output_dir] [--run] [--timing-out file]" >&2
exit 1
fi
input=$1
out_dir="test/test_result/asm"
out_dir="$REPO_ROOT/test/test_result/asm"
run_exec=false
input_dir=$(dirname "$input")
timing_out=""
_compile_ns=0
_run_ns=0
now_ns() { date +%s%N; }
shift
while [[ $# -gt 0 ]]; do
@ -18,6 +25,10 @@ while [[ $# -gt 0 ]]; do
--run)
run_exec=true
;;
--timing-out)
timing_out="$2"
shift
;;
*)
out_dir="$1"
;;
@ -26,18 +37,24 @@ while [[ $# -gt 0 ]]; do
done
if [[ ! -f "$input" ]]; then
echo "输入文件不存在: $input" >&2
echo "input file not found: $input" >&2
exit 1
fi
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2
compiler=""
for candidate in "$REPO_ROOT/build_lab3/bin/compiler" "$REPO_ROOT/build_lab2/bin/compiler" "$REPO_ROOT/build/bin/compiler"; do
if [[ -x "$candidate" ]]; then
compiler="$candidate"
break
fi
done
if [[ -z "$compiler" ]]; then
echo "compiler not found; try: cmake -S . -B build_lab3 && cmake --build build_lab3 -j" >&2
exit 1
fi
if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
echo "未找到 aarch64-linux-gnu-gcc无法汇编/链接。" >&2
echo "aarch64-linux-gnu-gcc not found" >&2
exit 1
fi
@ -49,31 +66,55 @@ exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
_compile_start_ns=$(now_ns)
"$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file"
echo "asm generated: $asm_file"
aarch64-linux-gnu-gcc "$asm_file" -o "$exe"
echo "可执行文件已生成: $exe"
aarch64-linux-gnu-gcc "$asm_file" "$REPO_ROOT/sylib/sylib.c" -O2 -o "$exe"
echo "executable generated: $exe"
_compile_ns=$(($(now_ns) - _compile_start_ns))
if [[ "$run_exec" == true ]]; then
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
echo "未找到 qemu-aarch64无法运行生成的可执行文件。" >&2
echo "qemu-aarch64 not found" >&2
exit 1
fi
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
echo "运行 $exe ..."
timeout_sec="${RUN_TIMEOUT_SEC:-60}"
if [[ "$input" == *"/performance/"* || "$input" == *"/h_performance/"* ]]; then
timeout_sec="${PERF_TIMEOUT_SEC:-300}"
fi
set +e
if [[ -f "$stdin_file" ]]; then
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
_run_start_ns=$(now_ns)
if command -v timeout >/dev/null 2>&1; then
if [[ -f "$stdin_file" ]]; then
timeout "$timeout_sec" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
timeout "$timeout_sec" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
else
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
if [[ -f "$stdin_file" ]]; then
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
fi
status=$?
_run_ns=$(($(now_ns) - _run_start_ns))
set -e
if [[ $status -eq 124 ]]; then
echo "timeout after ${timeout_sec}s: $exe" >&2
fi
cat "$stdout_file"
echo "退出码: $status"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
echo "exit code: $status"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
@ -83,14 +124,18 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file"
if diff -u <(awk '{ sub(/\r$/, ""); print }' "$expected_file") <(awk '{ sub(/\r$/, ""); print }' "$actual_file"); then
echo "matched: $expected_file"
else
echo "输出不匹配: $expected_file" >&2
echo "实际输出已保存: $actual_file" >&2
echo "mismatch: $expected_file" >&2
echo "actual saved to: $actual_file" >&2
exit 1
fi
else
echo "未找到预期输出文件,跳过比对: $expected_file"
echo "expected output not found, skipped diff: $expected_file"
fi
fi
if [[ -n "$timing_out" ]]; then
printf 'compile_ns=%s\nrun_ns=%s\n' "$_compile_ns" "$_run_ns" > "$timing_out"
fi

@ -0,0 +1,145 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ $# -lt 1 || $# -gt 2 ]]; then
echo "用法: $0 <test_dir> [output_dir]" >&2
exit 1
fi
test_dir=${1%/}
out_dir="test/test_result/function/asm"
shift
while [[ $# -gt 0 ]]; do
out_dir="$1"
shift
done
if [[ ! -d "$test_dir" ]]; then
echo "测试目录不存在: $test_dir" >&2
exit 1
fi
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2
exit 1
fi
if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
echo "未找到 aarch64-linux-gnu-gcc无法汇编/链接。" >&2
exit 1
fi
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
echo "未找到 qemu-aarch64无法运行生成的可执行文件。" >&2
exit 1
fi
sylib_c="sylib/sylib.c"
if [[ ! -f "$sylib_c" ]]; then
echo "未找到 sylib: $sylib_c" >&2
exit 1
fi
mkdir -p "$out_dir"
sylib_obj="$out_dir/sylib.o"
aarch64-linux-gnu-gcc -c "$sylib_c" -I sylib -o "$sylib_obj"
mapfile -t inputs < <(find "$test_dir" -type f -name '*.sy' | sort)
if [[ ${#inputs[@]} -eq 0 ]]; then
echo "测试目录下未找到 .sy 文件: $test_dir" >&2
exit 1
fi
failures=0
normalize() {
# strip CR, then strip a single trailing newline so both files
# are comparable regardless of CRLF / trailing-newline differences
tr -d '\r' < "$1" | sed -e '${ /^$/d; }' | perl -pe 'chomp if eof'
}
run_case() {
local input=$1
local input_dir base stem rel_path rel_dir case_out_dir asm_file exe
local stdin_file expected_file stdout_file actual_file status
input_dir=$(dirname "$input")
base=$(basename "$input")
stem=${base%.sy}
rel_path=${input#"$test_dir"/}
rel_dir=$(dirname "$rel_path")
case_out_dir="$out_dir"
if [[ "$rel_dir" != "." ]]; then
case_out_dir="$out_dir/$rel_dir"
fi
mkdir -p "$case_out_dir"
asm_file="$case_out_dir/$stem.s"
exe="$case_out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
stdout_file="$case_out_dir/$stem.stdout"
actual_file="$case_out_dir/$stem.actual.out"
if ! "$compiler" --emit-asm "$input" > "$asm_file" 2>"$case_out_dir/$stem.err"; then
echo "$stem: 编译失败"
cat "$case_out_dir/$stem.err" >&2
return 1
fi
if ! aarch64-linux-gnu-gcc "$asm_file" "$sylib_obj" -o "$exe" 2>"$case_out_dir/$stem.link.err"; then
echo "$stem: 链接失败"
cat "$case_out_dir/$stem.link.err" >&2
return 1
fi
set +e
if [[ -f "$stdin_file" ]]; then
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
status=$?
set -e
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff <(normalize "$expected_file") <(normalize "$actual_file") >/dev/null 2>&1; then
echo "$stem: PASS"
return 0
else
echo "$stem: FAIL (退出码: $status)"
diff -u --strip-trailing-cr "$expected_file" "$actual_file" >&2 || true
return 1
fi
else
echo "$stem: SKIP (无预期输出, 退出码: $status)"
return 0
fi
}
for input in "${inputs[@]}"; do
if ! run_case "$input"; then
((failures+=1))
fi
done
total=${#inputs[@]}
passed=$((total - failures))
echo "总计: $total, 通过: $passed, 失败: $failures"
if (( failures > 0 )); then
exit 1
fi

@ -0,0 +1,160 @@
#!/usr/bin/env bash
set -euo pipefail
if [[ $# -lt 1 || $# -gt 2 ]]; then
echo "用法: $0 <test_dir> [output_dir]" >&2
exit 1
fi
test_dir=${1%/}
out_dir="test/test_result/function/asm_time"
shift
while [[ $# -gt 0 ]]; do
out_dir="$1"
shift
done
if [[ ! -d "$test_dir" ]]; then
echo "测试目录不存在: $test_dir" >&2
exit 1
fi
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建。" >&2
exit 1
fi
if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
echo "未找到 aarch64-linux-gnu-gcc无法汇编/链接。" >&2
exit 1
fi
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
echo "未找到 qemu-aarch64无法运行生成的可执行文件。" >&2
exit 1
fi
sylib_c="sylib/sylib.c"
if [[ ! -f "$sylib_c" ]]; then
echo "未找到 sylib: $sylib_c" >&2
exit 1
fi
mkdir -p "$out_dir"
sylib_obj="$out_dir/sylib.o"
aarch64-linux-gnu-gcc -c "$sylib_c" -I sylib -o "$sylib_obj"
mapfile -t inputs < <(find "$test_dir" -type f -name '*.sy' | sort)
if [[ ${#inputs[@]} -eq 0 ]]; then
echo "测试目录下未找到 .sy 文件: $test_dir" >&2
exit 1
fi
failures=0
normalize() {
tr -d '\r' < "$1" | sed -e '${ /^$/d; }' | perl -pe 'chomp if eof'
}
run_case() {
local input=$1
local input_dir base stem rel_path rel_dir case_out_dir asm_file exe
local stdin_file expected_file stdout_file actual_file time_file elapsed status
input_dir=$(dirname "$input")
base=$(basename "$input")
stem=${base%.sy}
rel_path=${input#"$test_dir"/}
rel_dir=$(dirname "$rel_path")
case_out_dir="$out_dir"
if [[ "$rel_dir" != "." ]]; then
case_out_dir="$out_dir/$rel_dir"
fi
mkdir -p "$case_out_dir"
asm_file="$case_out_dir/$stem.s"
exe="$case_out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
stdout_file="$case_out_dir/$stem.stdout"
actual_file="$case_out_dir/$stem.actual.out"
time_file="$case_out_dir/$stem.time"
if ! "$compiler" --emit-asm "$input" > "$asm_file" 2>"$case_out_dir/$stem.err"; then
echo "$stem: 编译失败"
cat "$case_out_dir/$stem.err" >&2
return 1
fi
if ! aarch64-linux-gnu-gcc "$asm_file" "$sylib_obj" -o "$exe" 2>"$case_out_dir/$stem.link.err"; then
echo "$stem: 链接失败"
cat "$case_out_dir/$stem.link.err" >&2
return 1
fi
set +e
if [[ -f "$stdin_file" ]]; then
/usr/bin/time -f "%e" -o "$time_file" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
else
/usr/bin/time -f "%e" -o "$time_file" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file"
fi
status=$?
set -e
elapsed=$(tail -1 "$time_file")
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff <(normalize "$expected_file") <(normalize "$actual_file") >/dev/null 2>&1; then
printf "%s: PASS (%.3fs)\n" "$stem" "$elapsed"
printf '%s\t%s\n' "$elapsed" "$stem" >> "$out_dir/elapsed.log"
return 0
else
printf "%s: FAIL (退出码: %d, 耗时: %.3fs)\n" "$stem" "$status" "$elapsed"
diff -u --strip-trailing-cr "$expected_file" "$actual_file" >&2 || true
return 1
fi
else
printf "%s: SKIP (无预期输出, %.3fs, 退出码: %d)\n" "$stem" "$elapsed" "$status"
printf '%s\t%s\n' "$elapsed" "$stem" >> "$out_dir/elapsed.log"
return 0
fi
}
rm -f "$out_dir/elapsed.log"
for input in "${inputs[@]}"; do
if ! run_case "$input"; then
((failures+=1))
fi
done
total=${#inputs[@]}
passed=$((total - failures))
if [[ -f "$out_dir/elapsed.log" && -s "$out_dir/elapsed.log" ]]; then
total_elapsed=$(awk '{s+=$1} END {printf "%.3f", s}' "$out_dir/elapsed.log")
else
total_elapsed="0.000"
fi
echo "总计: $total, 通过: $passed, 失败: $failures"
echo "通过用例总耗时: $total_elapsed s"
if (( failures > 0 )); then
exit 1
fi

@ -1,15 +1,16 @@
#!/usr/bin/env bash
# ./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/function/ir --run
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
echo "usage: $0 input.sy [output_dir] [--run]" >&2
exit 1
fi
input=$1
out_dir="test/test_result/ir"
out_dir="$REPO_ROOT/test/test_result/ir"
run_exec=false
input_dir=$(dirname "$input")
@ -27,13 +28,19 @@ while [[ $# -gt 0 ]]; do
done
if [[ ! -f "$input" ]]; then
echo "输入文件不存在: $input" >&2
echo "input file not found: $input" >&2
exit 1
fi
compiler="./build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建(如: mkdir -p build && cd build && cmake .. && make -j" >&2
compiler=""
for candidate in "$REPO_ROOT/build_lab2/bin/compiler" "$REPO_ROOT/build/bin/compiler"; do
if [[ -x "$candidate" ]]; then
compiler="$candidate"
break
fi
done
if [[ -z "$compiler" ]]; then
echo "compiler not found; try: cmake -S . -B build_lab2 && cmake --build build_lab2 -j" >&2
exit 1
fi
@ -44,34 +51,59 @@ out_file="$out_dir/$stem.ll"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
"$compiler" --emit-ir "$input" > "$out_file"
echo "IR 已生成: $out_file"
echo "IR generated: $out_file"
if [[ "$run_exec" == true ]]; then
if ! command -v llc >/dev/null 2>&1; then
echo "未找到 llc无法运行 IR。请安装 LLVM。" >&2
echo "llc not found" >&2
exit 1
fi
if ! command -v clang >/dev/null 2>&1; then
echo "未找到 clang无法链接可执行文件。请安装 LLVM/Clang。" >&2
echo "clang not found" >&2
exit 1
fi
obj="$out_dir/$stem.o"
exe="$out_dir/$stem"
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" -o "$exe"
echo "运行 $exe ..."
llc -opaque-pointers -filetype=obj "$out_file" -o "$obj"
clang "$obj" "$REPO_ROOT/sylib/sylib.c" -o "$exe"
# Optional timeout to prevent hanging test cases.
# Override with RUN_TIMEOUT_SEC/PERF_TIMEOUT_SEC env vars.
timeout_sec="${RUN_TIMEOUT_SEC:-60}"
if [[ "$input" == *"/performance/"* || "$input" == *"/h_performance/"* ]]; then
timeout_sec="${PERF_TIMEOUT_SEC:-300}"
fi
set +e
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
if command -v timeout >/dev/null 2>&1; then
if [[ -f "$stdin_file" ]]; then
timeout "$timeout_sec" "$exe" < "$stdin_file" > "$stdout_file"
else
timeout "$timeout_sec" "$exe" > "$stdout_file"
fi
else
"$exe" > "$stdout_file"
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
else
"$exe" > "$stdout_file"
fi
fi
status=$?
set -e
if [[ $status -eq 124 ]]; then
echo "timeout after ${timeout_sec}s: $exe" >&2
fi
cat "$stdout_file"
echo "退出码: $status"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
echo "exit code: $status"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
@ -81,14 +113,14 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file"
if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then
echo "输出匹配: $expected_file"
if diff -u <(awk '{ sub(/\r$/, ""); print }' "$expected_file") <(awk '{ sub(/\r$/, ""); print }' "$actual_file"); then
echo "matched: $expected_file"
else
echo "输出不匹配: $expected_file" >&2
echo "实际输出已保存: $actual_file" >&2
echo "mismatch: $expected_file" >&2
echo "actual saved to: $actual_file" >&2
exit 1
fi
else
echo "未找到预期输出文件,跳过比对: $expected_file"
echo "expected output not found, skipped diff: $expected_file"
fi
fi

@ -0,0 +1,591 @@
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.tree.ErrorNode;
import org.antlr.v4.runtime.tree.TerminalNode;
/**
* This class provides an empty implementation of {@link SysYListener},
* which can be extended to create a listener which only needs to handle a subset
* of the available methods.
*/
@SuppressWarnings("CheckReturnValue")
public class SysYBaseListener implements SysYListener {
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterCompUnit(SysYParser.CompUnitContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitCompUnit(SysYParser.CompUnitContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterDecl(SysYParser.DeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitDecl(SysYParser.DeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterConstDecl(SysYParser.ConstDeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitConstDecl(SysYParser.ConstDeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBType(SysYParser.BTypeContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBType(SysYParser.BTypeContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterConstDef(SysYParser.ConstDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitConstDef(SysYParser.ConstDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterConstInitVal(SysYParser.ConstInitValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitConstInitVal(SysYParser.ConstInitValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterVarDecl(SysYParser.VarDeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitVarDecl(SysYParser.VarDeclContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterVarDef(SysYParser.VarDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitVarDef(SysYParser.VarDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterInitVal(SysYParser.InitValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitInitVal(SysYParser.InitValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncDef(SysYParser.FuncDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncDef(SysYParser.FuncDefContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncType(SysYParser.FuncTypeContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncType(SysYParser.FuncTypeContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncFParams(SysYParser.FuncFParamsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncFParams(SysYParser.FuncFParamsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncFParam(SysYParser.FuncFParamContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncFParam(SysYParser.FuncFParamContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBlock(SysYParser.BlockContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBlock(SysYParser.BlockContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBlockItem(SysYParser.BlockItemContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBlockItem(SysYParser.BlockItemContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterAssignStmt(SysYParser.AssignStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitAssignStmt(SysYParser.AssignStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterExpStmt(SysYParser.ExpStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitExpStmt(SysYParser.ExpStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBlockStmt(SysYParser.BlockStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBlockStmt(SysYParser.BlockStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterIfStmt(SysYParser.IfStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitIfStmt(SysYParser.IfStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterWhileStmt(SysYParser.WhileStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitWhileStmt(SysYParser.WhileStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBreakStmt(SysYParser.BreakStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBreakStmt(SysYParser.BreakStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterContinueStmt(SysYParser.ContinueStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitContinueStmt(SysYParser.ContinueStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterReturnStmt(SysYParser.ReturnStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitReturnStmt(SysYParser.ReturnStmtContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterExp(SysYParser.ExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitExp(SysYParser.ExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterCond(SysYParser.CondContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitCond(SysYParser.CondContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterLVal(SysYParser.LValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitLVal(SysYParser.LValContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterPrimaryExp(SysYParser.PrimaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitPrimaryExp(SysYParser.PrimaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterNumber(SysYParser.NumberContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitNumber(SysYParser.NumberContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterCallUnaryExp(SysYParser.CallUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitCallUnaryExp(SysYParser.CallUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterOpUnaryExp(SysYParser.OpUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitOpUnaryExp(SysYParser.OpUnaryExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterUnaryOp(SysYParser.UnaryOpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitUnaryOp(SysYParser.UnaryOpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterFuncRParams(SysYParser.FuncRParamsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitFuncRParams(SysYParser.FuncRParamsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryMulExp(SysYParser.BinaryMulExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryMulExp(SysYParser.BinaryMulExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterUnaryMulExp(SysYParser.UnaryMulExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitUnaryMulExp(SysYParser.UnaryMulExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryAddExp(SysYParser.BinaryAddExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryAddExp(SysYParser.BinaryAddExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterMulAddExp(SysYParser.MulAddExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitMulAddExp(SysYParser.MulAddExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterAddRelExp(SysYParser.AddRelExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitAddRelExp(SysYParser.AddRelExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryRelExp(SysYParser.BinaryRelExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryRelExp(SysYParser.BinaryRelExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryEqExp(SysYParser.BinaryEqExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryEqExp(SysYParser.BinaryEqExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterRelEqExp(SysYParser.RelEqExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitRelEqExp(SysYParser.RelEqExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterEqLAndExp(SysYParser.EqLAndExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitEqLAndExp(SysYParser.EqLAndExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterAndLOrExp(SysYParser.AndLOrExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitAndLOrExp(SysYParser.AndLOrExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterConstExp(SysYParser.ConstExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitConstExp(SysYParser.ConstExpContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterEveryRule(ParserRuleContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitEveryRule(ParserRuleContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void visitTerminal(TerminalNode node) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void visitErrorNode(ErrorNode node) { }
}

@ -0,0 +1,358 @@
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
import org.antlr.v4.runtime.Lexer;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.*;
import org.antlr.v4.runtime.atn.*;
import org.antlr.v4.runtime.dfa.DFA;
import org.antlr.v4.runtime.misc.*;
@SuppressWarnings({"all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue", "this-escape"})
public class SysYLexer extends Lexer {
static { RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); }
protected static final DFA[] _decisionToDFA;
protected static final PredictionContextCache _sharedContextCache =
new PredictionContextCache();
public static final int
CONST=1, INT=2, FLOAT=3, VOID=4, IF=5, ELSE=6, WHILE=7, BREAK=8, CONTINUE=9,
RETURN=10, ADD=11, SUB=12, MUL=13, DIV=14, MOD=15, ASSIGN=16, EQ=17, NE=18,
LT=19, LE=20, GT=21, GE=22, NOT=23, AND=24, OR=25, LPAREN=26, RPAREN=27,
LBRACK=28, RBRACK=29, LBRACE=30, RBRACE=31, COMMA=32, SEMI=33, Ident=34,
IntConst=35, FloatConst=36, WS=37, LINE_COMMENT=38, BLOCK_COMMENT=39;
public static String[] channelNames = {
"DEFAULT_TOKEN_CHANNEL", "HIDDEN"
};
public static String[] modeNames = {
"DEFAULT_MODE"
};
private static String[] makeRuleNames() {
return new String[] {
"CONST", "INT", "FLOAT", "VOID", "IF", "ELSE", "WHILE", "BREAK", "CONTINUE",
"RETURN", "ADD", "SUB", "MUL", "DIV", "MOD", "ASSIGN", "EQ", "NE", "LT",
"LE", "GT", "GE", "NOT", "AND", "OR", "LPAREN", "RPAREN", "LBRACK", "RBRACK",
"LBRACE", "RBRACE", "COMMA", "SEMI", "Ident", "Digit", "NonzeroDigit",
"OctDigit", "HexDigit", "DecInteger", "OctInteger", "HexInteger", "DecFraction",
"DecExponent", "DecFloat", "HexFraction", "BinExponent", "HexFloat",
"IntConst", "FloatConst", "WS", "LINE_COMMENT", "BLOCK_COMMENT"
};
}
public static final String[] ruleNames = makeRuleNames();
private static String[] makeLiteralNames() {
return new String[] {
null, "'const'", "'int'", "'float'", "'void'", "'if'", "'else'", "'while'",
"'break'", "'continue'", "'return'", "'+'", "'-'", "'*'", "'/'", "'%'",
"'='", "'=='", "'!='", "'<'", "'<='", "'>'", "'>='", "'!'", "'&&'", "'||'",
"'('", "')'", "'['", "']'", "'{'", "'}'", "','", "';'"
};
}
private static final String[] _LITERAL_NAMES = makeLiteralNames();
private static String[] makeSymbolicNames() {
return new String[] {
null, "CONST", "INT", "FLOAT", "VOID", "IF", "ELSE", "WHILE", "BREAK",
"CONTINUE", "RETURN", "ADD", "SUB", "MUL", "DIV", "MOD", "ASSIGN", "EQ",
"NE", "LT", "LE", "GT", "GE", "NOT", "AND", "OR", "LPAREN", "RPAREN",
"LBRACK", "RBRACK", "LBRACE", "RBRACE", "COMMA", "SEMI", "Ident", "IntConst",
"FloatConst", "WS", "LINE_COMMENT", "BLOCK_COMMENT"
};
}
private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames();
public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES);
/**
* @deprecated Use {@link #VOCABULARY} instead.
*/
@Deprecated
public static final String[] tokenNames;
static {
tokenNames = new String[_SYMBOLIC_NAMES.length];
for (int i = 0; i < tokenNames.length; i++) {
tokenNames[i] = VOCABULARY.getLiteralName(i);
if (tokenNames[i] == null) {
tokenNames[i] = VOCABULARY.getSymbolicName(i);
}
if (tokenNames[i] == null) {
tokenNames[i] = "<INVALID>";
}
}
}
@Override
@Deprecated
public String[] getTokenNames() {
return tokenNames;
}
@Override
public Vocabulary getVocabulary() {
return VOCABULARY;
}
public SysYLexer(CharStream input) {
super(input);
_interp = new LexerATNSimulator(this,_ATN,_decisionToDFA,_sharedContextCache);
}
@Override
public String getGrammarFileName() { return "SysY.g4"; }
@Override
public String[] getRuleNames() { return ruleNames; }
@Override
public String getSerializedATN() { return _serializedATN; }
@Override
public String[] getChannelNames() { return channelNames; }
@Override
public String[] getModeNames() { return modeNames; }
@Override
public ATN getATN() { return _ATN; }
public static final String _serializedATN =
"\u0004\u0000\'\u0171\u0006\uffff\uffff\u0002\u0000\u0007\u0000\u0002\u0001"+
"\u0007\u0001\u0002\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004"+
"\u0007\u0004\u0002\u0005\u0007\u0005\u0002\u0006\u0007\u0006\u0002\u0007"+
"\u0007\u0007\u0002\b\u0007\b\u0002\t\u0007\t\u0002\n\u0007\n\u0002\u000b"+
"\u0007\u000b\u0002\f\u0007\f\u0002\r\u0007\r\u0002\u000e\u0007\u000e\u0002"+
"\u000f\u0007\u000f\u0002\u0010\u0007\u0010\u0002\u0011\u0007\u0011\u0002"+
"\u0012\u0007\u0012\u0002\u0013\u0007\u0013\u0002\u0014\u0007\u0014\u0002"+
"\u0015\u0007\u0015\u0002\u0016\u0007\u0016\u0002\u0017\u0007\u0017\u0002"+
"\u0018\u0007\u0018\u0002\u0019\u0007\u0019\u0002\u001a\u0007\u001a\u0002"+
"\u001b\u0007\u001b\u0002\u001c\u0007\u001c\u0002\u001d\u0007\u001d\u0002"+
"\u001e\u0007\u001e\u0002\u001f\u0007\u001f\u0002 \u0007 \u0002!\u0007"+
"!\u0002\"\u0007\"\u0002#\u0007#\u0002$\u0007$\u0002%\u0007%\u0002&\u0007"+
"&\u0002\'\u0007\'\u0002(\u0007(\u0002)\u0007)\u0002*\u0007*\u0002+\u0007"+
"+\u0002,\u0007,\u0002-\u0007-\u0002.\u0007.\u0002/\u0007/\u00020\u0007"+
"0\u00021\u00071\u00022\u00072\u00023\u00073\u0001\u0000\u0001\u0000\u0001"+
"\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0001\u0001\u0001\u0001"+
"\u0001\u0001\u0001\u0001\u0002\u0001\u0002\u0001\u0002\u0001\u0002\u0001"+
"\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0003\u0001\u0003\u0001"+
"\u0003\u0001\u0004\u0001\u0004\u0001\u0004\u0001\u0005\u0001\u0005\u0001"+
"\u0005\u0001\u0005\u0001\u0005\u0001\u0006\u0001\u0006\u0001\u0006\u0001"+
"\u0006\u0001\u0006\u0001\u0006\u0001\u0007\u0001\u0007\u0001\u0007\u0001"+
"\u0007\u0001\u0007\u0001\u0007\u0001\b\u0001\b\u0001\b\u0001\b\u0001\b"+
"\u0001\b\u0001\b\u0001\b\u0001\b\u0001\t\u0001\t\u0001\t\u0001\t\u0001"+
"\t\u0001\t\u0001\t\u0001\n\u0001\n\u0001\u000b\u0001\u000b\u0001\f\u0001"+
"\f\u0001\r\u0001\r\u0001\u000e\u0001\u000e\u0001\u000f\u0001\u000f\u0001"+
"\u0010\u0001\u0010\u0001\u0010\u0001\u0011\u0001\u0011\u0001\u0011\u0001"+
"\u0012\u0001\u0012\u0001\u0013\u0001\u0013\u0001\u0013\u0001\u0014\u0001"+
"\u0014\u0001\u0015\u0001\u0015\u0001\u0015\u0001\u0016\u0001\u0016\u0001"+
"\u0017\u0001\u0017\u0001\u0017\u0001\u0018\u0001\u0018\u0001\u0018\u0001"+
"\u0019\u0001\u0019\u0001\u001a\u0001\u001a\u0001\u001b\u0001\u001b\u0001"+
"\u001c\u0001\u001c\u0001\u001d\u0001\u001d\u0001\u001e\u0001\u001e\u0001"+
"\u001f\u0001\u001f\u0001 \u0001 \u0001!\u0001!\u0005!\u00d9\b!\n!\f!\u00dc"+
"\t!\u0001\"\u0001\"\u0001#\u0001#\u0001$\u0001$\u0001%\u0001%\u0001&\u0001"+
"&\u0005&\u00e8\b&\n&\f&\u00eb\t&\u0001\'\u0001\'\u0005\'\u00ef\b\'\n\'"+
"\f\'\u00f2\t\'\u0001(\u0001(\u0001(\u0004(\u00f7\b(\u000b(\f(\u00f8\u0001"+
")\u0004)\u00fc\b)\u000b)\f)\u00fd\u0001)\u0001)\u0005)\u0102\b)\n)\f)"+
"\u0105\t)\u0001)\u0001)\u0004)\u0109\b)\u000b)\f)\u010a\u0003)\u010d\b"+
")\u0001*\u0001*\u0003*\u0111\b*\u0001*\u0004*\u0114\b*\u000b*\f*\u0115"+
"\u0001+\u0001+\u0003+\u011a\b+\u0001+\u0001+\u0001+\u0003+\u011f\b+\u0001"+
",\u0005,\u0122\b,\n,\f,\u0125\t,\u0001,\u0001,\u0004,\u0129\b,\u000b,"+
"\f,\u012a\u0001,\u0004,\u012e\b,\u000b,\f,\u012f\u0001,\u0001,\u0003,"+
"\u0134\b,\u0001-\u0001-\u0003-\u0138\b-\u0001-\u0004-\u013b\b-\u000b-"+
"\f-\u013c\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0003"+
".\u0147\b.\u0001/\u0001/\u0001/\u0003/\u014c\b/\u00010\u00010\u00030\u0150"+
"\b0\u00011\u00041\u0153\b1\u000b1\f1\u0154\u00011\u00011\u00012\u0001"+
"2\u00012\u00012\u00052\u015d\b2\n2\f2\u0160\t2\u00012\u00012\u00013\u0001"+
"3\u00013\u00013\u00053\u0168\b3\n3\f3\u016b\t3\u00013\u00013\u00013\u0001"+
"3\u00013\u0001\u0169\u00004\u0001\u0001\u0003\u0002\u0005\u0003\u0007"+
"\u0004\t\u0005\u000b\u0006\r\u0007\u000f\b\u0011\t\u0013\n\u0015\u000b"+
"\u0017\f\u0019\r\u001b\u000e\u001d\u000f\u001f\u0010!\u0011#\u0012%\u0013"+
"\'\u0014)\u0015+\u0016-\u0017/\u00181\u00193\u001a5\u001b7\u001c9\u001d"+
";\u001e=\u001f? A!C\"E\u0000G\u0000I\u0000K\u0000M\u0000O\u0000Q\u0000"+
"S\u0000U\u0000W\u0000Y\u0000[\u0000]\u0000_#a$c%e&g\'\u0001\u0000\f\u0003"+
"\u0000AZ__az\u0004\u000009AZ__az\u0001\u000009\u0001\u000019\u0001\u0000"+
"07\u0003\u000009AFaf\u0002\u0000XXxx\u0002\u0000EEee\u0002\u0000++--\u0002"+
"\u0000PPpp\u0003\u0000\t\n\r\r \u0002\u0000\n\n\r\r\u017c\u0000\u0001"+
"\u0001\u0000\u0000\u0000\u0000\u0003\u0001\u0000\u0000\u0000\u0000\u0005"+
"\u0001\u0000\u0000\u0000\u0000\u0007\u0001\u0000\u0000\u0000\u0000\t\u0001"+
"\u0000\u0000\u0000\u0000\u000b\u0001\u0000\u0000\u0000\u0000\r\u0001\u0000"+
"\u0000\u0000\u0000\u000f\u0001\u0000\u0000\u0000\u0000\u0011\u0001\u0000"+
"\u0000\u0000\u0000\u0013\u0001\u0000\u0000\u0000\u0000\u0015\u0001\u0000"+
"\u0000\u0000\u0000\u0017\u0001\u0000\u0000\u0000\u0000\u0019\u0001\u0000"+
"\u0000\u0000\u0000\u001b\u0001\u0000\u0000\u0000\u0000\u001d\u0001\u0000"+
"\u0000\u0000\u0000\u001f\u0001\u0000\u0000\u0000\u0000!\u0001\u0000\u0000"+
"\u0000\u0000#\u0001\u0000\u0000\u0000\u0000%\u0001\u0000\u0000\u0000\u0000"+
"\'\u0001\u0000\u0000\u0000\u0000)\u0001\u0000\u0000\u0000\u0000+\u0001"+
"\u0000\u0000\u0000\u0000-\u0001\u0000\u0000\u0000\u0000/\u0001\u0000\u0000"+
"\u0000\u00001\u0001\u0000\u0000\u0000\u00003\u0001\u0000\u0000\u0000\u0000"+
"5\u0001\u0000\u0000\u0000\u00007\u0001\u0000\u0000\u0000\u00009\u0001"+
"\u0000\u0000\u0000\u0000;\u0001\u0000\u0000\u0000\u0000=\u0001\u0000\u0000"+
"\u0000\u0000?\u0001\u0000\u0000\u0000\u0000A\u0001\u0000\u0000\u0000\u0000"+
"C\u0001\u0000\u0000\u0000\u0000_\u0001\u0000\u0000\u0000\u0000a\u0001"+
"\u0000\u0000\u0000\u0000c\u0001\u0000\u0000\u0000\u0000e\u0001\u0000\u0000"+
"\u0000\u0000g\u0001\u0000\u0000\u0000\u0001i\u0001\u0000\u0000\u0000\u0003"+
"o\u0001\u0000\u0000\u0000\u0005s\u0001\u0000\u0000\u0000\u0007y\u0001"+
"\u0000\u0000\u0000\t~\u0001\u0000\u0000\u0000\u000b\u0081\u0001\u0000"+
"\u0000\u0000\r\u0086\u0001\u0000\u0000\u0000\u000f\u008c\u0001\u0000\u0000"+
"\u0000\u0011\u0092\u0001\u0000\u0000\u0000\u0013\u009b\u0001\u0000\u0000"+
"\u0000\u0015\u00a2\u0001\u0000\u0000\u0000\u0017\u00a4\u0001\u0000\u0000"+
"\u0000\u0019\u00a6\u0001\u0000\u0000\u0000\u001b\u00a8\u0001\u0000\u0000"+
"\u0000\u001d\u00aa\u0001\u0000\u0000\u0000\u001f\u00ac\u0001\u0000\u0000"+
"\u0000!\u00ae\u0001\u0000\u0000\u0000#\u00b1\u0001\u0000\u0000\u0000%"+
"\u00b4\u0001\u0000\u0000\u0000\'\u00b6\u0001\u0000\u0000\u0000)\u00b9"+
"\u0001\u0000\u0000\u0000+\u00bb\u0001\u0000\u0000\u0000-\u00be\u0001\u0000"+
"\u0000\u0000/\u00c0\u0001\u0000\u0000\u00001\u00c3\u0001\u0000\u0000\u0000"+
"3\u00c6\u0001\u0000\u0000\u00005\u00c8\u0001\u0000\u0000\u00007\u00ca"+
"\u0001\u0000\u0000\u00009\u00cc\u0001\u0000\u0000\u0000;\u00ce\u0001\u0000"+
"\u0000\u0000=\u00d0\u0001\u0000\u0000\u0000?\u00d2\u0001\u0000\u0000\u0000"+
"A\u00d4\u0001\u0000\u0000\u0000C\u00d6\u0001\u0000\u0000\u0000E\u00dd"+
"\u0001\u0000\u0000\u0000G\u00df\u0001\u0000\u0000\u0000I\u00e1\u0001\u0000"+
"\u0000\u0000K\u00e3\u0001\u0000\u0000\u0000M\u00e5\u0001\u0000\u0000\u0000"+
"O\u00ec\u0001\u0000\u0000\u0000Q\u00f3\u0001\u0000\u0000\u0000S\u010c"+
"\u0001\u0000\u0000\u0000U\u010e\u0001\u0000\u0000\u0000W\u011e\u0001\u0000"+
"\u0000\u0000Y\u0133\u0001\u0000\u0000\u0000[\u0135\u0001\u0000\u0000\u0000"+
"]\u0146\u0001\u0000\u0000\u0000_\u014b\u0001\u0000\u0000\u0000a\u014f"+
"\u0001\u0000\u0000\u0000c\u0152\u0001\u0000\u0000\u0000e\u0158\u0001\u0000"+
"\u0000\u0000g\u0163\u0001\u0000\u0000\u0000ij\u0005c\u0000\u0000jk\u0005"+
"o\u0000\u0000kl\u0005n\u0000\u0000lm\u0005s\u0000\u0000mn\u0005t\u0000"+
"\u0000n\u0002\u0001\u0000\u0000\u0000op\u0005i\u0000\u0000pq\u0005n\u0000"+
"\u0000qr\u0005t\u0000\u0000r\u0004\u0001\u0000\u0000\u0000st\u0005f\u0000"+
"\u0000tu\u0005l\u0000\u0000uv\u0005o\u0000\u0000vw\u0005a\u0000\u0000"+
"wx\u0005t\u0000\u0000x\u0006\u0001\u0000\u0000\u0000yz\u0005v\u0000\u0000"+
"z{\u0005o\u0000\u0000{|\u0005i\u0000\u0000|}\u0005d\u0000\u0000}\b\u0001"+
"\u0000\u0000\u0000~\u007f\u0005i\u0000\u0000\u007f\u0080\u0005f\u0000"+
"\u0000\u0080\n\u0001\u0000\u0000\u0000\u0081\u0082\u0005e\u0000\u0000"+
"\u0082\u0083\u0005l\u0000\u0000\u0083\u0084\u0005s\u0000\u0000\u0084\u0085"+
"\u0005e\u0000\u0000\u0085\f\u0001\u0000\u0000\u0000\u0086\u0087\u0005"+
"w\u0000\u0000\u0087\u0088\u0005h\u0000\u0000\u0088\u0089\u0005i\u0000"+
"\u0000\u0089\u008a\u0005l\u0000\u0000\u008a\u008b\u0005e\u0000\u0000\u008b"+
"\u000e\u0001\u0000\u0000\u0000\u008c\u008d\u0005b\u0000\u0000\u008d\u008e"+
"\u0005r\u0000\u0000\u008e\u008f\u0005e\u0000\u0000\u008f\u0090\u0005a"+
"\u0000\u0000\u0090\u0091\u0005k\u0000\u0000\u0091\u0010\u0001\u0000\u0000"+
"\u0000\u0092\u0093\u0005c\u0000\u0000\u0093\u0094\u0005o\u0000\u0000\u0094"+
"\u0095\u0005n\u0000\u0000\u0095\u0096\u0005t\u0000\u0000\u0096\u0097\u0005"+
"i\u0000\u0000\u0097\u0098\u0005n\u0000\u0000\u0098\u0099\u0005u\u0000"+
"\u0000\u0099\u009a\u0005e\u0000\u0000\u009a\u0012\u0001\u0000\u0000\u0000"+
"\u009b\u009c\u0005r\u0000\u0000\u009c\u009d\u0005e\u0000\u0000\u009d\u009e"+
"\u0005t\u0000\u0000\u009e\u009f\u0005u\u0000\u0000\u009f\u00a0\u0005r"+
"\u0000\u0000\u00a0\u00a1\u0005n\u0000\u0000\u00a1\u0014\u0001\u0000\u0000"+
"\u0000\u00a2\u00a3\u0005+\u0000\u0000\u00a3\u0016\u0001\u0000\u0000\u0000"+
"\u00a4\u00a5\u0005-\u0000\u0000\u00a5\u0018\u0001\u0000\u0000\u0000\u00a6"+
"\u00a7\u0005*\u0000\u0000\u00a7\u001a\u0001\u0000\u0000\u0000\u00a8\u00a9"+
"\u0005/\u0000\u0000\u00a9\u001c\u0001\u0000\u0000\u0000\u00aa\u00ab\u0005"+
"%\u0000\u0000\u00ab\u001e\u0001\u0000\u0000\u0000\u00ac\u00ad\u0005=\u0000"+
"\u0000\u00ad \u0001\u0000\u0000\u0000\u00ae\u00af\u0005=\u0000\u0000\u00af"+
"\u00b0\u0005=\u0000\u0000\u00b0\"\u0001\u0000\u0000\u0000\u00b1\u00b2"+
"\u0005!\u0000\u0000\u00b2\u00b3\u0005=\u0000\u0000\u00b3$\u0001\u0000"+
"\u0000\u0000\u00b4\u00b5\u0005<\u0000\u0000\u00b5&\u0001\u0000\u0000\u0000"+
"\u00b6\u00b7\u0005<\u0000\u0000\u00b7\u00b8\u0005=\u0000\u0000\u00b8("+
"\u0001\u0000\u0000\u0000\u00b9\u00ba\u0005>\u0000\u0000\u00ba*\u0001\u0000"+
"\u0000\u0000\u00bb\u00bc\u0005>\u0000\u0000\u00bc\u00bd\u0005=\u0000\u0000"+
"\u00bd,\u0001\u0000\u0000\u0000\u00be\u00bf\u0005!\u0000\u0000\u00bf."+
"\u0001\u0000\u0000\u0000\u00c0\u00c1\u0005&\u0000\u0000\u00c1\u00c2\u0005"+
"&\u0000\u0000\u00c20\u0001\u0000\u0000\u0000\u00c3\u00c4\u0005|\u0000"+
"\u0000\u00c4\u00c5\u0005|\u0000\u0000\u00c52\u0001\u0000\u0000\u0000\u00c6"+
"\u00c7\u0005(\u0000\u0000\u00c74\u0001\u0000\u0000\u0000\u00c8\u00c9\u0005"+
")\u0000\u0000\u00c96\u0001\u0000\u0000\u0000\u00ca\u00cb\u0005[\u0000"+
"\u0000\u00cb8\u0001\u0000\u0000\u0000\u00cc\u00cd\u0005]\u0000\u0000\u00cd"+
":\u0001\u0000\u0000\u0000\u00ce\u00cf\u0005{\u0000\u0000\u00cf<\u0001"+
"\u0000\u0000\u0000\u00d0\u00d1\u0005}\u0000\u0000\u00d1>\u0001\u0000\u0000"+
"\u0000\u00d2\u00d3\u0005,\u0000\u0000\u00d3@\u0001\u0000\u0000\u0000\u00d4"+
"\u00d5\u0005;\u0000\u0000\u00d5B\u0001\u0000\u0000\u0000\u00d6\u00da\u0007"+
"\u0000\u0000\u0000\u00d7\u00d9\u0007\u0001\u0000\u0000\u00d8\u00d7\u0001"+
"\u0000\u0000\u0000\u00d9\u00dc\u0001\u0000\u0000\u0000\u00da\u00d8\u0001"+
"\u0000\u0000\u0000\u00da\u00db\u0001\u0000\u0000\u0000\u00dbD\u0001\u0000"+
"\u0000\u0000\u00dc\u00da\u0001\u0000\u0000\u0000\u00dd\u00de\u0007\u0002"+
"\u0000\u0000\u00deF\u0001\u0000\u0000\u0000\u00df\u00e0\u0007\u0003\u0000"+
"\u0000\u00e0H\u0001\u0000\u0000\u0000\u00e1\u00e2\u0007\u0004\u0000\u0000"+
"\u00e2J\u0001\u0000\u0000\u0000\u00e3\u00e4\u0007\u0005\u0000\u0000\u00e4"+
"L\u0001\u0000\u0000\u0000\u00e5\u00e9\u0003G#\u0000\u00e6\u00e8\u0003"+
"E\"\u0000\u00e7\u00e6\u0001\u0000\u0000\u0000\u00e8\u00eb\u0001\u0000"+
"\u0000\u0000\u00e9\u00e7\u0001\u0000\u0000\u0000\u00e9\u00ea\u0001\u0000"+
"\u0000\u0000\u00eaN\u0001\u0000\u0000\u0000\u00eb\u00e9\u0001\u0000\u0000"+
"\u0000\u00ec\u00f0\u00050\u0000\u0000\u00ed\u00ef\u0003I$\u0000\u00ee"+
"\u00ed\u0001\u0000\u0000\u0000\u00ef\u00f2\u0001\u0000\u0000\u0000\u00f0"+
"\u00ee\u0001\u0000\u0000\u0000\u00f0\u00f1\u0001\u0000\u0000\u0000\u00f1"+
"P\u0001\u0000\u0000\u0000\u00f2\u00f0\u0001\u0000\u0000\u0000\u00f3\u00f4"+
"\u00050\u0000\u0000\u00f4\u00f6\u0007\u0006\u0000\u0000\u00f5\u00f7\u0003"+
"K%\u0000\u00f6\u00f5\u0001\u0000\u0000\u0000\u00f7\u00f8\u0001\u0000\u0000"+
"\u0000\u00f8\u00f6\u0001\u0000\u0000\u0000\u00f8\u00f9\u0001\u0000\u0000"+
"\u0000\u00f9R\u0001\u0000\u0000\u0000\u00fa\u00fc\u0003E\"\u0000\u00fb"+
"\u00fa\u0001\u0000\u0000\u0000\u00fc\u00fd\u0001\u0000\u0000\u0000\u00fd"+
"\u00fb\u0001\u0000\u0000\u0000\u00fd\u00fe\u0001\u0000\u0000\u0000\u00fe"+
"\u00ff\u0001\u0000\u0000\u0000\u00ff\u0103\u0005.\u0000\u0000\u0100\u0102"+
"\u0003E\"\u0000\u0101\u0100\u0001\u0000\u0000\u0000\u0102\u0105\u0001"+
"\u0000\u0000\u0000\u0103\u0101\u0001\u0000\u0000\u0000\u0103\u0104\u0001"+
"\u0000\u0000\u0000\u0104\u010d\u0001\u0000\u0000\u0000\u0105\u0103\u0001"+
"\u0000\u0000\u0000\u0106\u0108\u0005.\u0000\u0000\u0107\u0109\u0003E\""+
"\u0000\u0108\u0107\u0001\u0000\u0000\u0000\u0109\u010a\u0001\u0000\u0000"+
"\u0000\u010a\u0108\u0001\u0000\u0000\u0000\u010a\u010b\u0001\u0000\u0000"+
"\u0000\u010b\u010d\u0001\u0000\u0000\u0000\u010c\u00fb\u0001\u0000\u0000"+
"\u0000\u010c\u0106\u0001\u0000\u0000\u0000\u010dT\u0001\u0000\u0000\u0000"+
"\u010e\u0110\u0007\u0007\u0000\u0000\u010f\u0111\u0007\b\u0000\u0000\u0110"+
"\u010f\u0001\u0000\u0000\u0000\u0110\u0111\u0001\u0000\u0000\u0000\u0111"+
"\u0113\u0001\u0000\u0000\u0000\u0112\u0114\u0003E\"\u0000\u0113\u0112"+
"\u0001\u0000\u0000\u0000\u0114\u0115\u0001\u0000\u0000\u0000\u0115\u0113"+
"\u0001\u0000\u0000\u0000\u0115\u0116\u0001\u0000\u0000\u0000\u0116V\u0001"+
"\u0000\u0000\u0000\u0117\u0119\u0003S)\u0000\u0118\u011a\u0003U*\u0000"+
"\u0119\u0118\u0001\u0000\u0000\u0000\u0119\u011a\u0001\u0000\u0000\u0000"+
"\u011a\u011f\u0001\u0000\u0000\u0000\u011b\u011c\u0003M&\u0000\u011c\u011d"+
"\u0003U*\u0000\u011d\u011f\u0001\u0000\u0000\u0000\u011e\u0117\u0001\u0000"+
"\u0000\u0000\u011e\u011b\u0001\u0000\u0000\u0000\u011fX\u0001\u0000\u0000"+
"\u0000\u0120\u0122\u0003K%\u0000\u0121\u0120\u0001\u0000\u0000\u0000\u0122"+
"\u0125\u0001\u0000\u0000\u0000\u0123\u0121\u0001\u0000\u0000\u0000\u0123"+
"\u0124\u0001\u0000\u0000\u0000\u0124\u0126\u0001\u0000\u0000\u0000\u0125"+
"\u0123\u0001\u0000\u0000\u0000\u0126\u0128\u0005.\u0000\u0000\u0127\u0129"+
"\u0003K%\u0000\u0128\u0127\u0001\u0000\u0000\u0000\u0129\u012a\u0001\u0000"+
"\u0000\u0000\u012a\u0128\u0001\u0000\u0000\u0000\u012a\u012b\u0001\u0000"+
"\u0000\u0000\u012b\u0134\u0001\u0000\u0000\u0000\u012c\u012e\u0003K%\u0000"+
"\u012d\u012c\u0001\u0000\u0000\u0000\u012e\u012f\u0001\u0000\u0000\u0000"+
"\u012f\u012d\u0001\u0000\u0000\u0000\u012f\u0130\u0001\u0000\u0000\u0000"+
"\u0130\u0131\u0001\u0000\u0000\u0000\u0131\u0132\u0005.\u0000\u0000\u0132"+
"\u0134\u0001\u0000\u0000\u0000\u0133\u0123\u0001\u0000\u0000\u0000\u0133"+
"\u012d\u0001\u0000\u0000\u0000\u0134Z\u0001\u0000\u0000\u0000\u0135\u0137"+
"\u0007\t\u0000\u0000\u0136\u0138\u0007\b\u0000\u0000\u0137\u0136\u0001"+
"\u0000\u0000\u0000\u0137\u0138\u0001\u0000\u0000\u0000\u0138\u013a\u0001"+
"\u0000\u0000\u0000\u0139\u013b\u0003E\"\u0000\u013a\u0139\u0001\u0000"+
"\u0000\u0000\u013b\u013c\u0001\u0000\u0000\u0000\u013c\u013a\u0001\u0000"+
"\u0000\u0000\u013c\u013d\u0001\u0000\u0000\u0000\u013d\\\u0001\u0000\u0000"+
"\u0000\u013e\u013f\u00050\u0000\u0000\u013f\u0140\u0007\u0006\u0000\u0000"+
"\u0140\u0141\u0003Y,\u0000\u0141\u0142\u0003[-\u0000\u0142\u0147\u0001"+
"\u0000\u0000\u0000\u0143\u0144\u0003Q(\u0000\u0144\u0145\u0003[-\u0000"+
"\u0145\u0147\u0001\u0000\u0000\u0000\u0146\u013e\u0001\u0000\u0000\u0000"+
"\u0146\u0143\u0001\u0000\u0000\u0000\u0147^\u0001\u0000\u0000\u0000\u0148"+
"\u014c\u0003M&\u0000\u0149\u014c\u0003O\'\u0000\u014a\u014c\u0003Q(\u0000"+
"\u014b\u0148\u0001\u0000\u0000\u0000\u014b\u0149\u0001\u0000\u0000\u0000"+
"\u014b\u014a\u0001\u0000\u0000\u0000\u014c`\u0001\u0000\u0000\u0000\u014d"+
"\u0150\u0003W+\u0000\u014e\u0150\u0003].\u0000\u014f\u014d\u0001\u0000"+
"\u0000\u0000\u014f\u014e\u0001\u0000\u0000\u0000\u0150b\u0001\u0000\u0000"+
"\u0000\u0151\u0153\u0007\n\u0000\u0000\u0152\u0151\u0001\u0000\u0000\u0000"+
"\u0153\u0154\u0001\u0000\u0000\u0000\u0154\u0152\u0001\u0000\u0000\u0000"+
"\u0154\u0155\u0001\u0000\u0000\u0000\u0155\u0156\u0001\u0000\u0000\u0000"+
"\u0156\u0157\u00061\u0000\u0000\u0157d\u0001\u0000\u0000\u0000\u0158\u0159"+
"\u0005/\u0000\u0000\u0159\u015a\u0005/\u0000\u0000\u015a\u015e\u0001\u0000"+
"\u0000\u0000\u015b\u015d\b\u000b\u0000\u0000\u015c\u015b\u0001\u0000\u0000"+
"\u0000\u015d\u0160\u0001\u0000\u0000\u0000\u015e\u015c\u0001\u0000\u0000"+
"\u0000\u015e\u015f\u0001\u0000\u0000\u0000\u015f\u0161\u0001\u0000\u0000"+
"\u0000\u0160\u015e\u0001\u0000\u0000\u0000\u0161\u0162\u00062\u0000\u0000"+
"\u0162f\u0001\u0000\u0000\u0000\u0163\u0164\u0005/\u0000\u0000\u0164\u0165"+
"\u0005*\u0000\u0000\u0165\u0169\u0001\u0000\u0000\u0000\u0166\u0168\t"+
"\u0000\u0000\u0000\u0167\u0166\u0001\u0000\u0000\u0000\u0168\u016b\u0001"+
"\u0000\u0000\u0000\u0169\u016a\u0001\u0000\u0000\u0000\u0169\u0167\u0001"+
"\u0000\u0000\u0000\u016a\u016c\u0001\u0000\u0000\u0000\u016b\u0169\u0001"+
"\u0000\u0000\u0000\u016c\u016d\u0005*\u0000\u0000\u016d\u016e\u0005/\u0000"+
"\u0000\u016e\u016f\u0001\u0000\u0000\u0000\u016f\u0170\u00063\u0000\u0000"+
"\u0170h\u0001\u0000\u0000\u0000\u0019\u0000\u00da\u00e9\u00f0\u00f8\u00fd"+
"\u0103\u010a\u010c\u0110\u0115\u0119\u011e\u0123\u012a\u012f\u0133\u0137"+
"\u013c\u0146\u014b\u014f\u0154\u015e\u0169\u0001\u0006\u0000\u0000";
public static final ATN _ATN =
new ATNDeserializer().deserialize(_serializedATN.toCharArray());
static {
_decisionToDFA = new DFA[_ATN.getNumberOfDecisions()];
for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) {
_decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i);
}
}
}

@ -0,0 +1,515 @@
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
import org.antlr.v4.runtime.tree.ParseTreeListener;
/**
* This interface defines a complete listener for a parse tree produced by
* {@link SysYParser}.
*/
public interface SysYListener extends ParseTreeListener {
/**
* Enter a parse tree produced by {@link SysYParser#compUnit}.
* @param ctx the parse tree
*/
void enterCompUnit(SysYParser.CompUnitContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#compUnit}.
* @param ctx the parse tree
*/
void exitCompUnit(SysYParser.CompUnitContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#decl}.
* @param ctx the parse tree
*/
void enterDecl(SysYParser.DeclContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#decl}.
* @param ctx the parse tree
*/
void exitDecl(SysYParser.DeclContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#constDecl}.
* @param ctx the parse tree
*/
void enterConstDecl(SysYParser.ConstDeclContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#constDecl}.
* @param ctx the parse tree
*/
void exitConstDecl(SysYParser.ConstDeclContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#bType}.
* @param ctx the parse tree
*/
void enterBType(SysYParser.BTypeContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#bType}.
* @param ctx the parse tree
*/
void exitBType(SysYParser.BTypeContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#constDef}.
* @param ctx the parse tree
*/
void enterConstDef(SysYParser.ConstDefContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#constDef}.
* @param ctx the parse tree
*/
void exitConstDef(SysYParser.ConstDefContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#constInitVal}.
* @param ctx the parse tree
*/
void enterConstInitVal(SysYParser.ConstInitValContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#constInitVal}.
* @param ctx the parse tree
*/
void exitConstInitVal(SysYParser.ConstInitValContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#varDecl}.
* @param ctx the parse tree
*/
void enterVarDecl(SysYParser.VarDeclContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#varDecl}.
* @param ctx the parse tree
*/
void exitVarDecl(SysYParser.VarDeclContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#varDef}.
* @param ctx the parse tree
*/
void enterVarDef(SysYParser.VarDefContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#varDef}.
* @param ctx the parse tree
*/
void exitVarDef(SysYParser.VarDefContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#initVal}.
* @param ctx the parse tree
*/
void enterInitVal(SysYParser.InitValContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#initVal}.
* @param ctx the parse tree
*/
void exitInitVal(SysYParser.InitValContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcDef}.
* @param ctx the parse tree
*/
void enterFuncDef(SysYParser.FuncDefContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcDef}.
* @param ctx the parse tree
*/
void exitFuncDef(SysYParser.FuncDefContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcType}.
* @param ctx the parse tree
*/
void enterFuncType(SysYParser.FuncTypeContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcType}.
* @param ctx the parse tree
*/
void exitFuncType(SysYParser.FuncTypeContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcFParams}.
* @param ctx the parse tree
*/
void enterFuncFParams(SysYParser.FuncFParamsContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcFParams}.
* @param ctx the parse tree
*/
void exitFuncFParams(SysYParser.FuncFParamsContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcFParam}.
* @param ctx the parse tree
*/
void enterFuncFParam(SysYParser.FuncFParamContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcFParam}.
* @param ctx the parse tree
*/
void exitFuncFParam(SysYParser.FuncFParamContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#block}.
* @param ctx the parse tree
*/
void enterBlock(SysYParser.BlockContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#block}.
* @param ctx the parse tree
*/
void exitBlock(SysYParser.BlockContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#blockItem}.
* @param ctx the parse tree
*/
void enterBlockItem(SysYParser.BlockItemContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#blockItem}.
* @param ctx the parse tree
*/
void exitBlockItem(SysYParser.BlockItemContext ctx);
/**
* Enter a parse tree produced by the {@code assignStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterAssignStmt(SysYParser.AssignStmtContext ctx);
/**
* Exit a parse tree produced by the {@code assignStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitAssignStmt(SysYParser.AssignStmtContext ctx);
/**
* Enter a parse tree produced by the {@code expStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterExpStmt(SysYParser.ExpStmtContext ctx);
/**
* Exit a parse tree produced by the {@code expStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitExpStmt(SysYParser.ExpStmtContext ctx);
/**
* Enter a parse tree produced by the {@code blockStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterBlockStmt(SysYParser.BlockStmtContext ctx);
/**
* Exit a parse tree produced by the {@code blockStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitBlockStmt(SysYParser.BlockStmtContext ctx);
/**
* Enter a parse tree produced by the {@code ifStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterIfStmt(SysYParser.IfStmtContext ctx);
/**
* Exit a parse tree produced by the {@code ifStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitIfStmt(SysYParser.IfStmtContext ctx);
/**
* Enter a parse tree produced by the {@code whileStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterWhileStmt(SysYParser.WhileStmtContext ctx);
/**
* Exit a parse tree produced by the {@code whileStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitWhileStmt(SysYParser.WhileStmtContext ctx);
/**
* Enter a parse tree produced by the {@code breakStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterBreakStmt(SysYParser.BreakStmtContext ctx);
/**
* Exit a parse tree produced by the {@code breakStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitBreakStmt(SysYParser.BreakStmtContext ctx);
/**
* Enter a parse tree produced by the {@code continueStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterContinueStmt(SysYParser.ContinueStmtContext ctx);
/**
* Exit a parse tree produced by the {@code continueStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitContinueStmt(SysYParser.ContinueStmtContext ctx);
/**
* Enter a parse tree produced by the {@code returnStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void enterReturnStmt(SysYParser.ReturnStmtContext ctx);
/**
* Exit a parse tree produced by the {@code returnStmt}
* labeled alternative in {@link SysYParser#stmt}.
* @param ctx the parse tree
*/
void exitReturnStmt(SysYParser.ReturnStmtContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#exp}.
* @param ctx the parse tree
*/
void enterExp(SysYParser.ExpContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#exp}.
* @param ctx the parse tree
*/
void exitExp(SysYParser.ExpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#cond}.
* @param ctx the parse tree
*/
void enterCond(SysYParser.CondContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#cond}.
* @param ctx the parse tree
*/
void exitCond(SysYParser.CondContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#lVal}.
* @param ctx the parse tree
*/
void enterLVal(SysYParser.LValContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#lVal}.
* @param ctx the parse tree
*/
void exitLVal(SysYParser.LValContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#primaryExp}.
* @param ctx the parse tree
*/
void enterPrimaryExp(SysYParser.PrimaryExpContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#primaryExp}.
* @param ctx the parse tree
*/
void exitPrimaryExp(SysYParser.PrimaryExpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#number}.
* @param ctx the parse tree
*/
void enterNumber(SysYParser.NumberContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#number}.
* @param ctx the parse tree
*/
void exitNumber(SysYParser.NumberContext ctx);
/**
* Enter a parse tree produced by the {@code primaryUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void enterPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx);
/**
* Exit a parse tree produced by the {@code primaryUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void exitPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx);
/**
* Enter a parse tree produced by the {@code callUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void enterCallUnaryExp(SysYParser.CallUnaryExpContext ctx);
/**
* Exit a parse tree produced by the {@code callUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void exitCallUnaryExp(SysYParser.CallUnaryExpContext ctx);
/**
* Enter a parse tree produced by the {@code opUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void enterOpUnaryExp(SysYParser.OpUnaryExpContext ctx);
/**
* Exit a parse tree produced by the {@code opUnaryExp}
* labeled alternative in {@link SysYParser#unaryExp}.
* @param ctx the parse tree
*/
void exitOpUnaryExp(SysYParser.OpUnaryExpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#unaryOp}.
* @param ctx the parse tree
*/
void enterUnaryOp(SysYParser.UnaryOpContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#unaryOp}.
* @param ctx the parse tree
*/
void exitUnaryOp(SysYParser.UnaryOpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#funcRParams}.
* @param ctx the parse tree
*/
void enterFuncRParams(SysYParser.FuncRParamsContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#funcRParams}.
* @param ctx the parse tree
*/
void exitFuncRParams(SysYParser.FuncRParamsContext ctx);
/**
* Enter a parse tree produced by the {@code binaryMulExp}
* labeled alternative in {@link SysYParser#mulExp}.
* @param ctx the parse tree
*/
void enterBinaryMulExp(SysYParser.BinaryMulExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryMulExp}
* labeled alternative in {@link SysYParser#mulExp}.
* @param ctx the parse tree
*/
void exitBinaryMulExp(SysYParser.BinaryMulExpContext ctx);
/**
* Enter a parse tree produced by the {@code unaryMulExp}
* labeled alternative in {@link SysYParser#mulExp}.
* @param ctx the parse tree
*/
void enterUnaryMulExp(SysYParser.UnaryMulExpContext ctx);
/**
* Exit a parse tree produced by the {@code unaryMulExp}
* labeled alternative in {@link SysYParser#mulExp}.
* @param ctx the parse tree
*/
void exitUnaryMulExp(SysYParser.UnaryMulExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryAddExp}
* labeled alternative in {@link SysYParser#addExp}.
* @param ctx the parse tree
*/
void enterBinaryAddExp(SysYParser.BinaryAddExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryAddExp}
* labeled alternative in {@link SysYParser#addExp}.
* @param ctx the parse tree
*/
void exitBinaryAddExp(SysYParser.BinaryAddExpContext ctx);
/**
* Enter a parse tree produced by the {@code mulAddExp}
* labeled alternative in {@link SysYParser#addExp}.
* @param ctx the parse tree
*/
void enterMulAddExp(SysYParser.MulAddExpContext ctx);
/**
* Exit a parse tree produced by the {@code mulAddExp}
* labeled alternative in {@link SysYParser#addExp}.
* @param ctx the parse tree
*/
void exitMulAddExp(SysYParser.MulAddExpContext ctx);
/**
* Enter a parse tree produced by the {@code addRelExp}
* labeled alternative in {@link SysYParser#relExp}.
* @param ctx the parse tree
*/
void enterAddRelExp(SysYParser.AddRelExpContext ctx);
/**
* Exit a parse tree produced by the {@code addRelExp}
* labeled alternative in {@link SysYParser#relExp}.
* @param ctx the parse tree
*/
void exitAddRelExp(SysYParser.AddRelExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryRelExp}
* labeled alternative in {@link SysYParser#relExp}.
* @param ctx the parse tree
*/
void enterBinaryRelExp(SysYParser.BinaryRelExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryRelExp}
* labeled alternative in {@link SysYParser#relExp}.
* @param ctx the parse tree
*/
void exitBinaryRelExp(SysYParser.BinaryRelExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryEqExp}
* labeled alternative in {@link SysYParser#eqExp}.
* @param ctx the parse tree
*/
void enterBinaryEqExp(SysYParser.BinaryEqExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryEqExp}
* labeled alternative in {@link SysYParser#eqExp}.
* @param ctx the parse tree
*/
void exitBinaryEqExp(SysYParser.BinaryEqExpContext ctx);
/**
* Enter a parse tree produced by the {@code relEqExp}
* labeled alternative in {@link SysYParser#eqExp}.
* @param ctx the parse tree
*/
void enterRelEqExp(SysYParser.RelEqExpContext ctx);
/**
* Exit a parse tree produced by the {@code relEqExp}
* labeled alternative in {@link SysYParser#eqExp}.
* @param ctx the parse tree
*/
void exitRelEqExp(SysYParser.RelEqExpContext ctx);
/**
* Enter a parse tree produced by the {@code eqLAndExp}
* labeled alternative in {@link SysYParser#lAndExp}.
* @param ctx the parse tree
*/
void enterEqLAndExp(SysYParser.EqLAndExpContext ctx);
/**
* Exit a parse tree produced by the {@code eqLAndExp}
* labeled alternative in {@link SysYParser#lAndExp}.
* @param ctx the parse tree
*/
void exitEqLAndExp(SysYParser.EqLAndExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryLAndExp}
* labeled alternative in {@link SysYParser#lAndExp}.
* @param ctx the parse tree
*/
void enterBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryLAndExp}
* labeled alternative in {@link SysYParser#lAndExp}.
* @param ctx the parse tree
*/
void exitBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx);
/**
* Enter a parse tree produced by the {@code andLOrExp}
* labeled alternative in {@link SysYParser#lOrExp}.
* @param ctx the parse tree
*/
void enterAndLOrExp(SysYParser.AndLOrExpContext ctx);
/**
* Exit a parse tree produced by the {@code andLOrExp}
* labeled alternative in {@link SysYParser#lOrExp}.
* @param ctx the parse tree
*/
void exitAndLOrExp(SysYParser.AndLOrExpContext ctx);
/**
* Enter a parse tree produced by the {@code binaryLOrExp}
* labeled alternative in {@link SysYParser#lOrExp}.
* @param ctx the parse tree
*/
void enterBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx);
/**
* Exit a parse tree produced by the {@code binaryLOrExp}
* labeled alternative in {@link SysYParser#lOrExp}.
* @param ctx the parse tree
*/
void exitBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx);
/**
* Enter a parse tree produced by {@link SysYParser#constExp}.
* @param ctx the parse tree
*/
void enterConstExp(SysYParser.ConstExpContext ctx);
/**
* Exit a parse tree produced by {@link SysYParser#constExp}.
* @param ctx the parse tree
*/
void exitConstExp(SysYParser.ConstExpContext ctx);
}

File diff suppressed because it is too large Load Diff

@ -1,98 +1,178 @@
// SysY 子集语法:支持形如
// int main() { int a = 1; int b = 2; return a + b; }
// 的最小返回表达式编译。
// 后续需要自行添加
grammar SysY;
/*===-------------------------------------------===*/
/* Lexer rules */
/*===-------------------------------------------===*/
INT: 'int';
RETURN: 'return';
ASSIGN: '=';
ADD: '+';
LPAREN: '(';
RPAREN: ')';
LBRACE: '{';
RBRACE: '}';
SEMICOLON: ';';
ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*===-------------------------------------------===*/
/* Syntax rules */
/*===-------------------------------------------===*/
compUnit
: funcDef EOF
;
decl
: btype varDef SEMICOLON
;
btype
: INT
;
varDef
: lValue (ASSIGN initValue)?
;
initValue
: exp
;
funcDef
: funcType ID LPAREN RPAREN blockStmt
;
funcType
: INT
;
blockStmt
: LBRACE blockItem* RBRACE
;
blockItem
: decl
| stmt
;
stmt
: returnStmt
;
returnStmt
: RETURN exp SEMICOLON
;
exp
: LPAREN exp RPAREN # parenExp
| var # varExp
| number # numberExp
| exp ADD exp # additiveExp
;
var
: ID
;
lValue
: ID
;
number
: ILITERAL
;
grammar SysY;
////Grammer
module: compUnit EOF;
compUnit: (decl | funcDef)+;
decl: constDecl | varDecl;
constDecl: CONST bType constDef (COMMA constDef)* SEMI;
bType: INT | FLOAT;
constDef: Ident (LBRACK constExp RBRACK)* ASSIGN constInitVal;
constInitVal:
constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE;
varDecl: bType varDef (COMMA varDef)* SEMI;
varDef:
Ident (LBRACK constExp RBRACK)*
| Ident (LBRACK constExp RBRACK)* ASSIGN initVal;
initVal: exp | LBRACE (initVal (COMMA initVal)*)? RBRACE;
funcDef: funcType Ident LPAREN funcFParams? RPAREN block;
funcType: VOID | INT | FLOAT;
funcFParams: funcFParam (COMMA funcFParam)*;
funcFParam:
bType Ident
| bType Ident LBRACK RBRACK (LBRACK exp RBRACK)*;
block: LBRACE blockItem* RBRACE;
blockItem: decl | stmt;
stmt:
lVal ASSIGN exp SEMI
| exp? SEMI
| block
| IF LPAREN cond RPAREN stmt (ELSE stmt)?
| WHILE LPAREN cond RPAREN stmt
| BREAK SEMI
| CONTINUE SEMI
| RETURN exp? SEMI;
exp: addExp;
cond: lOrExp;
lVal: Ident (LBRACK exp RBRACK)*;
primaryExp: LPAREN exp RPAREN | lVal | number;
number: IntConst | FloatConst;
unaryExp:
primaryExp
| Ident LPAREN funcRParams? RPAREN
| unaryOp unaryExp;
unaryOp: ADD | SUB | NOT;
funcRParams: exp (COMMA exp)*;
mulExp:
unaryExp
| mulExp op = (MUL | DIV | MOD) unaryExp;
addExp:
mulExp
| addExp op = (ADD | SUB) mulExp;
relExp:
addExp
| relExp op = (LT | GT | LE | GE) addExp;
eqExp:
relExp
| eqExp op = (EQ | NE) relExp;
lAndExp: eqExp | lAndExp AND eqExp ;
lOrExp: lAndExp | lOrExp OR lAndExp ;
constExp: addExp;
////Lexer
//keywords
CONST: 'const';
INT: 'int';
FLOAT: 'float';
VOID: 'void';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
RETURN: 'return';
//operators
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
ASSIGN: '=';
EQ: '==';
NE: '!=';
LT: '<';
LE: '<=';
GT: '>';
GE: '>=';
NOT: '!';
AND: '&&';
OR: '||';
//括号
LPAREN: '(';
RPAREN: ')';
LBRACK: '[';
RBRACK: ']';
LBRACE: '{';
RBRACE: '}';
COMMA: ',';
SEMI: ';';
//标识符
Ident: [a-zA-Z_] [a-zA-Z_0-9]*;
//数字常量片段
// 十进制数字
fragment Digit: [0-9];
// 非零十进制数字
fragment NonzeroDigit: [1-9];
// 八进制数字
fragment OctDigit: [0-7];
// 十六进制数字
fragment HexDigit: [0-9a-fA-F];
// 十进制整数:非零开头,后接若干十进制数字
fragment DecInteger: NonzeroDigit Digit*;
// 八进制整数:以 0 开头
fragment OctInteger: '0' OctDigit*;
// 十六进制整数:以 0x 或 0X 开头
fragment HexInteger: '0' [xX] HexDigit+;
// 十进制小数部分
fragment DecFraction: Digit+ '.' Digit* | '.' Digit+;
// 十进制指数部分
fragment DecExponent: [eE] [+\-]? Digit+;
// 十进制浮点数
fragment DecFloat:
DecFraction DecExponent?
| DecInteger DecExponent;
// 十六进制小数部分
fragment HexFraction: HexDigit* '.' HexDigit+ | HexDigit+ '.';
// 十六进制浮点数的二进制指数部分
fragment BinExponent: [pP] [+\-]? Digit+;
// 十六进制浮点数
fragment HexFloat:
'0' [xX] HexFraction BinExponent
| HexInteger BinExponent;
//整型常量
IntConst: DecInteger | OctInteger | HexInteger;
//浮点常量
FloatConst: DecFloat | HexFloat;
//空白符规则
WS: [ \t\r\n]+ -> skip;
// 单行注释
LINE_COMMENT: '//' ~[\r\n]* -> skip;
// 跨行注释
BLOCK_COMMENT: '/*' .*? '*/' -> skip;

@ -1,17 +1,34 @@
find_package(Java REQUIRED COMPONENTS Runtime)
set(SYSY_GRAMMAR "${PROJECT_SOURCE_DIR}/src/antlr4/SysY.g4")
set(SYSY_ANTLR_JAR "${PROJECT_SOURCE_DIR}/third_party/antlr-4.13.2-complete.jar")
set(SYSY_ANTLR_OUTPUTS
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.h"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.h"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.cpp"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.h"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.cpp"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.h"
)
add_custom_command(
OUTPUT ${SYSY_ANTLR_OUTPUTS}
COMMAND ${CMAKE_COMMAND} -E make_directory "${ANTLR4_GENERATED_DIR}"
COMMAND ${Java_JAVA_EXECUTABLE} -jar "${SYSY_ANTLR_JAR}" -Dlanguage=Cpp -visitor -no-listener -o "${ANTLR4_GENERATED_DIR}" -Xexact-output-dir "${SYSY_GRAMMAR}"
DEPENDS "${SYSY_GRAMMAR}" "${SYSY_ANTLR_JAR}"
COMMENT "Generating SysY parser with ANTLR4"
VERBATIM
)
add_library(frontend STATIC
AntlrDriver.cpp
SyntaxTreePrinter.cpp
${SYSY_ANTLR_OUTPUTS}
)
target_link_libraries(frontend PUBLIC
build_options
${ANTLR4_RUNTIME_TARGET}
)
# Lexer/Parser
file(GLOB_RECURSE ANTLR4_GENERATED_SOURCES CONFIGURE_DEPENDS
"${ANTLR4_GENERATED_DIR}/*.cpp"
)
if(ANTLR4_GENERATED_SOURCES)
target_sources(frontend PRIVATE ${ANTLR4_GENERATED_SOURCES})
endif()
)

@ -1,45 +1,62 @@
// IR 基本块:
// - 保存指令序列
// - 为后续 CFG 分析预留前驱/后继接口
//
// 当前仍是最小实现:
// - BasicBlock 已纳入 Value 体系,但类型先用 void 占位;
// - 指令追加与 terminator 约束主要在头文件中的 Append 模板里处理;
// - 前驱/后继容器已经预留,但当前项目里还没有分支指令与自动维护逻辑。
#include "ir/IR.h"
#include <utility>
#include <algorithm>
#include <stdexcept>
namespace ir {
// 当前 BasicBlock 还没有专门的 label type因此先用 void 作为占位类型。
BasicBlock::BasicBlock(std::string name)
: Value(Type::GetVoidType(), std::move(name)) {}
Function* BasicBlock::GetParent() const { return parent_; }
void BasicBlock::SetParent(Function* parent) { parent_ = parent; }
BasicBlock::BasicBlock(const std::string& name)
: Value(Type::GetLabelType(), name) {}
BasicBlock::BasicBlock(Function* parent, const std::string& name)
: Value(Type::GetLabelType(), name), parent_(parent) {}
bool BasicBlock::HasTerminator() const {
return !instructions_.empty() && instructions_.back()->IsTerminator();
}
// 按插入顺序返回块内指令序列。
const std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetInstructions()
const {
return instructions_;
void BasicBlock::EraseInstruction(Instruction* inst) {
if (!inst) {
return;
}
if (inst->IsTerminator()) {
throw std::runtime_error("cannot erase terminator instruction");
}
auto it = std::find_if(instructions_.begin(), instructions_.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == inst;
});
if (it == instructions_.end()) {
return;
}
(*it)->ClearAllOperands();
instructions_.erase(it);
}
void BasicBlock::AddPredecessor(BasicBlock* pred) {
if (pred &&
std::find(predecessors_.begin(), predecessors_.end(), pred) ==
predecessors_.end()) {
predecessors_.push_back(pred);
}
}
void BasicBlock::AddSuccessor(BasicBlock* succ) {
if (succ && std::find(successors_.begin(), successors_.end(), succ) ==
successors_.end()) {
successors_.push_back(succ);
}
}
// 前驱/后继接口先保留给后续 CFG 扩展使用。
// 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。
const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const {
return predecessors_;
void BasicBlock::RemovePredecessor(BasicBlock* pred) {
predecessors_.erase(
std::remove(predecessors_.begin(), predecessors_.end(), pred),
predecessors_.end());
}
const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_;
void BasicBlock::RemoveSuccessor(BasicBlock* succ) {
successors_.erase(std::remove(successors_.begin(), successors_.end(), succ),
successors_.end());
}
} // namespace ir
} // namespace ir

@ -1,5 +1,4 @@
// 管理基础类型、整型常量池和临时名生成。
#include "ir/IR.h"
#include "ir/IR.h"
#include <sstream>
@ -9,16 +8,34 @@ Context::~Context() = default;
ConstantInt* Context::GetConstInt(int v) {
auto it = const_ints_.find(v);
if (it != const_ints_.end()) return it->second.get();
auto inserted =
const_ints_.emplace(v, std::make_unique<ConstantInt>(Type::GetInt32Type(), v)).first;
return inserted->second.get();
if (it != const_ints_.end()) {
return it->second.get();
}
auto inserted = const_ints_.emplace(
v, std::make_unique<ConstantInt>(Type::GetInt32Type(), v));
return inserted.first->second.get();
}
ConstantI1* Context::GetConstBool(bool v) {
auto it = const_bools_.find(v);
if (it != const_bools_.end()) {
return it->second.get();
}
auto inserted = const_bools_.emplace(
v, std::make_unique<ConstantI1>(Type::GetInt1Type(), v));
return inserted.first->second.get();
}
std::string Context::NextTemp() {
std::ostringstream oss;
oss << "%" << ++temp_index_;
oss << "%t" << ++temp_index_;
return oss.str();
}
std::string Context::NextBlockName(const std::string& prefix) {
std::ostringstream oss;
oss << prefix << "." << ++block_index_;
return oss.str();
}
} // namespace ir
} // namespace ir

@ -1,17 +1,44 @@
// IR Function
// - 保存参数列表、基本块列表
// - 记录函数属性/元信息(按需要扩展)
#include "ir/IR.h"
namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> ret_type)
: Value(std::move(ret_type), std::move(name)) {
entry_ = CreateBlock("entry");
Argument::Argument(std::shared_ptr<Type> type, std::string name, size_t index)
: Value(std::move(type), std::move(name)), index_(index) {}
Function::Function(std::string name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types,
const std::vector<std::string>& param_names,
bool is_external)
: Value(Type::GetPointerType(), std::move(name)),
return_type_(std::move(ret_type)),
param_types_(param_types),
is_external_(is_external) {
for (size_t i = 0; i < param_types_.size(); ++i) {
std::string arg_name = i < param_names.size() && !param_names[i].empty()
? param_names[i]
: "%arg" + std::to_string(i);
arguments_.push_back(
std::make_unique<Argument>(param_types_[i], std::move(arg_name), i));
}
}
Argument* Function::GetArgument(size_t index) const {
return index < arguments_.size() ? arguments_[index].get() : nullptr;
}
BasicBlock* Function::EnsureEntryBlock() {
if (!entry_) {
entry_ = CreateBlock("entry");
}
return entry_;
}
BasicBlock* Function::CreateBlock(const std::string& name) {
auto block = std::make_unique<BasicBlock>(name);
auto block = std::make_unique<BasicBlock>(this, name);
return AddBlock(std::move(block));
}
BasicBlock* Function::AddBlock(std::unique_ptr<BasicBlock> block) {
auto* ptr = block.get();
ptr->SetParent(this);
blocks_.push_back(std::move(block));
@ -21,12 +48,4 @@ BasicBlock* Function::CreateBlock(const std::string& name) {
return ptr;
}
BasicBlock* Function::GetEntry() { return entry_; }
const BasicBlock* Function::GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_;
}
} // namespace ir
} // namespace ir

@ -1,11 +1,12 @@
// GlobalValue 占位实现:
// - 具体的全局初始化器、打印和链接语义需要自行补全
#include "ir/IR.h"
namespace ir {
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {}
GlobalValue::GlobalValue(std::shared_ptr<Type> object_type,
const std::string& name, bool is_const, Value* init)
: User(Type::GetPointerType(object_type), name),
object_type_(std::move(object_type)),
is_const_(is_const),
init_(init) {}
} // namespace ir
} // namespace ir

@ -1,89 +1,213 @@
// IR 构建工具:
// - 管理插入点(当前基本块/位置)
// - 提供创建各类指令的便捷接口,降低 IRGen 复杂度
#include "ir/IR.h"
#include <stdexcept>
#include "utils/Log.h"
namespace ir {
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb)
: ctx_(ctx), insert_block_(bb) {}
namespace {
BasicBlock* RequireInsertBlock(BasicBlock* block) {
if (!block) {
throw std::runtime_error("IRBuilder has no insert block");
}
return block;
}
bool IsFloatBinaryOp(Opcode op) {
return op == Opcode::FAdd || op == Opcode::FSub || op == Opcode::FMul ||
op == Opcode::FDiv || op == Opcode::FRem || op == Opcode::FCmpEQ ||
op == Opcode::FCmpNE || op == Opcode::FCmpLT || op == Opcode::FCmpGT ||
op == Opcode::FCmpLE || op == Opcode::FCmpGE;
}
bool IsCompareOp(Opcode op) {
return (op >= Opcode::ICmpEQ && op <= Opcode::ICmpGE) ||
(op >= Opcode::FCmpEQ && op <= Opcode::FCmpGE);
}
std::shared_ptr<Type> ResultTypeForBinary(Opcode op, Value* lhs) {
if (IsCompareOp(op)) {
return Type::GetInt1Type();
}
if (IsFloatBinaryOp(op)) {
return Type::GetFloatType();
}
return lhs->GetType();
}
} // namespace
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) : ctx_(ctx), insert_block_(bb) {}
void IRBuilder::SetInsertPoint(BasicBlock* bb) { insert_block_ = bb; }
BasicBlock* IRBuilder::GetInsertBlock() const { return insert_block_; }
ConstantInt* IRBuilder::CreateConstInt(int v) { return ctx_.GetConstInt(v); }
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
return new ConstantFloat(Type::GetFloatType(), v);
}
ConstantI1* IRBuilder::CreateConstBool(bool v) { return ctx_.GetConstBool(v); }
ConstantInt* IRBuilder::CreateConstInt(int v) {
// 常量不需要挂在基本块里,由 Context 负责去重与生命周期。
return ctx_.GetConstInt(v);
ConstantArrayValue* IRBuilder::CreateConstArray(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name) {
return new ConstantArrayValue(std::move(array_type), elements, dims, name);
}
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!lhs) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateBinary 缺少 lhs"));
}
if (!rhs) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateBinary 缺少 rhs"));
}
return insert_block_->Append<BinaryInst>(op, lhs->GetType(), lhs, rhs, name);
auto* block = RequireInsertBlock(insert_block_);
return block->Append<BinaryInst>(op, ResultTypeForBinary(op, lhs), lhs, rhs,
nullptr, name);
}
BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
const std::string& name) {
BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Add, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Div, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateRem(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Rem, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateAnd(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::And, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Or, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateXor(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Xor, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateShl(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::Shl, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateAShr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::AShr, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateLShr(Value* lhs, Value* rhs, const std::string& name) {
return CreateBinary(Opcode::LShr, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateICmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(op, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFCmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(op, lhs, rhs, name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name);
UnaryInst* IRBuilder::CreateNeg(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::Neg, operand->GetType(), operand, nullptr,
name);
}
UnaryInst* IRBuilder::CreateNot(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::Not, operand->GetType(), operand, nullptr,
name);
}
UnaryInst* IRBuilder::CreateFNeg(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::FNeg, operand->GetType(), operand,
nullptr, name);
}
UnaryInst* IRBuilder::CreateFtoI(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::FtoI, Type::GetInt32Type(), operand,
nullptr, name);
}
UnaryInst* IRBuilder::CreateIToF(Value* operand, const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnaryInst>(Opcode::IToF, Type::GetFloatType(), operand,
nullptr, name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!ptr) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
}
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name);
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> allocated_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<AllocaInst>(std::move(allocated_type), nullptr, name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, std::shared_ptr<Type> value_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<LoadInst>(std::move(value_type), ptr, nullptr, name);
}
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!val) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateStore 缺少 val"));
}
if (!ptr) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateStore 缺少 ptr"));
}
return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr);
auto* block = RequireInsertBlock(insert_block_);
return block->Append<StoreInst>(val, ptr, nullptr);
}
ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!v) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
}
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
UncondBrInst* IRBuilder::CreateBr(BasicBlock* dest) {
auto* block = RequireInsertBlock(insert_block_);
auto* inst = block->Append<UncondBrInst>(dest, nullptr);
block->AddSuccessor(dest);
dest->AddPredecessor(block);
return inst;
}
CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* then_bb,
BasicBlock* else_bb) {
auto* block = RequireInsertBlock(insert_block_);
auto* inst = block->Append<CondBrInst>(cond, then_bb, else_bb, nullptr);
block->AddSuccessor(then_bb);
block->AddSuccessor(else_bb);
then_bb->AddPredecessor(block);
else_bb->AddPredecessor(block);
return inst;
}
ReturnInst* IRBuilder::CreateRet(Value* val) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<ReturnInst>(val, nullptr);
}
UnreachableInst* IRBuilder::CreateUnreachable() {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<UnreachableInst>(nullptr);
}
CallInst* IRBuilder::CreateCall(Function* callee, const std::vector<Value*>& args,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
std::string real_name = callee->GetReturnType()->IsVoid() ? std::string() : name;
return block->Append<CallInst>(callee, args, nullptr, real_name);
}
GetElementPtrInst* IRBuilder::CreateGEP(Value* ptr,
std::shared_ptr<Type> source_type,
const std::vector<Value*>& indices,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<GetElementPtrInst>(std::move(source_type), ptr, indices,
nullptr, name);
}
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<PhiInst>(std::move(type), nullptr, name);
}
ZextInst* IRBuilder::CreateZext(Value* val, std::shared_ptr<Type> target_type,
const std::string& name) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<ZextInst>(val, std::move(target_type), nullptr, name);
}
MemsetInst* IRBuilder::CreateMemset(Value* dst, Value* val, Value* len,
Value* is_volatile) {
auto* block = RequireInsertBlock(insert_block_);
return block->Append<MemsetInst>(dst, val, len, is_volatile, nullptr);
}
} // namespace ir
} // namespace ir

@ -1,107 +1,484 @@
// IR 文本输出:
// - 将 IR 打印为 .ll 风格的文本
// - 支撑调试与测试对比diff
#include "ir/IR.h"
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include "utils/Log.h"
#include <vector>
namespace ir {
namespace {
std::string TypeToString(const Type& ty) {
std::ostringstream oss;
ty.Print(oss);
return oss.str();
}
std::string FloatToString(float value) {
double promoted = static_cast<double>(value);
std::uint64_t bits = 0;
std::memcpy(&bits, &promoted, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::uppercase << std::hex << std::setw(16)
<< std::setfill('0') << bits;
return oss.str();
}
std::string ValueRef(const Value* value) {
if (!value) {
return "<null>";
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return FloatToString(cf->GetValue());
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return cb->GetValue() ? "1" : "0";
}
if (isa<Function>(value) || isa<GlobalValue>(value)) {
return "@" + value->GetName();
}
return value->GetName();
}
std::string BlockRef(const BasicBlock* block) {
if (!block) {
return "%<null>";
}
return "%" + block->GetName();
}
bool IsZeroScalarConstant(const Value* value) {
if (!value) {
return true;
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return ci->GetValue() == 0;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return cf->GetValue() == 0.0f;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return !cb->GetValue();
}
return false;
}
static const char* TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
case Type::Kind::Int32:
return "i32";
case Type::Kind::PtrInt32:
return "i32*";
size_t CountScalarElements(const Type& type) {
if (!type.IsArray()) {
return 1;
}
return type.GetNumElements() * CountScalarElements(*type.GetElementType());
}
bool IsZeroArrayRange(const std::vector<Value*>& elements, const Type& type,
size_t offset) {
const auto count = CountScalarElements(type);
for (size_t i = 0; i < count; ++i) {
if (offset + i < elements.size() &&
!IsZeroScalarConstant(elements[offset + i])) {
return false;
}
}
throw std::runtime_error(FormatError("ir", "未知类型"));
return true;
}
static const char* OpcodeToString(Opcode op) {
switch (op) {
void PrintConstantForType(std::ostream& os, const Type& type, Value* value);
void PrintArrayConstant(std::ostream& os, const Type& type,
const std::vector<Value*>& elements, size_t& offset) {
if (IsZeroArrayRange(elements, type, offset)) {
os << "zeroinitializer";
offset += CountScalarElements(type);
return;
}
const auto elem_type = type.GetElementType();
os << "[";
for (size_t i = 0; i < type.GetNumElements(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*elem_type) << " ";
if (elem_type->IsArray()) {
PrintArrayConstant(os, *elem_type, elements, offset);
} else {
Value* elem = offset < elements.size() ? elements[offset] : nullptr;
PrintConstantForType(os, *elem_type, elem);
++offset;
}
}
os << "]";
}
void PrintConstantForType(std::ostream& os, const Type& type, Value* value) {
if (type.IsArray()) {
auto* array_value = dyncast<ConstantArrayValue>(value);
size_t offset = 0;
if (array_value) {
PrintArrayConstant(os, type, array_value->GetElements(), offset);
} else {
os << "zeroinitializer";
}
return;
}
if (!value) {
if (type.IsFloat()) {
os << FloatToString(0.0f);
} else {
os << "0";
}
return;
}
if (auto* ci = dyncast<ConstantInt>(value)) {
os << ci->GetValue();
return;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
os << FloatToString(cf->GetValue());
return;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
os << (cb->GetValue() ? "1" : "0");
return;
}
throw std::runtime_error("global initializer must be constant");
}
const char* BinaryOpcodeMnemonic(Opcode opcode) {
switch (opcode) {
case Opcode::Add:
return "add";
case Opcode::Sub:
return "sub";
case Opcode::Mul:
return "mul";
case Opcode::Alloca:
return "alloca";
case Opcode::Load:
return "load";
case Opcode::Store:
return "store";
case Opcode::Ret:
return "ret";
}
return "?";
case Opcode::Div:
return "sdiv";
case Opcode::Rem:
return "srem";
case Opcode::FAdd:
return "fadd";
case Opcode::FSub:
return "fsub";
case Opcode::FMul:
return "fmul";
case Opcode::FDiv:
return "fdiv";
case Opcode::FRem:
return "frem";
case Opcode::And:
return "and";
case Opcode::Or:
return "or";
case Opcode::Xor:
return "xor";
case Opcode::Shl:
return "shl";
case Opcode::AShr:
return "ashr";
case Opcode::LShr:
return "lshr";
case Opcode::ICmpEQ:
return "icmp eq";
case Opcode::ICmpNE:
return "icmp ne";
case Opcode::ICmpLT:
return "icmp slt";
case Opcode::ICmpGT:
return "icmp sgt";
case Opcode::ICmpLE:
return "icmp sle";
case Opcode::ICmpGE:
return "icmp sge";
case Opcode::FCmpEQ:
return "fcmp oeq";
case Opcode::FCmpNE:
return "fcmp one";
case Opcode::FCmpLT:
return "fcmp olt";
case Opcode::FCmpGT:
return "fcmp ogt";
case Opcode::FCmpLE:
return "fcmp ole";
case Opcode::FCmpGE:
return "fcmp oge";
default:
throw std::runtime_error("unsupported binary opcode");
}
}
static std::string ValueToString(const Value* v) {
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
bool NeedsMemsetDeclaration(const Module& module) {
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsExternal()) {
continue;
}
for (const auto& bb : func->GetBlocks()) {
for (const auto& inst : bb->GetInstructions()) {
if (inst->GetOpcode() == Opcode::Memset) {
return true;
}
}
}
}
return false;
}
void PrintInstruction(const Instruction& inst, std::ostream& os) {
switch (inst.GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto& bin = static_cast<const BinaryInst&>(inst);
os << " " << bin.GetName() << " = " << BinaryOpcodeMnemonic(bin.GetOpcode())
<< " " << TypeToString(*bin.GetLhs()->GetType()) << " "
<< ValueRef(bin.GetLhs()) << ", " << ValueRef(bin.GetRhs()) << "\n";
return;
}
case Opcode::Neg: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = sub " << TypeToString(*un.GetType())
<< " 0, " << ValueRef(un.GetOprd()) << "\n";
return;
}
case Opcode::Not: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = xor " << TypeToString(*un.GetType())
<< " " << ValueRef(un.GetOprd()) << ", 1\n";
return;
}
case Opcode::FNeg: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = fneg " << TypeToString(*un.GetType())
<< " " << ValueRef(un.GetOprd()) << "\n";
return;
}
case Opcode::FtoI: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = fptosi "
<< TypeToString(*un.GetOprd()->GetType()) << " " << ValueRef(un.GetOprd())
<< " to " << TypeToString(*un.GetType()) << "\n";
return;
}
case Opcode::IToF: {
auto& un = static_cast<const UnaryInst&>(inst);
os << " " << un.GetName() << " = sitofp "
<< TypeToString(*un.GetOprd()->GetType()) << " " << ValueRef(un.GetOprd())
<< " to " << TypeToString(*un.GetType()) << "\n";
return;
}
case Opcode::Alloca: {
auto& alloca_inst = static_cast<const AllocaInst&>(inst);
os << " " << alloca_inst.GetName() << " = alloca "
<< TypeToString(*alloca_inst.GetAllocatedType()) << "\n";
return;
}
case Opcode::Load: {
auto& load = static_cast<const LoadInst&>(inst);
os << " " << load.GetName() << " = load "
<< TypeToString(*load.GetType()) << ", ptr " << ValueRef(load.GetPtr())
<< "\n";
return;
}
case Opcode::Store: {
auto& store = static_cast<const StoreInst&>(inst);
os << " store " << TypeToString(*store.GetValue()->GetType()) << " "
<< ValueRef(store.GetValue()) << ", ptr " << ValueRef(store.GetPtr())
<< "\n";
return;
}
case Opcode::Br: {
auto& br = static_cast<const UncondBrInst&>(inst);
os << " br label " << BlockRef(br.GetDest()) << "\n";
return;
}
case Opcode::CondBr: {
auto& br = static_cast<const CondBrInst&>(inst);
os << " br i1 " << ValueRef(br.GetCondition()) << ", label "
<< BlockRef(br.GetThenBlock()) << ", label "
<< BlockRef(br.GetElseBlock()) << "\n";
return;
}
case Opcode::Return: {
auto& ret = static_cast<const ReturnInst&>(inst);
if (!ret.HasReturnValue()) {
os << " ret void\n";
} else {
os << " ret " << TypeToString(*ret.GetReturnValue()->GetType()) << " "
<< ValueRef(ret.GetReturnValue()) << "\n";
}
return;
}
case Opcode::Unreachable:
os << " unreachable\n";
return;
case Opcode::Call: {
auto& call = static_cast<const CallInst&>(inst);
if (!call.GetType()->IsVoid()) {
os << " " << call.GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call.GetCallee()->GetReturnType()) << " @"
<< call.GetCallee()->GetName() << "(";
const auto args = call.GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType()) << " " << ValueRef(args[i]);
}
os << ")\n";
return;
}
case Opcode::GetElementPtr: {
auto& gep = static_cast<const GetElementPtrInst&>(inst);
os << " " << gep.GetName() << " = getelementptr "
<< TypeToString(*gep.GetSourceType()) << ", ptr "
<< ValueRef(gep.GetPointer());
for (size_t i = 0; i < gep.GetNumIndices(); ++i) {
auto* index = gep.GetIndex(i);
os << ", " << TypeToString(*index->GetType()) << " " << ValueRef(index);
}
os << "\n";
return;
}
case Opcode::Phi: {
auto& phi = static_cast<const PhiInst&>(inst);
os << " " << phi.GetName() << " = phi " << TypeToString(*phi.GetType())
<< " ";
for (int i = 0; i < phi.GetNumIncomings(); ++i) {
if (i > 0) {
os << ", ";
}
os << "[ " << ValueRef(phi.GetIncomingValue(i)) << ", "
<< BlockRef(phi.GetIncomingBlock(i)) << " ]";
}
os << "\n";
return;
}
case Opcode::Zext: {
auto& zext = static_cast<const ZextInst&>(inst);
os << " " << zext.GetName() << " = zext "
<< TypeToString(*zext.GetValue()->GetType()) << " "
<< ValueRef(zext.GetValue()) << " to " << TypeToString(*zext.GetType())
<< "\n";
return;
}
case Opcode::Memset: {
auto& memset = static_cast<const MemsetInst&>(inst);
os << " call void @llvm.memset.p0.i32(ptr " << ValueRef(memset.GetDest())
<< ", i8 " << ValueRef(memset.GetValue()) << ", i32 "
<< ValueRef(memset.GetLength()) << ", i1 "
<< ValueRef(memset.GetIsVolatile()) << ")\n";
return;
}
}
return v ? v->GetName() : "<null>";
throw std::runtime_error("unsupported instruction in printer");
}
} // namespace
void IRPrinter::Print(const Module& module, std::ostream& os) {
if (NeedsMemsetDeclaration(module)) {
os << "declare void @llvm.memset.p0.i32(ptr, i8, i32, i1)\n\n";
}
for (const auto& global : module.GetGlobalValues()) {
os << "@" << global->GetName() << " = "
<< (global->IsConstant() ? "constant " : "global ")
<< TypeToString(*global->GetObjectType()) << " ";
PrintConstantForType(os, *global->GetObjectType(), global->GetInitializer());
os << "\n";
}
if (!module.GetGlobalValues().empty()) {
os << "\n";
}
for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "() {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
continue;
if (!func->IsExternal()) {
continue;
}
os << "declare " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType());
}
os << ")\n";
}
bool printed_decl = false;
for (const auto& func : module.GetFunctions()) {
if (func->IsExternal()) {
printed_decl = true;
}
}
if (printed_decl) {
os << "\n";
}
for (const auto& func : module.GetFunctions()) {
if (func->IsExternal()) {
continue;
}
os << "define " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) {
os << ", ";
}
os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName();
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
os << bb->GetName() << ":\n";
for (const auto& instPtr : bb->GetInstructions()) {
const auto* inst = instPtr.get();
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca i32\n";
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load i32, i32* "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst);
os << " store i32 " << ValueToString(store->GetValue())
<< ", i32* " << ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
break;
}
}
for (const auto& inst : bb->GetInstructions()) {
PrintInstruction(*inst, os);
}
}
os << "}\n";
os << "}\n\n";
}
}
} // namespace ir
} // namespace ir

@ -1,151 +1,263 @@
// IR 指令体系:
// - 二元运算/比较、load/store、call、br/condbr、ret、phi、alloca 等
// - 指令操作数与结果类型管理,支持打印与优化
#include "ir/IR.h"
#include "ir/IR.h"
#include <stdexcept>
#include "utils/Log.h"
namespace ir {
User::User(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
size_t User::GetNumOperands() const { return operands_.size(); }
Value* User::GetOperand(size_t index) const {
if (index >= operands_.size()) {
throw std::out_of_range("User operand index out of range");
throw std::out_of_range("operand index out of range");
}
return operands_[index];
return operands_[index].GetValue();
}
void User::SetOperand(size_t index, Value* value) {
if (index >= operands_.size()) {
throw std::out_of_range("User operand index out of range");
}
if (!value) {
throw std::runtime_error(FormatError("ir", "User operand 不能为空"));
throw std::out_of_range("operand index out of range");
}
auto* old = operands_[index];
if (old == value) {
auto* old_value = operands_[index].GetValue();
if (old_value == value) {
return;
}
if (old) {
old->RemoveUse(this, index);
if (old_value) {
old_value->RemoveUse(this, index);
}
operands_[index].SetValue(value);
if (value) {
value->AddUse(this, index);
}
operands_[index] = value;
value->AddUse(this, index);
}
void User::AddOperand(Value* value) {
if (!value) {
throw std::runtime_error(FormatError("ir", "User operand 不能为空"));
throw std::runtime_error("operand cannot be null");
}
size_t operand_index = operands_.size();
operands_.push_back(value);
value->AddUse(this, operand_index);
operands_.emplace_back(value, this, operands_.size());
value->AddUse(this, operands_.size() - 1);
}
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)), opcode_(op) {}
Opcode Instruction::GetOpcode() const { return opcode_; }
bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; }
BasicBlock* Instruction::GetParent() const { return parent_; }
void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add"));
void User::AddOperands(const std::vector<Value*>& values) {
for (auto* value : values) {
AddOperand(value);
}
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
}
void User::RemoveOperand(size_t index) {
if (index >= operands_.size()) {
throw std::out_of_range("operand index out of range");
}
if (!type_ || !lhs->GetType() || !rhs->GetType()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息"));
if (auto* value = operands_[index].GetValue()) {
value->RemoveUse(this, index);
}
if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() ||
type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
operands_.erase(operands_.begin() + static_cast<long long>(index));
for (size_t i = index; i < operands_.size(); ++i) {
if (auto* value = operands_[i].GetValue()) {
value->RemoveUse(this, i + 1);
value->AddUse(this, i);
}
operands_[i].SetOperandIndex(i);
}
if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32"));
}
void User::ClearAllOperands() {
for (size_t i = 0; i < operands_.size(); ++i) {
if (auto* value = operands_[i].GetValue()) {
value->RemoveUse(this, i);
}
}
AddOperand(lhs);
AddOperand(rhs);
operands_.clear();
}
Value* BinaryInst::GetLhs() const { return GetOperand(0); }
Instruction::Instruction(Opcode opcode, std::shared_ptr<Type> ty,
BasicBlock* parent, const std::string& name)
: User(std::move(ty), name), opcode_(opcode), parent_(parent) {}
Value* BinaryInst::GetRhs() const { return GetOperand(1); }
bool Instruction::IsTerminator() const {
return opcode_ == Opcode::Br || opcode_ == Opcode::CondBr ||
opcode_ == Opcode::Return || opcode_ == Opcode::Unreachable;
}
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!val) {
throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值"));
static bool IsBinaryOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
return true;
default:
return false;
}
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
}
AddOperand(val);
}
Value* ReturnInst::GetValue() const { return GetOperand(0); }
bool BinaryInst::classof(const Value* value) {
return value && Instruction::classof(value) &&
IsBinaryOpcode(static_cast<const Instruction*>(value)->GetOpcode());
}
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
if (!type_ || !type_->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*"));
}
BinaryInst::BinaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, BasicBlock* parent,
const std::string& name)
: Instruction(opcode, std::move(ty), parent, name) {
AddOperand(lhs);
AddOperand(rhs);
}
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
: Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
bool UnaryInst::classof(const Value* value) {
if (!value || !Instruction::classof(value)) {
return false;
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32"));
switch (static_cast<const Instruction*>(value)->GetOpcode()) {
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
return true;
default:
return false;
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
}
UnaryInst::UnaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* operand,
BasicBlock* parent, const std::string& name)
: Instruction(opcode, std::move(ty), parent, name) {
AddOperand(operand);
}
ReturnInst::ReturnInst(Value* value, BasicBlock* parent)
: Instruction(Opcode::Return, Type::GetVoidType(), parent, "") {
if (value) {
AddOperand(value);
}
}
AllocaInst::AllocaInst(std::shared_ptr<Type> allocated_type,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Alloca, Type::GetPointerType(allocated_type), parent,
name),
allocated_type_(std::move(allocated_type)) {}
LoadInst::LoadInst(std::shared_ptr<Type> value_type, Value* ptr,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Load, std::move(value_type), parent, name) {
AddOperand(ptr);
}
Value* LoadInst::GetPtr() const { return GetOperand(0); }
StoreInst::StoreInst(Value* value, Value* ptr, BasicBlock* parent)
: Instruction(Opcode::Store, Type::GetVoidType(), parent, "") {
AddOperand(value);
AddOperand(ptr);
}
StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
: Instruction(Opcode::Store, std::move(void_ty), "") {
if (!val) {
throw std::runtime_error(FormatError("ir", "StoreInst 缺少 value"));
}
if (!ptr) {
throw std::runtime_error(FormatError("ir", "StoreInst 缺少 ptr"));
}
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
}
if (!val->GetType() || !val->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "StoreInst 当前只支持写入 i32*"));
UncondBrInst::UncondBrInst(BasicBlock* dest, BasicBlock* parent)
: Instruction(Opcode::Br, Type::GetVoidType(), parent, "") {
AddOperand(dest);
}
BasicBlock* UncondBrInst::GetDest() const {
return dyncast<BasicBlock>(GetOperand(0));
}
CondBrInst::CondBrInst(Value* cond, BasicBlock* then_block,
BasicBlock* else_block, BasicBlock* parent)
: Instruction(Opcode::CondBr, Type::GetVoidType(), parent, "") {
AddOperand(cond);
AddOperand(then_block);
AddOperand(else_block);
}
BasicBlock* CondBrInst::GetThenBlock() const {
return dyncast<BasicBlock>(GetOperand(1));
}
BasicBlock* CondBrInst::GetElseBlock() const {
return dyncast<BasicBlock>(GetOperand(2));
}
UnreachableInst::UnreachableInst(BasicBlock* parent)
: Instruction(Opcode::Unreachable, Type::GetVoidType(), parent, "") {}
CallInst::CallInst(Function* callee, const std::vector<Value*>& args,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Call, callee->GetReturnType(), parent, name) {
AddOperand(callee);
AddOperands(args);
}
Function* CallInst::GetCallee() const { return dyncast<Function>(GetOperand(0)); }
std::vector<Value*> CallInst::GetArguments() const {
std::vector<Value*> args;
for (size_t i = 1; i < GetNumOperands(); ++i) {
args.push_back(GetOperand(i));
}
AddOperand(val);
return args;
}
GetElementPtrInst::GetElementPtrInst(std::shared_ptr<Type> source_type,
Value* ptr,
const std::vector<Value*>& indices,
BasicBlock* parent,
const std::string& name)
: Instruction(Opcode::GetElementPtr, Type::GetPointerType(), parent, name),
source_type_(std::move(source_type)) {
AddOperand(ptr);
AddOperands(indices);
}
Value* StoreInst::GetValue() const { return GetOperand(0); }
PhiInst::PhiInst(std::shared_ptr<Type> type, BasicBlock* parent,
const std::string& name)
: Instruction(Opcode::Phi, std::move(type), parent, name) {}
Value* StoreInst::GetPtr() const { return GetOperand(1); }
void PhiInst::AddIncoming(Value* value, BasicBlock* block) {
AddOperand(value);
AddOperand(block);
}
BasicBlock* PhiInst::GetIncomingBlock(int index) const {
return dyncast<BasicBlock>(GetOperand(static_cast<size_t>(2 * index + 1)));
}
ZextInst::ZextInst(Value* value, std::shared_ptr<Type> target_type,
BasicBlock* parent, const std::string& name)
: Instruction(Opcode::Zext, std::move(target_type), parent, name) {
AddOperand(value);
}
MemsetInst::MemsetInst(Value* dst, Value* value, Value* len,
Value* is_volatile, BasicBlock* parent)
: Instruction(Opcode::Memset, Type::GetVoidType(), parent, "") {
AddOperand(dst);
AddOperand(value);
AddOperand(len);
AddOperand(is_volatile);
}
} // namespace ir
} // namespace ir

@ -1,21 +1,45 @@
// 保存函数列表并提供模块级上下文访问。
#include "ir/IR.h"
namespace ir {
Context& Module::GetContext() { return context_; }
Function* Module::CreateFunction(
const std::string& name, std::shared_ptr<Type> ret_type,
const std::vector<std::shared_ptr<Type>>& param_types,
const std::vector<std::string>& param_names, bool is_external) {
if (auto* existing = GetFunction(name)) {
existing->SetExternal(existing->IsExternal() && is_external);
return existing;
}
auto func = std::make_unique<Function>(name, std::move(ret_type), param_types,
param_names, is_external);
auto* ptr = func.get();
functions_.push_back(std::move(func));
function_map_[name] = ptr;
return ptr;
}
const Context& Module::GetContext() const { return context_; }
Function* Module::GetFunction(const std::string& name) const {
auto it = function_map_.find(name);
return it == function_map_.end() ? nullptr : it->second;
}
Function* Module::CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type) {
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type)));
return functions_.back().get();
GlobalValue* Module::CreateGlobalValue(const std::string& name,
std::shared_ptr<Type> object_type,
bool is_const, Value* init) {
if (auto* existing = GetGlobalValue(name)) {
return existing;
}
auto global =
std::make_unique<GlobalValue>(std::move(object_type), name, is_const, init);
auto* ptr = global.get();
globals_.push_back(std::move(global));
global_map_[name] = ptr;
return ptr;
}
const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
return functions_;
GlobalValue* Module::GetGlobalValue(const std::string& name) const {
auto it = global_map_.find(name);
return it == global_map_.end() ? nullptr : it->second;
}
} // namespace ir
} // namespace ir

@ -1,31 +1,111 @@
// 当前仅支持 void、i32 和 i32*。
#include "ir/IR.h"
#include <ostream>
#include <stdexcept>
namespace ir {
Type::Type(Kind k) : kind_(k) {}
Type::Type(Kind kind) : kind_(kind) {}
Type::Type(Kind kind, std::shared_ptr<Type> element_type, size_t num_elements)
: kind_(kind),
element_type_(std::move(element_type)),
num_elements_(num_elements) {}
const std::shared_ptr<Type>& Type::GetVoidType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void);
static const auto type = std::make_shared<Type>(Kind::Void);
return type;
}
const std::shared_ptr<Type>& Type::GetInt1Type() {
static const auto type = std::make_shared<Type>(Kind::Int1);
return type;
}
const std::shared_ptr<Type>& Type::GetInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32);
static const auto type = std::make_shared<Type>(Kind::Int32);
return type;
}
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32);
const std::shared_ptr<Type>& Type::GetFloatType() {
static const auto type = std::make_shared<Type>(Kind::Float);
return type;
}
const std::shared_ptr<Type>& Type::GetLabelType() {
static const auto type = std::make_shared<Type>(Kind::Label);
return type;
}
Type::Kind Type::GetKind() const { return kind_; }
const std::shared_ptr<Type>& Type::GetBoolType() { return GetInt1Type(); }
bool Type::IsVoid() const { return kind_ == Kind::Void; }
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> pointee) {
return std::make_shared<Type>(Kind::Pointer, std::move(pointee));
}
bool Type::IsInt32() const { return kind_ == Kind::Int32; }
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
static const auto type = std::make_shared<Type>(Kind::Pointer);
return type;
}
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> element_type,
size_t num_elements) {
return std::make_shared<Type>(Kind::Array, std::move(element_type), num_elements);
}
int Type::GetSize() const {
switch (kind_) {
case Kind::Void:
case Kind::Label:
case Kind::Function:
return 0;
case Kind::Int1:
return 1;
case Kind::Int32:
case Kind::Float:
return 4;
case Kind::Pointer:
return 8;
case Kind::Array:
return static_cast<int>(num_elements_) *
(element_type_ ? element_type_->GetSize() : 0);
}
throw std::runtime_error("unknown IR type kind");
}
void Type::Print(std::ostream& os) const {
switch (kind_) {
case Kind::Void:
os << "void";
return;
case Kind::Int1:
os << "i1";
return;
case Kind::Int32:
os << "i32";
return;
case Kind::Float:
os << "float";
return;
case Kind::Label:
os << "label";
return;
case Kind::Function:
os << "fn";
return;
case Kind::Pointer:
os << "ptr";
return;
case Kind::Array:
os << "[" << num_elements_ << " x ";
if (element_type_) {
element_type_->Print(os);
} else {
os << "void";
}
os << "]";
return;
}
}
} // namespace ir
} // namespace ir

@ -1,46 +1,19 @@
// SSA 值体系抽象:
// - 常量、参数、指令结果等统一为 Value
// - 提供类型信息与使用/被使用关系(按需要实现)
#include "ir/IR.h"
#include <algorithm>
#include <ostream>
#include <stdexcept>
namespace ir {
Value::Value(std::shared_ptr<Type> ty, std::string name)
: type_(std::move(ty)), name_(std::move(name)) {}
const std::shared_ptr<Type>& Value::GetType() const { return type_; }
const std::string& Value::GetName() const { return name_; }
void Value::SetName(std::string n) { name_ = std::move(n); }
bool Value::IsVoid() const { return type_ && type_->IsVoid(); }
bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
bool Value::IsConstant() const {
return dynamic_cast<const ConstantValue*>(this) != nullptr;
}
bool Value::IsInstruction() const {
return dynamic_cast<const Instruction*>(this) != nullptr;
}
bool Value::IsUser() const {
return dynamic_cast<const User*>(this) != nullptr;
}
bool Value::IsFunction() const {
return dynamic_cast<const Function*>(this) != nullptr;
}
void Value::AddUse(User* user, size_t operand_index) {
if (!user) return;
uses_.push_back(Use(this, user, operand_index));
if (!user) {
return;
}
uses_.emplace_back(this, user, operand_index);
}
void Value::RemoveUse(User* user, size_t operand_index) {
@ -53,31 +26,41 @@ void Value::RemoveUse(User* user, size_t operand_index) {
uses_.end());
}
const std::vector<Use>& Value::GetUses() const { return uses_; }
void Value::ReplaceAllUsesWith(Value* new_value) {
if (!new_value) {
throw std::runtime_error("ReplaceAllUsesWith 缺少 new_value");
throw std::runtime_error("ReplaceAllUsesWith requires a new value");
}
if (new_value == this) {
return;
}
auto uses = uses_;
for (const auto& use : uses) {
auto* user = use.GetUser();
if (!user) continue;
size_t operand_index = use.GetOperandIndex();
if (user->GetOperand(operand_index) == this) {
user->SetOperand(operand_index, new_value);
if (auto* user = use.GetUser()) {
user->SetOperand(use.GetOperandIndex(), new_value);
}
}
}
void Value::Print(std::ostream& os) const { os << name_; }
ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: ConstantValue(std::move(ty), ""), value_(v) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantI1::ConstantI1(std::shared_ptr<Type> ty, bool value)
: ConstantValue(std::move(ty), ""), value_(value) {}
ConstantArrayValue::ConstantArrayValue(std::shared_ptr<Type> array_type,
const std::vector<Value*>& elements,
const std::vector<size_t>& dims,
const std::string& name)
: Value(std::move(array_type), name), elements_(elements), dims_(dims) {}
void ConstantArrayValue::Print(std::ostream& os) const { os << name_; }
} // namespace ir
} // namespace ir

@ -1,4 +1,167 @@
// 支配树分析:
// - 构建/查询 Dominator Tree 及相关关系
// - 为 mem2reg、CFG 优化与循环分析提供基础能力
#include "ir/Analysis.h"
#include <algorithm>
#include <functional>
namespace ir {
namespace {
std::vector<BasicBlock*> BuildReversePostOrder(Function& function) {
std::vector<BasicBlock*> post_order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return post_order;
}
std::unordered_set<BasicBlock*> visited;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* block) {
if (!block || !visited.insert(block).second) {
return;
}
for (auto* succ : block->GetSuccessors()) {
dfs(succ);
}
post_order.push_back(block);
};
dfs(entry);
std::reverse(post_order.begin(), post_order.end());
return post_order;
}
} // namespace
DominatorTree::DominatorTree(Function& function) : function_(&function) {
Recalculate();
}
void DominatorTree::Recalculate() {
reverse_post_order_ = BuildReversePostOrder(*function_);
block_index_.clear();
dominates_.clear();
immediate_dominator_.clear();
dom_children_.clear();
const auto num_blocks = reverse_post_order_.size();
for (std::size_t i = 0; i < num_blocks; ++i) {
block_index_.emplace(reverse_post_order_[i], i);
}
if (num_blocks == 0) {
return;
}
dominates_.assign(num_blocks, std::vector<std::uint8_t>(num_blocks, 1));
dominates_[0].assign(num_blocks, 0);
dominates_[0][0] = 1;
bool changed = true;
while (changed) {
changed = false;
for (std::size_t i = 1; i < num_blocks; ++i) {
auto* block = reverse_post_order_[i];
std::vector<std::uint8_t> next(num_blocks, 1);
bool has_reachable_pred = false;
for (auto* pred : block->GetPredecessors()) {
auto pred_it = block_index_.find(pred);
if (pred_it == block_index_.end()) {
continue;
}
has_reachable_pred = true;
const auto& pred_dom = dominates_[pred_it->second];
for (std::size_t bit = 0; bit < num_blocks; ++bit) {
next[bit] &= pred_dom[bit];
}
}
if (!has_reachable_pred) {
next.assign(num_blocks, 0);
}
next[i] = 1;
if (next != dominates_[i]) {
dominates_[i] = std::move(next);
changed = true;
}
}
}
std::vector<std::size_t> dom_depth(num_blocks, 0);
for (std::size_t i = 0; i < num_blocks; ++i) {
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
if (dominates_[i][candidate]) {
++dom_depth[i];
}
}
}
for (std::size_t i = 1; i < num_blocks; ++i) {
auto* block = reverse_post_order_[i];
BasicBlock* idom = nullptr;
std::size_t best_depth = 0;
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
if (candidate == i || !dominates_[i][candidate]) {
continue;
}
auto* candidate_block = reverse_post_order_[candidate];
if (idom == nullptr || dom_depth[candidate] > best_depth) {
idom = candidate_block;
best_depth = dom_depth[candidate];
}
}
immediate_dominator_.emplace(block, idom);
if (idom) {
dom_children_[idom].push_back(block);
}
}
}
bool DominatorTree::IsReachable(BasicBlock* block) const {
return block != nullptr && block_index_.find(block) != block_index_.end();
}
bool DominatorTree::Dominates(BasicBlock* dom, BasicBlock* node) const {
if (!dom || !node) {
return false;
}
const auto dom_it = block_index_.find(dom);
const auto node_it = block_index_.find(node);
if (dom_it == block_index_.end() || node_it == block_index_.end()) {
return false;
}
return dominates_[node_it->second][dom_it->second] != 0;
}
bool DominatorTree::Dominates(Instruction* dom, Instruction* user) const {
if (!dom || !user) {
return false;
}
if (dom == user) {
return true;
}
auto* dom_block = dom->GetParent();
auto* user_block = user->GetParent();
if (dom_block != user_block) {
return Dominates(dom_block, user_block);
}
for (const auto& inst_ptr : dom_block->GetInstructions()) {
if (inst_ptr.get() == dom) {
return true;
}
if (inst_ptr.get() == user) {
return false;
}
}
return false;
}
BasicBlock* DominatorTree::GetIDom(BasicBlock* block) const {
auto it = immediate_dominator_.find(block);
return it == immediate_dominator_.end() ? nullptr : it->second;
}
const std::vector<BasicBlock*>& DominatorTree::GetChildren(BasicBlock* block) const {
static const std::vector<BasicBlock*> kEmpty;
auto it = dom_children_.find(block);
return it == dom_children_.end() ? kEmpty : it->second;
}
} // namespace ir

@ -1,4 +1,214 @@
// 循环分析:
// - 识别循环结构与层级关系
// - 为后续优化(可选)提供循环信息
#include "ir/Analysis.h"
#include <algorithm>
#include <functional>
namespace ir {
namespace {
std::vector<BasicBlock*> CollectNaturalLoopBlocks(BasicBlock* header,
BasicBlock* latch) {
std::vector<BasicBlock*> stack{latch};
std::unordered_set<BasicBlock*> loop_blocks{header, latch};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
for (auto* pred : block->GetPredecessors()) {
if (!pred || !loop_blocks.insert(pred).second) {
continue;
}
stack.push_back(pred);
}
}
return {loop_blocks.begin(), loop_blocks.end()};
}
} // namespace
bool Loop::Contains(BasicBlock* block) const {
return block != nullptr && blocks.find(block) != blocks.end();
}
bool Loop::Contains(const Loop* other) const {
if (!other) {
return false;
}
for (auto* block : other->blocks) {
if (!Contains(block)) {
return false;
}
}
return true;
}
bool Loop::IsInnermost() const { return subloops.empty(); }
LoopInfo::LoopInfo(Function& function, const DominatorTree& dom_tree)
: function_(&function), dom_tree_(&dom_tree) {
Recalculate();
}
void LoopInfo::Recalculate() {
loops_.clear();
top_level_loops_.clear();
block_to_loop_.clear();
std::unordered_map<BasicBlock*, Loop*> loops_by_header;
for (auto* block : dom_tree_->GetReversePostOrder()) {
for (auto* succ : block->GetSuccessors()) {
if (!dom_tree_->Dominates(succ, block)) {
continue;
}
Loop* loop = nullptr;
auto it = loops_by_header.find(succ);
if (it == loops_by_header.end()) {
auto new_loop = std::make_unique<Loop>();
new_loop->header = succ;
loop = new_loop.get();
loops_.push_back(std::move(new_loop));
loops_by_header.emplace(succ, loop);
} else {
loop = it->second;
}
if (std::find(loop->latches.begin(), loop->latches.end(), block) ==
loop->latches.end()) {
loop->latches.push_back(block);
}
for (auto* natural_block : CollectNaturalLoopBlocks(succ, block)) {
loop->blocks.insert(natural_block);
}
}
}
std::unordered_map<BasicBlock*, std::size_t> function_order;
for (std::size_t i = 0; i < function_->GetBlocks().size(); ++i) {
function_order.emplace(function_->GetBlocks()[i].get(), i);
}
for (const auto& loop_ptr : loops_) {
auto& loop = *loop_ptr;
loop.block_list.clear();
loop.exiting_blocks.clear();
loop.exit_blocks.clear();
loop.subloops.clear();
loop.parent = nullptr;
for (const auto& block_ptr : function_->GetBlocks()) {
if (loop.Contains(block_ptr.get())) {
loop.block_list.push_back(block_ptr.get());
}
}
std::sort(loop.latches.begin(), loop.latches.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
std::vector<BasicBlock*> outside_preds;
for (auto* pred : loop.header->GetPredecessors()) {
if (!loop.Contains(pred)) {
outside_preds.push_back(pred);
}
}
if (outside_preds.size() == 1 &&
outside_preds.front()->GetSuccessors().size() == 1) {
loop.preheader = outside_preds.front();
} else {
loop.preheader = nullptr;
}
std::unordered_set<BasicBlock*> exiting_seen;
std::unordered_set<BasicBlock*> exit_seen;
for (auto* block : loop.block_list) {
for (auto* succ : block->GetSuccessors()) {
if (loop.Contains(succ)) {
continue;
}
if (exiting_seen.insert(block).second) {
loop.exiting_blocks.push_back(block);
}
if (exit_seen.insert(succ).second) {
loop.exit_blocks.push_back(succ);
}
}
}
std::sort(loop.exiting_blocks.begin(), loop.exiting_blocks.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
std::sort(loop.exit_blocks.begin(), loop.exit_blocks.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
}
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
Loop* parent = nullptr;
for (const auto& candidate_ptr : loops_) {
auto* candidate = candidate_ptr.get();
if (candidate == loop || !candidate->Contains(loop)) {
continue;
}
if (!parent || candidate->blocks.size() < parent->blocks.size()) {
parent = candidate;
}
}
loop->parent = parent;
if (parent) {
parent->subloops.push_back(loop);
} else {
top_level_loops_.push_back(loop);
}
}
auto loop_order = [&](Loop* lhs, Loop* rhs) {
return function_order[lhs->header] < function_order[rhs->header];
};
std::sort(top_level_loops_.begin(), top_level_loops_.end(), loop_order);
for (const auto& loop_ptr : loops_) {
std::sort(loop_ptr->subloops.begin(), loop_ptr->subloops.end(), loop_order);
}
for (const auto& block_ptr : function_->GetBlocks()) {
Loop* innermost = nullptr;
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
if (!loop->Contains(block_ptr.get())) {
continue;
}
if (!innermost || loop->blocks.size() < innermost->blocks.size()) {
innermost = loop;
}
}
if (innermost) {
block_to_loop_.emplace(block_ptr.get(), innermost);
}
}
}
std::vector<Loop*> LoopInfo::GetTopLevelLoops() const { return top_level_loops_; }
std::vector<Loop*> LoopInfo::GetLoopsInPostOrder() const {
std::vector<Loop*> ordered;
std::function<void(Loop*)> dfs = [&](Loop* loop) {
for (auto* subloop : loop->subloops) {
dfs(subloop);
}
ordered.push_back(loop);
};
for (auto* loop : top_level_loops_) {
dfs(loop);
}
return ordered;
}
Loop* LoopInfo::GetLoopFor(BasicBlock* block) const {
auto it = block_to_loop_.find(block);
return it == block_to_loop_.end() ? nullptr : it->second;
}
} // namespace ir

@ -0,0 +1,137 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>
namespace ir {
namespace {
bool IsPowerOfTwoPositive(int value) {
return value > 0 && (value & (value - 1)) == 0;
}
std::size_t FindInstructionIndex(BasicBlock* block, Instruction* inst) {
if (!block || !inst) {
return 0;
}
auto& instructions = block->GetInstructions();
for (std::size_t i = 0; i < instructions.size(); ++i) {
if (instructions[i].get() == inst) {
return i;
}
}
return instructions.size();
}
bool IsZero(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return ci->GetValue() == 0;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return !cb->GetValue();
}
return false;
}
Value* OtherCompareOperand(BinaryInst* cmp, Value* value) {
if (!cmp || cmp->GetNumOperands() != 2) {
return nullptr;
}
if (cmp->GetLhs() == value) {
return cmp->GetRhs();
}
if (cmp->GetRhs() == value) {
return cmp->GetLhs();
}
return nullptr;
}
bool SimplifyPowerOfTwoRemTests(Function& function) {
bool changed = false;
std::vector<Instruction*> dead_rems;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (!block) {
continue;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* rem = dyncast<BinaryInst>(inst_ptr.get());
if (!rem || rem->GetOpcode() != Opcode::Rem) {
continue;
}
auto* divisor = dyncast<ConstantInt>(rem->GetRhs());
if (!divisor || !IsPowerOfTwoPositive(divisor->GetValue())) {
continue;
}
const int mask_value = divisor->GetValue() - 1;
if (mask_value == 0) {
rem->ReplaceAllUsesWith(looputils::ConstInt(0));
dead_rems.push_back(rem);
changed = true;
continue;
}
std::vector<BinaryInst*> compare_uses;
bool all_uses_are_zero_tests = !rem->GetUses().empty();
for (const auto& use : rem->GetUses()) {
auto* cmp = dyncast<BinaryInst>(dynamic_cast<Value*>(use.GetUser()));
if (!cmp || (cmp->GetOpcode() != Opcode::ICmpEQ &&
cmp->GetOpcode() != Opcode::ICmpNE) ||
!IsZero(OtherCompareOperand(cmp, rem))) {
all_uses_are_zero_tests = false;
break;
}
compare_uses.push_back(cmp);
}
if (!all_uses_are_zero_tests || compare_uses.empty()) {
continue;
}
const auto insert_index = FindInstructionIndex(block, rem) + 1;
auto* masked = block->Insert<BinaryInst>(
insert_index, Opcode::And, Type::GetInt32Type(), rem->GetLhs(),
looputils::ConstInt(mask_value), nullptr,
looputils::NextSyntheticName(function, "pow2.mask."));
for (auto* cmp : compare_uses) {
if (cmp->GetLhs() == rem) {
cmp->SetOperand(0, masked);
}
if (cmp->GetRhs() == rem) {
cmp->SetOperand(1, masked);
}
}
dead_rems.push_back(rem);
changed = true;
}
}
for (auto* rem : dead_rems) {
if (rem->GetUses().empty() && rem->GetParent()) {
rem->GetParent()->EraseInstruction(rem);
}
}
return changed;
}
} // namespace
bool RunArithmeticSimplify(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (!function || function->IsExternal()) {
continue;
}
changed |= SimplifyPowerOfTwoRemTests(*function);
}
return changed;
}
} // namespace ir

@ -1,4 +1,107 @@
// CFG 简化:
// - 删除不可达块、合并空块、简化分支等
// - 改善 IR 结构,便于后续优化与后端生成
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <vector>
namespace ir {
namespace {
bool TryGetConstBranchTarget(CondBrInst* br, BasicBlock*& target, BasicBlock*& removed) {
if (!br) {
return false;
}
auto* then_block = br->GetThenBlock();
auto* else_block = br->GetElseBlock();
if (then_block == else_block) {
target = then_block;
removed = nullptr;
return true;
}
if (auto* cond = dyncast<ConstantI1>(br->GetCondition())) {
target = cond->GetValue() ? then_block : else_block;
removed = cond->GetValue() ? else_block : then_block;
return true;
}
return false;
}
bool SimplifyBlockTerminator(BasicBlock* block) {
if (!block || block->GetInstructions().empty()) {
return false;
}
auto* term = block->GetInstructions().back().get();
auto* condbr = dyncast<CondBrInst>(term);
if (!condbr) {
return false;
}
BasicBlock* target = nullptr;
BasicBlock* removed = nullptr;
if (!TryGetConstBranchTarget(condbr, target, removed)) {
return false;
}
if (removed) {
passutils::RemoveIncomingFromSuccessor(removed, block);
removed->RemovePredecessor(block);
block->RemoveSuccessor(removed);
}
passutils::ReplaceTerminatorWithBr(block, target);
return true;
}
bool SimplifyPhiNodes(Function& function) {
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (!passutils::SimplifyPhiInst(phi)) {
continue;
}
local_changed = true;
changed = true;
break;
}
}
}
return changed;
}
bool RunCFGSimplifyOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
changed |= SimplifyBlockTerminator(block_ptr.get());
}
changed |= passutils::RemoveUnreachableBlocks(function);
changed |= SimplifyPhiNodes(function);
return changed;
}
} // namespace
bool RunCFGSimplify(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunCFGSimplifyOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -3,9 +3,20 @@ add_library(ir_passes STATIC
Mem2Reg.cpp
ConstFold.cpp
ConstProp.cpp
TailRecursionElim.cpp
ArithmeticSimplify.cpp
Inline.cpp
CSE.cpp
GVN.cpp
LoadStoreElim.cpp
DCE.cpp
CFGSimplify.cpp
LICM.cpp
LoopMemoryPromotion.cpp
LoopUnswitch.cpp
LoopStrengthReduction.cpp
LoopUnroll.cpp
LoopFission.cpp
)
target_link_libraries(ir_passes PUBLIC

@ -1,4 +1,164 @@
// 公共子表达式消除CSE
// - 识别并复用重复计算的等价表达式
// - 典型放置在 ConstFold 之后、DCE 之前
// - 当前为 Lab4 的框架占位,具体算法由实验实现
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct ExprKey {
Opcode opcode = Opcode::Add;
struct OperandKey {
int kind = 0;
std::intptr_t value = 0;
bool operator==(const OperandKey& rhs) const {
return kind == rhs.kind && value == rhs.value;
}
};
std::vector<OperandKey> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && operands == rhs.operands;
}
};
struct ExprKeyHash {
std::size_t operator()(const ExprKey& key) const {
std::size_t h = static_cast<std::size_t>(key.opcode);
for (auto operand : key.operands) {
h ^= std::hash<int>{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::intptr_t>{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
ExprKey::OperandKey BuildOperandKey(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {1, ci->GetValue()};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {2, cb->GetValue() ? 1 : 0};
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return {3, static_cast<std::intptr_t>(passutils::FloatBits(cf->GetValue()))};
}
return {0, reinterpret_cast<std::intptr_t>(value)};
}
bool IsSupportedCSEInstruction(Instruction* inst) {
if (!inst || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Zext:
return true;
default:
return false;
}
}
ExprKey BuildExprKey(Instruction* inst) {
ExprKey key;
key.opcode = inst->GetOpcode();
key.operands.reserve(inst->GetNumOperands());
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(BuildOperandKey(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 && passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
(key.operands[1].kind < key.operands[0].kind ||
(key.operands[1].kind == key.operands[0].kind &&
key.operands[1].value < key.operands[0].value))) {
std::swap(key.operands[0], key.operands[1]);
}
return key;
}
bool RunCSEOnFunction(Function& function) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::unordered_map<ExprKey, Value*, ExprKeyHash> available_exprs;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!IsSupportedCSEInstruction(inst)) {
continue;
}
const auto key = BuildExprKey(inst);
auto it = available_exprs.find(key);
if (it == available_exprs.end()) {
available_exprs.emplace(key, inst);
continue;
}
inst->ReplaceAllUsesWith(it->second);
to_remove.push_back(inst);
changed = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return changed;
}
} // namespace
bool RunCSE(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunCSEOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,469 @@
// IR 常量折叠:
// - 折叠可判定的常量表达式
// - 简化常量控制流分支(按实现范围裁剪)
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>
namespace ir {
namespace {
Value* GetInt32Const(Context& ctx, std::int32_t value) {
return ctx.GetConstInt(static_cast<int>(value));
}
Value* GetBoolConst(Context& ctx, bool value) { return ctx.GetConstBool(value); }
Value* GetFloatConst(float value) {
return new ConstantFloat(Type::GetFloatType(), value);
}
bool TryGetInt32(Value* value, std::int32_t& out) {
if (auto* ci = dyncast<ConstantInt>(value)) {
out = static_cast<std::int32_t>(ci->GetValue());
return true;
}
return false;
}
bool TryGetBool(Value* value, bool& out) {
if (auto* cb = dyncast<ConstantI1>(value)) {
out = cb->GetValue();
return true;
}
return false;
}
bool TryGetFloat(Value* value, float& out) {
if (auto* cf = dyncast<ConstantFloat>(value)) {
out = cf->GetValue();
return true;
}
return false;
}
bool IsZeroValue(Value* value) {
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
return (TryGetInt32(value, i32) && i32 == 0) || (TryGetBool(value, i1) && !i1) ||
(TryGetFloat(value, f32) && passutils::FloatBits(f32) == 0);
}
bool IsOneValue(Value* value) {
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
return (TryGetInt32(value, i32) && i32 == 1) || (TryGetBool(value, i1) && i1) ||
(TryGetFloat(value, f32) &&
passutils::FloatBits(f32) == passutils::FloatBits(1.0f));
}
bool IsAllOnesInt(Value* value) {
std::int32_t i32 = 0;
return TryGetInt32(value, i32) && i32 == -1;
}
std::int32_t WrapInt32(std::uint32_t value) {
return static_cast<std::int32_t>(value);
}
Value* FoldBinary(Context& ctx, BinaryInst* inst) {
const auto opcode = inst->GetOpcode();
auto* lhs = inst->GetLhs();
auto* rhs = inst->GetRhs();
std::int32_t lhs_i32 = 0;
std::int32_t rhs_i32 = 0;
bool lhs_i1 = false;
bool rhs_i1 = false;
float lhs_f32 = 0.0f;
float rhs_f32 = 0.0f;
const bool has_lhs_i32 = TryGetInt32(lhs, lhs_i32);
const bool has_rhs_i32 = TryGetInt32(rhs, rhs_i32);
const bool has_lhs_i1 = TryGetBool(lhs, lhs_i1);
const bool has_rhs_i1 = TryGetBool(rhs, rhs_i1);
const bool has_lhs_f32 = TryGetFloat(lhs, lhs_f32);
const bool has_rhs_f32 = TryGetFloat(rhs, rhs_f32);
if (has_lhs_i32 && has_rhs_i32) {
switch (opcode) {
case Opcode::Add:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) +
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Sub:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) -
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Mul:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) *
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Div:
if (rhs_i32 == 0 ||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 / rhs_i32);
case Opcode::Rem:
if (rhs_i32 == 0 ||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 % rhs_i32);
case Opcode::And:
return GetInt32Const(ctx, lhs_i32 & rhs_i32);
case Opcode::Or:
return GetInt32Const(ctx, lhs_i32 | rhs_i32);
case Opcode::Xor:
return GetInt32Const(ctx, lhs_i32 ^ rhs_i32);
case Opcode::Shl:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32)
<< rhs_i32));
case Opcode::AShr:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 >> rhs_i32);
case Opcode::LShr:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(
ctx,
WrapInt32(static_cast<std::uint32_t>(lhs_i32) >> rhs_i32));
case Opcode::ICmpEQ:
return GetBoolConst(ctx, lhs_i32 == rhs_i32);
case Opcode::ICmpNE:
return GetBoolConst(ctx, lhs_i32 != rhs_i32);
case Opcode::ICmpLT:
return GetBoolConst(ctx, lhs_i32 < rhs_i32);
case Opcode::ICmpGT:
return GetBoolConst(ctx, lhs_i32 > rhs_i32);
case Opcode::ICmpLE:
return GetBoolConst(ctx, lhs_i32 <= rhs_i32);
case Opcode::ICmpGE:
return GetBoolConst(ctx, lhs_i32 >= rhs_i32);
default:
break;
}
}
if (has_lhs_i1 && has_rhs_i1) {
switch (opcode) {
case Opcode::And:
return GetBoolConst(ctx, lhs_i1 && rhs_i1);
case Opcode::Or:
return GetBoolConst(ctx, lhs_i1 || rhs_i1);
case Opcode::Xor:
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
case Opcode::ICmpEQ:
return GetBoolConst(ctx, lhs_i1 == rhs_i1);
case Opcode::ICmpNE:
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
case Opcode::ICmpLT:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) < static_cast<int>(rhs_i1));
case Opcode::ICmpGT:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) > static_cast<int>(rhs_i1));
case Opcode::ICmpLE:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) <= static_cast<int>(rhs_i1));
case Opcode::ICmpGE:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) >= static_cast<int>(rhs_i1));
default:
break;
}
}
if (has_lhs_f32 && has_rhs_f32) {
switch (opcode) {
case Opcode::FAdd:
return GetFloatConst(lhs_f32 + rhs_f32);
case Opcode::FSub:
return GetFloatConst(lhs_f32 - rhs_f32);
case Opcode::FMul:
return GetFloatConst(lhs_f32 * rhs_f32);
case Opcode::FDiv:
return GetFloatConst(lhs_f32 / rhs_f32);
case Opcode::FRem:
return GetFloatConst(std::fmod(lhs_f32, rhs_f32));
case Opcode::FCmpEQ:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 == rhs_f32);
case Opcode::FCmpNE:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 != rhs_f32);
case Opcode::FCmpLT:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 < rhs_f32);
case Opcode::FCmpGT:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 > rhs_f32);
case Opcode::FCmpLE:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 <= rhs_f32);
case Opcode::FCmpGE:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 >= rhs_f32);
default:
break;
}
}
switch (opcode) {
case Opcode::Add:
if (IsZeroValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs)) {
return rhs;
}
break;
case Opcode::Sub:
if (IsZeroValue(rhs)) {
return lhs;
}
break;
case Opcode::Mul:
if (IsOneValue(rhs)) {
return lhs;
}
if (IsOneValue(lhs)) {
return rhs;
}
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::Div:
if (IsOneValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::Rem:
if ((has_rhs_i32 && (rhs_i32 == 1 || rhs_i32 == -1)) ||
(has_rhs_i1 && rhs_i1)) {
return GetInt32Const(ctx, 0);
}
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::And:
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
: GetInt32Const(ctx, 0);
}
if (has_lhs_i1 && lhs_i1) {
return rhs;
}
if (has_rhs_i1 && rhs_i1) {
return lhs;
}
if (IsAllOnesInt(lhs)) {
return rhs;
}
if (IsAllOnesInt(rhs)) {
return lhs;
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return lhs;
}
break;
case Opcode::Or:
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
if (has_lhs_i1 && lhs_i1) {
return GetBoolConst(ctx, true);
}
if (has_rhs_i1 && rhs_i1) {
return GetBoolConst(ctx, true);
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return lhs;
}
break;
case Opcode::Xor:
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
: GetInt32Const(ctx, 0);
}
break;
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
if (IsZeroValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::FAdd:
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
break;
case Opcode::FSub:
if (IsZeroValue(rhs)) {
return lhs;
}
break;
case Opcode::FMul:
if (IsOneValue(lhs)) {
return rhs;
}
if (IsOneValue(rhs)) {
return lhs;
}
break;
case Opcode::FDiv:
if (IsOneValue(rhs)) {
return lhs;
}
break;
default:
break;
}
return nullptr;
}
Value* FoldUnary(Context& ctx, UnaryInst* inst) {
auto* operand = inst->GetOprd();
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
switch (inst->GetOpcode()) {
case Opcode::Neg:
if (TryGetInt32(operand, i32)) {
return GetInt32Const(ctx, WrapInt32(0u - static_cast<std::uint32_t>(i32)));
}
break;
case Opcode::Not:
if (TryGetBool(operand, i1)) {
return GetBoolConst(ctx, !i1);
}
if (TryGetInt32(operand, i32)) {
return GetInt32Const(ctx, i32 ^ 1);
}
break;
case Opcode::FNeg:
if (TryGetFloat(operand, f32)) {
return GetFloatConst(-f32);
}
break;
case Opcode::FtoI:
if (TryGetFloat(operand, f32)) {
return GetInt32Const(ctx, static_cast<std::int32_t>(f32));
}
break;
case Opcode::IToF:
if (TryGetInt32(operand, i32)) {
return GetFloatConst(static_cast<float>(i32));
}
if (TryGetBool(operand, i1)) {
return GetFloatConst(i1 ? 1.0f : 0.0f);
}
break;
default:
break;
}
return nullptr;
}
Value* FoldZext(Context& ctx, ZextInst* inst) {
auto* value = inst->GetValue();
bool i1 = false;
std::int32_t i32 = 0;
if (inst->GetType()->IsInt1()) {
if (TryGetBool(value, i1)) {
return GetBoolConst(ctx, i1);
}
if (TryGetInt32(value, i32)) {
return GetBoolConst(ctx, i32 != 0);
}
}
if (inst->GetType()->IsInt32()) {
if (TryGetBool(value, i1)) {
return GetInt32Const(ctx, i1 ? 1 : 0);
}
if (TryGetInt32(value, i32)) {
return GetInt32Const(ctx, i32);
}
}
return nullptr;
}
bool FoldFunction(Function& function, Context& ctx) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
Value* replacement = nullptr;
if (auto* binary = dyncast<BinaryInst>(inst)) {
replacement = FoldBinary(ctx, binary);
} else if (auto* unary = dyncast<UnaryInst>(inst)) {
replacement = FoldUnary(ctx, unary);
} else if (auto* zext = dyncast<ZextInst>(inst)) {
replacement = FoldZext(ctx, zext);
}
if (!replacement || replacement == inst) {
continue;
}
inst->ReplaceAllUsesWith(replacement);
to_remove.push_back(inst);
changed = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return changed;
}
} // namespace
bool RunConstFold(Module& module) {
bool changed = false;
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= FoldFunction(*function, ctx);
}
}
return changed;
}
} // namespace ir

@ -1,5 +1,550 @@
// 常量传播Constant Propagation
// - 沿 use-def 关系传播已知常量
// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会
// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cmath>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
enum class LatticeKind { Unknown, Constant, Overdefined };
struct ConstantValue {
enum class Kind { Int32, Bool, Float };
Kind kind = Kind::Int32;
std::int32_t int32_value = 0;
bool bool_value = false;
float float_value = 0.0f;
};
struct LatticeValue {
LatticeKind kind = LatticeKind::Unknown;
ConstantValue constant;
};
bool EqualConstants(const ConstantValue& lhs, const ConstantValue& rhs) {
if (lhs.kind != rhs.kind) {
return false;
}
switch (lhs.kind) {
case ConstantValue::Kind::Int32:
return lhs.int32_value == rhs.int32_value;
case ConstantValue::Kind::Bool:
return lhs.bool_value == rhs.bool_value;
case ConstantValue::Kind::Float:
return passutils::FloatBits(lhs.float_value) ==
passutils::FloatBits(rhs.float_value);
}
return false;
}
Value* MaterializeConstant(Context& ctx, const ConstantValue& constant) {
switch (constant.kind) {
case ConstantValue::Kind::Int32:
return ctx.GetConstInt(static_cast<int>(constant.int32_value));
case ConstantValue::Kind::Bool:
return ctx.GetConstBool(constant.bool_value);
case ConstantValue::Kind::Float:
return new ConstantFloat(Type::GetFloatType(), constant.float_value);
}
return nullptr;
}
bool TryGetConstantValue(Value* value, ConstantValue& out) {
if (auto* ci = dyncast<ConstantInt>(value)) {
out.kind = ConstantValue::Kind::Int32;
out.int32_value = static_cast<std::int32_t>(ci->GetValue());
return true;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
out.kind = ConstantValue::Kind::Bool;
out.bool_value = cb->GetValue();
return true;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
out.kind = ConstantValue::Kind::Float;
out.float_value = cf->GetValue();
return true;
}
return false;
}
LatticeValue ConstantLattice(const ConstantValue& constant) {
LatticeValue value;
value.kind = LatticeKind::Constant;
value.constant = constant;
return value;
}
LatticeValue OverdefinedLattice() {
LatticeValue value;
value.kind = LatticeKind::Overdefined;
return value;
}
LatticeValue GetValueState(
Value* value, const std::unordered_map<Value*, LatticeValue>& states) {
ConstantValue constant;
if (TryGetConstantValue(value, constant)) {
return ConstantLattice(constant);
}
auto it = states.find(value);
if (it != states.end()) {
return it->second;
}
return OverdefinedLattice();
}
LatticeValue Meet(LatticeValue lhs, const LatticeValue& rhs) {
if (lhs.kind == LatticeKind::Overdefined || rhs.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (lhs.kind == LatticeKind::Unknown) {
return rhs;
}
if (rhs.kind == LatticeKind::Unknown) {
return lhs;
}
if (EqualConstants(lhs.constant, rhs.constant)) {
return lhs;
}
return OverdefinedLattice();
}
bool EvaluateUnary(Opcode opcode, const ConstantValue& operand,
ConstantValue& result) {
switch (opcode) {
case Opcode::Neg:
if (operand.kind != ConstantValue::Kind::Int32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
0u - static_cast<std::uint32_t>(operand.int32_value));
return true;
case Opcode::Not:
if (operand.kind == ConstantValue::Kind::Bool) {
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !operand.bool_value;
return true;
}
if (operand.kind == ConstantValue::Kind::Int32) {
result.kind = ConstantValue::Kind::Int32;
result.int32_value = operand.int32_value ^ 1;
return true;
}
return false;
case Opcode::FNeg:
if (operand.kind != ConstantValue::Kind::Float) {
return false;
}
result.kind = ConstantValue::Kind::Float;
result.float_value = -operand.float_value;
return true;
case Opcode::FtoI:
if (operand.kind != ConstantValue::Kind::Float) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(operand.float_value);
return true;
case Opcode::IToF:
if (operand.kind == ConstantValue::Kind::Int32) {
result.kind = ConstantValue::Kind::Float;
result.float_value = static_cast<float>(operand.int32_value);
return true;
}
if (operand.kind == ConstantValue::Kind::Bool) {
result.kind = ConstantValue::Kind::Float;
result.float_value = operand.bool_value ? 1.0f : 0.0f;
return true;
}
return false;
default:
return false;
}
}
bool EvaluateBinary(Opcode opcode, const ConstantValue& lhs,
const ConstantValue& rhs, ConstantValue& result) {
if (lhs.kind == ConstantValue::Kind::Int32 &&
rhs.kind == ConstantValue::Kind::Int32) {
const auto left = lhs.int32_value;
const auto right = rhs.int32_value;
switch (opcode) {
case Opcode::Add:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) + static_cast<std::uint32_t>(right));
return true;
case Opcode::Sub:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) - static_cast<std::uint32_t>(right));
return true;
case Opcode::Mul:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) * static_cast<std::uint32_t>(right));
return true;
case Opcode::Div:
if (right == 0 ||
(left == std::numeric_limits<std::int32_t>::min() && right == -1)) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left / right;
return true;
case Opcode::Rem:
if (right == 0 ||
(left == std::numeric_limits<std::int32_t>::min() && right == -1)) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left % right;
return true;
case Opcode::And:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left & right;
return true;
case Opcode::Or:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left | right;
return true;
case Opcode::Xor:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left ^ right;
return true;
case Opcode::Shl:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value =
static_cast<std::int32_t>(static_cast<std::uint32_t>(left) << right);
return true;
case Opcode::AShr:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left >> right;
return true;
case Opcode::LShr:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) >> right);
return true;
case Opcode::ICmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left == right;
return true;
case Opcode::ICmpNE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left < right;
return true;
case Opcode::ICmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left > right;
return true;
case Opcode::ICmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left <= right;
return true;
case Opcode::ICmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left >= right;
return true;
default:
break;
}
}
if (lhs.kind == ConstantValue::Kind::Bool && rhs.kind == ConstantValue::Kind::Bool) {
const auto left = lhs.bool_value;
const auto right = rhs.bool_value;
switch (opcode) {
case Opcode::And:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left && right;
return true;
case Opcode::Or:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left || right;
return true;
case Opcode::Xor:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left == right;
return true;
case Opcode::ICmpNE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) < static_cast<int>(right);
return true;
case Opcode::ICmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) > static_cast<int>(right);
return true;
case Opcode::ICmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) <= static_cast<int>(right);
return true;
case Opcode::ICmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) >= static_cast<int>(right);
return true;
default:
break;
}
}
if (lhs.kind == ConstantValue::Kind::Float &&
rhs.kind == ConstantValue::Kind::Float) {
const auto left = lhs.float_value;
const auto right = rhs.float_value;
switch (opcode) {
case Opcode::FAdd:
result.kind = ConstantValue::Kind::Float;
result.float_value = left + right;
return true;
case Opcode::FSub:
result.kind = ConstantValue::Kind::Float;
result.float_value = left - right;
return true;
case Opcode::FMul:
result.kind = ConstantValue::Kind::Float;
result.float_value = left * right;
return true;
case Opcode::FDiv:
result.kind = ConstantValue::Kind::Float;
result.float_value = left / right;
return true;
case Opcode::FRem:
result.kind = ConstantValue::Kind::Float;
result.float_value = std::fmod(left, right);
return true;
case Opcode::FCmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left == right;
return true;
case Opcode::FCmpNE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left != right;
return true;
case Opcode::FCmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left < right;
return true;
case Opcode::FCmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left > right;
return true;
case Opcode::FCmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left <= right;
return true;
case Opcode::FCmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left >= right;
return true;
default:
break;
}
}
return false;
}
LatticeValue EvaluateInstruction(
Instruction* inst, const std::unordered_map<Value*, LatticeValue>& states) {
if (!inst || inst->IsVoid()) {
return OverdefinedLattice();
}
if (auto* phi = dyncast<PhiInst>(inst)) {
LatticeValue merged;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
merged = Meet(merged, GetValueState(phi->GetIncomingValue(i), states));
if (merged.kind == LatticeKind::Overdefined) {
break;
}
}
return merged;
}
if (auto* binary = dyncast<BinaryInst>(inst)) {
const auto lhs = GetValueState(binary->GetLhs(), states);
const auto rhs = GetValueState(binary->GetRhs(), states);
if (lhs.kind == LatticeKind::Overdefined || rhs.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (lhs.kind != LatticeKind::Constant || rhs.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (!EvaluateBinary(binary->GetOpcode(), lhs.constant, rhs.constant, folded)) {
return OverdefinedLattice();
}
return ConstantLattice(folded);
}
if (auto* unary = dyncast<UnaryInst>(inst)) {
const auto operand = GetValueState(unary->GetOprd(), states);
if (operand.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (operand.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (!EvaluateUnary(unary->GetOpcode(), operand.constant, folded)) {
return OverdefinedLattice();
}
return ConstantLattice(folded);
}
if (auto* zext = dyncast<ZextInst>(inst)) {
const auto operand = GetValueState(zext->GetValue(), states);
if (operand.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (operand.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (zext->GetType()->IsInt1()) {
folded.kind = ConstantValue::Kind::Bool;
if (operand.constant.kind == ConstantValue::Kind::Bool) {
folded.bool_value = operand.constant.bool_value;
return ConstantLattice(folded);
}
if (operand.constant.kind == ConstantValue::Kind::Int32) {
folded.bool_value = operand.constant.int32_value != 0;
return ConstantLattice(folded);
}
return OverdefinedLattice();
}
if (zext->GetType()->IsInt32()) {
folded.kind = ConstantValue::Kind::Int32;
if (operand.constant.kind == ConstantValue::Kind::Bool) {
folded.int32_value = operand.constant.bool_value ? 1 : 0;
return ConstantLattice(folded);
}
if (operand.constant.kind == ConstantValue::Kind::Int32) {
folded.int32_value = operand.constant.int32_value;
return ConstantLattice(folded);
}
}
return OverdefinedLattice();
}
return OverdefinedLattice();
}
bool RewriteFunction(Function& function, Context& ctx) {
if (function.IsExternal()) {
return false;
}
std::unordered_map<Value*, LatticeValue> states;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst->IsVoid()) {
states[inst] = {};
}
}
}
bool changed = true;
while (changed) {
changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsVoid()) {
continue;
}
const auto evaluated = EvaluateInstruction(inst, states);
if (evaluated.kind != states[inst].kind ||
(evaluated.kind == LatticeKind::Constant &&
!EqualConstants(evaluated.constant, states[inst].constant))) {
states[inst] = evaluated;
changed = true;
}
}
}
}
bool rewritten = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
if (isa<BasicBlock>(operand) || isa<Function>(operand) || operand->IsConstant()) {
continue;
}
const auto state = GetValueState(operand, states);
if (state.kind != LatticeKind::Constant) {
continue;
}
inst->SetOperand(i, MaterializeConstant(ctx, state.constant));
rewritten = true;
}
if (inst->IsVoid()) {
continue;
}
const auto state = states[inst];
if (state.kind != LatticeKind::Constant) {
continue;
}
inst->ReplaceAllUsesWith(MaterializeConstant(ctx, state.constant));
to_remove.push_back(inst);
rewritten = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return rewritten;
}
} // namespace
bool RunConstProp(Module& module) {
bool changed = false;
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RewriteFunction(*function, ctx);
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,55 @@
// 死代码删除DCE
// - 删除无用指令与无用基本块
// - 通常与 CFG 简化配合使用
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <vector>
namespace ir {
namespace {
bool RunDCEOnFunction(Function& function) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!passutils::IsTriviallyDead(inst)) {
continue;
}
to_remove.push_back(inst);
}
if (to_remove.empty()) {
continue;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
local_changed = true;
changed = true;
}
}
return changed;
}
} // namespace
bool RunDCE(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunDCEOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,219 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "MemoryUtils.h"
#include "PassUtils.h"
#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct ExprKey {
Opcode opcode = Opcode::Add;
std::uintptr_t result_type = 0;
std::uintptr_t aux_type = 0;
struct OperandKey {
int kind = 0;
std::intptr_t value = 0;
bool operator==(const OperandKey& rhs) const {
return kind == rhs.kind && value == rhs.value;
}
};
std::vector<OperandKey> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && result_type == rhs.result_type &&
aux_type == rhs.aux_type && operands == rhs.operands;
}
};
struct ExprKeyHash {
std::size_t operator()(const ExprKey& key) const {
std::size_t h = static_cast<std::size_t>(key.opcode);
h ^= std::hash<std::uintptr_t>{}(key.result_type) + 0x9e3779b9 + (h << 6) +
(h >> 2);
h ^= std::hash<std::uintptr_t>{}(key.aux_type) + 0x9e3779b9 + (h << 6) +
(h >> 2);
for (auto operand : key.operands) {
h ^= std::hash<int>{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::intptr_t>{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
ExprKey::OperandKey BuildOperandKey(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {1, ci->GetValue()};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {2, cb->GetValue() ? 1 : 0};
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return {3, static_cast<std::intptr_t>(passutils::FloatBits(cf->GetValue()))};
}
return {0, reinterpret_cast<std::intptr_t>(value)};
}
struct ScopedExpr {
ExprKey key;
Value* previous = nullptr;
bool had_previous = false;
};
bool IsSupportedGVNInstruction(Instruction* inst) {
if (!inst || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::GetElementPtr:
case Opcode::Zext:
return true;
case Opcode::Call:
return memutils::IsPureCall(dyncast<CallInst>(inst));
default:
return false;
}
}
ExprKey BuildExprKey(Instruction* inst) {
ExprKey key;
key.opcode = inst->GetOpcode();
key.result_type =
reinterpret_cast<std::uintptr_t>(inst->GetType().get());
if (auto* gep = dyncast<GetElementPtrInst>(inst)) {
key.aux_type = reinterpret_cast<std::uintptr_t>(gep->GetSourceType().get());
}
key.operands.reserve(inst->GetNumOperands());
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(BuildOperandKey(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 &&
passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
(key.operands[1].kind < key.operands[0].kind ||
(key.operands[1].kind == key.operands[0].kind &&
key.operands[1].value < key.operands[0].value))) {
std::swap(key.operands[0], key.operands[1]);
}
return key;
}
bool RunGVNInDomSubtree(
BasicBlock* block, const DominatorTree& dom_tree,
std::unordered_map<ExprKey, Value*, ExprKeyHash>& available) {
if (!block) {
return false;
}
bool changed = false;
std::vector<ScopedExpr> scope;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!IsSupportedGVNInstruction(inst)) {
continue;
}
const auto key = BuildExprKey(inst);
auto it = available.find(key);
if (it != available.end()) {
inst->ReplaceAllUsesWith(it->second);
to_remove.push_back(inst);
changed = true;
continue;
}
ScopedExpr scoped{key, nullptr, false};
auto existing = available.find(key);
if (existing != available.end()) {
scoped.previous = existing->second;
scoped.had_previous = true;
existing->second = inst;
} else {
available.emplace(key, inst);
}
scope.push_back(std::move(scoped));
}
for (auto* inst : to_remove) {
block->EraseInstruction(inst);
}
for (auto* child : dom_tree.GetChildren(block)) {
changed |= RunGVNInDomSubtree(child, dom_tree, available);
}
for (auto it = scope.rbegin(); it != scope.rend(); ++it) {
if (it->had_previous) {
available[it->key] = it->previous;
} else {
available.erase(it->key);
}
}
return changed;
}
bool RunGVNOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
DominatorTree dom_tree(function);
std::unordered_map<ExprKey, Value*, ExprKeyHash> available;
return RunGVNInDomSubtree(function.GetEntryBlock(), dom_tree, available);
}
} // namespace
bool RunGVN(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunGVNOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,756 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include "MathIdiomUtils.h"
#include <algorithm>
#include <cstdint>
#include <unordered_set>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct InlineCandidateInfo {
bool valid = false;
int cost = 0;
bool has_nested_call = false;
bool has_control_flow = false;
};
bool IsInlineableInstruction(const Instruction* inst) {
if (!inst) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Load:
case Opcode::Store:
case Opcode::GetElementPtr:
case Opcode::Phi:
case Opcode::Zext:
case Opcode::Memset:
case Opcode::Call:
case Opcode::Return:
case Opcode::Br:
case Opcode::CondBr:
return true;
default:
return false;
}
}
int EstimateInstructionCost(const Instruction* inst) {
if (!inst) {
return 0;
}
switch (inst->GetOpcode()) {
case Opcode::Phi:
case Opcode::Return:
return 0;
case Opcode::Load:
case Opcode::Store:
case Opcode::Memset:
return 3;
case Opcode::Call:
return 8;
case Opcode::GetElementPtr:
return 2;
default:
return 1;
}
}
InlineCandidateInfo AnalyzeInlineCandidate(const Function& function) {
InlineCandidateInfo info;
if (function.IsExternal() || function.IsRecursive()) {
return info;
}
if (function.GetBlocks().empty() || function.GetBlocks().size() > 16) {
return info;
}
bool saw_return = false;
for (const auto& block : function.GetBlocks()) {
if (!block || block->GetInstructions().empty()) {
return info;
}
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
auto* inst = block->GetInstructions()[i].get();
if (!IsInlineableInstruction(inst) || dyncast<AllocaInst>(inst) ||
dyncast<UnreachableInst>(inst)) {
return {};
}
if (dyncast<ReturnInst>(inst)) {
if (i + 1 != block->GetInstructions().size()) {
return {};
}
saw_return = true;
continue;
}
if ((dyncast<UncondBrInst>(inst) || dyncast<CondBrInst>(inst)) &&
i + 1 != block->GetInstructions().size()) {
return {};
}
if (dyncast<CondBrInst>(inst) || dyncast<UncondBrInst>(inst)) {
info.has_control_flow = true;
}
if (dyncast<CallInst>(inst)) {
info.has_nested_call = true;
}
info.cost += EstimateInstructionCost(inst);
}
}
if (!saw_return) {
return {};
}
info.valid = true;
return info;
}
std::unordered_map<Function*, int> CountDirectCalls(Module& module) {
std::unordered_map<Function*, int> counts;
for (const auto& function_ptr : module.GetFunctions()) {
if (!function_ptr) {
continue;
}
for (const auto& block_ptr : function_ptr->GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
if (auto* call = dyncast<CallInst>(inst_ptr.get())) {
if (auto* callee = call->GetCallee()) {
++counts[callee];
}
}
}
}
}
return counts;
}
bool ShouldInlineCallSite(const Function& caller, const CallInst& call,
const InlineCandidateInfo& callee_info, int call_count) {
auto* callee = call.GetCallee();
if (!callee || callee == &caller || !callee_info.valid) {
return false;
}
if (mathidiom::IsToleranceNewtonSqrtShape(*callee)) {
return false;
}
if (mathidiom::IsPow2DigitExtractShape(*callee)) {
return false;
}
if (callee_info.has_control_flow && callee_info.has_nested_call) {
return false;
}
int budget = callee->CanDiscardUnusedCall() ? 96 : 72;
if (call_count <= 1) {
budget += 48;
}
if (callee_info.has_nested_call) {
budget -= 8;
}
if (callee_info.has_control_flow) {
budget -= 12;
}
if (callee->MayWriteMemory()) {
budget -= 4;
}
return callee_info.cost <= budget;
}
Instruction* CloneInstructionAt(Function& function, Instruction* inst, BasicBlock* dest,
std::size_t insert_index,
std::unordered_map<Value*, Value*>& remap) {
if (!inst || !dest) {
return nullptr;
}
const auto name = inst->IsVoid() ? std::string()
: looputils::NextSyntheticName(function, "inline.");
auto remap_operand = [&](Value* value) { return looputils::RemapValue(remap, value); };
auto remember = [&](Instruction* clone) {
if (clone && !inst->IsVoid()) {
remap[inst] = clone;
}
return clone;
};
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<BinaryInst*>(inst);
return remember(dest->Insert<BinaryInst>(insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(bin->GetLhs()),
remap_operand(bin->GetRhs()), nullptr, name));
}
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF: {
auto* un = static_cast<UnaryInst*>(inst);
return remember(dest->Insert<UnaryInst>(insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(un->GetOprd()), nullptr, name));
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
return remember(dest->Insert<LoadInst>(insert_index, inst->GetType(),
remap_operand(load->GetPtr()), nullptr, name));
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
return dest->Insert<StoreInst>(insert_index, remap_operand(store->GetValue()),
remap_operand(store->GetPtr()), nullptr);
}
case Opcode::Memset: {
auto* memset = static_cast<MemsetInst*>(inst);
return dest->Insert<MemsetInst>(insert_index, remap_operand(memset->GetDest()),
remap_operand(memset->GetValue()),
remap_operand(memset->GetLength()),
remap_operand(memset->GetIsVolatile()), nullptr);
}
case Opcode::GetElementPtr: {
auto* gep = static_cast<GetElementPtrInst*>(inst);
std::vector<Value*> indices;
indices.reserve(gep->GetNumIndices());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
indices.push_back(remap_operand(gep->GetIndex(i)));
}
return remember(dest->Insert<GetElementPtrInst>(
insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()), indices, nullptr,
name));
}
case Opcode::Zext: {
auto* zext = static_cast<ZextInst*>(inst);
return remember(dest->Insert<ZextInst>(insert_index, remap_operand(zext->GetValue()),
inst->GetType(), nullptr, name));
}
case Opcode::Call: {
auto* call = static_cast<CallInst*>(inst);
std::vector<Value*> args;
const auto original_args = call->GetArguments();
args.reserve(original_args.size());
for (auto* arg : original_args) {
args.push_back(remap_operand(arg));
}
return remember(
dest->Insert<CallInst>(insert_index, call->GetCallee(), args, nullptr, name));
}
case Opcode::Return:
case Opcode::Alloca:
case Opcode::Phi:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Unreachable:
break;
}
return nullptr;
}
bool InlineCallSite(Function& caller, CallInst* call) {
if (!call) {
return false;
}
auto* callee = call->GetCallee();
if (!callee || callee->GetBlocks().size() != 1) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
auto* block = call->GetParent();
if (!block) {
return false;
}
auto& instructions = block->GetInstructions();
auto call_it = std::find_if(instructions.begin(), instructions.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == call;
});
if (call_it == instructions.end()) {
return false;
}
std::size_t insert_index = static_cast<std::size_t>(call_it - instructions.begin());
std::unordered_map<Value*, Value*> remap;
for (std::size_t i = 0; i < call_args.size(); ++i) {
remap[callee_args[i].get()] = call_args[i];
}
Value* return_value = nullptr;
for (const auto& inst_ptr : callee->GetBlocks().front()->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* ret = dyncast<ReturnInst>(inst)) {
if (ret->HasReturnValue()) {
return_value = looputils::RemapValue(remap, ret->GetReturnValue());
}
break;
}
if (!CloneInstructionAt(caller, inst, block, insert_index, remap)) {
return false;
}
++insert_index;
}
if (!call->GetType()->IsVoid()) {
if (!return_value) {
return false;
}
call->ReplaceAllUsesWith(return_value);
}
block->EraseInstruction(call);
return true;
}
void ReplaceIncomingBlock(BasicBlock* block, BasicBlock* old_pred, BasicBlock* new_pred) {
if (!block || !old_pred || !new_pred) {
return;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int index = looputils::GetPhiIncomingIndex(phi, old_pred);
if (index >= 0) {
phi->SetOperand(static_cast<std::size_t>(2 * index + 1), new_pred);
}
}
}
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::vector<BasicBlock*> stack{entry};
std::unordered_set<BasicBlock*> visited;
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it) {
stack.push_back(*it);
}
}
}
return order;
}
BasicBlock* SplitBlockAfterCall(Function& caller, BasicBlock* block, CallInst* call) {
if (!block || !call) {
return nullptr;
}
auto& instructions = block->GetInstructions();
auto call_it = std::find_if(instructions.begin(), instructions.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == call;
});
if (call_it == instructions.end() || std::next(call_it) == instructions.end()) {
return nullptr;
}
auto* continuation =
caller.CreateBlock(looputils::NextSyntheticBlockName(caller, "inline.cont"));
auto& continuation_insts = continuation->GetInstructions();
for (auto it = std::next(call_it); it != instructions.end(); ++it) {
(*it)->SetParent(continuation);
continuation_insts.push_back(std::move(*it));
}
instructions.erase(std::next(call_it), instructions.end());
auto old_succs = block->GetSuccessors();
for (auto* succ : old_succs) {
block->RemoveSuccessor(succ);
succ->RemovePredecessor(block);
succ->AddPredecessor(continuation);
continuation->AddSuccessor(succ);
ReplaceIncomingBlock(succ, block, continuation);
}
return continuation;
}
bool CanInlineCFGCallSite(Function& caller, CallInst* call,
std::vector<BasicBlock*>& callee_blocks) {
auto* callee = call ? call->GetCallee() : nullptr;
if (!call || !callee || callee->GetBlocks().size() <= 1 ||
callee == &caller) {
return false;
}
if (mathidiom::IsToleranceNewtonSqrtShape(*callee)) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
callee_blocks = CollectReachableBlocks(*callee);
if (callee_blocks.empty()) {
return false;
}
std::unordered_set<BasicBlock*> reachable(callee_blocks.begin(), callee_blocks.end());
for (auto* block : callee_blocks) {
if (!block || block->GetInstructions().empty()) {
return false;
}
bool seen_non_phi = false;
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
auto* inst = block->GetInstructions()[i].get();
if (dyncast<AllocaInst>(inst) || dyncast<UnreachableInst>(inst) ||
!IsInlineableInstruction(inst)) {
return false;
}
if (dyncast<PhiInst>(inst)) {
if (seen_non_phi) {
return false;
}
continue;
}
seen_non_phi = true;
if (auto* br = dyncast<UncondBrInst>(inst)) {
if (i + 1 != block->GetInstructions().size() ||
reachable.count(br->GetDest()) == 0) {
return false;
}
continue;
}
if (auto* condbr = dyncast<CondBrInst>(inst)) {
if (i + 1 != block->GetInstructions().size() ||
reachable.count(condbr->GetThenBlock()) == 0 ||
reachable.count(condbr->GetElseBlock()) == 0) {
return false;
}
continue;
}
if (dyncast<ReturnInst>(inst)) {
if (i + 1 != block->GetInstructions().size()) {
return false;
}
continue;
}
if (inst->IsTerminator() || !looputils::IsCloneableInstruction(inst)) {
return false;
}
}
}
return true;
}
bool InlineCFGCallSite(Function& caller, CallInst* call) {
auto* callee = call ? call->GetCallee() : nullptr;
if (!call || !callee || callee->GetBlocks().size() <= 1) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
std::vector<BasicBlock*> callee_blocks;
if (!CanInlineCFGCallSite(caller, call, callee_blocks)) {
return false;
}
auto* call_block = call->GetParent();
auto* continuation = SplitBlockAfterCall(caller, call_block, call);
if (!call_block || !continuation) {
return false;
}
std::unordered_map<Value*, Value*> remap;
for (std::size_t i = 0; i < call_args.size(); ++i) {
remap[callee_args[i].get()] = call_args[i];
}
std::unordered_map<BasicBlock*, BasicBlock*> block_map;
for (auto* block : callee_blocks) {
block_map[block] =
caller.CreateBlock(looputils::NextSyntheticBlockName(caller, "inline.bb"));
}
for (auto* block : callee_blocks) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
auto* cloned_phi = clone->Append<PhiInst>(
phi->GetType(), nullptr,
looputils::NextSyntheticName(caller, "inline.phi."));
remap[phi] = cloned_phi;
}
}
for (auto* block : callee_blocks) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst)) {
continue;
}
if (inst->IsTerminator()) {
continue;
}
if (!CloneInstructionAt(caller, inst, clone,
looputils::GetTerminatorIndex(clone), remap)) {
return false;
}
}
}
for (auto* block : callee_blocks) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
auto* cloned_phi = static_cast<PhiInst*>(remap.at(phi));
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* incoming_block = phi->GetIncomingBlock(i);
auto block_it = block_map.find(incoming_block);
if (block_it == block_map.end()) {
return false;
}
cloned_phi->AddIncoming(looputils::RemapValue(remap, phi->GetIncomingValue(i)),
block_it->second);
}
}
}
std::vector<std::pair<BasicBlock*, Value*>> return_edges;
for (auto* block : callee_blocks) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst) || !inst->IsTerminator()) {
continue;
}
if (auto* ret = dyncast<ReturnInst>(inst)) {
clone->Append<UncondBrInst>(continuation, nullptr);
clone->AddSuccessor(continuation);
continuation->AddPredecessor(clone);
return_edges.emplace_back(
clone, ret->HasReturnValue() ? looputils::RemapValue(remap, ret->GetReturnValue())
: nullptr);
continue;
}
if (auto* br = dyncast<UncondBrInst>(inst)) {
auto* target = block_map.at(br->GetDest());
clone->Append<UncondBrInst>(target, nullptr);
clone->AddSuccessor(target);
target->AddPredecessor(clone);
continue;
}
if (auto* condbr = dyncast<CondBrInst>(inst)) {
auto* then_block = block_map.at(condbr->GetThenBlock());
auto* else_block = block_map.at(condbr->GetElseBlock());
clone->Append<CondBrInst>(looputils::RemapValue(remap, condbr->GetCondition()),
then_block, else_block, nullptr);
clone->AddSuccessor(then_block);
clone->AddSuccessor(else_block);
then_block->AddPredecessor(clone);
else_block->AddPredecessor(clone);
continue;
}
return false;
}
}
call_block->Append<UncondBrInst>(block_map.at(callee->GetEntryBlock()), nullptr);
call_block->AddSuccessor(block_map.at(callee->GetEntryBlock()));
block_map.at(callee->GetEntryBlock())->AddPredecessor(call_block);
if (!call->GetType()->IsVoid()) {
Value* return_value = nullptr;
if (return_edges.size() == 1) {
return_value = return_edges.front().second;
} else {
auto* phi = continuation->Insert<PhiInst>(
looputils::GetFirstNonPhiIndex(continuation), call->GetType(), nullptr,
looputils::NextSyntheticName(caller, "inline.ret."));
for (const auto& [pred, value] : return_edges) {
if (!value) {
return false;
}
phi->AddIncoming(value, pred);
}
return_value = phi;
}
if (!return_value) {
return false;
}
call->ReplaceAllUsesWith(return_value);
}
call_block->EraseInstruction(call);
return true;
}
bool RunFunctionInliningOnFunction(
Function& function,
const std::unordered_map<Function*, InlineCandidateInfo>& callee_info,
const std::unordered_map<Function*, int>& call_counts) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
std::vector<BasicBlock*> block_snapshot;
block_snapshot.reserve(function.GetBlocks().size());
for (const auto& block_ptr : function.GetBlocks()) {
if (block_ptr) {
block_snapshot.push_back(block_ptr.get());
}
}
for (auto* block : block_snapshot) {
if (!block) {
continue;
}
std::vector<CallInst*> calls;
for (const auto& inst_ptr : block->GetInstructions()) {
if (auto* call = dyncast<CallInst>(inst_ptr.get())) {
calls.push_back(call);
}
}
for (auto* call : calls) {
auto* callee = call->GetCallee();
if (!callee) {
continue;
}
auto info_it = callee_info.find(callee);
if (info_it == callee_info.end()) {
continue;
}
const int call_count =
call_counts.count(callee) != 0 ? call_counts.at(callee) : 0;
if (!ShouldInlineCallSite(function, *call, info_it->second, call_count)) {
continue;
}
if (callee->GetBlocks().size() == 1) {
changed |= InlineCallSite(function, call);
} else {
changed |= InlineCFGCallSite(function, call);
}
}
}
return changed;
}
} // namespace
bool RunFunctionInlining(Module& module) {
std::unordered_map<Function*, InlineCandidateInfo> callee_info;
for (const auto& function_ptr : module.GetFunctions()) {
if (function_ptr) {
callee_info.emplace(function_ptr.get(), AnalyzeInlineCandidate(*function_ptr));
}
}
const auto call_counts = CountDirectCalls(module);
bool changed = false;
for (const auto& function_ptr : module.GetFunctions()) {
if (function_ptr) {
changed |= RunFunctionInliningOnFunction(*function_ptr, callee_info, call_counts);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,236 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include "MemoryUtils.h"
#include <cstdint>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct HoistedLoadKey {
memutils::AddressKey address;
std::uintptr_t type_id = 0;
bool operator==(const HoistedLoadKey& rhs) const {
return type_id == rhs.type_id && address == rhs.address;
}
};
struct HoistedLoadKeyHash {
std::size_t operator()(const HoistedLoadKey& key) const {
std::size_t h = memutils::AddressKeyHash{}(key.address);
h ^= std::hash<std::uintptr_t>{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2);
return h;
}
};
bool IsHoistableInstruction(const Instruction* inst) {
if (!inst || inst->IsTerminator() || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::GetElementPtr:
case Opcode::Zext:
case Opcode::Load:
return true;
default:
return false;
}
}
bool IsLoopInvariantInstruction(
const Loop& loop, Instruction* inst,
const std::unordered_set<Instruction*>& invariant_insts,
PhiInst* iv, int iv_stride,
const std::vector<loopmem::MemoryAccessInfo>& accesses,
const memutils::EscapeSummary& escapes) {
if (!IsHoistableInstruction(inst)) {
return false;
}
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
auto* operand_inst = dyncast<Instruction>(operand);
if (!operand_inst) {
continue;
}
if (!loop.Contains(operand_inst->GetParent())) {
continue;
}
if (invariant_insts.find(operand_inst) == invariant_insts.end()) {
return false;
}
}
if (auto* load = dyncast<LoadInst>(inst)) {
return loopmem::IsSafeInvariantLoadToHoist(loop, load, iv, iv_stride, accesses, &escapes);
}
return true;
}
bool HoistLoopInvariants(Function& function, const Loop& loop,
BasicBlock* preheader) {
if (!preheader) {
return false;
}
loopmem::SimpleInductionVar induction_var;
PhiInst* iv = nullptr;
int iv_stride = 1;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (loopmem::MatchSimpleInductionVariable(loop, preheader, phi, induction_var)) {
iv = induction_var.phi;
iv_stride = induction_var.stride;
break;
}
}
const auto escapes = memutils::AnalyzeEscapes(function);
const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes);
std::unordered_set<Instruction*> invariant_insts;
std::vector<Instruction*> hoist_list;
bool progress = true;
while (progress) {
progress = false;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (!loop.Contains(block) || block == preheader) {
continue;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (invariant_insts.find(inst) != invariant_insts.end()) {
continue;
}
if (!IsLoopInvariantInstruction(loop, inst, invariant_insts, iv, iv_stride,
accesses, escapes)) {
continue;
}
invariant_insts.insert(inst);
hoist_list.push_back(inst);
progress = true;
}
}
}
bool changed = false;
std::unordered_map<HoistedLoadKey, LoadInst*, HoistedLoadKeyHash> hoisted_loads;
for (auto* inst : hoist_list) {
if (auto* load = dyncast<LoadInst>(inst)) {
auto ptr = loopmem::AnalyzePointer(load->GetPtr(), iv, loop,
load->GetType()->GetSize(), &escapes);
if (ptr.exact_key_valid) {
HoistedLoadKey key{ptr.exact_key,
reinterpret_cast<std::uintptr_t>(load->GetType().get())};
auto it = hoisted_loads.find(key);
if (it != hoisted_loads.end()) {
load->ReplaceAllUsesWith(it->second);
load->GetParent()->EraseInstruction(load);
changed = true;
continue;
}
auto* moved = dyncast<LoadInst>(
looputils::MoveInstructionBeforeTerminator(load, preheader));
if (moved) {
hoisted_loads.emplace(std::move(key), moved);
changed = true;
}
continue;
}
}
if (looputils::MoveInstructionBeforeTerminator(inst, preheader)) {
changed = true;
}
}
return changed;
}
bool RunLICMOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
auto* old_preheader = loop->preheader;
auto* preheader = looputils::EnsurePreheader(function, *loop);
bool loop_changed = preheader != old_preheader;
loop_changed |= HoistLoopInvariants(function, *loop, preheader);
if (!loop_changed) {
continue;
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLICM(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLICMOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,323 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "MemoryUtils.h"
#include "PassUtils.h"
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct AvailableValue {
Value* value = nullptr;
bool operator==(const AvailableValue& rhs) const {
return passutils::AreEquivalentValues(value, rhs.value) || value == rhs.value;
}
};
using MemoryState =
std::unordered_map<memutils::AddressKey, AvailableValue,
memutils::AddressKeyHash>;
bool SameAvailableValue(const AvailableValue& lhs, const AvailableValue& rhs) {
return lhs == rhs;
}
bool SameMemoryState(const MemoryState& lhs, const MemoryState& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto& [key, value] : lhs) {
auto it = rhs.find(key);
if (it == rhs.end() || !SameAvailableValue(value, it->second)) {
return false;
}
}
return true;
}
MemoryState MeetMemoryStates(const std::vector<MemoryState*>& predecessors) {
if (predecessors.empty()) {
return {};
}
MemoryState in = *predecessors.front();
for (auto it = in.begin(); it != in.end();) {
bool keep = true;
for (std::size_t i = 1; i < predecessors.size(); ++i) {
auto pred_it = predecessors[i]->find(it->first);
if (pred_it == predecessors[i]->end() ||
!SameAvailableValue(it->second, pred_it->second)) {
keep = false;
break;
}
}
if (!keep) {
it = in.erase(it);
continue;
}
++it;
}
return in;
}
void InvalidateAliasStates(MemoryState& state,
const memutils::AddressKey& key) {
for (auto it = state.begin(); it != state.end();) {
if (memutils::MayAliasConservatively(it->first, key)) {
it = state.erase(it);
continue;
}
++it;
}
}
void InvalidateStatesForCall(MemoryState& state, Function* callee) {
for (auto it = state.begin(); it != state.end();) {
if (memutils::CallMayWriteRoot(callee, it->first.kind)) {
it = state.erase(it);
continue;
}
++it;
}
}
void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* inst,
MemoryState& state) {
if (!inst) {
return;
}
if (auto* load = dyncast<LoadInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
state.clear();
return;
}
if (state.find(key) == state.end()) {
state[key] = {load};
}
return;
}
if (auto* store = dyncast<StoreInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
state.clear();
return;
}
InvalidateAliasStates(state, key);
state[key] = {store->GetValue()};
return;
}
if (auto* call = dyncast<CallInst>(inst)) {
InvalidateStatesForCall(state, call->GetCallee());
return;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
state.clear();
return;
}
InvalidateAliasStates(state, key);
return;
}
}
MemoryState SimulateBlock(const memutils::EscapeSummary& escapes, BasicBlock* block,
const MemoryState& in_state) {
MemoryState state = in_state;
for (const auto& inst_ptr : block->GetInstructions()) {
SimulateInstruction(escapes, inst_ptr.get(), state);
}
return state;
}
bool MarkLoadObserved(
const memutils::AddressKey& key,
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
pending_stores) {
bool changed = false;
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (memutils::MayAliasConservatively(it->first, key)) {
it = pending_stores.erase(it);
changed = true;
continue;
}
++it;
}
return changed;
}
void InvalidatePendingForCall(
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
pending_stores,
Function* callee) {
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (memutils::CallMayReadRoot(callee, it->first.kind) ||
memutils::CallMayWriteRoot(callee, it->first.kind)) {
it = pending_stores.erase(it);
continue;
}
++it;
}
}
bool OptimizeBlock(
const memutils::EscapeSummary& escapes, BasicBlock* block,
const MemoryState& in_state) {
bool changed = false;
MemoryState state = in_state;
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>
pending_stores;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* load = dyncast<LoadInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
MarkLoadObserved(key, pending_stores);
auto it = state.find(key);
if (it != state.end() && it->second.value != load) {
load->ReplaceAllUsesWith(it->second.value);
to_remove.push_back(load);
changed = true;
continue;
}
if (state.find(key) == state.end()) {
state[key] = {load};
}
continue;
}
if (auto* store = dyncast<StoreInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (!memutils::MayAliasConservatively(it->first, key)) {
++it;
continue;
}
if (it->first == key) {
to_remove.push_back(it->second);
changed = true;
}
it = pending_stores.erase(it);
}
pending_stores.emplace(key, store);
InvalidateAliasStates(state, key);
state[key] = {store->GetValue()};
continue;
}
if (auto* call = dyncast<CallInst>(inst)) {
InvalidateStatesForCall(state, call->GetCallee());
InvalidatePendingForCall(pending_stores, call->GetCallee());
continue;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
InvalidateAliasStates(state, key);
MarkLoadObserved(key, pending_stores);
continue;
}
}
for (auto* inst : to_remove) {
if (inst->GetParent() == block) {
block->EraseInstruction(inst);
}
}
return changed;
}
bool RunLoadStoreElimOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
const auto escapes = memutils::AnalyzeEscapes(function);
const auto reachable_blocks = passutils::CollectReachableBlocks(function);
if (reachable_blocks.empty()) {
return false;
}
std::unordered_map<BasicBlock*, MemoryState> in_states;
std::unordered_map<BasicBlock*, MemoryState> out_states;
bool dataflow_changed = true;
while (dataflow_changed) {
dataflow_changed = false;
for (auto* block : reachable_blocks) {
MemoryState in_state;
if (block != function.GetEntryBlock()) {
std::vector<MemoryState*> predecessors;
for (auto* pred : block->GetPredecessors()) {
auto it = out_states.find(pred);
if (it != out_states.end()) {
predecessors.push_back(&it->second);
}
}
in_state = MeetMemoryStates(predecessors);
}
auto out_state = SimulateBlock(escapes, block, in_state);
auto in_it = in_states.find(block);
if (in_it == in_states.end() || !SameMemoryState(in_it->second, in_state)) {
in_states[block] = in_state;
dataflow_changed = true;
}
auto out_it = out_states.find(block);
if (out_it == out_states.end() || !SameMemoryState(out_it->second, out_state)) {
out_states[block] = std::move(out_state);
dataflow_changed = true;
}
}
}
bool changed = false;
for (auto* block : reachable_blocks) {
changed |= OptimizeBlock(escapes, block, in_states[block]);
}
return changed;
}
} // namespace
bool RunLoadStoreElim(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoadStoreElimOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,326 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
struct FissionLoopInfo {
Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
CondBrInst* branch = nullptr;
BinaryInst* compare = nullptr;
Opcode compare_opcode = Opcode::ICmpLT;
Value* bound = nullptr;
loopmem::SimpleInductionVar induction_var;
PhiInst* iv = nullptr;
BinaryInst* step_inst = nullptr;
};
bool HasSyntheticLoopTag(const std::string& name) {
return name.find("unroll.") != std::string::npos ||
name.find("fission.") != std::string::npos;
}
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
if (!loop.preheader || !loop.header || !body) {
return true;
}
return HasSyntheticLoopTag(loop.preheader->GetName()) ||
HasSyntheticLoopTag(loop.header->GetName()) ||
HasSyntheticLoopTag(body->GetName());
}
Opcode SwapCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
return Opcode::ICmpGT;
case Opcode::ICmpLE:
return Opcode::ICmpGE;
case Opcode::ICmpGT:
return Opcode::ICmpLT;
case Opcode::ICmpGE:
return Opcode::ICmpLE;
default:
return opcode;
}
}
bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) {
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
return false;
}
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
return false;
}
if (IsAlreadyTransformedLoop(loop, body)) {
return false;
}
std::vector<PhiInst*> phis;
loopmem::SimpleInductionVar induction_var;
bool found_iv = false;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
phis.push_back(phi);
if (!found_iv &&
loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
found_iv = true;
}
}
if (!found_iv || phis.size() != 1) {
return false;
}
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
if (!branch || branch->GetThenBlock() != body || !compare) {
return false;
}
Opcode compare_opcode = compare->GetOpcode();
Value* bound = nullptr;
if (compare->GetLhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
bound = compare->GetRhs();
} else if (compare->GetRhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
bound = compare->GetLhs();
compare_opcode = SwapCompareOpcode(compare_opcode);
} else {
return false;
}
auto* step_inst = dyncast<BinaryInst>(induction_var.latch_value);
if (!step_inst || step_inst->GetParent() != body) {
return false;
}
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || inst == step_inst) {
continue;
}
if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
return false;
}
}
info.loop = &loop;
info.preheader = loop.preheader;
info.header = loop.header;
info.body = body;
info.exit = exit;
info.branch = branch;
info.compare = compare;
info.compare_opcode = compare_opcode;
info.bound = bound;
info.induction_var = induction_var;
info.iv = induction_var.phi;
info.step_inst = step_inst;
return true;
}
bool ContainsInterestingPayload(const std::vector<Instruction*>& group) {
bool has_memory = false;
for (auto* inst : group) {
if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
has_memory = true;
}
}
return has_memory;
}
Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) {
if (value == old_iv) {
return new_iv;
}
return value;
}
bool BuildSecondLoop(Function& function, const FissionLoopInfo& info,
const std::vector<Instruction*>& second_group) {
auto* second_header =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header"));
auto* second_body =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body"));
const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader);
if (preheader_index < 0) {
return false;
}
auto* second_iv = second_header->Append<PhiInst>(
info.iv->GetType(), nullptr,
looputils::NextSyntheticName(function, "fission.iv."));
second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header);
auto* second_cmp = second_header->Append<BinaryInst>(
info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr,
looputils::NextSyntheticName(function, "fission.cmp."));
second_header->Append<CondBrInst>(second_cmp, second_body, info.exit, nullptr);
second_header->AddPredecessor(info.header);
second_header->AddSuccessor(second_body);
second_header->AddSuccessor(info.exit);
std::unordered_map<Value*, Value*> remap;
remap[info.iv] = second_iv;
std::unordered_set<Instruction*> selected(second_group.begin(), second_group.end());
selected.insert(info.step_inst);
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || selected.find(inst) == selected.end()) {
continue;
}
looputils::CloneInstruction(function, inst, second_body, remap, "fission.");
}
auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst);
if (!cloned_step_value) {
return false;
}
second_iv->AddIncoming(cloned_step_value, second_body);
second_body->Append<UncondBrInst>(second_header, nullptr);
second_body->AddPredecessor(second_header);
second_body->AddSuccessor(second_header);
second_header->AddPredecessor(second_body);
if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) {
return false;
}
info.exit->RemovePredecessor(info.header);
info.exit->AddPredecessor(second_header);
for (const auto& inst_ptr : info.exit->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int incoming = looputils::GetPhiIncomingIndex(phi, info.header);
if (incoming < 0) {
continue;
}
phi->SetOperand(static_cast<std::size_t>(2 * incoming),
RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv));
phi->SetOperand(static_cast<std::size_t>(2 * incoming + 1), second_header);
}
return true;
}
bool RunLoopFissionOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock()) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
FissionLoopInfo info;
if (!MatchFissionLoop(*loop, info)) {
continue;
}
const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv);
std::vector<Instruction*> payload;
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || inst == info.step_inst) {
continue;
}
payload.push_back(inst);
}
if (payload.size() < 2) {
continue;
}
int chosen_cut = -1;
std::vector<Instruction*> first_group;
std::vector<Instruction*> second_group;
for (std::size_t cut = 1; cut < payload.size(); ++cut) {
std::vector<Instruction*> first(payload.begin(), payload.begin() + static_cast<long long>(cut));
std::vector<Instruction*> second(payload.begin() + static_cast<long long>(cut),
payload.end());
if (!ContainsInterestingPayload(first) || !ContainsInterestingPayload(second)) {
continue;
}
std::unordered_set<Instruction*> first_set(first.begin(), first.end());
std::unordered_set<Instruction*> second_set(second.begin(), second.end());
if (loopmem::HasScalarDependenceAcrossCut(first, second_set) ||
loopmem::HasMemoryDependenceAcrossCut(accesses, first_set, second_set,
info.induction_var.stride)) {
continue;
}
chosen_cut = static_cast<int>(cut);
first_group = std::move(first);
second_group = std::move(second);
break;
}
if (chosen_cut < 0) {
continue;
}
std::unordered_set<Instruction*> keep(first_group.begin(), first_group.end());
keep.insert(info.step_inst);
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || keep.find(inst) != keep.end()) {
continue;
}
to_remove.push_back(inst);
}
if (!BuildSecondLoop(function, info, second_group)) {
continue;
}
for (auto* inst : to_remove) {
info.body->EraseInstruction(inst);
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopFission(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopFissionOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,855 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include "MemoryUtils.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct DominatorInfo {
std::vector<BasicBlock*> blocks;
std::unordered_map<BasicBlock*, size_t> index;
std::vector<std::vector<bool>> dominates;
std::vector<BasicBlock*> idom;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_tree_children;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dominance_frontier;
};
enum class SeedStateKind { Unavailable, Available, Conflict };
struct SeedState {
SeedStateKind kind = SeedStateKind::Unavailable;
StoreInst* store = nullptr;
bool operator==(const SeedState& rhs) const {
return kind == rhs.kind && store == rhs.store;
}
bool operator!=(const SeedState& rhs) const { return !(*this == rhs); }
};
struct CandidateKey {
memutils::AddressKey address;
std::uintptr_t type_id = 0;
bool operator==(const CandidateKey& rhs) const {
return type_id == rhs.type_id && address == rhs.address;
}
};
struct CandidateKeyHash {
std::size_t operator()(const CandidateKey& key) const {
std::size_t h = memutils::AddressKeyHash{}(key.address);
h ^= std::hash<std::uintptr_t>{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2);
return h;
}
};
struct PromotionCandidate {
CandidateKey key;
std::shared_ptr<Type> value_type;
loopmem::PointerInfo pointer_info;
std::vector<LoadInst*> loads;
std::vector<StoreInst*> stores;
StoreInst* seed_store = nullptr;
Value* canonical_ptr = nullptr;
Value* initial_value = nullptr;
std::unordered_set<BasicBlock*> def_blocks;
std::unordered_map<BasicBlock*, PhiInst*> phis;
int EstimatedBenefit() const {
return static_cast<int>(loads.size()) + 2 * static_cast<int>(stores.size()) - 1;
}
};
bool IsScalarPromotableType(const std::shared_ptr<Type>& type) {
return type && (type->IsInt1() || type->IsInt32() || type->IsFloat());
}
int CountFunctionInstructions(const Function& function) {
int count = 0;
for (const auto& block_ptr : function.GetBlocks()) {
if (!block_ptr) {
continue;
}
count += static_cast<int>(block_ptr->GetInstructions().size());
}
return count;
}
int CountLoopInstructions(const Loop& loop) {
int count = 0;
for (auto* block : loop.block_list) {
if (!block) {
continue;
}
count += static_cast<int>(block->GetInstructions().size());
}
return count;
}
bool ShouldAnalyzeFunction(const Function& function) {
constexpr int kMaxFunctionInstructions = 2000;
return CountFunctionInstructions(function) <= kMaxFunctionInstructions;
}
bool ShouldAnalyzeLoop(const Loop& loop) {
constexpr int kMaxLoopBlocks = 8;
constexpr int kMaxLoopInstructions = 96;
return static_cast<int>(loop.block_list.size()) <= kMaxLoopBlocks &&
CountLoopInstructions(loop) <= kMaxLoopInstructions;
}
bool DominatesBlock(const DominatorInfo& info, BasicBlock* dom, BasicBlock* block) {
if (!dom || !block) {
return false;
}
auto dom_it = info.index.find(dom);
auto block_it = info.index.find(block);
if (dom_it == info.index.end() || block_it == info.index.end()) {
return false;
}
return info.dominates[block_it->second][dom_it->second];
}
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it) {
stack.push_back(*it);
}
}
}
return order;
}
std::vector<bool> IntersectDominators(const std::vector<std::vector<bool>>& doms,
const std::vector<size_t>& pred_indices,
size_t self_index) {
std::vector<bool> result(doms.size(), true);
if (pred_indices.empty()) {
std::fill(result.begin(), result.end(), false);
result[self_index] = true;
return result;
}
result = doms[pred_indices.front()];
for (size_t i = 1; i < pred_indices.size(); ++i) {
const auto& pred_dom = doms[pred_indices[i]];
for (size_t j = 0; j < result.size(); ++j) {
result[j] = result[j] && pred_dom[j];
}
}
result[self_index] = true;
return result;
}
DominatorInfo BuildDominatorInfo(Function& function) {
DominatorInfo info;
info.blocks = CollectReachableBlocks(function);
info.idom.resize(info.blocks.size(), nullptr);
info.dominates.assign(info.blocks.size(),
std::vector<bool>(info.blocks.size(), true));
if (info.blocks.empty()) {
return info;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
info.index[info.blocks[i]] = i;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
std::fill(info.dominates[i].begin(), info.dominates[i].end(), i != 0);
info.dominates[i][i] = true;
}
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 1; i < info.blocks.size(); ++i) {
std::vector<size_t> pred_indices;
for (auto* pred : info.blocks[i]->GetPredecessors()) {
auto it = info.index.find(pred);
if (it != info.index.end()) {
pred_indices.push_back(it->second);
}
}
auto new_dom = IntersectDominators(info.dominates, pred_indices, i);
if (new_dom != info.dominates[i]) {
info.dominates[i] = std::move(new_dom);
changed = true;
}
}
}
for (size_t i = 1; i < info.blocks.size(); ++i) {
BasicBlock* candidate_idom = nullptr;
for (size_t j = 0; j < info.blocks.size(); ++j) {
if (i == j || !info.dominates[i][j]) {
continue;
}
bool is_immediate = true;
for (size_t k = 0; k < info.blocks.size(); ++k) {
if (k == i || k == j || !info.dominates[i][k]) {
continue;
}
if (info.dominates[k][j]) {
is_immediate = false;
break;
}
}
if (is_immediate) {
candidate_idom = info.blocks[j];
break;
}
}
info.idom[i] = candidate_idom;
if (candidate_idom) {
info.dom_tree_children[candidate_idom].push_back(info.blocks[i]);
}
}
for (auto* block : info.blocks) {
info.dominance_frontier[block] = {};
}
for (auto* block : info.blocks) {
std::vector<BasicBlock*> reachable_preds;
for (auto* pred : block->GetPredecessors()) {
if (info.index.find(pred) != info.index.end()) {
reachable_preds.push_back(pred);
}
}
if (reachable_preds.size() < 2) {
continue;
}
auto* idom_block = info.idom[info.index[block]];
for (auto* pred : reachable_preds) {
auto* runner = pred;
while (runner && runner != idom_block) {
auto& frontier = info.dominance_frontier[runner];
if (std::find(frontier.begin(), frontier.end(), block) == frontier.end()) {
frontier.push_back(block);
}
auto idom_it = info.index.find(runner);
if (idom_it == info.index.end()) {
break;
}
runner = info.idom[idom_it->second];
}
}
}
return info;
}
SeedState MergeSeedState(const SeedState& lhs, const SeedState& rhs) {
if (lhs == rhs) {
return lhs;
}
return {SeedStateKind::Conflict, nullptr};
}
SeedState TransferSeedState(const SeedState& in, BasicBlock* block,
const PromotionCandidate& candidate,
const memutils::EscapeSummary& escapes) {
SeedState state = in;
if (!block) {
return state;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* call = dyncast<CallInst>(inst)) {
if (memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
state = {SeedStateKind::Unavailable, nullptr};
}
continue;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key) ||
memutils::MayAliasConservatively(key, candidate.key.address)) {
state = {SeedStateKind::Unavailable, nullptr};
}
continue;
}
auto* store = dyncast<StoreInst>(inst);
if (!store) {
continue;
}
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
state = {SeedStateKind::Unavailable, nullptr};
continue;
}
if (!memutils::MayAliasConservatively(key, candidate.key.address)) {
continue;
}
if (key == candidate.key.address && store->GetValue()->GetType() == candidate.value_type) {
state = {SeedStateKind::Available, store};
} else {
state = {SeedStateKind::Unavailable, nullptr};
}
}
return state;
}
StoreInst* FindSeedStoreInPreheader(const Loop& loop,
const PromotionCandidate& candidate,
const memutils::EscapeSummary& escapes) {
auto* preheader = loop.preheader;
if (!preheader) {
return nullptr;
}
StoreInst* seed = nullptr;
for (const auto& inst_ptr : preheader->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* call = dyncast<CallInst>(inst)) {
if (memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
seed = nullptr;
}
continue;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key) ||
memutils::MayAliasConservatively(key, candidate.key.address)) {
seed = nullptr;
}
continue;
}
auto* store = dyncast<StoreInst>(inst);
if (!store) {
continue;
}
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
seed = nullptr;
continue;
}
if (!memutils::MayAliasConservatively(key, candidate.key.address)) {
continue;
}
if (key == candidate.key.address && store->GetValue()->GetType() == candidate.value_type) {
seed = store;
} else {
seed = nullptr;
}
}
return seed;
}
StoreInst* FindReachingSeedStoreAtLoopEntry(Function& function, const Loop& loop,
const PromotionCandidate& candidate,
const memutils::EscapeSummary& escapes) {
auto* preheader = loop.preheader;
if (!preheader) {
return nullptr;
}
const auto blocks = CollectReachableBlocks(function);
std::unordered_map<BasicBlock*, SeedState> in_state;
std::unordered_map<BasicBlock*, SeedState> out_state;
for (auto* block : blocks) {
in_state[block] = {SeedStateKind::Unavailable, nullptr};
out_state[block] = {SeedStateKind::Unavailable, nullptr};
}
bool changed = true;
while (changed) {
changed = false;
for (auto* block : blocks) {
SeedState merged{SeedStateKind::Unavailable, nullptr};
bool first_pred = true;
for (auto* pred : block->GetPredecessors()) {
auto it = out_state.find(pred);
if (it == out_state.end()) {
continue;
}
if (first_pred) {
merged = it->second;
first_pred = false;
} else {
merged = MergeSeedState(merged, it->second);
}
}
if (block == function.GetEntryBlock() && first_pred) {
merged = {SeedStateKind::Unavailable, nullptr};
}
SeedState next_out = TransferSeedState(merged, block, candidate, escapes);
if (merged != in_state[block] || next_out != out_state[block]) {
in_state[block] = merged;
out_state[block] = next_out;
changed = true;
}
}
}
const auto it = out_state.find(preheader);
if (it == out_state.end() || it->second.kind != SeedStateKind::Available) {
return nullptr;
}
return it->second.store;
}
bool ExitBlocksArePromotable(const Loop& loop) {
for (auto* exit : loop.exit_blocks) {
if (!exit) {
return false;
}
for (auto* pred : exit->GetPredecessors()) {
if (!loop.Contains(pred)) {
return false;
}
}
}
return !loop.exit_blocks.empty();
}
bool IsSafeToPromoteCandidate(const Loop& loop, const PromotionCandidate& candidate,
const std::vector<loopmem::MemoryAccessInfo>& accesses,
int iv_stride, const DominatorInfo& dom_info) {
if (!candidate.seed_store || !candidate.canonical_ptr || !candidate.initial_value) {
return false;
}
if (loop.parent != nullptr && candidate.seed_store->GetParent() != loop.preheader) {
return false;
}
if (!DominatesBlock(dom_info, candidate.seed_store->GetParent(), loop.preheader)) {
return false;
}
auto* ptr_inst = dyncast<Instruction>(candidate.canonical_ptr);
if (ptr_inst &&
(loop.Contains(ptr_inst->GetParent()) ||
!DominatesBlock(dom_info, ptr_inst->GetParent(), loop.preheader))) {
return false;
}
if (!ExitBlocksArePromotable(loop)) {
return false;
}
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* call = dyncast<CallInst>(inst_ptr.get());
if (!call) {
continue;
}
if (memutils::CallMayReadRoot(call->GetCallee(), candidate.pointer_info.root_kind) ||
memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
return false;
}
}
}
for (const auto& access : accesses) {
if (!loopmem::MayAliasSameIteration(candidate.pointer_info, access.ptr) &&
!loopmem::HasCrossIterationDependence(candidate.pointer_info, access.ptr, iv_stride)) {
continue;
}
if (!access.ptr.exact_key_valid || !(access.ptr.exact_key == candidate.key.address)) {
return false;
}
if (isa<MemsetInst>(access.inst)) {
return false;
}
if (auto* load = dyncast<LoadInst>(access.inst)) {
if (load->GetType() != candidate.value_type) {
return false;
}
continue;
}
if (auto* store = dyncast<StoreInst>(access.inst)) {
if (store->GetValue()->GetType() != candidate.value_type) {
return false;
}
continue;
}
return false;
}
return true;
}
void InsertPhiNodes(const Loop& loop, PromotionCandidate& candidate,
const DominatorInfo& dom_info, Function& function) {
std::queue<BasicBlock*> worklist;
std::unordered_set<BasicBlock*> queued;
for (auto* block : candidate.def_blocks) {
worklist.push(block);
queued.insert(block);
}
while (!worklist.empty()) {
auto* block = worklist.front();
worklist.pop();
auto frontier_it = dom_info.dominance_frontier.find(block);
if (frontier_it == dom_info.dominance_frontier.end()) {
continue;
}
for (auto* frontier_block : frontier_it->second) {
if (!loop.Contains(frontier_block)) {
continue;
}
if (candidate.phis.find(frontier_block) != candidate.phis.end()) {
continue;
}
auto* phi = frontier_block->Insert<PhiInst>(
looputils::GetFirstNonPhiIndex(frontier_block), candidate.value_type, nullptr,
looputils::NextSyntheticName(function, "lmp.phi."));
candidate.phis[frontier_block] = phi;
if (candidate.def_blocks.insert(frontier_block).second && queued.insert(frontier_block).second) {
worklist.push(frontier_block);
}
}
}
}
void RenameCandidateInLoop(
BasicBlock* block, const Loop& loop, PromotionCandidate& candidate,
const DominatorInfo& dom_info, std::vector<Value*>& stack,
std::unordered_map<BasicBlock*, Value*>& block_out) {
if (!block || !loop.Contains(block)) {
return;
}
size_t pushed = 0;
PhiInst* block_phi = nullptr;
auto phi_it = candidate.phis.find(block);
if (phi_it != candidate.phis.end()) {
block_phi = phi_it->second;
stack.push_back(block_phi);
++pushed;
}
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == block_phi) {
continue;
}
if (auto* load = dyncast<LoadInst>(inst)) {
auto it = std::find(candidate.loads.begin(), candidate.loads.end(), load);
if (it == candidate.loads.end()) {
continue;
}
load->ReplaceAllUsesWith(stack.back());
to_remove.push_back(load);
continue;
}
auto* store = dyncast<StoreInst>(inst);
if (!store) {
continue;
}
auto it = std::find(candidate.stores.begin(), candidate.stores.end(), store);
if (it == candidate.stores.end()) {
continue;
}
stack.push_back(store->GetValue());
++pushed;
to_remove.push_back(store);
}
block_out[block] = stack.back();
for (auto* succ : block->GetSuccessors()) {
if (!loop.Contains(succ)) {
continue;
}
auto succ_phi_it = candidate.phis.find(succ);
if (succ_phi_it == candidate.phis.end()) {
continue;
}
succ_phi_it->second->AddIncoming(stack.back(), block);
}
auto child_it = dom_info.dom_tree_children.find(block);
if (child_it != dom_info.dom_tree_children.end()) {
for (auto* child : child_it->second) {
RenameCandidateInLoop(child, loop, candidate, dom_info, stack, block_out);
}
}
for (auto* inst : to_remove) {
if (inst->GetParent() == block) {
block->EraseInstruction(inst);
}
}
while (pushed > 0) {
stack.pop_back();
--pushed;
}
}
void InsertExitStores(Function& function, const Loop& loop, PromotionCandidate& candidate,
const std::unordered_map<BasicBlock*, Value*>& block_out) {
std::unordered_set<BasicBlock*> seen;
for (auto* exit : loop.exit_blocks) {
if (!exit || !seen.insert(exit).second) {
continue;
}
std::vector<BasicBlock*> preds;
preds.reserve(exit->GetPredecessors().size());
for (auto* pred : exit->GetPredecessors()) {
if (loop.Contains(pred)) {
preds.push_back(pred);
}
}
if (preds.empty()) {
continue;
}
Value* final_value = nullptr;
auto insert_index = looputils::GetFirstNonPhiIndex(exit);
if (preds.size() == 1) {
auto it = block_out.find(preds.front());
if (it == block_out.end()) {
continue;
}
final_value = it->second;
} else {
auto* phi = exit->Insert<PhiInst>(insert_index, candidate.value_type, nullptr,
looputils::NextSyntheticName(function, "lmp.exit."));
++insert_index;
for (auto* pred : preds) {
auto it = block_out.find(pred);
if (it != block_out.end()) {
phi->AddIncoming(it->second, pred);
}
}
final_value = phi;
}
exit->Insert<StoreInst>(insert_index, final_value, candidate.canonical_ptr, nullptr);
}
}
bool PromoteCandidate(Function& function, const Loop& loop, PromotionCandidate& candidate,
const DominatorInfo& dom_info) {
if (!candidate.seed_store || !candidate.initial_value) {
return false;
}
InsertPhiNodes(loop, candidate, dom_info, function);
auto header_phi_it = candidate.phis.find(loop.header);
if (header_phi_it != candidate.phis.end()) {
header_phi_it->second->AddIncoming(candidate.initial_value, loop.preheader);
}
std::vector<Value*> stack{candidate.initial_value};
std::unordered_map<BasicBlock*, Value*> block_out;
RenameCandidateInLoop(loop.header, loop, candidate, dom_info, stack, block_out);
InsertExitStores(function, loop, candidate, block_out);
return true;
}
std::vector<PromotionCandidate> CollectCandidates(
const Loop& loop, const std::vector<loopmem::MemoryAccessInfo>& accesses,
const memutils::EscapeSummary& escapes, int iv_stride, Function& function,
const DominatorInfo& dom_info) {
constexpr std::size_t kMaxLoopAccesses = 64;
if (accesses.size() > kMaxLoopAccesses) {
return {};
}
std::unordered_map<CandidateKey, PromotionCandidate, CandidateKeyHash> groups;
for (const auto& access : accesses) {
if (!access.ptr.exact_key_valid || !access.ptr.invariant_address) {
continue;
}
if (!access.is_read && !access.is_write) {
continue;
}
std::shared_ptr<Type> value_type;
if (auto* load = dyncast<LoadInst>(access.inst)) {
value_type = load->GetType();
} else if (auto* store = dyncast<StoreInst>(access.inst)) {
value_type = store->GetValue()->GetType();
} else {
continue;
}
if (!IsScalarPromotableType(value_type)) {
continue;
}
CandidateKey key{access.ptr.exact_key,
reinterpret_cast<std::uintptr_t>(value_type.get())};
auto& candidate = groups[key];
candidate.key = key;
candidate.value_type = value_type;
candidate.pointer_info = access.ptr;
if (auto* load = dyncast<LoadInst>(access.inst)) {
candidate.loads.push_back(load);
} else if (auto* store = dyncast<StoreInst>(access.inst)) {
candidate.stores.push_back(store);
candidate.def_blocks.insert(store->GetParent());
}
}
std::vector<PromotionCandidate> candidates;
candidates.reserve(groups.size());
for (auto& [key, candidate] : groups) {
if (candidate.stores.empty()) {
continue;
}
candidate.seed_store = FindSeedStoreInPreheader(loop, candidate, escapes);
if (!candidate.seed_store && loop.parent == nullptr) {
candidate.seed_store =
FindReachingSeedStoreAtLoopEntry(function, loop, candidate, escapes);
}
if (!candidate.seed_store) {
continue;
}
candidate.initial_value = candidate.seed_store->GetValue();
candidate.canonical_ptr = candidate.seed_store->GetPtr();
if (!IsSafeToPromoteCandidate(loop, candidate, accesses, iv_stride, dom_info)) {
continue;
}
candidates.push_back(std::move(candidate));
}
std::sort(candidates.begin(), candidates.end(),
[](const PromotionCandidate& lhs, const PromotionCandidate& rhs) {
return lhs.EstimatedBenefit() > rhs.EstimatedBenefit();
});
return candidates;
}
bool PromoteLoopMemory(Function& function, const Loop& loop,
const DominatorInfo& dom_info) {
if (!loop.preheader || !loop.header || !loop.IsInnermost() ||
!ShouldAnalyzeLoop(loop)) {
return false;
}
loopmem::SimpleInductionVar induction_var;
PhiInst* iv = nullptr;
int iv_stride = 1;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
iv = induction_var.phi;
iv_stride = induction_var.stride;
break;
}
}
bool changed = false;
while (true) {
const auto escapes = memutils::AnalyzeEscapes(function);
const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes);
auto candidates = CollectCandidates(loop, accesses, escapes, iv_stride, function, dom_info);
if (candidates.empty()) {
break;
}
if (!PromoteCandidate(function, loop, candidates.front(), dom_info)) {
break;
}
changed = true;
}
return changed;
}
bool RunLoopMemoryPromotionOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
if (!ShouldAnalyzeFunction(function)) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool cfg_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
auto* old_preheader = loop->preheader;
auto* preheader = looputils::EnsurePreheader(function, *loop);
if (preheader != old_preheader) {
changed = true;
cfg_changed = true;
break;
}
}
if (cfg_changed) {
continue;
}
auto dom_info = BuildDominatorInfo(function);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
local_changed |= PromoteLoopMemory(function, *loop, dom_info);
}
changed |= local_changed;
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopMemoryPromotion(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopMemoryPromotionOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,506 @@
#pragma once
#include "LoopPassUtils.h"
#include "MemoryUtils.h"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <vector>
namespace ir::loopmem {
struct SimpleInductionVar {
PhiInst* phi = nullptr;
Value* start = nullptr;
Value* latch_value = nullptr;
BasicBlock* latch = nullptr;
int stride = 0;
};
inline bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
PhiInst* phi, SimpleInductionVar& info) {
if (!phi || !preheader || phi->GetParent() != loop.header ||
!phi->GetType()->IsInt32() || phi->GetNumIncomings() != 2 ||
loop.latches.size() != 1) {
return false;
}
auto* latch = loop.latches.front();
const int preheader_index = looputils::GetPhiIncomingIndex(phi, preheader);
const int latch_index = looputils::GetPhiIncomingIndex(phi, latch);
if (preheader_index < 0 || latch_index < 0) {
return false;
}
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
if (!step_inst || step_inst->GetParent() != latch) {
return false;
}
int stride = 0;
if (step_inst->GetOpcode() == Opcode::Add) {
if (step_inst->GetLhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else if (step_inst->GetRhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else {
return false;
}
} else if (step_inst->GetOpcode() == Opcode::Sub) {
if (step_inst->GetLhs() != phi) {
return false;
}
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = -delta->GetValue();
} else {
return false;
}
if (stride == 0) {
return false;
}
info.phi = phi;
info.start = phi->GetIncomingValue(preheader_index);
info.latch_value = phi->GetIncomingValue(latch_index);
info.latch = latch;
info.stride = stride;
return true;
}
inline bool GetCanonicalLoopBlocks(const Loop& loop, BasicBlock*& body,
BasicBlock*& exit) {
body = nullptr;
exit = nullptr;
if (!loop.header || loop.latches.size() != 1 || loop.block_list.size() != 2) {
return false;
}
auto* condbr = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
if (!condbr) {
return false;
}
auto* then_block = condbr->GetThenBlock();
auto* else_block = condbr->GetElseBlock();
const bool then_in_loop = loop.Contains(then_block);
const bool else_in_loop = loop.Contains(else_block);
if (then_in_loop == else_in_loop) {
return false;
}
body = then_in_loop ? then_block : else_block;
exit = then_in_loop ? else_block : then_block;
if (!body || !exit || body != loop.latches.front() ||
body->GetSuccessors().size() != 1 || body->GetSuccessors().front() != loop.header) {
return false;
}
return true;
}
struct AffineExpr {
bool valid = false;
Value* var = nullptr;
std::int64_t coeff = 0;
std::int64_t constant = 0;
};
inline AffineExpr MakeConst(std::int64_t value) {
return {true, nullptr, 0, value};
}
inline AffineExpr Scale(const AffineExpr& expr, std::int64_t factor) {
if (!expr.valid) {
return {};
}
return {true, expr.var, expr.coeff * factor, expr.constant * factor};
}
inline AffineExpr Combine(const AffineExpr& lhs, const AffineExpr& rhs, int sign) {
if (!lhs.valid || !rhs.valid) {
return {};
}
if (lhs.var != nullptr && rhs.var != nullptr && lhs.var != rhs.var) {
return {};
}
AffineExpr out;
out.valid = true;
out.var = lhs.var ? lhs.var : rhs.var;
out.coeff = lhs.coeff + sign * rhs.coeff;
out.constant = lhs.constant + sign * rhs.constant;
return out;
}
inline AffineExpr AnalyzeAffine(Value* value, PhiInst* iv, const Loop& loop) {
if (!value) {
return {};
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return MakeConst(ci->GetValue());
}
if (value == iv) {
return {true, iv, 1, 0};
}
if (looputils::IsLoopInvariantValue(loop, value)) {
return {};
}
if (auto* zext = dyncast<ZextInst>(value)) {
return AnalyzeAffine(zext->GetValue(), iv, loop);
}
auto* inst = dyncast<Instruction>(value);
if (!inst) {
return {};
}
switch (inst->GetOpcode()) {
case Opcode::Add:
return Combine(AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetLhs(), iv, loop),
AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetRhs(), iv, loop), +1);
case Opcode::Sub:
return Combine(AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetLhs(), iv, loop),
AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetRhs(), iv, loop), -1);
case Opcode::Mul: {
auto* bin = static_cast<BinaryInst*>(inst);
auto lhs = AnalyzeAffine(bin->GetLhs(), iv, loop);
auto rhs = AnalyzeAffine(bin->GetRhs(), iv, loop);
if (lhs.valid && lhs.var == nullptr && rhs.valid) {
return Scale(rhs, lhs.constant);
}
if (rhs.valid && rhs.var == nullptr && lhs.valid) {
return Scale(lhs, rhs.constant);
}
return {};
}
case Opcode::Neg:
return Scale(AnalyzeAffine(static_cast<UnaryInst*>(inst)->GetOprd(), iv, loop), -1);
default:
return {};
}
}
struct PointerInfo {
Value* base = nullptr;
AffineExpr byte_offset;
bool invariant_address = false;
bool distinct_root = false;
bool argument_root = false;
bool readonly_root = false;
bool exact_key_valid = false;
memutils::PointerRootKind root_kind = memutils::PointerRootKind::Unknown;
memutils::AddressKey exact_key;
int access_size = 0;
};
inline Value* StripPointerBase(Value* pointer) {
auto* value = pointer;
while (auto* gep = dyncast<GetElementPtrInst>(value)) {
value = gep->GetPointer();
}
return value;
}
inline std::shared_ptr<Type> AdvanceGEPType(std::shared_ptr<Type> current) {
if (current && current->IsArray()) {
return current->GetElementType();
}
return current;
}
inline PointerInfo AnalyzePointer(Value* pointer, PhiInst* iv, const Loop& loop,
int access_size,
const memutils::EscapeSummary* escapes = nullptr) {
PointerInfo info;
info.access_size = access_size;
info.base = StripPointerBase(pointer);
info.root_kind = memutils::ClassifyRoot(info.base, escapes);
info.argument_root = info.root_kind == memutils::PointerRootKind::Param;
info.readonly_root = info.root_kind == memutils::PointerRootKind::ReadonlyGlobal;
info.distinct_root = info.root_kind == memutils::PointerRootKind::Local ||
info.root_kind == memutils::PointerRootKind::Global ||
info.root_kind == memutils::PointerRootKind::ReadonlyGlobal;
info.exact_key_valid =
escapes != nullptr && memutils::BuildExactAddressKey(pointer, escapes, info.exact_key);
info.invariant_address = looputils::IsLoopInvariantValue(loop, pointer);
if (!dyncast<GetElementPtrInst>(pointer)) {
info.byte_offset = MakeConst(0);
return info;
}
auto* gep = static_cast<GetElementPtrInst*>(pointer);
std::shared_ptr<Type> current = gep->GetSourceType();
AffineExpr total = MakeConst(0);
bool all_indices_loop_invariant = looputils::IsLoopInvariantValue(loop, gep->GetPointer());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
auto* index = gep->GetIndex(i);
all_indices_loop_invariant &= looputils::IsLoopInvariantValue(loop, index);
const std::int64_t stride = current ? current->GetSize() : 0;
auto term = AnalyzeAffine(index, iv, loop);
if (!term.valid) {
total = {};
} else if (total.valid) {
total = Combine(total, Scale(term, stride), +1);
}
current = AdvanceGEPType(current);
}
info.invariant_address = all_indices_loop_invariant;
info.byte_offset = total;
return info;
}
struct MemoryAccessInfo {
Instruction* inst = nullptr;
Value* pointer = nullptr;
PointerInfo ptr;
bool is_read = false;
bool is_write = false;
};
inline std::vector<MemoryAccessInfo> CollectMemoryAccesses(const Loop& loop,
PhiInst* iv,
const memutils::EscapeSummary* escapes =
nullptr) {
std::vector<MemoryAccessInfo> accesses;
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* load = dyncast<LoadInst>(inst)) {
accesses.push_back(
{inst, load->GetPtr(),
AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes),
true,
false});
} else if (auto* store = dyncast<StoreInst>(inst)) {
accesses.push_back({inst, store->GetPtr(),
AnalyzePointer(store->GetPtr(), iv, loop,
store->GetValue()->GetType()->GetSize(), escapes),
false, true});
} else if (auto* memset = dyncast<MemsetInst>(inst)) {
accesses.push_back(
{inst, memset->GetDest(),
AnalyzePointer(memset->GetDest(), iv, loop, 1, escapes), false, true});
}
}
}
return accesses;
}
inline bool SameAffineAddress(const PointerInfo& lhs, const PointerInfo& rhs) {
return lhs.base == rhs.base && lhs.byte_offset.valid && rhs.byte_offset.valid &&
lhs.byte_offset.var == rhs.byte_offset.var &&
lhs.byte_offset.coeff == rhs.byte_offset.coeff &&
lhs.byte_offset.constant == rhs.byte_offset.constant;
}
inline bool MayAliasSameIteration(const PointerInfo& lhs, const PointerInfo& rhs) {
if (lhs.exact_key_valid && rhs.exact_key_valid) {
return memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key);
}
if (!lhs.base || !rhs.base) {
return true;
}
if (lhs.base != rhs.base) {
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
return false;
}
return true;
}
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
return true;
}
if (lhs.byte_offset.var != rhs.byte_offset.var) {
return true;
}
if (lhs.byte_offset.coeff != rhs.byte_offset.coeff) {
return true;
}
const auto diff = std::llabs(lhs.byte_offset.constant - rhs.byte_offset.constant);
const auto overlap = std::min(lhs.access_size, rhs.access_size);
return diff < overlap;
}
inline bool HasCrossIterationDependence(const PointerInfo& lhs, const PointerInfo& rhs,
int iv_stride) {
if (lhs.exact_key_valid && rhs.exact_key_valid &&
!memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key)) {
return false;
}
if (!lhs.base || !rhs.base) {
return true;
}
if (lhs.base != rhs.base) {
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
return false;
}
return true;
}
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
return true;
}
if (lhs.byte_offset.var != rhs.byte_offset.var) {
return true;
}
const auto lhs_step = lhs.byte_offset.coeff * iv_stride;
const auto rhs_step = rhs.byte_offset.coeff * iv_stride;
if (lhs_step == 0 && rhs_step == 0) {
return MayAliasSameIteration(lhs, rhs);
}
if (lhs_step == rhs_step && lhs_step != 0) {
const auto diff = rhs.byte_offset.constant - lhs.byte_offset.constant;
return diff != 0 && diff % std::llabs(lhs_step) == 0;
}
return true;
}
inline bool CallMayWritePointer(Function* callee, const PointerInfo& ptr) {
if (ptr.readonly_root) {
return false;
}
return memutils::CallMayWriteRoot(callee, ptr.root_kind);
}
inline bool IsSafeInvariantLoadToHoist(const Loop& loop, LoadInst* load, PhiInst* iv,
int iv_stride,
const std::vector<MemoryAccessInfo>& accesses,
const memutils::EscapeSummary* escapes = nullptr) {
if (!load) {
return false;
}
auto ptr = AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes);
if (!ptr.invariant_address) {
return false;
}
if (ptr.readonly_root) {
return true;
}
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == load) {
continue;
}
if (auto* call = dyncast<CallInst>(inst)) {
if (CallMayWritePointer(call->GetCallee(), ptr)) {
return false;
}
}
}
}
for (const auto& access : accesses) {
if (access.inst == load || !access.is_write) {
continue;
}
if (MayAliasSameIteration(ptr, access.ptr)) {
return false;
}
if (HasCrossIterationDependence(ptr, access.ptr, iv_stride)) {
return false;
}
}
return true;
}
inline bool HasScalarDependenceAcrossCut(const std::vector<Instruction*>& first_group,
const std::unordered_set<Instruction*>& second_set) {
for (auto* inst : first_group) {
if (!inst || inst->IsVoid()) {
continue;
}
for (const auto& use : inst->GetUses()) {
auto* user = dyncast<Instruction>(use.GetUser());
if (user && second_set.find(user) != second_set.end()) {
return true;
}
}
}
return false;
}
inline bool HasMemoryDependenceAcrossCut(const std::vector<MemoryAccessInfo>& accesses,
const std::unordered_set<Instruction*>& first_set,
const std::unordered_set<Instruction*>& second_set,
int iv_stride) {
for (const auto& lhs : accesses) {
if (first_set.find(lhs.inst) == first_set.end()) {
continue;
}
for (const auto& rhs : accesses) {
if (second_set.find(rhs.inst) == second_set.end()) {
continue;
}
if (!lhs.is_write && !rhs.is_write) {
continue;
}
if (MayAliasSameIteration(lhs.ptr, rhs.ptr) ||
HasCrossIterationDependence(lhs.ptr, rhs.ptr, iv_stride)) {
return true;
}
}
}
return false;
}
inline bool IsLoopParallelizable(const Loop& loop, PhiInst* iv, int iv_stride,
const std::vector<MemoryAccessInfo>& accesses) {
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (phi != iv) {
return false;
}
}
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* call = dyncast<CallInst>(inst)) {
auto* callee = call->GetCallee();
if (callee == nullptr || callee->HasObservableSideEffects() || callee->IsRecursive()) {
return false;
}
for (const auto& access : accesses) {
if (CallMayWritePointer(callee, access.ptr)) {
return false;
}
}
continue;
}
if (dyncast<MemsetInst>(inst)) {
return false;
}
}
}
for (std::size_t i = 0; i < accesses.size(); ++i) {
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
if (!accesses[i].is_write && !accesses[j].is_write) {
continue;
}
if (HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr, iv_stride)) {
return false;
}
}
}
return true;
}
} // namespace ir::loopmem

@ -0,0 +1,440 @@
#pragma once
#include "ir/Analysis.h"
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir::looputils {
inline Instruction* GetTerminator(BasicBlock* block) {
if (!block || block->GetInstructions().empty()) {
return nullptr;
}
auto* inst = block->GetInstructions().back().get();
return inst && inst->IsTerminator() ? inst : nullptr;
}
inline std::size_t GetTerminatorIndex(BasicBlock* block) {
if (!block) {
return 0;
}
const auto size = block->GetInstructions().size();
if (!block->HasTerminator()) {
return size;
}
return size == 0 ? 0 : size - 1;
}
inline std::size_t GetFirstNonPhiIndex(BasicBlock* block) {
if (!block) {
return 0;
}
std::size_t index = 0;
for (const auto& inst_ptr : block->GetInstructions()) {
if (!dyncast<PhiInst>(inst_ptr.get())) {
break;
}
++index;
}
return index;
}
inline std::string NextSyntheticName(Function& function, const std::string& prefix) {
static std::unordered_map<Function*, int> counters;
const int id = ++counters[&function];
return "%" + prefix + std::to_string(id);
}
inline std::string NextSyntheticBlockName(Function& function,
const std::string& prefix) {
static std::unordered_map<Function*, int> counters;
const int id = ++counters[&function];
return prefix + "." + std::to_string(id);
}
inline ConstantInt* ConstInt(int value) {
return new ConstantInt(Type::GetInt32Type(), value);
}
inline int GetPhiIncomingIndex(PhiInst* phi, BasicBlock* block) {
if (!phi || !block) {
return -1;
}
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (phi->GetIncomingBlock(i) == block) {
return i;
}
}
return -1;
}
inline bool ReplacePhiIncoming(PhiInst* phi, BasicBlock* old_block,
Value* new_value, BasicBlock* new_block) {
if (!phi || !old_block || !new_value || !new_block) {
return false;
}
const int index = GetPhiIncomingIndex(phi, old_block);
if (index < 0) {
return false;
}
phi->SetOperand(static_cast<std::size_t>(2 * index), new_value);
phi->SetOperand(static_cast<std::size_t>(2 * index + 1), new_block);
return true;
}
inline bool RedirectSuccessorEdge(BasicBlock* pred, BasicBlock* old_succ,
BasicBlock* new_succ) {
auto* terminator = GetTerminator(pred);
if (!terminator || !old_succ || !new_succ) {
return false;
}
if (auto* br = dyncast<UncondBrInst>(terminator)) {
if (br->GetDest() != old_succ) {
return false;
}
br->SetOperand(0, new_succ);
} else if (auto* condbr = dyncast<CondBrInst>(terminator)) {
bool changed = false;
if (condbr->GetThenBlock() == old_succ) {
condbr->SetOperand(1, new_succ);
changed = true;
}
if (condbr->GetElseBlock() == old_succ) {
condbr->SetOperand(2, new_succ);
changed = true;
}
if (!changed) {
return false;
}
} else {
return false;
}
pred->RemoveSuccessor(old_succ);
pred->AddSuccessor(new_succ);
return true;
}
inline Instruction* MoveInstructionBeforeTerminator(Instruction* inst,
BasicBlock* dest) {
if (!inst || !dest) {
return nullptr;
}
auto* src = inst->GetParent();
if (!src || src == dest) {
return inst;
}
auto& src_insts = src->GetInstructions();
auto src_it = std::find_if(src_insts.begin(), src_insts.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == inst;
});
if (src_it == src_insts.end()) {
return nullptr;
}
auto moved = std::move(*src_it);
src_insts.erase(src_it);
moved->SetParent(dest);
auto& dest_insts = dest->GetInstructions();
auto insert_it = dest_insts.begin() +
static_cast<long long>(GetTerminatorIndex(dest));
auto* ptr = moved.get();
dest_insts.insert(insert_it, std::move(moved));
return ptr;
}
inline bool IsLoopInvariantValue(const Loop& loop, Value* value) {
auto* inst = dyncast<Instruction>(value);
return inst == nullptr || !loop.Contains(inst->GetParent());
}
inline Value* RemapValue(const std::unordered_map<Value*, Value*>& remap,
Value* value) {
auto it = remap.find(value);
return it == remap.end() ? value : it->second;
}
inline bool IsCloneableInstruction(const Instruction* inst) {
if (!inst || inst->IsTerminator() || inst->GetOpcode() == Opcode::Phi) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Alloca:
case Opcode::Load:
case Opcode::Store:
case Opcode::Memset:
case Opcode::GetElementPtr:
case Opcode::Zext:
case Opcode::Call:
return true;
default:
return false;
}
}
inline Instruction* CloneInstruction(Function& function, Instruction* inst,
BasicBlock* dest,
std::unordered_map<Value*, Value*>& remap,
const std::string& prefix) {
if (!inst || !dest || !IsCloneableInstruction(inst)) {
return nullptr;
}
const auto insert_index = GetTerminatorIndex(dest);
const auto name = inst->IsVoid() ? std::string()
: NextSyntheticName(function, prefix);
auto remap_operand = [&](Value* value) { return RemapValue(remap, value); };
auto remember = [&](Instruction* clone) {
if (clone && !inst->IsVoid()) {
remap[inst] = clone;
}
return clone;
};
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<BinaryInst*>(inst);
return remember(dest->Insert<BinaryInst>(
insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(bin->GetLhs()), remap_operand(bin->GetRhs()), nullptr,
name));
}
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF: {
auto* un = static_cast<UnaryInst*>(inst);
return remember(dest->Insert<UnaryInst>(insert_index, inst->GetOpcode(),
inst->GetType(),
remap_operand(un->GetOprd()),
nullptr, name));
}
case Opcode::Alloca: {
auto* alloca = static_cast<AllocaInst*>(inst);
return remember(dest->Insert<AllocaInst>(insert_index,
alloca->GetAllocatedType(),
nullptr, name));
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
return remember(dest->Insert<LoadInst>(insert_index, inst->GetType(),
remap_operand(load->GetPtr()),
nullptr, name));
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
return dest->Insert<StoreInst>(insert_index,
remap_operand(store->GetValue()),
remap_operand(store->GetPtr()), nullptr);
}
case Opcode::Memset: {
auto* memset = static_cast<MemsetInst*>(inst);
return dest->Insert<MemsetInst>(insert_index,
remap_operand(memset->GetDest()),
remap_operand(memset->GetValue()),
remap_operand(memset->GetLength()),
remap_operand(memset->GetIsVolatile()),
nullptr);
}
case Opcode::GetElementPtr: {
auto* gep = static_cast<GetElementPtrInst*>(inst);
std::vector<Value*> indices;
indices.reserve(gep->GetNumIndices());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
indices.push_back(remap_operand(gep->GetIndex(i)));
}
return remember(dest->Insert<GetElementPtrInst>(
insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()),
indices, nullptr, name));
}
case Opcode::Zext: {
auto* zext = static_cast<ZextInst*>(inst);
return remember(dest->Insert<ZextInst>(insert_index,
remap_operand(zext->GetValue()),
inst->GetType(), nullptr, name));
}
case Opcode::Call: {
auto* call = static_cast<CallInst*>(inst);
std::vector<Value*> args;
const auto original_args = call->GetArguments();
args.reserve(original_args.size());
for (auto* arg : original_args) {
args.push_back(remap_operand(arg));
}
return remember(dest->Insert<CallInst>(insert_index, call->GetCallee(),
args, nullptr, name));
}
case Opcode::Phi:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Return:
case Opcode::Unreachable:
break;
}
return nullptr;
}
inline BasicBlock* EnsurePreheader(Function& function, Loop& loop) {
if (loop.preheader) {
return loop.preheader;
}
auto* header = loop.header;
if (!header) {
return nullptr;
}
std::vector<BasicBlock*> outside_preds;
for (auto* pred : header->GetPredecessors()) {
if (!loop.Contains(pred)) {
outside_preds.push_back(pred);
}
}
if (outside_preds.empty()) {
return nullptr;
}
if (outside_preds.size() == 1 &&
outside_preds.front()->GetSuccessors().size() == 1) {
loop.preheader = outside_preds.front();
return loop.preheader;
}
auto* preheader = function.CreateBlock(
NextSyntheticBlockName(function, header->GetName() + ".preheader"));
std::size_t phi_insert_index = 0;
for (const auto& inst_ptr : header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
std::vector<int> outside_incomings;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (!loop.Contains(phi->GetIncomingBlock(i))) {
outside_incomings.push_back(i);
}
}
if (outside_incomings.empty()) {
continue;
}
Value* merged_value = nullptr;
if (outside_incomings.size() == 1) {
merged_value = phi->GetIncomingValue(outside_incomings.front());
} else {
auto new_phi = std::make_unique<PhiInst>(
phi->GetType(), nullptr,
NextSyntheticName(function, "preheader.phi."));
auto* new_phi_ptr = new_phi.get();
new_phi_ptr->SetParent(preheader);
auto& preheader_insts = preheader->GetInstructions();
preheader_insts.insert(preheader_insts.begin() +
static_cast<long long>(phi_insert_index),
std::move(new_phi));
++phi_insert_index;
for (int incoming_index : outside_incomings) {
new_phi_ptr->AddIncoming(phi->GetIncomingValue(incoming_index),
phi->GetIncomingBlock(incoming_index));
}
merged_value = new_phi_ptr;
}
for (auto it = outside_incomings.rbegin(); it != outside_incomings.rend();
++it) {
phi->RemoveOperand(static_cast<std::size_t>(2 * *it + 1));
phi->RemoveOperand(static_cast<std::size_t>(2 * *it));
}
phi->AddIncoming(merged_value, preheader);
}
preheader->Append<UncondBrInst>(header, nullptr);
preheader->AddSuccessor(header);
header->AddPredecessor(preheader);
for (auto* pred : outside_preds) {
if (RedirectSuccessorEdge(pred, header, preheader)) {
preheader->AddPredecessor(pred);
header->RemovePredecessor(pred);
}
}
loop.preheader = preheader;
return preheader;
}
} // namespace ir::looputils

@ -0,0 +1,295 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <cstdlib>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct InductionVarInfo {
PhiInst* phi = nullptr;
Value* start = nullptr;
BasicBlock* latch = nullptr;
int stride = 0;
};
Value* BuildMulValue(Function& function, BasicBlock* block, Value* lhs, Value* rhs,
const std::string& prefix) {
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
if (lhs_const->GetValue() == 0) {
return looputils::ConstInt(0);
}
if (lhs_const->GetValue() == 1) {
return rhs;
}
}
if (auto* rhs_const = dyncast<ConstantInt>(rhs)) {
if (rhs_const->GetValue() == 0) {
return looputils::ConstInt(0);
}
if (rhs_const->GetValue() == 1) {
return lhs;
}
}
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
if (auto* rhs_const = dyncast<ConstantInt>(rhs)) {
return looputils::ConstInt(lhs_const->GetValue() * rhs_const->GetValue());
}
}
return block->Insert<BinaryInst>(looputils::GetTerminatorIndex(block), Opcode::Mul,
Type::GetInt32Type(), lhs, rhs, nullptr,
looputils::NextSyntheticName(function, prefix));
}
Value* BuildScaledValue(Function& function, BasicBlock* block, Value* base,
int factor, const std::string& prefix) {
if (factor == 0) {
return looputils::ConstInt(0);
}
if (factor == 1) {
return base;
}
if (auto* base_const = dyncast<ConstantInt>(base)) {
return looputils::ConstInt(base_const->GetValue() * factor);
}
if (factor == -1) {
return block->Insert<UnaryInst>(looputils::GetTerminatorIndex(block), Opcode::Neg,
base->GetType(), base, nullptr,
looputils::NextSyntheticName(function, prefix));
}
return BuildMulValue(function, block, base, looputils::ConstInt(factor), prefix);
}
bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
PhiInst* phi, InductionVarInfo& info) {
if (!phi || !phi->GetType() || !phi->GetType()->IsInt32() ||
phi->GetParent() != loop.header || phi->GetNumIncomings() != 2 ||
loop.latches.size() != 1) {
return false;
}
auto* latch = loop.latches.front();
int preheader_index = -1;
int latch_index = -1;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (phi->GetIncomingBlock(i) == preheader) {
preheader_index = i;
} else if (phi->GetIncomingBlock(i) == latch) {
latch_index = i;
}
}
if (preheader_index < 0 || latch_index < 0) {
return false;
}
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
if (!step_inst || step_inst->GetParent() != latch) {
return false;
}
int stride = 0;
if (step_inst->GetOpcode() == Opcode::Add) {
if (step_inst->GetLhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else if (step_inst->GetRhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else {
return false;
}
} else if (step_inst->GetOpcode() == Opcode::Sub) {
if (step_inst->GetLhs() != phi) {
return false;
}
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = -delta->GetValue();
} else {
return false;
}
if (stride == 0) {
return false;
}
info.phi = phi;
info.start = phi->GetIncomingValue(preheader_index);
info.latch = latch;
info.stride = stride;
return true;
}
bool IsMulCandidate(const Loop& loop, Instruction* inst, PhiInst* phi, Value*& factor) {
auto* mul = dyncast<BinaryInst>(inst);
if (!mul || mul->GetOpcode() != Opcode::Mul || !mul->GetType()->IsInt32()) {
return false;
}
if (mul->GetLhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetRhs())) {
factor = mul->GetRhs();
return true;
}
if (mul->GetRhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetLhs())) {
factor = mul->GetLhs();
return true;
}
return false;
}
Value* CreateReducedPhi(Function& function, BasicBlock* header, BasicBlock* preheader,
const InductionVarInfo& iv, Value* factor) {
auto* reduced_phi = header->Insert<PhiInst>(
looputils::GetFirstNonPhiIndex(header), Type::GetInt32Type(), nullptr,
looputils::NextSyntheticName(function, "lsr.phi."));
Value* init = BuildMulValue(function, preheader, iv.start, factor, "lsr.init.");
reduced_phi->AddIncoming(init, preheader);
Value* step = BuildScaledValue(function, preheader, factor, std::abs(iv.stride),
"lsr.step.");
Instruction* next = nullptr;
if (iv.stride > 0) {
next = iv.latch->Insert<BinaryInst>(
looputils::GetTerminatorIndex(iv.latch), Opcode::Add, Type::GetInt32Type(),
reduced_phi, step, nullptr,
looputils::NextSyntheticName(function, "lsr.next."));
} else {
next = iv.latch->Insert<BinaryInst>(
looputils::GetTerminatorIndex(iv.latch), Opcode::Sub, Type::GetInt32Type(),
reduced_phi, step, nullptr,
looputils::NextSyntheticName(function, "lsr.next."));
}
reduced_phi->AddIncoming(next, iv.latch);
return reduced_phi;
}
bool ReduceLoopMultiplications(Function& function, const Loop& loop,
BasicBlock* preheader) {
if (!preheader || loop.latches.size() != 1) {
return false;
}
std::vector<InductionVarInfo> induction_vars;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
InductionVarInfo info;
if (MatchSimpleInductionVariable(loop, preheader, phi, info)) {
induction_vars.push_back(info);
}
}
if (induction_vars.empty()) {
return false;
}
bool changed = false;
std::vector<Instruction*> to_remove;
for (const auto& iv : induction_vars) {
std::vector<std::pair<Instruction*, Value*>> candidates;
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == iv.phi) {
continue;
}
Value* factor = nullptr;
if (IsMulCandidate(loop, inst, iv.phi, factor)) {
candidates.push_back({inst, factor});
}
}
}
if (candidates.empty()) {
continue;
}
std::unordered_map<Value*, Value*> reduced_cache;
for (const auto& candidate : candidates) {
auto* inst = candidate.first;
auto* factor = candidate.second;
auto cache_it = reduced_cache.find(factor);
Value* replacement = nullptr;
if (cache_it != reduced_cache.end()) {
replacement = cache_it->second;
} else {
replacement = CreateReducedPhi(function, loop.header, preheader, iv, factor);
reduced_cache.emplace(factor, replacement);
}
inst->ReplaceAllUsesWith(replacement);
to_remove.push_back(inst);
changed = true;
}
}
for (auto* inst : to_remove) {
if (inst && inst->GetParent()) {
inst->GetParent()->EraseInstruction(inst);
}
}
return changed;
}
bool RunLoopStrengthReductionOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
auto* old_preheader = loop->preheader;
auto* preheader = looputils::EnsurePreheader(function, *loop);
bool loop_changed = preheader != old_preheader;
loop_changed |= ReduceLoopMultiplications(function, *loop, preheader);
if (!loop_changed) {
continue;
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopStrengthReduction(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopStrengthReductionOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,400 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct CountedLoopInfo {
Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
CondBrInst* branch = nullptr;
BinaryInst* compare = nullptr;
Opcode compare_opcode = Opcode::ICmpLT;
Value* bound = nullptr;
loopmem::SimpleInductionVar induction_var;
std::vector<PhiInst*> phis;
};
bool HasSyntheticLoopTag(const std::string& name) {
return name.find("unroll.") != std::string::npos;
}
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
if (!loop.preheader || !loop.header || !body) {
return true;
}
if (HasSyntheticLoopTag(loop.preheader->GetName()) ||
HasSyntheticLoopTag(loop.header->GetName()) ||
HasSyntheticLoopTag(body->GetName())) {
return true;
}
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int incoming = looputils::GetPhiIncomingIndex(phi, loop.preheader);
if (incoming < 0) {
continue;
}
auto* incoming_phi = dyncast<PhiInst>(phi->GetIncomingValue(incoming));
if (incoming_phi && incoming_phi->GetParent() &&
HasSyntheticLoopTag(incoming_phi->GetParent()->GetName())) {
return true;
}
}
return false;
}
bool IsSupportedCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
case Opcode::ICmpLE:
case Opcode::ICmpGT:
case Opcode::ICmpGE:
return true;
default:
return false;
}
}
Opcode SwapCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
return Opcode::ICmpGT;
case Opcode::ICmpLE:
return Opcode::ICmpGE;
case Opcode::ICmpGT:
return Opcode::ICmpLT;
case Opcode::ICmpGE:
return Opcode::ICmpLE;
default:
return opcode;
}
}
int CountPayloadInstructions(BasicBlock* block) {
int count = 0;
if (!block) {
return 0;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
break;
}
++count;
}
return count;
}
int ChooseUnrollFactor(BasicBlock* body) {
const int inst_count = CountPayloadInstructions(body);
int mem_ops = 0;
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
break;
}
if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
++mem_ops;
}
}
if (inst_count >= 2 && inst_count <= 6 && mem_ops <= 2) {
return 4;
}
if (inst_count >= 2 && inst_count <= 18) {
return 2;
}
return 1;
}
bool HasUnsafeLoopCarriedMemoryDependence(
const std::vector<loopmem::MemoryAccessInfo>& accesses, int iv_stride) {
for (std::size_t i = 0; i < accesses.size(); ++i) {
if (accesses[i].is_write &&
loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[i].ptr,
iv_stride)) {
return true;
}
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
if (!accesses[i].is_write && !accesses[j].is_write) {
continue;
}
if (loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr,
iv_stride)) {
return true;
}
}
}
return false;
}
bool MatchCountedLoop(Loop& loop, CountedLoopInfo& info) {
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
return false;
}
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
return false;
}
if (IsAlreadyTransformedLoop(loop, body)) {
return false;
}
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
if (!branch || branch->GetThenBlock() != body) {
return false;
}
auto* compare = dyncast<BinaryInst>(branch->GetCondition());
if (!compare || !compare->GetType()->IsBool() ||
!IsSupportedCompareOpcode(compare->GetOpcode())) {
return false;
}
bool found_iv = false;
loopmem::SimpleInductionVar induction_var;
std::vector<PhiInst*> phis;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
phis.push_back(phi);
if (!found_iv &&
loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
found_iv = true;
}
}
if (!found_iv) {
return false;
}
Opcode compare_opcode = compare->GetOpcode();
Value* bound = nullptr;
if (compare->GetLhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
bound = compare->GetRhs();
} else if (compare->GetRhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
bound = compare->GetLhs();
compare_opcode = SwapCompareOpcode(compare_opcode);
} else {
return false;
}
if (!bound) {
return false;
}
if ((induction_var.stride > 0 &&
!(compare_opcode == Opcode::ICmpLT || compare_opcode == Opcode::ICmpLE)) ||
(induction_var.stride < 0 &&
!(compare_opcode == Opcode::ICmpGT || compare_opcode == Opcode::ICmpGE))) {
return false;
}
const auto accesses = loopmem::CollectMemoryAccesses(loop, induction_var.phi);
if (HasUnsafeLoopCarriedMemoryDependence(accesses, induction_var.stride)) {
return false;
}
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst) || inst == compare || inst->IsTerminator()) {
continue;
}
return false;
}
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
continue;
}
if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
return false;
}
}
info.loop = &loop;
info.preheader = loop.preheader;
info.header = loop.header;
info.body = body;
info.exit = exit;
info.branch = branch;
info.compare = compare;
info.compare_opcode = compare_opcode;
info.bound = bound;
info.induction_var = induction_var;
info.phis = std::move(phis);
return true;
}
Value* BuildAdjustedBound(Function& function, BasicBlock* preheader, Value* bound,
int stride, int factor) {
const int delta = std::abs(stride) * (factor - 1);
if (delta == 0) {
return bound;
}
if (auto* ci = dyncast<ConstantInt>(bound)) {
return looputils::ConstInt(stride > 0 ? ci->GetValue() - delta : ci->GetValue() + delta);
}
return preheader->Insert<BinaryInst>(
looputils::GetTerminatorIndex(preheader),
stride > 0 ? Opcode::Sub : Opcode::Add, Type::GetInt32Type(), bound,
looputils::ConstInt(delta), nullptr,
looputils::NextSyntheticName(function, "unroll.bound."));
}
bool RunLoopUnrollOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock()) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
CountedLoopInfo info;
if (!MatchCountedLoop(*loop, info)) {
continue;
}
const int factor = ChooseUnrollFactor(info.body);
if (factor <= 1) {
continue;
}
auto* unrolled_header =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.header"));
auto* unrolled_body =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.body"));
auto* unrolled_exit =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.exit"));
std::unordered_map<Value*, Value*> remap;
std::unordered_map<PhiInst*, PhiInst*> unrolled_phis;
std::unordered_map<PhiInst*, PhiInst*> exit_phis;
std::unordered_map<PhiInst*, Value*> current_phi_values;
std::unordered_map<PhiInst*, Value*> latch_values;
for (auto* phi : info.phis) {
auto* cloned_phi = unrolled_header->Append<PhiInst>(
phi->GetType(), nullptr,
looputils::NextSyntheticName(function, "unroll.phi."));
const int preheader_index = looputils::GetPhiIncomingIndex(phi, info.preheader);
const int latch_index = looputils::GetPhiIncomingIndex(phi, info.body);
if (preheader_index < 0 || latch_index < 0) {
continue;
}
cloned_phi->AddIncoming(phi->GetIncomingValue(preheader_index), info.preheader);
remap[phi] = cloned_phi;
unrolled_phis.emplace(phi, cloned_phi);
current_phi_values.emplace(phi, cloned_phi);
latch_values.emplace(phi, phi->GetIncomingValue(latch_index));
}
auto* adjusted_bound = BuildAdjustedBound(function, info.preheader, info.bound,
info.induction_var.stride, factor);
auto* unrolled_cond = unrolled_header->Append<BinaryInst>(
info.compare_opcode, Type::GetBoolType(), unrolled_phis[info.induction_var.phi],
adjusted_bound, nullptr,
looputils::NextSyntheticName(function, "unroll.cmp."));
unrolled_header->Append<CondBrInst>(unrolled_cond, unrolled_body, unrolled_exit, nullptr);
unrolled_header->AddPredecessor(info.preheader);
unrolled_header->AddSuccessor(unrolled_body);
unrolled_header->AddSuccessor(unrolled_exit);
for (int lane = 0; lane < factor; ++lane) {
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
continue;
}
looputils::CloneInstruction(function, inst, unrolled_body, remap,
"unroll." + std::to_string(lane) + ".");
}
std::unordered_map<PhiInst*, Value*> next_phi_values;
for (const auto& entry : latch_values) {
next_phi_values.emplace(entry.first,
looputils::RemapValue(remap, entry.second));
}
for (const auto& entry : next_phi_values) {
remap[entry.first] = entry.second;
current_phi_values[entry.first] = entry.second;
}
}
for (const auto& entry : unrolled_phis) {
entry.second->AddIncoming(current_phi_values[entry.first], unrolled_body);
}
unrolled_body->Append<UncondBrInst>(unrolled_header, nullptr);
unrolled_body->AddPredecessor(unrolled_header);
unrolled_body->AddSuccessor(unrolled_header);
unrolled_header->AddPredecessor(unrolled_body);
for (const auto& entry : unrolled_phis) {
auto* exit_phi = unrolled_exit->Append<PhiInst>(
entry.first->GetType(), nullptr,
looputils::NextSyntheticName(function, "unroll.exit."));
exit_phi->AddIncoming(entry.second, unrolled_header);
exit_phis.emplace(entry.first, exit_phi);
}
unrolled_exit->Append<UncondBrInst>(info.header, nullptr);
unrolled_exit->AddPredecessor(unrolled_header);
unrolled_exit->AddSuccessor(info.header);
if (!looputils::RedirectSuccessorEdge(info.preheader, info.header, unrolled_header)) {
continue;
}
info.header->RemovePredecessor(info.preheader);
info.header->AddPredecessor(unrolled_exit);
for (auto* phi : info.phis) {
looputils::ReplacePhiIncoming(phi, info.preheader, exit_phis[phi], unrolled_exit);
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopUnroll(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopUnrollOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,313 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
struct UnswitchInfo {
Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* guard_block = nullptr;
CondBrInst* guard = nullptr;
Value* condition = nullptr;
std::vector<BasicBlock*> order;
};
bool HasSyntheticUnswitchTag(const std::string& name) {
return name.find("unswitch.") != std::string::npos;
}
bool IsSafeLoopBlockForUnswitch(BasicBlock* block) {
if (!block) {
return false;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<CallInst>(inst) || dyncast<MemsetInst>(inst) ||
dyncast<AllocaInst>(inst) || dyncast<UnreachableInst>(inst)) {
return false;
}
}
return true;
}
void CollectLoopDFS(BasicBlock* block, const Loop& loop,
std::unordered_set<BasicBlock*>& visited,
std::vector<BasicBlock*>& postorder) {
if (!block || !loop.Contains(block) || !visited.insert(block).second) {
return;
}
for (auto* succ : block->GetSuccessors()) {
if (loop.Contains(succ)) {
CollectLoopDFS(succ, loop, visited, postorder);
}
}
postorder.push_back(block);
}
std::vector<BasicBlock*> CollectLoopRPO(const Loop& loop) {
std::vector<BasicBlock*> postorder;
std::unordered_set<BasicBlock*> visited;
CollectLoopDFS(loop.header, loop, visited, postorder);
return std::vector<BasicBlock*>(postorder.rbegin(), postorder.rend());
}
void RemovePhiIncomingFromPred(BasicBlock* block, BasicBlock* pred) {
if (!block || !pred) {
return;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int index = looputils::GetPhiIncomingIndex(phi, pred);
if (index < 0) {
continue;
}
phi->RemoveOperand(static_cast<std::size_t>(2 * index + 1));
phi->RemoveOperand(static_cast<std::size_t>(2 * index));
}
}
void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) {
auto& instructions = block->GetInstructions();
if (instructions.empty() || !instructions.back()->IsTerminator()) {
return;
}
instructions.back()->ClearAllOperands();
instructions.pop_back();
block->Append<UncondBrInst>(dest, nullptr);
}
void ReplaceTerminatorWithCondBr(BasicBlock* block, Value* cond,
BasicBlock* then_block, BasicBlock* else_block) {
auto& instructions = block->GetInstructions();
if (instructions.empty() || !instructions.back()->IsTerminator()) {
return;
}
instructions.back()->ClearAllOperands();
instructions.pop_back();
block->Append<CondBrInst>(cond, then_block, else_block, nullptr);
}
bool MatchLoopUnswitch(Loop& loop, UnswitchInfo& info) {
if (!loop.preheader || !loop.IsInnermost() || loop.block_list.size() > 6) {
return false;
}
if (HasSyntheticUnswitchTag(loop.preheader->GetName()) ||
HasSyntheticUnswitchTag(loop.header->GetName())) {
return false;
}
int instruction_count = 0;
for (auto* block : loop.block_list) {
if (!IsSafeLoopBlockForUnswitch(block) || HasSyntheticUnswitchTag(block->GetName())) {
return false;
}
instruction_count += static_cast<int>(block->GetInstructions().size());
}
if (instruction_count > 48) {
return false;
}
for (auto* block : loop.block_list) {
auto* condbr = dyncast<CondBrInst>(looputils::GetTerminator(block));
if (!condbr) {
continue;
}
auto* cond_inst = dyncast<Instruction>(condbr->GetCondition());
if (cond_inst && loop.Contains(cond_inst->GetParent())) {
continue;
}
info.loop = &loop;
info.preheader = loop.preheader;
info.guard_block = block;
info.guard = condbr;
info.condition = condbr->GetCondition();
info.order = CollectLoopRPO(loop);
return true;
}
return false;
}
bool CloneLoopForUnswitch(Function& function, const UnswitchInfo& info,
std::unordered_map<BasicBlock*, BasicBlock*>& block_map,
std::unordered_map<Value*, Value*>& value_map) {
for (auto* block : info.order) {
block_map[block] = function.CreateBlock(
looputils::NextSyntheticBlockName(function, "unswitch.loop"));
}
for (auto* block : info.order) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
auto* cloned_phi = clone->Append<PhiInst>(
phi->GetType(), nullptr, looputils::NextSyntheticName(function, "unswitch.phi."));
value_map[phi] = cloned_phi;
}
}
for (auto* block : info.order) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst) || inst->IsTerminator()) {
continue;
}
if (!looputils::CloneInstruction(function, inst, clone, value_map, "unswitch.")) {
return false;
}
}
}
for (auto* block : info.order) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
auto* cloned_phi = static_cast<PhiInst*>(value_map.at(phi));
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* incoming_block = phi->GetIncomingBlock(i);
auto value_it = value_map.find(phi->GetIncomingValue(i));
Value* incoming_value =
value_it == value_map.end() ? phi->GetIncomingValue(i) : value_it->second;
auto block_it = block_map.find(incoming_block);
cloned_phi->AddIncoming(incoming_value,
block_it == block_map.end() ? incoming_block
: block_it->second);
}
}
}
for (auto* block : info.order) {
auto* clone = block_map.at(block);
auto* terminator = looputils::GetTerminator(block);
if (auto* br = dyncast<UncondBrInst>(terminator)) {
auto target_it = block_map.find(br->GetDest());
auto* target = target_it == block_map.end() ? br->GetDest() : target_it->second;
clone->Append<UncondBrInst>(target, nullptr);
clone->AddSuccessor(target);
target->AddPredecessor(clone);
continue;
}
auto* condbr = dyncast<CondBrInst>(terminator);
if (!condbr) {
return false;
}
auto cond_it = value_map.find(condbr->GetCondition());
Value* cond = cond_it == value_map.end() ? condbr->GetCondition() : cond_it->second;
auto then_it = block_map.find(condbr->GetThenBlock());
auto else_it = block_map.find(condbr->GetElseBlock());
auto* then_block = then_it == block_map.end() ? condbr->GetThenBlock() : then_it->second;
auto* else_block = else_it == block_map.end() ? condbr->GetElseBlock() : else_it->second;
clone->Append<CondBrInst>(cond, then_block, else_block, nullptr);
clone->AddSuccessor(then_block);
clone->AddSuccessor(else_block);
then_block->AddPredecessor(clone);
else_block->AddPredecessor(clone);
}
return true;
}
bool RunLoopUnswitchOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock()) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
UnswitchInfo info;
if (!MatchLoopUnswitch(*loop, info)) {
continue;
}
std::unordered_map<BasicBlock*, BasicBlock*> block_map;
std::unordered_map<Value*, Value*> value_map;
if (!CloneLoopForUnswitch(function, info, block_map, value_map)) {
continue;
}
auto* then_target = info.guard->GetThenBlock();
auto* else_target = info.guard->GetElseBlock();
auto* cloned_guard = block_map.at(info.guard_block);
auto* cloned_then =
block_map.count(then_target) ? block_map.at(then_target) : then_target;
auto* cloned_else =
block_map.count(else_target) ? block_map.at(else_target) : else_target;
RemovePhiIncomingFromPred(else_target, info.guard_block);
if (then_target != else_target) {
RemovePhiIncomingFromPred(cloned_then, cloned_guard);
}
info.guard_block->RemoveSuccessor(else_target);
else_target->RemovePredecessor(info.guard_block);
ReplaceTerminatorWithBr(info.guard_block, then_target);
info.guard_block->AddSuccessor(then_target);
then_target->AddPredecessor(info.guard_block);
cloned_guard->RemoveSuccessor(cloned_then);
cloned_then->RemovePredecessor(cloned_guard);
ReplaceTerminatorWithBr(cloned_guard, cloned_else);
cloned_guard->AddSuccessor(cloned_else);
cloned_else->AddPredecessor(cloned_guard);
auto* cloned_header = block_map.at(loop->header);
auto* old_preheader_term = looputils::GetTerminator(info.preheader);
if (!old_preheader_term || !dyncast<UncondBrInst>(old_preheader_term)) {
continue;
}
info.preheader->RemoveSuccessor(loop->header);
loop->header->RemovePredecessor(info.preheader);
ReplaceTerminatorWithCondBr(info.preheader, info.condition, loop->header, cloned_header);
info.preheader->AddSuccessor(loop->header);
info.preheader->AddSuccessor(cloned_header);
loop->header->AddPredecessor(info.preheader);
cloned_header->AddPredecessor(info.preheader);
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopUnswitch(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopUnswitchOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,375 @@
#pragma once
#include "ir/IR.h"
#include <cstdint>
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
namespace ir {
namespace mathidiom {
inline bool IsFloatConstant(Value* value, float expected) {
auto* constant = dyncast<ConstantFloat>(value);
return constant != nullptr && constant->GetValue() == expected;
}
inline bool IsFloatValue(Value* value, float expected) {
if (IsFloatConstant(value, expected)) {
return true;
}
auto* unary = dyncast<UnaryInst>(value);
if (unary == nullptr || unary->GetOpcode() != Opcode::IToF) {
return false;
}
auto* constant = dyncast<ConstantInt>(unary->GetOprd());
return constant != nullptr &&
static_cast<float>(constant->GetValue()) == expected;
}
inline Function* ParentFunction(const Instruction* inst) {
auto* block = inst == nullptr ? nullptr : inst->GetParent();
return block == nullptr ? nullptr : block->GetParent();
}
inline bool IsGlobalOnlyUsedByFunction(const GlobalValue* global,
const Function& function) {
if (global == nullptr) {
return false;
}
for (const auto& use : global->GetUses()) {
auto* inst = dyncast<Instruction>(use.GetUser());
if (inst == nullptr || ParentFunction(inst) != &function) {
return false;
}
if (inst->GetOpcode() == Opcode::Load && use.GetOperandIndex() == 0) {
continue;
}
if (inst->GetOpcode() == Opcode::Store && use.GetOperandIndex() == 1) {
continue;
}
return false;
}
return true;
}
inline bool HasBackedgeLikeBranch(const Function& function) {
std::unordered_map<const BasicBlock*, std::size_t> index;
const auto& blocks = function.GetBlocks();
for (std::size_t i = 0; i < blocks.size(); ++i) {
index[blocks[i].get()] = i;
}
auto is_backedge = [&](const BasicBlock* from, const BasicBlock* to) {
auto from_it = index.find(from);
auto to_it = index.find(to);
return from_it != index.end() && to_it != index.end() &&
to_it->second <= from_it->second;
};
for (std::size_t i = 0; i < blocks.size(); ++i) {
const auto& instructions = blocks[i]->GetInstructions();
if (instructions.empty()) {
continue;
}
auto* terminator = instructions.back().get();
if (auto* br = dyncast<UncondBrInst>(terminator)) {
if (is_backedge(blocks[i].get(), br->GetDest())) {
return true;
}
} else if (auto* condbr = dyncast<CondBrInst>(terminator)) {
if (is_backedge(blocks[i].get(), condbr->GetThenBlock()) ||
is_backedge(blocks[i].get(), condbr->GetElseBlock())) {
return true;
}
}
}
return false;
}
inline bool IsPowerOfTwoPositive(int value) {
return value > 0 && (value & (value - 1)) == 0;
}
inline int Log2Exact(int value) {
int shift = 0;
while (value > 1) {
value >>= 1;
++shift;
}
return shift;
}
inline bool DependsOnValueImpl(Value* value, Value* needle, int depth,
std::unordered_set<Value*>& visiting) {
if (value == needle) {
return true;
}
if (value == nullptr || depth <= 0 || !visiting.insert(value).second) {
return false;
}
auto* inst = dyncast<Instruction>(value);
if (inst == nullptr) {
return false;
}
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (DependsOnValueImpl(inst->GetOperand(i), needle, depth - 1, visiting)) {
return true;
}
}
return false;
}
inline bool DependsOnValue(Value* value, Value* needle, int depth = 12) {
std::unordered_set<Value*> visiting;
return DependsOnValueImpl(value, needle, depth, visiting);
}
// Recognize the radix-digit helper:
// while (i < pos) num = num / C;
// return num % C;
// for power-of-two C >= 4. Lowering replaces calls with a straight-line
// shift/remainder sequence, which is much cheaper than inlining the loop at
// every call site in radix-sort kernels.
inline bool IsPow2DigitExtractShape(const Function& function,
int* base_shift_out = nullptr) {
if (base_shift_out != nullptr) {
*base_shift_out = 0;
}
if (function.IsExternal() || function.GetReturnType() == nullptr ||
!function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
!function.GetArgument(0)->GetType()->IsInt32() ||
!function.GetArgument(1)->GetType()->IsInt32() ||
!HasBackedgeLikeBranch(function)) {
return false;
}
auto* num_arg = function.GetArgument(0);
auto* pos_arg = function.GetArgument(1);
int divisor = 0;
int div_count = 0;
int rem_count = 0;
bool return_is_rem = false;
bool divisor_chain_uses_num = false;
bool compare_uses_pos = false;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<CallInst>(inst) || dyncast<LoadInst>(inst) ||
dyncast<StoreInst>(inst) || dyncast<AllocaInst>(inst) ||
dyncast<GetElementPtrInst>(inst) || dyncast<MemsetInst>(inst) ||
dyncast<UnreachableInst>(inst)) {
return false;
}
if (auto* ret = dyncast<ReturnInst>(inst)) {
auto* returned = ret->HasReturnValue() ? ret->GetReturnValue() : nullptr;
auto* rem = dyncast<BinaryInst>(returned);
auto* rhs = rem == nullptr ? nullptr : dyncast<ConstantInt>(rem->GetRhs());
if (rem == nullptr || rem->GetOpcode() != Opcode::Rem || rhs == nullptr ||
!IsPowerOfTwoPositive(rhs->GetValue()) || rhs->GetValue() < 4) {
return false;
}
if (divisor == 0) {
divisor = rhs->GetValue();
} else if (divisor != rhs->GetValue()) {
return false;
}
return_is_rem = true;
continue;
}
auto* bin = dyncast<BinaryInst>(inst);
if (!bin) {
continue;
}
if (bin->GetOpcode() == Opcode::Div || bin->GetOpcode() == Opcode::Rem) {
auto* rhs = dyncast<ConstantInt>(bin->GetRhs());
if (rhs == nullptr || !IsPowerOfTwoPositive(rhs->GetValue()) ||
rhs->GetValue() < 4) {
return false;
}
if (divisor == 0) {
divisor = rhs->GetValue();
} else if (divisor != rhs->GetValue()) {
return false;
}
if (bin->GetOpcode() == Opcode::Div) {
++div_count;
} else {
++rem_count;
}
divisor_chain_uses_num |= DependsOnValue(bin->GetLhs(), num_arg);
}
switch (bin->GetOpcode()) {
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
compare_uses_pos |= DependsOnValue(bin->GetLhs(), pos_arg) ||
DependsOnValue(bin->GetRhs(), pos_arg);
break;
default:
break;
}
}
}
if (divisor == 0 || div_count == 0 || rem_count == 0 || !return_is_rem ||
!divisor_chain_uses_num || !compare_uses_pos) {
return false;
}
if (base_shift_out != nullptr) {
*base_shift_out = Log2Exact(divisor);
}
return true;
}
// Recognize the common tolerance-driven Newton iteration for sqrt:
// while (abs(t - x / t) > eps) t = (t + x / t) / 2;
// The matcher is intentionally structural: it does not inspect source names or
// filenames. Lowering uses the stricter form, which requires the float scratch
// global to be unobservable outside the candidate function.
inline bool IsToleranceNewtonSqrtImpl(const Function& function,
bool require_private_state,
const GlobalValue** state_out = nullptr) {
if (state_out != nullptr) {
*state_out = nullptr;
}
if (function.IsExternal() || function.GetReturnType() == nullptr ||
!function.GetReturnType()->IsFloat() || function.GetArguments().size() != 1 ||
!function.GetArguments()[0]->GetType()->IsFloat() ||
function.GetBlocks().size() < 3 || function.GetBlocks().size() > 8 ||
!HasBackedgeLikeBranch(function)) {
return false;
}
auto* input = function.GetArguments()[0].get();
int fdiv_count = 0;
int fadd_count = 0;
int fsub_count = 0;
int fcmp_count = 0;
int return_count = 0;
bool has_input_over_state = false;
bool has_newton_half_update = false;
std::unordered_set<const GlobalValue*> loaded_globals;
std::unordered_set<const GlobalValue*> stored_globals;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
switch (inst->GetOpcode()) {
case Opcode::FDiv: {
++fdiv_count;
auto* binary = static_cast<BinaryInst*>(inst);
if (binary->GetLhs() == input) {
has_input_over_state = true;
}
if (IsFloatValue(binary->GetRhs(), 2.0f) &&
dyncast<Instruction>(binary->GetLhs()) != nullptr &&
static_cast<Instruction*>(binary->GetLhs())->GetOpcode() == Opcode::FAdd) {
has_newton_half_update = true;
}
break;
}
case Opcode::FAdd:
++fadd_count;
break;
case Opcode::FSub:
++fsub_count;
break;
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
++fcmp_count;
break;
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
auto* global = dyncast<GlobalValue>(load->GetPtr());
if (global == nullptr || !load->GetType()->IsFloat() ||
!global->GetObjectType()->IsFloat()) {
return false;
}
loaded_globals.insert(global);
break;
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
auto* global = dyncast<GlobalValue>(store->GetPtr());
if (global == nullptr || !store->GetValue()->GetType()->IsFloat() ||
!global->GetObjectType()->IsFloat()) {
return false;
}
stored_globals.insert(global);
break;
}
case Opcode::Return:
++return_count;
if (!static_cast<ReturnInst*>(inst)->HasReturnValue() ||
!static_cast<ReturnInst*>(inst)->GetReturnValue()->GetType()->IsFloat()) {
return false;
}
break;
case Opcode::Call:
case Opcode::Alloca:
case Opcode::GetElementPtr:
case Opcode::Memset:
case Opcode::Unreachable:
return false;
default:
break;
}
}
}
if (fdiv_count < 2 || fadd_count < 1 || fsub_count < 1 || fcmp_count < 1 ||
return_count != 1 || !has_input_over_state || !has_newton_half_update) {
return false;
}
const GlobalValue* state = nullptr;
for (auto* global : stored_globals) {
if (loaded_globals.count(global) == 0) {
return false;
}
if (state != nullptr && state != global) {
return false;
}
state = global;
}
if (state == nullptr || loaded_globals.size() != 1 || !state->HasInitializer() ||
!IsFloatConstant(state->GetInitializer(), 1.0f)) {
return false;
}
if (require_private_state && !IsGlobalOnlyUsedByFunction(state, function)) {
return false;
}
if (state_out != nullptr) {
*state_out = state;
}
return true;
}
inline bool IsToleranceNewtonSqrtShape(const Function& function) {
return IsToleranceNewtonSqrtImpl(function, false);
}
inline bool IsPrivateToleranceNewtonSqrt(const Function& function,
const GlobalValue** state_out = nullptr) {
return IsToleranceNewtonSqrtImpl(function, true, state_out);
}
} // namespace mathidiom
} // namespace ir

@ -1,4 +1,405 @@
// Mem2RegSSA 构造):
// Mem2RegSSA 构造):
// - 将局部变量的 alloca/load/store 提升为 SSA 形式
// - 插入 PHI 并重写使用,依赖支配树等分析
#include "ir/PassManager.h"
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <queue>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct DominatorInfo {
std::vector<BasicBlock*> blocks;
std::unordered_map<BasicBlock*, size_t> index;
std::vector<std::vector<bool>> dominates;
std::vector<BasicBlock*> idom;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_tree_children;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dominance_frontier;
};
struct PromotableAlloca {
AllocaInst* alloca = nullptr;
std::shared_ptr<Type> value_type;
std::unordered_set<BasicBlock*> def_blocks;
std::unordered_map<BasicBlock*, PhiInst*> phis;
};
bool IsScalarPromotableType(const std::shared_ptr<Type>& type) {
return type && (type->IsInt1() || type->IsInt32() || type->IsFloat());
}
Value* DefaultValueFor(Context& ctx, const std::shared_ptr<Type>& type) {
if (type->IsInt1()) {
return ctx.GetConstBool(false);
}
if (type->IsInt32()) {
return ctx.GetConstInt(0);
}
if (type->IsFloat()) {
return new ConstantFloat(Type::GetFloatType(), 0.0f);
}
throw std::runtime_error("Mem2Reg encountered unsupported promotable type");
}
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it != nullptr) {
stack.push_back(*it);
}
}
}
return order;
}
std::vector<bool> IntersectDominators(const std::vector<std::vector<bool>>& doms,
const std::vector<size_t>& pred_indices,
size_t self_index) {
std::vector<bool> result(doms.size(), true);
if (pred_indices.empty()) {
std::fill(result.begin(), result.end(), false);
result[self_index] = true;
return result;
}
result = doms[pred_indices.front()];
for (size_t i = 1; i < pred_indices.size(); ++i) {
const auto& pred_dom = doms[pred_indices[i]];
for (size_t j = 0; j < result.size(); ++j) {
result[j] = result[j] && pred_dom[j];
}
}
result[self_index] = true;
return result;
}
DominatorInfo BuildDominatorInfo(Function& function) {
DominatorInfo info;
info.blocks = CollectReachableBlocks(function);
info.idom.resize(info.blocks.size(), nullptr);
info.dominates.assign(info.blocks.size(),
std::vector<bool>(info.blocks.size(), true));
if (info.blocks.empty()) {
return info;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
info.index[info.blocks[i]] = i;
}
for (size_t i = 0; i < info.blocks.size(); ++i) {
std::fill(info.dominates[i].begin(), info.dominates[i].end(), i != 0);
info.dominates[i][i] = true;
}
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 1; i < info.blocks.size(); ++i) {
std::vector<size_t> pred_indices;
for (auto* pred : info.blocks[i]->GetPredecessors()) {
auto it = info.index.find(pred);
if (it != info.index.end()) {
pred_indices.push_back(it->second);
}
}
auto new_dom = IntersectDominators(info.dominates, pred_indices, i);
if (new_dom != info.dominates[i]) {
info.dominates[i] = std::move(new_dom);
changed = true;
}
}
}
for (size_t i = 1; i < info.blocks.size(); ++i) {
BasicBlock* candidate_idom = nullptr;
for (size_t j = 0; j < info.blocks.size(); ++j) {
if (i == j || !info.dominates[i][j]) {
continue;
}
bool is_immediate = true;
for (size_t k = 0; k < info.blocks.size(); ++k) {
if (k == i || k == j || !info.dominates[i][k]) {
continue;
}
if (info.dominates[k][j]) {
is_immediate = false;
break;
}
}
if (is_immediate) {
candidate_idom = info.blocks[j];
break;
}
}
info.idom[i] = candidate_idom;
if (candidate_idom != nullptr) {
info.dom_tree_children[candidate_idom].push_back(info.blocks[i]);
}
}
for (auto* block : info.blocks) {
info.dominance_frontier[block] = {};
}
for (auto* block : info.blocks) {
std::vector<BasicBlock*> reachable_preds;
for (auto* pred : block->GetPredecessors()) {
if (info.index.find(pred) != info.index.end()) {
reachable_preds.push_back(pred);
}
}
if (reachable_preds.size() < 2) {
continue;
}
auto* idom_block = info.idom[info.index[block]];
for (auto* pred : reachable_preds) {
auto* runner = pred;
while (runner != nullptr && runner != idom_block) {
auto& frontier = info.dominance_frontier[runner];
if (std::find(frontier.begin(), frontier.end(), block) == frontier.end()) {
frontier.push_back(block);
}
auto idom_it = info.index.find(runner);
if (idom_it == info.index.end()) {
break;
}
runner = info.idom[idom_it->second];
}
}
}
return info;
}
bool IsPromotableAlloca(AllocaInst& alloca, const DominatorInfo& dom_info) {
if (!IsScalarPromotableType(alloca.GetAllocatedType())) {
return false;
}
for (const auto& use : alloca.GetUses()) {
auto* user = use.GetUser();
auto* inst = dynamic_cast<Instruction*>(user);
if (inst == nullptr || inst->GetParent() == nullptr ||
dom_info.index.find(inst->GetParent()) == dom_info.index.end()) {
return false;
}
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
if (load->GetPtr() != &alloca) {
return false;
}
continue;
}
auto* store = dynamic_cast<StoreInst*>(inst);
if (store == nullptr || store->GetPtr() != &alloca ||
store->GetValue() == &alloca) {
return false;
}
if (store->GetValue()->GetType() != alloca.GetAllocatedType()) {
return false;
}
}
return true;
}
size_t CountLeadingPhiNodes(BasicBlock& block) {
size_t count = 0;
for (const auto& inst : block.GetInstructions()) {
if (!isa<PhiInst>(inst.get())) {
break;
}
++count;
}
return count;
}
void InsertPhiNodes(Context& ctx, PromotableAlloca& slot,
const DominatorInfo& dom_info) {
std::queue<BasicBlock*> worklist;
std::unordered_set<BasicBlock*> queued;
for (auto* block : slot.def_blocks) {
worklist.push(block);
queued.insert(block);
}
while (!worklist.empty()) {
auto* block = worklist.front();
worklist.pop();
auto frontier_it = dom_info.dominance_frontier.find(block);
if (frontier_it == dom_info.dominance_frontier.end()) {
continue;
}
for (auto* frontier_block : frontier_it->second) {
if (slot.phis.find(frontier_block) != slot.phis.end()) {
continue;
}
auto phi_index = CountLeadingPhiNodes(*frontier_block);
auto* phi = frontier_block->Insert<PhiInst>(phi_index, slot.value_type, nullptr,
ctx.NextTemp());
slot.phis[frontier_block] = phi;
if (slot.def_blocks.insert(frontier_block).second) {
worklist.push(frontier_block);
}
}
}
}
void RenamePromotedAlloca(BasicBlock* block, PromotableAlloca& slot,
const DominatorInfo& dom_info,
std::vector<Value*>& stack, Context& ctx) {
if (block == nullptr) {
return;
}
size_t pushed = 0;
PhiInst* block_phi = nullptr;
auto phi_it = slot.phis.find(block);
if (phi_it != slot.phis.end()) {
block_phi = phi_it->second;
stack.push_back(block_phi);
++pushed;
}
std::vector<Instruction*> to_remove;
Instruction* alloca_to_remove = nullptr;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == block_phi) {
continue;
}
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
if (load->GetPtr() != slot.alloca) {
continue;
}
auto* replacement =
stack.empty() ? DefaultValueFor(ctx, slot.value_type) : stack.back();
load->ReplaceAllUsesWith(replacement);
to_remove.push_back(load);
continue;
}
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
if (store->GetPtr() != slot.alloca) {
continue;
}
stack.push_back(store->GetValue());
++pushed;
to_remove.push_back(store);
continue;
}
if (inst == slot.alloca) {
alloca_to_remove = inst;
}
}
for (auto* succ : block->GetSuccessors()) {
auto succ_phi_it = slot.phis.find(succ);
if (succ_phi_it == slot.phis.end()) {
continue;
}
auto* incoming =
stack.empty() ? DefaultValueFor(ctx, slot.value_type) : stack.back();
succ_phi_it->second->AddIncoming(incoming, block);
}
auto child_it = dom_info.dom_tree_children.find(block);
if (child_it != dom_info.dom_tree_children.end()) {
for (auto* child : child_it->second) {
RenamePromotedAlloca(child, slot, dom_info, stack, ctx);
}
}
for (auto* inst : to_remove) {
block->EraseInstruction(inst);
}
if (alloca_to_remove != nullptr) {
block->EraseInstruction(alloca_to_remove);
}
while (pushed > 0) {
stack.pop_back();
--pushed;
}
}
void PromoteAllocasInFunction(Function& function, Context& ctx) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return;
}
auto dom_info = BuildDominatorInfo(function);
if (dom_info.blocks.empty()) {
return;
}
std::vector<PromotableAlloca> promotable_allocas;
for (const auto& inst_ptr : function.GetEntryBlock()->GetInstructions()) {
auto* alloca = dynamic_cast<AllocaInst*>(inst_ptr.get());
if (alloca == nullptr || !IsPromotableAlloca(*alloca, dom_info)) {
continue;
}
PromotableAlloca slot;
slot.alloca = alloca;
slot.value_type = alloca->GetAllocatedType();
for (const auto& use : alloca->GetUses()) {
auto* inst = dynamic_cast<Instruction*>(use.GetUser());
auto* store = inst == nullptr ? nullptr : dynamic_cast<StoreInst*>(inst);
if (store != nullptr && store->GetPtr() == alloca) {
slot.def_blocks.insert(inst->GetParent());
}
}
promotable_allocas.push_back(std::move(slot));
}
for (auto& slot : promotable_allocas) {
InsertPhiNodes(ctx, slot, dom_info);
std::vector<Value*> stack;
RenamePromotedAlloca(function.GetEntryBlock(), slot, dom_info, stack, ctx);
}
}
} // namespace
void RunMem2Reg(Module& module) {
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function != nullptr) {
PromoteAllocasInFunction(*function, ctx);
}
}
}
} // namespace ir

@ -0,0 +1,261 @@
#pragma once
#include "ir/IR.h"
#include "PassUtils.h"
#include <cstdint>
#include <unordered_set>
#include <vector>
namespace ir::memutils {
enum class PointerRootKind {
Local,
Global,
ReadonlyGlobal,
Param,
Unknown,
};
struct AddressComponent {
bool is_constant = false;
std::int64_t constant = 0;
Value* value = nullptr;
bool operator==(const AddressComponent& rhs) const {
return is_constant == rhs.is_constant && constant == rhs.constant &&
value == rhs.value;
}
};
struct AddressKey {
PointerRootKind kind = PointerRootKind::Unknown;
Value* root = nullptr;
std::vector<AddressComponent> components;
bool operator==(const AddressKey& rhs) const {
return kind == rhs.kind && root == rhs.root && components == rhs.components;
}
};
struct AddressKeyHash {
std::size_t operator()(const AddressKey& key) const {
std::size_t h = static_cast<std::size_t>(key.kind);
h ^= std::hash<Value*>{}(key.root) + 0x9e3779b9 + (h << 6) + (h >> 2);
for (const auto& component : key.components) {
h ^= std::hash<bool>{}(component.is_constant) + 0x9e3779b9 + (h << 6) + (h >> 2);
if (component.is_constant) {
h ^= std::hash<std::int64_t>{}(component.constant) + 0x9e3779b9 + (h << 6) +
(h >> 2);
} else {
h ^= std::hash<Value*>{}(component.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
}
return h;
}
};
struct EscapeSummary {
std::unordered_set<Value*> escaped_locals;
bool IsEscaped(Value* value) const {
return value != nullptr && escaped_locals.find(value) != escaped_locals.end();
}
};
inline bool IsNoEscapePointerUse(Value* current, Instruction* user) {
if (!current || !user) {
return false;
}
if (auto* load = dyncast<LoadInst>(user)) {
return load->GetPtr() == current;
}
if (auto* store = dyncast<StoreInst>(user)) {
return store->GetPtr() == current;
}
if (auto* memset = dyncast<MemsetInst>(user)) {
return memset->GetDest() == current;
}
return false;
}
inline bool PointerValueEscapes(Value* current, Value* root,
std::unordered_set<Value*>& visiting) {
if (!current || !root || !visiting.insert(current).second) {
return false;
}
for (const auto& use : current->GetUses()) {
auto* user = dyncast<Instruction>(use.GetUser());
if (!user) {
return true;
}
if (auto* gep = dyncast<GetElementPtrInst>(user)) {
if (gep->GetPointer() == current &&
PointerValueEscapes(gep, root, visiting)) {
return true;
}
continue;
}
if (IsNoEscapePointerUse(current, user)) {
continue;
}
return true;
}
return false;
}
inline EscapeSummary AnalyzeEscapes(Function& function) {
EscapeSummary summary;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* alloca = dyncast<AllocaInst>(inst_ptr.get());
if (!alloca) {
continue;
}
std::unordered_set<Value*> visiting;
if (PointerValueEscapes(alloca, alloca, visiting)) {
summary.escaped_locals.insert(alloca);
}
}
}
return summary;
}
inline PointerRootKind ClassifyRoot(Value* root, const EscapeSummary* summary) {
if (root == nullptr) {
return PointerRootKind::Unknown;
}
if (auto* global = dyncast<GlobalValue>(root)) {
return global->IsConstant() ? PointerRootKind::ReadonlyGlobal
: PointerRootKind::Global;
}
if (isa<Argument>(root)) {
return PointerRootKind::Param;
}
if (isa<AllocaInst>(root)) {
if (summary != nullptr && summary->IsEscaped(root)) {
return PointerRootKind::Unknown;
}
return PointerRootKind::Local;
}
return PointerRootKind::Unknown;
}
inline Value* StripPointerRoot(Value* pointer) {
auto* current = pointer;
while (auto* gep = dyncast<GetElementPtrInst>(current)) {
current = gep->GetPointer();
}
return current;
}
inline AddressComponent MakeAddressComponent(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {true, ci->GetValue(), nullptr};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {true, cb->GetValue() ? 1 : 0, nullptr};
}
return {false, 0, value};
}
inline bool BuildExactAddressKey(Value* pointer, const EscapeSummary* summary,
AddressKey& key) {
if (!pointer) {
return false;
}
if (auto* gep = dyncast<GetElementPtrInst>(pointer)) {
if (!BuildExactAddressKey(gep->GetPointer(), summary, key)) {
return false;
}
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
key.components.push_back(MakeAddressComponent(gep->GetIndex(i)));
}
return true;
}
key.kind = ClassifyRoot(pointer, summary);
key.root = pointer;
key.components.clear();
return true;
}
inline bool HasOnlyConstantComponents(const AddressKey& key) {
for (const auto& component : key.components) {
if (!component.is_constant) {
return false;
}
}
return true;
}
inline bool MayAliasConservatively(const AddressKey& lhs, const AddressKey& rhs) {
if (lhs.kind == PointerRootKind::Unknown || rhs.kind == PointerRootKind::Unknown) {
return true;
}
if (lhs.kind != rhs.kind || lhs.root != rhs.root) {
return false;
}
if (lhs.components == rhs.components) {
return true;
}
if (HasOnlyConstantComponents(lhs) && HasOnlyConstantComponents(rhs)) {
return false;
}
return true;
}
inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) {
if (!callee) {
return true;
}
if (callee->HasUnknownEffects()) {
return true;
}
switch (kind) {
case PointerRootKind::ReadonlyGlobal:
return callee->ReadsGlobalMemory();
case PointerRootKind::Global:
return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory() ||
callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Param:
return callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Local:
return callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Unknown:
return callee->MayReadMemory();
}
return true;
}
inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) {
if (!callee) {
return true;
}
if (callee->HasUnknownEffects()) {
return true;
}
switch (kind) {
case PointerRootKind::ReadonlyGlobal:
return false;
case PointerRootKind::Global:
return callee->WritesGlobalMemory() || callee->WritesParamMemory();
case PointerRootKind::Param:
return callee->WritesParamMemory();
case PointerRootKind::Local:
return callee->WritesParamMemory();
case PointerRootKind::Unknown:
return callee->MayWriteMemory();
}
return true;
}
inline bool IsPureCall(const CallInst* call) {
auto* callee = call == nullptr ? nullptr : call->GetCallee();
return callee != nullptr && callee->CanDiscardUnusedCall() &&
!callee->MayReadMemory();
}
} // namespace ir::memutils

@ -1 +1,79 @@
// IR Pass 管理骨架。
// IR Pass 管理骨架。
#include "ir/PassManager.h"
#include <cstdlib>
namespace ir {
void RunIRPassPipeline(Module& module) {
const char* disable_mem2reg = std::getenv("NUDTC_DISABLE_MEM2REG");
if (disable_mem2reg != nullptr && disable_mem2reg[0] != '\0' && disable_mem2reg[0] != '0') {
return;
}
const char* disable_loop_mem_promotion =
std::getenv("NUDTC_DISABLE_LOOP_MEM_PROMOTION");
const bool run_loop_mem_promotion =
disable_loop_mem_promotion == nullptr || disable_loop_mem_promotion[0] == '\0' ||
disable_loop_mem_promotion[0] == '0';
const char* disable_inline_cfg = std::getenv("NUDTC_DISABLE_CFG_INLINE");
const bool run_cfg_inline =
disable_inline_cfg == nullptr || disable_inline_cfg[0] == '\0' ||
disable_inline_cfg[0] == '0';
const char* disable_loop_unswitch = std::getenv("NUDTC_DISABLE_LOOP_UNSWITCH");
const bool run_loop_unswitch =
disable_loop_unswitch == nullptr || disable_loop_unswitch[0] == '\0' ||
disable_loop_unswitch[0] == '0';
const char* disable_tail_recursion =
std::getenv("NUDTC_DISABLE_TAIL_RECURSION");
const bool run_tail_recursion =
disable_tail_recursion == nullptr || disable_tail_recursion[0] == '\0' ||
disable_tail_recursion[0] == '0';
RunMem2Reg(module);
if (run_tail_recursion) {
RunTailRecursionElim(module);
}
constexpr int kMaxIterations = 8;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
bool changed = false;
if (run_tail_recursion) {
changed |= RunTailRecursionElim(module);
}
if (run_cfg_inline) {
changed |= RunFunctionInlining(module);
}
changed |= RunArithmeticSimplify(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
changed |= RunLoadStoreElim(module);
changed |= RunCSE(module);
changed |= RunDCE(module);
changed |= RunCFGSimplify(module);
changed |= RunLICM(module);
if (run_loop_mem_promotion) {
changed |= RunLoopMemoryPromotion(module);
}
if (run_loop_unswitch) {
changed |= RunLoopUnswitch(module);
}
changed |= RunLoopStrengthReduction(module);
changed |= RunLoopFission(module);
changed |= RunLoopUnroll(module);
changed |= RunArithmeticSimplify(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
changed |= RunLoadStoreElim(module);
changed |= RunCSE(module);
changed |= RunDCE(module);
changed |= RunCFGSimplify(module);
if (!changed) {
break;
}
}
}
} // namespace ir

@ -0,0 +1,234 @@
#pragma once
#include "ir/IR.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <unordered_set>
#include <vector>
namespace ir::passutils {
inline std::uint32_t FloatBits(float value) {
std::uint32_t bits = 0;
std::memcpy(&bits, &value, sizeof(bits));
return bits;
}
inline bool AreEquivalentValues(Value* lhs, Value* rhs) {
if (lhs == rhs) {
return true;
}
auto* lhs_i32 = dyncast<ConstantInt>(lhs);
auto* rhs_i32 = dyncast<ConstantInt>(rhs);
if (lhs_i32 && rhs_i32) {
return lhs_i32->GetValue() == rhs_i32->GetValue();
}
auto* lhs_i1 = dyncast<ConstantI1>(lhs);
auto* rhs_i1 = dyncast<ConstantI1>(rhs);
if (lhs_i1 && rhs_i1) {
return lhs_i1->GetValue() == rhs_i1->GetValue();
}
auto* lhs_f32 = dyncast<ConstantFloat>(lhs);
auto* rhs_f32 = dyncast<ConstantFloat>(rhs);
if (lhs_f32 && rhs_f32) {
return FloatBits(lhs_f32->GetValue()) == FloatBits(rhs_f32->GetValue());
}
return false;
}
inline std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it != nullptr) {
stack.push_back(*it);
}
}
}
return order;
}
inline bool IsSideEffectingInstruction(const Instruction* inst) {
if (!inst) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Store:
case Opcode::Memset:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Return:
case Opcode::Unreachable:
return true;
case Opcode::Call: {
auto* call = dyncast<const CallInst>(inst);
auto* callee = call == nullptr ? nullptr : call->GetCallee();
return callee == nullptr || !callee->CanDiscardUnusedCall();
}
default:
return false;
}
}
inline bool IsTriviallyDead(Instruction* inst) {
return inst != nullptr && !IsSideEffectingInstruction(inst) &&
inst->GetUses().empty();
}
inline void RemoveIncomingForBlock(PhiInst* phi, BasicBlock* block) {
if (!phi || !block) {
return;
}
for (int i = phi->GetNumIncomings() - 1; i >= 0; --i) {
if (phi->GetIncomingBlock(i) != block) {
continue;
}
phi->RemoveOperand(static_cast<size_t>(2 * i + 1));
phi->RemoveOperand(static_cast<size_t>(2 * i));
}
}
inline void RemoveIncomingFromSuccessor(BasicBlock* succ, BasicBlock* pred) {
if (!succ || !pred) {
return;
}
for (const auto& inst_ptr : succ->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
RemoveIncomingForBlock(phi, pred);
}
}
inline void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) {
auto& instructions = block->GetInstructions();
if (instructions.empty() || !instructions.back()->IsTerminator()) {
return;
}
instructions.back()->ClearAllOperands();
auto branch = std::make_unique<UncondBrInst>(dest, nullptr);
branch->SetParent(block);
instructions.back() = std::move(branch);
}
inline bool SimplifyPhiInst(PhiInst* phi) {
if (!phi) {
return false;
}
Value* unique_value = nullptr;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* incoming = phi->GetIncomingValue(i);
if (incoming == phi) {
continue;
}
if (unique_value == nullptr) {
unique_value = incoming;
continue;
}
if (!AreEquivalentValues(unique_value, incoming)) {
return false;
}
}
if (unique_value == nullptr) {
return false;
}
auto* parent = phi->GetParent();
phi->ReplaceAllUsesWith(unique_value);
parent->EraseInstruction(phi);
return true;
}
inline void EraseBlock(Function& function, BasicBlock* block) {
if (!block) {
return;
}
auto& blocks = function.GetBlocks();
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
[&](const std::unique_ptr<BasicBlock>& current) {
return current.get() == block;
}),
blocks.end());
}
inline bool RemoveUnreachableBlocks(Function& function) {
auto reachable = CollectReachableBlocks(function);
std::unordered_set<BasicBlock*> reachable_set(reachable.begin(), reachable.end());
std::vector<BasicBlock*> dead_blocks;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (reachable_set.find(block) == reachable_set.end()) {
dead_blocks.push_back(block);
}
}
if (dead_blocks.empty()) {
return false;
}
for (auto* block : dead_blocks) {
auto preds = block->GetPredecessors();
auto succs = block->GetSuccessors();
for (auto* succ : succs) {
RemoveIncomingFromSuccessor(succ, block);
succ->RemovePredecessor(block);
}
for (auto* pred : preds) {
pred->RemoveSuccessor(block);
}
}
for (auto* block : dead_blocks) {
for (const auto& inst_ptr : block->GetInstructions()) {
inst_ptr->ClearAllOperands();
}
}
for (auto* block : dead_blocks) {
EraseBlock(function, block);
}
return true;
}
inline bool IsCommutativeOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::Add:
case Opcode::Mul:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::FAdd:
case Opcode::FMul:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
return true;
default:
return false;
}
}
} // namespace ir::passutils

@ -0,0 +1,249 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
struct TailCallSite {
BasicBlock* block = nullptr;
CallInst* call = nullptr;
ReturnInst* ret = nullptr;
};
bool HasEntryPhi(Function& function) {
auto* entry = function.GetEntryBlock();
if (!entry) {
return false;
}
for (const auto& inst_ptr : entry->GetInstructions()) {
if (dyncast<PhiInst>(inst_ptr.get())) {
return true;
}
break;
}
return false;
}
bool IsOnlyUsedByReturn(CallInst* call, ReturnInst* ret) {
if (!call || !ret) {
return false;
}
const auto& uses = call->GetUses();
return uses.size() == 1 && uses.front().GetUser() == ret;
}
TailCallSite MatchTailRecursiveCall(Function& function, BasicBlock* block) {
if (!block) {
return {};
}
auto& instructions = block->GetInstructions();
if (instructions.size() < 2) {
return {};
}
auto* ret = dyncast<ReturnInst>(instructions.back().get());
if (!ret) {
return {};
}
auto* previous = instructions[instructions.size() - 2].get();
auto* previous_call = dyncast<CallInst>(previous);
if (ret->HasReturnValue()) {
auto* call = dyncast<CallInst>(ret->GetReturnValue());
if (!call || call != previous_call || call->GetParent() != block ||
call->GetCallee() != &function || !IsOnlyUsedByReturn(call, ret)) {
return {};
}
return {block, call, ret};
}
if (!previous_call || previous_call->GetCallee() != &function ||
!previous_call->GetType()->IsVoid() || !previous_call->GetUses().empty()) {
return {};
}
return {block, previous_call, ret};
}
std::vector<TailCallSite> CollectTailCallSites(Function& function) {
std::vector<TailCallSite> sites;
for (const auto& block_ptr : function.GetBlocks()) {
auto site = MatchTailRecursiveCall(function, block_ptr.get());
if (site.block && site.call && site.ret) {
sites.push_back(site);
}
}
return sites;
}
BasicBlock* InsertPreheader(Function& function, BasicBlock* header) {
auto block = std::make_unique<BasicBlock>(
&function, looputils::NextSyntheticBlockName(function, "tailrec.entry"));
auto* preheader = block.get();
auto& blocks = function.GetBlocks();
blocks.insert(blocks.begin(), std::move(block));
function.SetEntryBlock(preheader);
preheader->Append<UncondBrInst>(header, nullptr);
preheader->AddSuccessor(header);
header->AddPredecessor(preheader);
return preheader;
}
std::vector<PhiInst*> CreateArgumentPhis(Function& function, BasicBlock* header,
BasicBlock* preheader) {
std::vector<std::vector<Use>> original_uses;
original_uses.reserve(function.GetArguments().size());
for (const auto& arg : function.GetArguments()) {
original_uses.push_back(arg->GetUses());
}
std::vector<PhiInst*> phis;
phis.reserve(function.GetArguments().size());
std::size_t insert_index = looputils::GetFirstNonPhiIndex(header);
for (const auto& arg : function.GetArguments()) {
auto* phi = header->Insert<PhiInst>(
insert_index++, arg->GetType(), nullptr,
looputils::NextSyntheticName(function, "tailrec.arg."));
phi->AddIncoming(arg.get(), preheader);
phis.push_back(phi);
}
for (std::size_t i = 0; i < function.GetArguments().size(); ++i) {
for (const auto& use : original_uses[i]) {
if (auto* user = use.GetUser()) {
user->SetOperand(use.GetOperandIndex(), phis[i]);
}
}
}
return phis;
}
void ReplaceTerminatorWithBranch(BasicBlock* block, BasicBlock* dest) {
auto& instructions = block->GetInstructions();
instructions.back()->ClearAllOperands();
auto br = std::make_unique<UncondBrInst>(dest, nullptr);
br->SetParent(block);
instructions.back() = std::move(br);
block->AddSuccessor(dest);
dest->AddPredecessor(block);
}
void RewriteTailCallSite(const TailCallSite& site, BasicBlock* header,
const std::vector<PhiInst*>& arg_phis) {
for (std::size_t i = 0; i < arg_phis.size(); ++i) {
arg_phis[i]->AddIncoming(site.call->GetOperand(i + 1), site.block);
}
ReplaceTerminatorWithBranch(site.block, header);
site.block->EraseInstruction(site.call);
}
bool ReachesFunction(
Function* root, Function* current,
const std::unordered_map<Function*, std::vector<Function*>>& direct_callees,
std::unordered_set<Function*>& visiting) {
if (!root || !current || current->IsExternal()) {
return false;
}
if (!visiting.insert(current).second) {
return false;
}
auto it = direct_callees.find(current);
if (it == direct_callees.end()) {
return false;
}
for (auto* callee : it->second) {
if (callee == root) {
return true;
}
if (ReachesFunction(root, callee, direct_callees, visiting)) {
return true;
}
}
return false;
}
void RecomputeRecursiveFlags(Module& module) {
std::unordered_map<Function*, std::vector<Function*>> direct_callees;
for (const auto& function_ptr : module.GetFunctions()) {
auto* function = function_ptr.get();
if (!function || function->IsExternal()) {
continue;
}
auto& callees = direct_callees[function];
for (const auto& block_ptr : function->GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* call = dyncast<CallInst>(inst_ptr.get());
auto* callee = call ? call->GetCallee() : nullptr;
if (callee && !callee->IsExternal() &&
std::find(callees.begin(), callees.end(), callee) == callees.end()) {
callees.push_back(callee);
}
}
}
}
for (const auto& function_ptr : module.GetFunctions()) {
auto* function = function_ptr.get();
if (!function || function->IsExternal()) {
continue;
}
std::unordered_set<Function*> visiting;
const bool is_recursive =
ReachesFunction(function, function, direct_callees, visiting);
function->SetEffectInfo(function->ReadsGlobalMemory(),
function->WritesGlobalMemory(),
function->ReadsParamMemory(),
function->WritesParamMemory(), function->HasIO(),
function->HasUnknownEffects(), is_recursive);
}
}
bool RunOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock() || HasEntryPhi(function)) {
return false;
}
auto sites = CollectTailCallSites(function);
if (sites.empty()) {
return false;
}
auto* header = function.GetEntryBlock();
auto* preheader = InsertPreheader(function, header);
auto arg_phis = CreateArgumentPhis(function, header, preheader);
for (const auto& site : sites) {
RewriteTailCallSite(site, header, arg_phis);
}
return true;
}
} // namespace
bool RunTailRecursionElim(Module& module) {
bool changed = false;
for (const auto& function_ptr : module.GetFunctions()) {
if (function_ptr) {
changed |= RunOnFunction(*function_ptr);
}
}
if (changed) {
RecomputeRecursiveFlags(module);
}
return changed;
}
} // namespace ir

@ -1,4 +1,4 @@
add_library(irgen STATIC
add_library(irgen STATIC
IRGenDriver.cpp
IRGenFunc.cpp
IRGenStmt.cpp
@ -8,6 +8,8 @@ add_library(irgen STATIC
target_link_libraries(irgen PUBLIC
build_options
frontend
${ANTLR4_RUNTIME_TARGET}
ir
)
sem
)

@ -1,107 +1,670 @@
#include "irgen/IRGen.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
#include <utility>
namespace {
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
std::vector<int> ExpandLinearIndex(const std::vector<int>& dims, size_t flat_index) {
std::vector<int> indices(dims.size(), 0);
for (size_t i = dims.size(); i > 0; --i) {
const auto dim_index = i - 1;
indices[dim_index] = static_cast<int>(flat_index % static_cast<size_t>(dims[dim_index]));
flat_index /= static_cast<size_t>(dims[dim_index]);
}
return lvalue.ID()->getText();
return indices;
}
} // namespace
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
std::string IRGenImpl::ExpectIdent(const antlr4::ParserRuleContext& ctx,
antlr4::tree::TerminalNode* ident) const {
if (ident == nullptr) {
ThrowError(&ctx, "?????");
}
for (auto* item : ctx->blockItem()) {
if (item) {
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
// 当前语法要求 return 为块内最后一条语句;命中后可停止生成。
break;
}
return ident->getText();
}
SemanticType IRGenImpl::ParseBType(SysYParser::BTypeContext* ctx) const {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
if (ctx->INT()) {
return SemanticType::Int;
}
if (ctx->FLOAT()) {
return SemanticType::Float;
}
ThrowError(ctx, "????? int/float ????");
}
SemanticType IRGenImpl::ParseFuncType(SysYParser::FuncTypeContext* ctx) const {
if (ctx == nullptr) {
ThrowError(ctx, "????????");
}
if (ctx->VOID()) {
return SemanticType::Void;
}
if (ctx->INT()) {
return SemanticType::Int;
}
if (ctx->FLOAT()) {
return SemanticType::Float;
}
ThrowError(ctx, "????? void/int/float ??????");
}
std::shared_ptr<ir::Type> IRGenImpl::GetIRScalarType(SemanticType type) const {
switch (type) {
case SemanticType::Void:
return ir::Type::GetVoidType();
case SemanticType::Int:
return ir::Type::GetInt32Type();
case SemanticType::Float:
return ir::Type::GetFloatType();
}
throw std::runtime_error("unknown semantic type");
}
std::shared_ptr<ir::Type> IRGenImpl::BuildArrayType(
SemanticType base_type, const std::vector<int>& dims) const {
auto type = GetIRScalarType(base_type);
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
type = ir::Type::GetArrayType(type, static_cast<size_t>(*it));
}
return type;
}
std::vector<int> IRGenImpl::ParseArrayDims(
const std::vector<SysYParser::ConstExpContext*>& dims_ctx) {
std::vector<int> dims;
dims.reserve(dims_ctx.size());
for (auto* dim_ctx : dims_ctx) {
if (dim_ctx == nullptr || dim_ctx->addExp() == nullptr) {
ThrowError(dim_ctx, "???????????");
}
auto dim = ConvertConst(EvalConstAddExp(*dim_ctx->addExp()), SemanticType::Int);
if (dim.int_value <= 0) {
ThrowError(dim_ctx, "??????????");
}
dims.push_back(dim.int_value);
}
return {};
return dims;
}
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
SysYParser::BlockItemContext& item) {
return std::any_cast<BlockFlow>(item.accept(this));
std::vector<int> IRGenImpl::ParseParamDims(SysYParser::FuncFParamContext& ctx) {
std::vector<int> dims;
for (auto* exp_ctx : ctx.exp()) {
auto dim = ConvertConst(EvalConstExp(*exp_ctx), SemanticType::Int);
if (dim.int_value <= 0) {
ThrowError(exp_ctx, "????????????");
}
dims.push_back(dim.int_value);
}
return dims;
}
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少块内项"));
void IRGenImpl::PredeclareGlobalDecl(SysYParser::DeclContext& ctx) {
auto declare_one = [&](const std::string& name, SemanticType type, bool is_const,
const std::vector<int>& dims, const antlr4::ParserRuleContext* node) {
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(node, "????????: " + name);
}
SymbolEntry entry;
entry.kind = is_const ? SymbolKind::Constant : SymbolKind::Variable;
entry.type = type;
entry.is_const = is_const;
entry.is_array = !dims.empty();
entry.is_param_array = false;
entry.dims = dims;
entry.ir_value = module_.CreateGlobalValue(
name, dims.empty() ? GetIRScalarType(type) : BuildArrayType(type, dims),
is_const, nullptr);
if (!symbols_.Insert(name, entry)) {
ThrowError(node, "????????: " + name);
}
};
if (ctx.constDecl() != nullptr) {
const auto type = ParseBType(ctx.constDecl()->bType());
for (auto* def : ctx.constDecl()->constDef()) {
const auto name = ExpectIdent(*def, def->Ident());
const auto dims = ParseArrayDims(def->constExp());
declare_one(name, type, true, dims, def);
auto* symbol = symbols_.Lookup(name);
if (symbol != nullptr && dims.empty()) {
symbol->const_scalar = ConvertConst(
EvalConstAddExp(*def->constInitVal()->constExp()->addExp()), type);
}
}
return;
}
if (ctx.varDecl() != nullptr) {
const auto type = ParseBType(ctx.varDecl()->bType());
for (auto* def : ctx.varDecl()->varDef()) {
declare_one(ExpectIdent(*def, def->Ident()), type, false,
ParseArrayDims(def->constExp()), def);
}
return;
}
if (ctx->decl()) {
ctx->decl()->accept(this);
return BlockFlow::Continue;
ThrowError(&ctx, "????");
}
void IRGenImpl::EmitGlobalDecl(SysYParser::DeclContext& ctx) { EmitDecl(ctx, true); }
void IRGenImpl::EmitDecl(SysYParser::DeclContext& ctx, bool is_global) {
if (ctx.constDecl() != nullptr) {
EmitConstDecl(ctx.constDecl(), is_global);
return;
}
if (ctx->stmt()) {
return ctx->stmt()->accept(this);
if (ctx.varDecl() != nullptr) {
EmitVarDecl(ctx.varDecl(), is_global, false);
return;
}
throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明"));
ThrowError(&ctx, "????");
}
// 变量声明的 IR 生成目前也是最小实现:
// - 先检查声明的基础类型,当前仅支持局部 int
// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。
//
// 和更完整的版本相比,这里还没有:
// - 一个 Decl 中多个变量定义的顺序处理;
// - const、数组、全局变量等不同声明形态
// - 更丰富的类型系统。
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
void IRGenImpl::EmitVarDecl(SysYParser::VarDeclContext* ctx, bool is_global,
bool is_const) {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
const auto type = ParseBType(ctx->bType());
for (auto* def : ctx->varDef()) {
if (is_global) {
EmitGlobalVarDef(*def, type);
} else {
EmitLocalVarDef(*def, type, is_const);
}
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
}
void IRGenImpl::EmitConstDecl(SysYParser::ConstDeclContext* ctx, bool is_global) {
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
auto* var_def = ctx->varDef();
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
const auto type = ParseBType(ctx->bType());
for (auto* def : ctx->constDef()) {
if (is_global) {
EmitGlobalConstDef(*def, type);
} else {
EmitLocalConstDef(*def, type);
}
}
var_def->accept(this);
return {};
}
void IRGenImpl::EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || !ir::isa<ir::GlobalValue>(symbol->ir_value)) {
ThrowError(&ctx, "??????????: " + name);
}
auto* global = static_cast<ir::GlobalValue*>(symbol->ir_value);
symbol->kind = SymbolKind::Variable;
symbol->type = type;
symbol->is_const = false;
symbol->is_array = !ctx.constExp().empty();
symbol->dims = ParseArrayDims(ctx.constExp());
if (symbol->is_array) {
// Leave uninitialized globals as zeroinitializer instead of materializing
// an explicit all-zero constant array, which can explode memory usage.
if (ctx.initVal() == nullptr) {
global->SetInitializer(nullptr);
return;
}
if (IsExplicitZeroInitVal(ctx.initVal(), type)) {
global->SetInitializer(nullptr);
return;
}
auto flat = FlattenInitVal(ctx.initVal(), type, symbol->dims);
std::vector<ir::Value*> elements;
elements.reserve(flat.size());
for (const auto& value : flat) {
elements.push_back(CreateTypedConstant(value));
}
global->SetInitializer(builder_.CreateConstArray(BuildArrayType(type, symbol->dims),
elements, {}));
} else {
ConstantValue init = ZeroConst(type);
if (ctx.initVal() != nullptr) {
if (ctx.initVal()->exp() == nullptr) {
ThrowError(ctx.initVal(), "???????????????");
}
init = ConvertConst(EvalConstExp(*ctx.initVal()->exp()), type);
}
global->SetInitializer(CreateTypedConstant(init));
}
}
// 当前仍是教学用的最小版本,因此这里只支持:
// - 局部 int 变量;
// - 标量初始化;
// - 一个 VarDef 对应一个槽位。
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量定义"));
void IRGenImpl::EmitGlobalConstDef(SysYParser::ConstDefContext& ctx,
SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || !ir::isa<ir::GlobalValue>(symbol->ir_value)) {
ThrowError(&ctx, "??????????: " + name);
}
if (!ctx->lValue()) {
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
auto* global = static_cast<ir::GlobalValue*>(symbol->ir_value);
symbol->kind = SymbolKind::Constant;
symbol->type = type;
symbol->is_const = true;
symbol->is_array = !ctx.constExp().empty();
symbol->dims = ParseArrayDims(ctx.constExp());
global->SetConstant(true);
if (symbol->is_array) {
if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) {
symbol->const_array.clear();
symbol->const_array_all_zero = true;
global->SetInitializer(nullptr);
return;
}
symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
symbol->const_array_all_zero = false;
std::vector<ir::Value*> elements;
elements.reserve(symbol->const_array.size());
for (const auto& value : symbol->const_array) {
elements.push_back(CreateTypedConstant(value));
}
global->SetInitializer(builder_.CreateConstArray(BuildArrayType(type, symbol->dims),
elements, {}));
} else {
auto init = ConvertConst(EvalConstAddExp(*ctx.constInitVal()->constExp()->addExp()), type);
symbol->const_scalar = init;
global->SetInitializer(CreateTypedConstant(init));
}
GetLValueName(*ctx->lValue());
if (storage_map_.find(ctx) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
}
ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr<ir::Type> allocated_type,
const std::string& name) {
if (current_function_ == nullptr || current_function_->GetEntryBlock() == nullptr) {
throw std::runtime_error("CreateEntryAlloca requires an active function entry block");
}
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
ir::Value* init = nullptr;
if (auto* init_value = ctx->initValue()) {
if (!init_value->exp()) {
throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化"));
auto* entry = current_function_->GetEntryBlock();
size_t insert_pos = 0;
for (const auto& inst : entry->GetInstructions()) {
if (!ir::isa<ir::AllocaInst>(inst.get())) {
break;
}
init = EvalExpr(*init_value->exp());
++insert_pos;
}
return entry->Insert<ir::AllocaInst>(insert_pos, std::move(allocated_type), nullptr,
name);
}
void IRGenImpl::EmitLocalVarDef(SysYParser::VarDefContext& ctx, SemanticType type,
bool is_const) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
SymbolEntry entry;
entry.kind = is_const ? SymbolKind::Constant : SymbolKind::Variable;
entry.type = type;
entry.is_const = is_const;
entry.is_array = !ctx.constExp().empty();
entry.dims = ParseArrayDims(ctx.constExp());
if (entry.is_array) {
entry.ir_value = CreateEntryAlloca(BuildArrayType(type, entry.dims),
NextTemp());
} else {
init = builder_.CreateConstInt(0);
entry.ir_value = CreateEntryAlloca(GetIRScalarType(type),
NextTemp());
}
if (!symbols_.Insert(name, entry)) {
ThrowError(&ctx, "????????: " + name);
}
auto* symbol = symbols_.Lookup(name);
if (!entry.is_array) {
TypedValue init_value{ZeroIRValue(type), type, false, {}};
if (ctx.initVal() != nullptr) {
if (ctx.initVal()->exp() == nullptr) {
ThrowError(ctx.initVal(), "???????????????");
}
init_value = CastScalar(EmitExp(*ctx.initVal()->exp()), type, ctx.initVal());
}
builder_.CreateStore(init_value.value, symbol->ir_value);
return;
}
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
if (ctx.initVal() != nullptr) {
auto init_slots = FlattenLocalInitVal(ctx.initVal(), symbol->dims);
StoreLocalArrayElements(symbol->ir_value, type, symbol->dims, init_slots);
}
}
void IRGenImpl::EmitLocalConstDef(SysYParser::ConstDefContext& ctx,
SemanticType type) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
SymbolEntry entry;
entry.kind = SymbolKind::Constant;
entry.type = type;
entry.is_const = true;
entry.is_array = !ctx.constExp().empty();
entry.dims = ParseArrayDims(ctx.constExp());
entry.ir_value = CreateEntryAlloca(
entry.is_array ? BuildArrayType(type, entry.dims) : GetIRScalarType(type),
NextTemp());
if (!symbols_.Insert(name, entry)) {
ThrowError(&ctx, "????????: " + name);
}
auto* symbol = symbols_.Lookup(name);
if (!entry.is_array) {
auto init = ConvertConst(EvalConstAddExp(*ctx.constInitVal()->constExp()->addExp()), type);
symbol->const_scalar = init;
builder_.CreateStore(CreateTypedConstant(init), symbol->ir_value);
return;
}
if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) {
symbol->const_array.clear();
symbol->const_array_all_zero = true;
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
return;
}
symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
symbol->const_array_all_zero = false;
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
for (size_t i = 0; i < symbol->const_array.size(); ++i) {
if (symbol->const_array[i].type == SemanticType::Int && symbol->const_array[i].int_value == 0) {
continue;
}
if (symbol->const_array[i].type == SemanticType::Float &&
symbol->const_array[i].float_value == 0.0f) {
continue;
}
const auto indices = ExpandLinearIndex(symbol->dims, i);
std::vector<ir::Value*> index_values;
index_values.reserve(indices.size());
for (int index : indices) {
index_values.push_back(builder_.CreateConstInt(index));
}
auto* addr = CreateArrayElementAddr(symbol->ir_value, false, type, symbol->dims,
index_values, &ctx);
builder_.CreateStore(CreateTypedConstant(symbol->const_array[i]), addr);
}
}
std::vector<ConstantValue> IRGenImpl::FlattenConstInitVal(
SysYParser::ConstInitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims) {
std::vector<ConstantValue> out(CountArrayElements(dims), ZeroConst(base_type));
if (ctx != nullptr) {
size_t cursor = 0;
FlattenConstInitValImpl(ctx, base_type, dims, 0, 0, out.size(), cursor, out);
}
return out;
}
std::vector<ConstantValue> IRGenImpl::FlattenInitVal(
SysYParser::InitValContext* ctx, SemanticType base_type,
const std::vector<int>& dims) {
std::vector<ConstantValue> out(CountArrayElements(dims), ZeroConst(base_type));
if (ctx != nullptr) {
size_t cursor = 0;
FlattenInitValImpl(ctx, base_type, dims, 0, 0, out.size(), cursor, out);
}
return out;
}
std::vector<IRGenImpl::InitExprSlot> IRGenImpl::FlattenLocalInitVal(
SysYParser::InitValContext* ctx, const std::vector<int>& dims) {
std::vector<InitExprSlot> out;
if (ctx != nullptr) {
size_t cursor = 0;
FlattenLocalInitValImpl(ctx, dims, 0, 0, CountArrayElements(dims), cursor, out);
}
return out;
}
void IRGenImpl::FlattenConstInitValImpl(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t object_end, size_t& cursor,
std::vector<ConstantValue>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->constExp() != nullptr) {
out[cursor++] = ConvertConst(EvalConstAddExp(*ctx->constExp()->addExp()), base_type);
return;
}
for (auto* child : ctx->constInitVal()) {
if (cursor >= object_end) {
break;
}
if (child->constExp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenConstInitValImpl(child, base_type, dims, depth + 1, child_begin,
child_end, cursor, out);
cursor = child_end;
} else {
FlattenConstInitValImpl(child, base_type, dims, depth + 1, object_begin,
object_end, cursor, out);
}
}
}
void IRGenImpl::FlattenInitValImpl(SysYParser::InitValContext* ctx,
SemanticType base_type,
const std::vector<int>& dims, size_t depth,
size_t object_begin, size_t object_end,
size_t& cursor,
std::vector<ConstantValue>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->exp() != nullptr) {
out[cursor++] = ConvertConst(EvalConstExp(*ctx->exp()), base_type);
return;
}
for (auto* child : ctx->initVal()) {
if (cursor >= object_end) {
break;
}
if (child->exp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenInitValImpl(child, base_type, dims, depth + 1, child_begin,
child_end, cursor, out);
cursor = child_end;
} else {
FlattenInitValImpl(child, base_type, dims, depth + 1, object_begin,
object_end, cursor, out);
}
}
}
void IRGenImpl::FlattenLocalInitValImpl(SysYParser::InitValContext* ctx,
const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t object_end, size_t& cursor,
std::vector<InitExprSlot>& out) {
if (ctx == nullptr || cursor >= object_end) {
return;
}
if (ctx->exp() != nullptr) {
out.push_back({cursor++, ctx->exp()});
return;
}
for (auto* child : ctx->initVal()) {
if (cursor >= object_end) {
break;
}
if (child->exp() == nullptr && depth + 1 < dims.size()) {
cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
const auto child_begin = cursor;
const auto child_end = std::min(object_end,
child_begin + CountArrayElements(dims, depth + 1));
FlattenLocalInitValImpl(child, dims, depth + 1, child_begin, child_end,
cursor, out);
cursor = child_end;
} else {
FlattenLocalInitValImpl(child, dims, depth + 1, object_begin, object_end,
cursor, out);
}
}
}
size_t IRGenImpl::CountArrayElements(const std::vector<int>& dims, size_t start) const {
size_t count = 1;
for (size_t i = start; i < dims.size(); ++i) {
count *= static_cast<size_t>(dims[i]);
}
return count;
}
size_t IRGenImpl::AlignInitializerCursor(const std::vector<int>& dims,
size_t depth, size_t object_begin,
size_t cursor) const {
if (depth + 1 >= dims.size()) {
return cursor;
}
const auto stride = CountArrayElements(dims, depth + 1);
const auto relative = cursor - object_begin;
return object_begin + ((relative + stride - 1) / stride) * stride;
}
size_t IRGenImpl::FlattenIndices(const std::vector<int>& dims,
const std::vector<int>& indices) const {
size_t offset = 0;
for (size_t i = 0; i < dims.size(); ++i) {
offset *= static_cast<size_t>(dims[i]);
offset += static_cast<size_t>(indices[i]);
}
return offset;
}
bool IRGenImpl::IsZeroConstant(const ConstantValue& value) const {
switch (value.type) {
case SemanticType::Int:
return value.int_value == 0;
case SemanticType::Float: {
std::uint32_t bits = 0;
std::memcpy(&bits, &value.float_value, sizeof(bits));
return bits == 0;
}
case SemanticType::Void:
return false;
}
return false;
}
bool IRGenImpl::IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type) {
if (ctx == nullptr) {
return true;
}
if (ctx->constExp() != nullptr) {
return IsZeroConstant(
ConvertConst(EvalConstAddExp(*ctx->constExp()->addExp()), base_type));
}
for (auto* child : ctx->constInitVal()) {
if (!IsExplicitZeroConstInitVal(child, base_type)) {
return false;
}
}
return true;
}
bool IRGenImpl::IsExplicitZeroInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type) {
if (ctx == nullptr) {
return true;
}
if (ctx->exp() != nullptr) {
return IsZeroConstant(ConvertConst(EvalConstExp(*ctx->exp()), base_type));
}
for (auto* child : ctx->initVal()) {
if (!IsExplicitZeroInitVal(child, base_type)) {
return false;
}
}
return true;
}
ConstantValue IRGenImpl::ZeroConst(SemanticType type) const {
ConstantValue value;
value.type = type;
value.int_value = 0;
value.float_value = 0.0f;
return value;
}
ir::Value* IRGenImpl::ZeroIRValue(SemanticType type) {
switch (type) {
case SemanticType::Int:
return builder_.CreateConstInt(0);
case SemanticType::Float:
return builder_.CreateConstFloat(0.0f);
case SemanticType::Void:
break;
}
throw std::runtime_error("void type has no zero IR value");
}
ir::Value* IRGenImpl::CreateTypedConstant(const ConstantValue& value) {
switch (value.type) {
case SemanticType::Int:
return builder_.CreateConstInt(value.int_value);
case SemanticType::Float:
return builder_.CreateConstFloat(value.float_value);
case SemanticType::Void:
break;
}
throw std::runtime_error("void type has no constant value");
}
void IRGenImpl::ZeroInitializeLocalArray(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims) {
const auto elem_count = CountArrayElements(dims);
int bytes = static_cast<int>(elem_count * (base_type == SemanticType::Float ? 4 : 4));
builder_.CreateMemset(addr, builder_.CreateConstInt(0), builder_.CreateConstInt(bytes),
builder_.CreateConstBool(false));
}
void IRGenImpl::StoreLocalArrayElements(ir::Value* addr, SemanticType base_type,
const std::vector<int>& dims,
const std::vector<InitExprSlot>& init_slots) {
for (const auto& slot : init_slots) {
const auto indices = ExpandLinearIndex(dims, slot.index);
std::vector<ir::Value*> index_values;
index_values.reserve(indices.size());
for (int index : indices) {
index_values.push_back(builder_.CreateConstInt(index));
}
auto* elem_addr = CreateArrayElementAddr(addr, false, base_type, dims,
index_values, slot.expr);
auto value = CastScalar(EmitExp(*slot.expr), base_type, slot.expr);
builder_.CreateStore(value.value, elem_addr);
}
builder_.CreateStore(init, slot);
return {};
}

@ -1,15 +1,11 @@
#include "irgen/IRGen.h"
#include "irgen/IRGen.h"
#include <memory>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema) {
auto module = std::make_unique<ir::Module>();
IRGenImpl gen(*module, sema);
tree.accept(&gen);
return module;
}
}

@ -1,80 +1,696 @@
#include "irgen/IRGen.h"
#include <cstdlib>
#include <stdexcept>
#include <utility>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
bool IRGenImpl::IsNumeric(const TypedValue& value) const {
return !value.is_array && value.type != SemanticType::Void;
}
bool IRGenImpl::IsSameDims(const std::vector<int>& lhs,
const std::vector<int>& rhs) const {
if (rhs.empty()) {
return true;
}
if (lhs == rhs) {
return true;
}
if (lhs.size() == rhs.size() + 1) {
return std::equal(lhs.begin() + 1, lhs.end(), rhs.begin());
}
return false;
}
IRGenImpl::TypedValue IRGenImpl::CastScalar(
TypedValue value, SemanticType target_type,
const antlr4::ParserRuleContext* ctx) {
if (value.is_array) {
ThrowError(ctx, "????????????");
}
if (target_type == SemanticType::Void || value.type == SemanticType::Void) {
ThrowError(ctx, "void ?????????");
}
if (target_type == SemanticType::Int) {
if (value.type == SemanticType::Int) {
if (value.value->GetType()->IsInt1()) {
value.value = builder_.CreateZext(value.value, ir::Type::GetInt32Type(), NextTemp());
}
value.type = SemanticType::Int;
return value;
}
value.value = builder_.CreateFtoI(value.value, NextTemp());
value.type = SemanticType::Int;
return value;
}
if (target_type == SemanticType::Float) {
if (value.type == SemanticType::Float) {
return value;
}
if (value.value->GetType()->IsInt1()) {
value.value = builder_.CreateZext(value.value, ir::Type::GetInt32Type(), NextTemp());
}
value.value = builder_.CreateIToF(value.value, NextTemp());
value.type = SemanticType::Float;
return value;
}
ThrowError(ctx, "????????????");
}
ir::Value* IRGenImpl::CastToCondition(TypedValue value,
const antlr4::ParserRuleContext* ctx) {
if (value.is_array) {
ThrowError(ctx, "?????????");
}
if (value.type == SemanticType::Void) {
ThrowError(ctx, "void ???????");
}
if (value.value->GetType()->IsInt1()) {
return value.value;
}
if (value.type == SemanticType::Int) {
return builder_.CreateICmp(ir::Opcode::ICmpNE, value.value, builder_.CreateConstInt(0),
NextTemp());
}
return builder_.CreateFCmp(ir::Opcode::FCmpNE, value.value,
builder_.CreateConstFloat(0.0f), NextTemp());
}
IRGenImpl::TypedValue IRGenImpl::NormalizeLogicalValue(
TypedValue value, const antlr4::ParserRuleContext* ctx) {
auto* cond = CastToCondition(value, ctx);
return {builder_.CreateZext(cond, ir::Type::GetInt32Type(), NextTemp()),
SemanticType::Int, false, {}};
}
ConstantValue IRGenImpl::ParseNumber(SysYParser::NumberContext& ctx) const {
ConstantValue value;
if (ctx.IntConst() != nullptr) {
value.type = SemanticType::Int;
value.int_value = std::stoi(ctx.getText(), nullptr, 0);
value.float_value = static_cast<float>(value.int_value);
return value;
}
if (ctx.FloatConst() != nullptr) {
value.type = SemanticType::Float;
value.float_value = std::strtof(ctx.getText().c_str(), nullptr);
value.int_value = static_cast<int>(value.float_value);
return value;
}
ThrowError(&ctx, "?????????");
}
ConstantValue IRGenImpl::ConvertConst(ConstantValue value,
SemanticType target_type) const {
if (target_type == SemanticType::Void) {
throw std::runtime_error("void is not a valid constant target type");
}
if (value.type == target_type) {
return value;
}
if (target_type == SemanticType::Int) {
value.int_value = static_cast<int>(value.float_value);
value.type = SemanticType::Int;
return value;
}
value.float_value = static_cast<float>(value.int_value);
value.type = SemanticType::Float;
return value;
}
ConstantValue IRGenImpl::EvalConstExp(SysYParser::ExpContext& ctx) {
return EvalConstAddExp(*ctx.addExp());
}
ConstantValue IRGenImpl::EvalConstAddExp(SysYParser::AddExpContext& ctx) {
if (ctx.addExp() == nullptr) {
return EvalConstMulExp(*ctx.mulExp());
}
auto lhs = EvalConstAddExp(*ctx.addExp());
auto rhs = EvalConstMulExp(*ctx.mulExp());
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = ConvertConst(lhs, SemanticType::Float);
rhs = ConvertConst(rhs, SemanticType::Float);
lhs.float_value = ctx.op->getType() == SysYParser::ADD
? lhs.float_value + rhs.float_value
: lhs.float_value - rhs.float_value;
lhs.int_value = static_cast<int>(lhs.float_value);
lhs.type = SemanticType::Float;
return lhs;
}
lhs.int_value = ctx.op->getType() == SysYParser::ADD ? lhs.int_value + rhs.int_value
: lhs.int_value - rhs.int_value;
lhs.float_value = static_cast<float>(lhs.int_value);
lhs.type = SemanticType::Int;
return lhs;
}
ConstantValue IRGenImpl::EvalConstMulExp(SysYParser::MulExpContext& ctx) {
if (ctx.mulExp() == nullptr) {
return EvalConstUnaryExp(*ctx.unaryExp());
}
auto lhs = EvalConstMulExp(*ctx.mulExp());
auto rhs = EvalConstUnaryExp(*ctx.unaryExp());
if (ctx.op->getType() == SysYParser::MOD &&
(lhs.type == SemanticType::Float || rhs.type == SemanticType::Float)) {
ThrowError(&ctx, "?????? % ??");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = ConvertConst(lhs, SemanticType::Float);
rhs = ConvertConst(rhs, SemanticType::Float);
switch (ctx.op->getType()) {
case SysYParser::MUL:
lhs.float_value *= rhs.float_value;
break;
case SysYParser::DIV:
lhs.float_value /= rhs.float_value;
break;
default:
ThrowError(&ctx, "??????????");
}
lhs.int_value = static_cast<int>(lhs.float_value);
lhs.type = SemanticType::Float;
return lhs;
}
switch (ctx.op->getType()) {
case SysYParser::MUL:
lhs.int_value *= rhs.int_value;
break;
case SysYParser::DIV:
lhs.int_value /= rhs.int_value;
break;
case SysYParser::MOD:
lhs.int_value %= rhs.int_value;
break;
default:
ThrowError(&ctx, "????????");
}
lhs.float_value = static_cast<float>(lhs.int_value);
lhs.type = SemanticType::Int;
return lhs;
}
ConstantValue IRGenImpl::EvalConstUnaryExp(SysYParser::UnaryExpContext& ctx) {
if (ctx.primaryExp() != nullptr) {
return EvalConstPrimaryExp(*ctx.primaryExp());
}
if (ctx.Ident() != nullptr) {
ThrowError(&ctx, "?????????????");
}
auto operand = EvalConstUnaryExp(*ctx.unaryExp());
if (ctx.unaryOp()->ADD() != nullptr) {
return operand;
}
if (ctx.unaryOp()->SUB() != nullptr) {
if (operand.type == SemanticType::Float) {
operand.float_value = -operand.float_value;
operand.int_value = static_cast<int>(operand.float_value);
} else {
operand.int_value = -operand.int_value;
operand.float_value = static_cast<float>(operand.int_value);
}
return operand;
}
if (ctx.unaryOp()->NOT() != nullptr) {
const bool truthy = operand.type == SemanticType::Float ? operand.float_value != 0.0f
: operand.int_value != 0;
ConstantValue result;
result.type = SemanticType::Int;
result.int_value = truthy ? 0 : 1;
result.float_value = static_cast<float>(result.int_value);
return result;
}
ThrowError(&ctx, "???????");
}
ConstantValue IRGenImpl::EvalConstPrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp() != nullptr) {
return EvalConstExp(*ctx.exp());
}
if (ctx.number() != nullptr) {
return ParseNumber(*ctx.number());
}
if (ctx.lVal() != nullptr) {
return EvalConstLVal(*ctx.lVal());
}
ThrowError(&ctx, "???? primaryExp");
}
ConstantValue IRGenImpl::EvalConstLVal(SysYParser::LValContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr) {
ThrowError(&ctx, "?????: " + name);
}
if (!symbol->is_const) {
ThrowError(&ctx, "?????????????: " + name);
}
if (!symbol->is_array) {
if (!ctx.exp().empty()) {
ThrowError(&ctx, "??????????");
}
if (!symbol->const_scalar.has_value()) {
ThrowError(&ctx, "?????: " + name);
}
return *symbol->const_scalar;
}
if (ctx.exp().size() != symbol->dims.size()) {
ThrowError(&ctx, "???????????????????: " + name);
}
std::vector<int> indices;
indices.reserve(ctx.exp().size());
for (auto* exp_ctx : ctx.exp()) {
auto index = ConvertConst(EvalConstExp(*exp_ctx), SemanticType::Int);
indices.push_back(index.int_value);
}
for (size_t i = 0; i < indices.size(); ++i) {
if (indices[i] < 0 || indices[i] >= symbol->dims[i]) {
ThrowError(&ctx, "????????: " + name);
}
}
const auto offset = FlattenIndices(symbol->dims, indices);
if (symbol->const_array_all_zero) {
return ZeroConst(symbol->type);
}
if (offset >= symbol->const_array.size()) {
ThrowError(&ctx, "???????????: " + name);
}
return symbol->const_array[offset];
}
IRGenImpl::TypedValue IRGenImpl::EmitExp(SysYParser::ExpContext& ctx) {
return EmitAddExp(*ctx.addExp());
}
IRGenImpl::TypedValue IRGenImpl::EmitAddExp(SysYParser::AddExpContext& ctx) {
if (ctx.addExp() == nullptr) {
return EmitMulExp(*ctx.mulExp());
}
auto lhs = EmitAddExp(*ctx.addExp());
auto rhs = EmitMulExp(*ctx.mulExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
auto* value = ctx.op->getType() == SysYParser::ADD
? builder_.CreateBinary(ir::Opcode::FAdd, lhs.value, rhs.value,
NextTemp())
: builder_.CreateBinary(ir::Opcode::FSub, lhs.value, rhs.value,
NextTemp());
return {value, SemanticType::Float, false, {}};
}
// 表达式生成当前也只实现了很小的一个子集。
// 目前支持:
// - 整数字面量
// - 普通局部变量读取
// - 括号表达式
// - 二元加法
//
// 还未支持:
// - 减乘除与一元运算
// - 赋值表达式
// - 函数调用
// - 数组、指针、下标访问
// - 条件与比较表达式
// - ...
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
return std::any_cast<ir::Value*>(expr.accept(this));
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
auto* value = ctx.op->getType() == SysYParser::ADD
? builder_.CreateAdd(lhs.value, rhs.value, NextTemp())
: builder_.CreateSub(lhs.value, rhs.value, NextTemp());
return {value, SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitMulExp(SysYParser::MulExpContext& ctx) {
if (ctx.mulExp() == nullptr) {
return EmitUnaryExp(*ctx.unaryExp());
}
auto lhs = EmitMulExp(*ctx.mulExp());
auto rhs = EmitUnaryExp(*ctx.unaryExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "????????????");
}
if (ctx.op->getType() == SysYParser::MOD &&
(lhs.type == SemanticType::Float || rhs.type == SemanticType::Float)) {
ThrowError(&ctx, "?????? % ??");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
ir::Opcode opcode = ir::Opcode::FMul;
if (ctx.op->getType() == SysYParser::DIV) {
opcode = ir::Opcode::FDiv;
} else if (ctx.op->getType() == SysYParser::MUL) {
opcode = ir::Opcode::FMul;
} else {
ThrowError(&ctx, "?????????");
}
return {builder_.CreateBinary(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Float, false, {}};
}
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法括号表达式"));
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
ir::Value* value = nullptr;
switch (ctx.op->getType()) {
case SysYParser::MUL:
value = builder_.CreateMul(lhs.value, rhs.value, NextTemp());
break;
case SysYParser::DIV:
value = builder_.CreateDiv(lhs.value, rhs.value, NextTemp());
break;
case SysYParser::MOD:
value = builder_.CreateRem(lhs.value, rhs.value, NextTemp());
break;
default:
ThrowError(&ctx, "???????");
}
return EvalExpr(*ctx->exp());
return {value, SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitUnaryExp(SysYParser::UnaryExpContext& ctx) {
if (ctx.primaryExp() != nullptr) {
return EmitPrimaryExp(*ctx.primaryExp());
}
if (ctx.Ident() != nullptr) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || symbol->kind != SymbolKind::Function || symbol->function == nullptr) {
ThrowError(&ctx, "????????: " + name);
}
const auto& function_type = symbol->function_type;
std::vector<ir::Value*> args;
std::vector<SysYParser::ExpContext*> arg_exprs;
if (ctx.funcRParams() != nullptr) {
arg_exprs = ctx.funcRParams()->exp();
}
if (arg_exprs.size() != function_type.param_types.size()) {
ThrowError(&ctx, "?????????: " + name);
}
for (size_t i = 0; i < arg_exprs.size(); ++i) {
auto arg = EmitExp(*arg_exprs[i]);
if (i < function_type.param_is_array.size() && function_type.param_is_array[i]) {
if (!arg.is_array || !IsSameDims(arg.dims, function_type.param_dims[i])) {
ThrowError(arg_exprs[i], "????????????: " + name);
}
args.push_back(arg.value);
} else {
if (arg.is_array) {
ThrowError(arg_exprs[i], "??????????: " + name);
}
arg = CastScalar(arg, function_type.param_types[i], arg_exprs[i]);
args.push_back(arg.value);
}
}
if (function_type.return_type == SemanticType::Void) {
builder_.CreateCall(symbol->function, args);
return {nullptr, SemanticType::Void, false, {}};
}
return {builder_.CreateCall(symbol->function, args, NextTemp()),
function_type.return_type, false, {}};
}
auto operand = EmitUnaryExp(*ctx.unaryExp());
if (!IsNumeric(operand)) {
ThrowError(&ctx, "???????????");
}
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
if (ctx.unaryOp()->ADD() != nullptr) {
return operand;
}
return static_cast<ir::Value*>(
builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
if (ctx.unaryOp()->SUB() != nullptr) {
if (operand.type == SemanticType::Float) {
return {builder_.CreateFNeg(operand.value, NextTemp()), SemanticType::Float,
false, {}};
}
operand = CastScalar(operand, SemanticType::Int, &ctx);
return {builder_.CreateSub(builder_.CreateConstInt(0), operand.value, NextTemp()),
SemanticType::Int, false, {}};
}
if (ctx.unaryOp()->NOT() != nullptr) {
auto* cond = CastToCondition(operand, &ctx);
auto* inverted = builder_.CreateXor(cond, builder_.CreateConstBool(true), NextTemp());
return {builder_.CreateZext(inverted, ir::Type::GetInt32Type(), NextTemp()),
SemanticType::Int, false, {}};
}
ThrowError(&ctx, "???????");
}
// 变量使用的处理流程:
// 1. 先通过语义分析结果把变量使用绑定回声明;
// 2. 再通过 storage_map_ 找到该声明对应的栈槽位;
// 3. 最后生成 load把内存中的值读出来。
//
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
if (!ctx || !ctx->var() || !ctx->var()->ID()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
IRGenImpl::TypedValue IRGenImpl::EmitPrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp() != nullptr) {
return EmitExp(*ctx.exp());
}
auto* decl = sema_.ResolveVarUse(ctx->var());
if (!decl) {
throw std::runtime_error(
FormatError("irgen",
"变量使用缺少语义绑定: " + ctx->var()->ID()->getText()));
if (ctx.number() != nullptr) {
auto number = ParseNumber(*ctx.number());
return {CreateTypedConstant(number), number.type, false, {}};
}
auto it = storage_map_.find(decl);
if (it == storage_map_.end()) {
throw std::runtime_error(
FormatError("irgen",
"变量声明缺少存储槽位: " + ctx->var()->ID()->getText()));
if (ctx.lVal() != nullptr) {
return EmitLValValue(*ctx.lVal());
}
return static_cast<ir::Value*>(
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
ThrowError(&ctx, "?? primaryExp");
}
IRGenImpl::TypedValue IRGenImpl::EmitRelExp(SysYParser::RelExpContext& ctx) {
if (ctx.relExp() == nullptr) {
return EmitAddExp(*ctx.addExp());
}
auto lhs = EmitRelExp(*ctx.relExp());
auto rhs = EmitAddExp(*ctx.addExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
ir::Opcode opcode = ir::Opcode::FCmpLT;
switch (ctx.op->getType()) {
case SysYParser::LT:
opcode = ir::Opcode::FCmpLT;
break;
case SysYParser::GT:
opcode = ir::Opcode::FCmpGT;
break;
case SysYParser::LE:
opcode = ir::Opcode::FCmpLE;
break;
case SysYParser::GE:
opcode = ir::Opcode::FCmpGE;
break;
default:
ThrowError(&ctx, "????????");
}
return {builder_.CreateFCmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
ir::Opcode opcode = ir::Opcode::ICmpLT;
switch (ctx.op->getType()) {
case SysYParser::LT:
opcode = ir::Opcode::ICmpLT;
break;
case SysYParser::GT:
opcode = ir::Opcode::ICmpGT;
break;
case SysYParser::LE:
opcode = ir::Opcode::ICmpLE;
break;
case SysYParser::GE:
opcode = ir::Opcode::ICmpGE;
break;
default:
ThrowError(&ctx, "????????");
}
return {builder_.CreateICmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
IRGenImpl::TypedValue IRGenImpl::EmitEqExp(SysYParser::EqExpContext& ctx) {
if (ctx.eqExp() == nullptr) {
return EmitRelExp(*ctx.relExp());
}
auto lhs = EmitEqExp(*ctx.eqExp());
auto rhs = EmitRelExp(*ctx.relExp());
if (!IsNumeric(lhs) || !IsNumeric(rhs)) {
ThrowError(&ctx, "???????????");
}
if (lhs.type == SemanticType::Float || rhs.type == SemanticType::Float) {
lhs = CastScalar(lhs, SemanticType::Float, &ctx);
rhs = CastScalar(rhs, SemanticType::Float, &ctx);
const auto opcode = ctx.op->getType() == SysYParser::EQ ? ir::Opcode::FCmpEQ
: ir::Opcode::FCmpNE;
return {builder_.CreateFCmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
lhs = CastScalar(lhs, SemanticType::Int, &ctx);
rhs = CastScalar(rhs, SemanticType::Int, &ctx);
const auto opcode = ctx.op->getType() == SysYParser::EQ ? ir::Opcode::ICmpEQ
: ir::Opcode::ICmpNE;
return {builder_.CreateICmp(opcode, lhs.value, rhs.value, NextTemp()),
SemanticType::Int, false, {}};
}
std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
IRGenImpl::LValueInfo IRGenImpl::ResolveLVal(SysYParser::LValContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr) {
ThrowError(&ctx, "?????????: " + name);
}
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
return static_cast<ir::Value*>(
builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
module_.GetContext().NextTemp()));
if (symbol->kind == SymbolKind::Function) {
ThrowError(&ctx, "????????: " + name);
}
std::vector<ir::Value*> index_values;
index_values.reserve(ctx.exp().size());
for (auto* exp_ctx : ctx.exp()) {
auto index = CastScalar(EmitExp(*exp_ctx), SemanticType::Int, exp_ctx);
if (index.is_array) {
ThrowError(exp_ctx, "????????");
}
index_values.push_back(index.value);
}
if (!symbol->is_array) {
if (!index_values.empty()) {
ThrowError(&ctx, "??????????: " + name);
}
return {symbol, symbol->ir_value, symbol->type, false, {}, false};
}
std::vector<int> selected_dims;
if (symbol->is_param_array) {
if (index_values.size() > symbol->dims.size() + 1) {
ThrowError(&ctx, "????????: " + name);
}
if (index_values.empty()) {
selected_dims = symbol->dims;
} else if (index_values.size() <= 1) {
selected_dims = symbol->dims;
} else {
selected_dims.assign(symbol->dims.begin() + static_cast<long long>(index_values.size() - 1),
symbol->dims.end());
}
} else {
if (index_values.size() > symbol->dims.size()) {
ThrowError(&ctx, "??????: " + name);
}
selected_dims.assign(symbol->dims.begin() + static_cast<long long>(index_values.size()),
symbol->dims.end());
}
ir::Value* addr = symbol->ir_value;
if (!index_values.empty()) {
addr = CreateArrayElementAddr(symbol->ir_value, symbol->is_param_array, symbol->type,
symbol->dims, index_values, &ctx);
}
const bool root_param_array_no_index = symbol->is_param_array && index_values.empty();
const bool still_array = !selected_dims.empty() || root_param_array_no_index;
return {symbol, addr, symbol->type, still_array, selected_dims,
root_param_array_no_index};
}
ir::Value* IRGenImpl::GenLValAddr(SysYParser::LValContext& ctx) {
auto info = ResolveLVal(ctx);
if (info.is_array) {
ThrowError(&ctx, "?????????????");
}
return info.addr;
}
IRGenImpl::TypedValue IRGenImpl::EmitLValValue(SysYParser::LValContext& ctx) {
auto info = ResolveLVal(ctx);
if (!info.is_array) {
if (info.symbol != nullptr && info.symbol->const_scalar.has_value()) {
return {CreateTypedConstant(*info.symbol->const_scalar), info.type, false, {}};
}
return {builder_.CreateLoad(info.addr, GetIRScalarType(info.type), NextTemp()),
info.type, false, {}};
}
if (info.root_param_array_no_index) {
return {info.addr, info.type, true, info.dims};
}
auto* decayed = builder_.CreateGEP(info.addr, BuildArrayType(info.type, info.dims),
{builder_.CreateConstInt(0), builder_.CreateConstInt(0)},
NextTemp());
std::vector<int> decay_dims;
if (!info.dims.empty()) {
decay_dims.assign(info.dims.begin() + 1, info.dims.end());
}
return {decayed, info.type, true, decay_dims};
}
ir::Value* IRGenImpl::CreateArrayElementAddr(
ir::Value* base_addr, bool is_param_array, SemanticType base_type,
const std::vector<int>& dims, const std::vector<ir::Value*>& indices,
const antlr4::ParserRuleContext* ctx) {
if (base_addr == nullptr) {
ThrowError(ctx, "???????");
}
if (indices.empty()) {
return base_addr;
}
std::vector<ir::Value*> gep_indices;
if (!is_param_array) {
gep_indices.push_back(builder_.CreateConstInt(0));
}
gep_indices.insert(gep_indices.end(), indices.begin(), indices.end());
auto source_type = dims.empty() ? GetIRScalarType(base_type) : BuildArrayType(base_type, dims);
return builder_.CreateGEP(base_addr, source_type, gep_indices, NextTemp());
}
void IRGenImpl::EmitCond(SysYParser::CondContext& ctx, ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
EmitLOrCond(*ctx.lOrExp(), true_block, false_block);
}
void IRGenImpl::EmitLOrCond(SysYParser::LOrExpContext& ctx,
ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
if (ctx.lOrExp() == nullptr) {
EmitLAndCond(*ctx.lAndExp(), true_block, false_block);
return;
}
auto* rhs_block = current_function_->CreateBlock(NextBlockName("lor.rhs"));
EmitLOrCond(*ctx.lOrExp(), true_block, rhs_block);
builder_.SetInsertPoint(rhs_block);
EmitLAndCond(*ctx.lAndExp(), true_block, false_block);
}
void IRGenImpl::EmitLAndCond(SysYParser::LAndExpContext& ctx,
ir::BasicBlock* true_block,
ir::BasicBlock* false_block) {
if (ctx.lAndExp() == nullptr) {
auto cond = EmitEqExp(*ctx.eqExp());
builder_.CreateCondBr(CastToCondition(cond, &ctx), true_block, false_block);
return;
}
auto* rhs_block = current_function_->CreateBlock(NextBlockName("land.rhs"));
EmitLAndCond(*ctx.lAndExp(), rhs_block, false_block);
builder_.SetInsertPoint(rhs_block);
auto cond = EmitEqExp(*ctx.eqExp());
builder_.CreateCondBr(CastToCondition(cond, &ctx), true_block, false_block);
}

@ -1,87 +1,268 @@
#include "irgen/IRGen.h"
#include <stdexcept>
#include <utility>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
namespace {
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module), sema_(sema), builder_(module.GetContext(), nullptr) {}
void VerifyFunctionStructure(const ir::Function& func) {
// 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。
for (const auto& bb : func.GetBlocks()) {
if (!bb || !bb->HasTerminator()) {
throw std::runtime_error(
FormatError("irgen", "基本块未正确终结: " +
(bb ? bb->GetName() : std::string("<null>"))));
}
[[noreturn]] void IRGenImpl::ThrowError(
const antlr4::ParserRuleContext* ctx, const std::string& message) const {
if (ctx != nullptr && ctx->getStart() != nullptr) {
throw std::runtime_error(FormatErrorAt("irgen",
static_cast<size_t>(ctx->getStart()->getLine()),
static_cast<size_t>(ctx->getStart()->getCharPositionInLine() + 1),
message));
}
throw std::runtime_error(FormatError("irgen", message));
}
} // namespace
std::string IRGenImpl::NextTemp() { return module_.GetContext().NextTemp(); }
std::string IRGenImpl::NextBlockName(const std::string& prefix) {
return module_.GetContext().NextBlockName(prefix);
}
void IRGenImpl::ApplyFunctionSema(const std::string& name, ir::Function& function) {
const auto* info = sema_.LookupFunction(name);
if (info == nullptr) {
return;
}
function.SetEffectInfo(info->reads_global_memory, info->writes_global_memory,
info->reads_param_memory, info->writes_param_memory,
info->has_io, info->has_unknown_effects,
info->is_recursive);
}
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module),
sema_(sema),
func_(nullptr),
builder_(module.GetContext(), nullptr) {}
// 编译单元的 IR 生成当前只实现了最小功能:
// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容;
// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR
//
// 当前还没有实现:
// - 多个函数定义的遍历与生成;
// - 全局变量、全局常量的 IR 生成。
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
if (ctx == nullptr) {
ThrowError(ctx, "??????");
}
auto* func = ctx->funcDef();
if (!func) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
symbols_.Clear();
symbols_.EnterScope();
RegisterBuiltinFunctions();
PredeclareTopLevel(*ctx);
for (auto* child : ctx->children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
EmitGlobalDecl(*decl);
} else if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
EmitFunction(*func);
}
}
func->accept(this);
symbols_.ExitScope();
return {};
}
// 函数 IR 生成当前实现了:
// 1. 获取函数名;
// 2. 检查函数返回类型;
// 3. 在 Module 中创建 Function
// 4. 将 builder 插入点设置到入口基本块;
// 5. 继续生成函数体。
//
// 当前还没有实现:
// - 通用函数返回类型处理;
// - 形参列表遍历与参数类型收集;
// - FunctionType 这样的函数类型对象;
// - Argument/形式参数 IR 对象;
// - 入口块中的参数初始化逻辑。
// ...
// 因此这里目前只支持最小的“无参 int 函数”生成。
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
void IRGenImpl::RegisterBuiltinFunctions() {
if (builtins_registered_) {
return;
}
if (!ctx->blockStmt()) {
throw std::runtime_error(FormatError("irgen", "函数体为空"));
struct BuiltinSpec {
const char* name;
SemanticType return_type;
std::vector<SemanticType> param_types;
std::vector<bool> param_is_array;
std::vector<std::vector<int>> param_dims;
};
const std::vector<BuiltinSpec> builtins = {
{"getint", SemanticType::Int, {}, {}, {}},
{"getch", SemanticType::Int, {}, {}, {}},
{"getfloat", SemanticType::Float, {}, {}, {}},
{"getarray", SemanticType::Int, {SemanticType::Int}, {true}, {std::vector<int>{}}},
{"getfarray", SemanticType::Int, {SemanticType::Float}, {true}, {std::vector<int>{}}},
{"putint", SemanticType::Void, {SemanticType::Int}, {false}, {std::vector<int>{}}},
{"putch", SemanticType::Void, {SemanticType::Int}, {false}, {std::vector<int>{}}},
{"putfloat", SemanticType::Void, {SemanticType::Float}, {false}, {std::vector<int>{}}},
{"putarray", SemanticType::Void, {SemanticType::Int, SemanticType::Int}, {false, true}, {std::vector<int>{}, std::vector<int>{}}},
{"putfarray", SemanticType::Void, {SemanticType::Int, SemanticType::Float}, {false, true}, {std::vector<int>{}, std::vector<int>{}}},
{"starttime", SemanticType::Void, {}, {}, {}},
{"stoptime", SemanticType::Void, {}, {}, {}},
};
for (const auto& builtin : builtins) {
FunctionTypeInfo function_type;
function_type.return_type = builtin.return_type;
function_type.param_types = builtin.param_types;
function_type.param_is_array = builtin.param_is_array;
function_type.param_dims = builtin.param_dims;
std::vector<std::shared_ptr<ir::Type>> ir_param_types;
std::vector<std::string> ir_param_names;
for (size_t i = 0; i < builtin.param_types.size(); ++i) {
if (i < builtin.param_is_array.size() && builtin.param_is_array[i]) {
ir_param_types.push_back(ir::Type::GetPointerType());
} else {
ir_param_types.push_back(GetIRScalarType(builtin.param_types[i]));
}
ir_param_names.push_back("%arg" + std::to_string(i));
}
auto* function = module_.CreateFunction(
builtin.name, GetIRScalarType(builtin.return_type), ir_param_types,
ir_param_names, true);
ApplyFunctionSema(builtin.name, *function);
SymbolEntry entry;
entry.kind = SymbolKind::Function;
entry.type = builtin.return_type;
entry.function = function;
entry.ir_value = function;
entry.function_type = std::move(function_type);
symbols_.Insert(builtin.name, entry);
}
if (!ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "缺少函数名"));
builtins_registered_ = true;
}
void IRGenImpl::PredeclareTopLevel(SysYParser::CompUnitContext& ctx) {
for (auto* child : ctx.children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
PredeclareGlobalDecl(*decl);
} else if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
PredeclareFunction(*func);
}
}
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数"));
}
FunctionTypeInfo IRGenImpl::BuildFunctionTypeInfo(
SysYParser::FuncDefContext& ctx) {
FunctionTypeInfo function_type;
function_type.return_type = ParseFuncType(ctx.funcType());
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
const auto type = ParseBType(param->bType());
const auto dims = param->LBRACK().empty() ? std::vector<int>{} : ParseParamDims(*param);
function_type.param_types.push_back(type);
function_type.param_is_array.push_back(!param->LBRACK().empty());
function_type.param_dims.push_back(dims);
}
}
func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type());
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
return function_type;
}
ctx->blockStmt()->accept(this);
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。
VerifyFunctionStructure(*func_);
return {};
std::vector<std::shared_ptr<ir::Type>> IRGenImpl::BuildFunctionIRParamTypes(
const FunctionTypeInfo& function_type) const {
std::vector<std::shared_ptr<ir::Type>> param_types;
for (size_t i = 0; i < function_type.param_types.size(); ++i) {
if (i < function_type.param_is_array.size() && function_type.param_is_array[i]) {
param_types.push_back(ir::Type::GetPointerType());
} else {
param_types.push_back(GetIRScalarType(function_type.param_types[i]));
}
}
return param_types;
}
std::vector<std::string> IRGenImpl::BuildFunctionIRParamNames(
SysYParser::FuncDefContext& ctx) const {
std::vector<std::string> param_names;
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
param_names.push_back("%" + ExpectIdent(*param, param->Ident()));
}
}
return param_names;
}
void IRGenImpl::PredeclareFunction(SysYParser::FuncDefContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(&ctx, "????????: " + name);
}
auto function_type = BuildFunctionTypeInfo(ctx);
auto* function = module_.CreateFunction(
name, GetIRScalarType(function_type.return_type),
BuildFunctionIRParamTypes(function_type), BuildFunctionIRParamNames(ctx), false);
ApplyFunctionSema(name, *function);
SymbolEntry entry;
entry.kind = SymbolKind::Function;
entry.type = function_type.return_type;
entry.function = function;
entry.ir_value = function;
entry.function_type = std::move(function_type);
symbols_.Insert(name, entry);
}
void IRGenImpl::EmitFunction(SysYParser::FuncDefContext& ctx) {
const auto name = ExpectIdent(ctx, ctx.Ident());
auto* symbol = symbols_.Lookup(name);
if (symbol == nullptr || symbol->kind != SymbolKind::Function || symbol->function == nullptr) {
ThrowError(&ctx, "????????: " + name);
}
current_function_ = symbol->function;
current_return_type_ = symbol->function_type.return_type;
current_function_->SetExternal(false);
auto* entry_block = current_function_->EnsureEntryBlock();
builder_.SetInsertPoint(entry_block);
symbols_.EnterScope();
BindFunctionParams(ctx, *current_function_);
EmitBlock(*ctx.block(), false);
symbols_.ExitScope();
auto* insert_block = builder_.GetInsertBlock();
if (insert_block != nullptr && !insert_block->HasTerminator()) {
if (current_return_type_ == SemanticType::Void) {
builder_.CreateRet();
} else {
builder_.CreateRet(ZeroIRValue(current_return_type_));
}
}
builder_.SetInsertPoint(nullptr);
current_function_ = nullptr;
current_return_type_ = SemanticType::Void;
loop_stack_.clear();
}
void IRGenImpl::BindFunctionParams(SysYParser::FuncDefContext& ctx, ir::Function& func) {
if (ctx.funcFParams() == nullptr) {
return;
}
const auto& params = ctx.funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
auto* param = params[i];
const auto name = ExpectIdent(*param, param->Ident());
if (symbols_.ContainsInCurrentScope(name)) {
ThrowError(param, "??????: " + name);
}
SymbolEntry entry;
entry.kind = SymbolKind::Variable;
entry.type = ParseBType(param->bType());
entry.is_const = false;
entry.is_array = !param->LBRACK().empty();
entry.is_param_array = entry.is_array;
entry.dims = entry.is_array ? ParseParamDims(*param) : std::vector<int>{};
auto* arg = func.GetArgument(i);
if (arg == nullptr) {
ThrowError(param, "????????: " + name);
}
if (entry.is_array) {
entry.ir_value = arg;
} else {
auto* slot = CreateEntryAlloca(GetIRScalarType(entry.type), NextTemp());
builder_.CreateStore(arg, slot);
entry.ir_value = slot;
}
if (!symbols_.Insert(name, entry)) {
ThrowError(param, "??????: " + name);
}
}
}

@ -1,39 +1,159 @@
#include "irgen/IRGen.h"
#include "irgen/IRGen.h"
#include <stdexcept>
void IRGenImpl::EmitBlock(SysYParser::BlockContext& ctx, bool create_scope) {
if (create_scope) {
symbols_.EnterScope();
}
for (auto* item : ctx.blockItem()) {
if (item != nullptr && EmitBlockItem(*item) == FlowState::Terminated) {
break;
}
}
if (create_scope) {
symbols_.ExitScope();
}
}
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
IRGenImpl::FlowState IRGenImpl::EmitBlockItem(SysYParser::BlockItemContext& ctx) {
if (ctx.decl() != nullptr) {
EmitDecl(*ctx.decl(), false);
return FlowState::Continue;
}
if (ctx.stmt() != nullptr) {
return EmitStmt(*ctx.stmt());
}
ThrowError(&ctx, "??????");
}
// 语句生成当前只实现了最小子集。
// 目前支持:
// - return <exp>;
//
// 还未支持:
// - 赋值语句
// - if / while 等控制流
// - 空语句、块语句嵌套分发之外的更多语句形态
IRGenImpl::FlowState IRGenImpl::EmitStmt(SysYParser::StmtContext& ctx) {
auto branch_terminated = [this]() {
auto* block = builder_.GetInsertBlock();
return block == nullptr || block->HasTerminator();
};
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
if (ctx.ASSIGN() != nullptr) {
auto lhs = ResolveLVal(*ctx.lVal());
if (lhs.is_array) {
ThrowError(&ctx, "????????");
}
if (lhs.symbol != nullptr && lhs.symbol->is_const) {
ThrowError(&ctx, "??? const ????");
}
auto rhs = CastScalar(EmitExp(*ctx.exp()), lhs.type, ctx.exp());
builder_.CreateStore(rhs.value, lhs.addr);
return FlowState::Continue;
}
if (ctx->returnStmt()) {
return ctx->returnStmt()->accept(this);
if (ctx.RETURN() != nullptr) {
if (current_return_type_ == SemanticType::Void) {
if (ctx.exp() != nullptr) {
ThrowError(&ctx, "void ?????????");
}
builder_.CreateRet();
} else {
if (ctx.exp() == nullptr) {
ThrowError(&ctx, "? void ?????????");
}
auto value = CastScalar(EmitExp(*ctx.exp()), current_return_type_, ctx.exp());
builder_.CreateRet(value.value);
}
return FlowState::Terminated;
}
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}
if (ctx.block() != nullptr) {
EmitBlock(*ctx.block(), true);
return branch_terminated() ? FlowState::Terminated : FlowState::Continue;
}
if (ctx.IF() != nullptr) {
auto* then_block = current_function_->CreateBlock(NextBlockName("if.then"));
if (ctx.ELSE() == nullptr) {
auto* end_block = current_function_->CreateBlock(NextBlockName("if.end"));
EmitCond(*ctx.cond(), then_block, end_block);
builder_.SetInsertPoint(then_block);
auto then_state = EmitStmt(*ctx.stmt(0));
if (then_state != FlowState::Terminated && !branch_terminated()) {
builder_.CreateBr(end_block);
}
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
}
auto* else_block = current_function_->CreateBlock(NextBlockName("if.else"));
EmitCond(*ctx.cond(), then_block, else_block);
ir::BasicBlock* end_block = nullptr;
std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
builder_.SetInsertPoint(then_block);
auto then_state = EmitStmt(*ctx.stmt(0));
const bool then_terminated = then_state == FlowState::Terminated || branch_terminated();
if (!then_terminated) {
if (end_block == nullptr) {
end_block = current_function_->CreateBlock(NextBlockName("if.end"));
}
builder_.CreateBr(end_block);
}
builder_.SetInsertPoint(else_block);
auto else_state = EmitStmt(*ctx.stmt(1));
const bool else_terminated = else_state == FlowState::Terminated || branch_terminated();
if (!else_terminated) {
if (end_block == nullptr) {
end_block = current_function_->CreateBlock(NextBlockName("if.end"));
}
builder_.CreateBr(end_block);
}
if (end_block == nullptr) {
builder_.SetInsertPoint(nullptr);
return FlowState::Terminated;
}
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
}
if (ctx.WHILE() != nullptr) {
auto* cond_block = current_function_->CreateBlock(NextBlockName("while.cond"));
auto* body_block = current_function_->CreateBlock(NextBlockName("while.body"));
auto* end_block = current_function_->CreateBlock(NextBlockName("while.end"));
builder_.CreateBr(cond_block);
builder_.SetInsertPoint(cond_block);
EmitCond(*ctx.cond(), body_block, end_block);
loop_stack_.push_back({cond_block, end_block});
builder_.SetInsertPoint(body_block);
auto body_state = EmitStmt(*ctx.stmt(0));
if (body_state != FlowState::Terminated && !branch_terminated()) {
builder_.CreateBr(cond_block);
}
loop_stack_.pop_back();
builder_.SetInsertPoint(end_block);
return FlowState::Continue;
}
if (!ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
if (ctx.BREAK() != nullptr) {
if (loop_stack_.empty()) {
ThrowError(&ctx, "break ????? while ???");
}
builder_.CreateBr(loop_stack_.back().exit_block);
return FlowState::Terminated;
}
ir::Value* v = EvalExpr(*ctx->exp());
builder_.CreateRet(v);
return BlockFlow::Terminated;
}
if (ctx.CONTINUE() != nullptr) {
if (loop_stack_.empty()) {
ThrowError(&ctx, "continue ????? while ???");
}
builder_.CreateBr(loop_stack_.back().cond_block);
return FlowState::Terminated;
}
if (ctx.exp() != nullptr) {
(void)EmitExp(*ctx.exp());
}
return FlowState::Continue;
}

@ -1,68 +1,86 @@
#include <exception>
#include <iostream>
#include <stdexcept>
#include "frontend/AntlrDriver.h"
#include "frontend/SyntaxTreePrinter.h"
#if !COMPILER_PARSE_ONLY
#include "ir/IR.h"
#include "irgen/IRGen.h"
#include "mir/MIR.h"
#include "sem/Sema.h"
#endif
#include "utils/CLI.h"
#include "utils/Log.h"
int main(int argc, char** argv) {
try {
auto opts = ParseCLI(argc, argv);
if (opts.show_help) {
PrintHelp(std::cout);
return 0;
}
auto antlr = ParseFileWithAntlr(opts.input);
bool need_blank_line = false;
if (opts.emit_parse_tree) {
PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout);
need_blank_line = true;
}
#if !COMPILER_PARSE_ONLY
auto* comp_unit = dynamic_cast<SysYParser::CompUnitContext*>(antlr.tree);
if (!comp_unit) {
throw std::runtime_error(FormatError("main", "语法树根节点不是 compUnit"));
}
auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema);
if (opts.emit_ir) {
ir::IRPrinter printer;
if (need_blank_line) {
std::cout << "\n";
}
printer.Print(*module, std::cout);
need_blank_line = true;
}
#include <exception>
#include <iostream>
#include <stdexcept>
#include "frontend/AntlrDriver.h"
#include "frontend/SyntaxTreePrinter.h"
#if !COMPILER_PARSE_ONLY
#include "ir/IR.h"
#include "ir/PassManager.h"
#include "irgen/IRGen.h"
#include "mir/MIR.h"
#include "sem/Sema.h"
#endif
#include "utils/CLI.h"
#include "utils/Log.h"
int main(int argc, char** argv) {
try {
auto opts = ParseCLI(argc, argv);
if (opts.show_help) {
PrintHelp(std::cout);
return 0;
}
auto antlr = ParseFileWithAntlr(opts.input);
bool need_blank_line = false;
if (opts.emit_parse_tree) {
PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout);
need_blank_line = true;
}
#if !COMPILER_PARSE_ONLY
auto* comp_unit = dynamic_cast<SysYParser::CompUnitContext*>(antlr.tree);
if (!comp_unit) {
throw std::runtime_error(FormatError("main", "syntax tree root is not compUnit"));
}
auto sema = RunSema(*comp_unit);
std::unique_ptr<ir::Module> asm_module;
if (opts.emit_asm) {
asm_module = GenerateIR(*comp_unit, sema);
ir::RunIRPassPipeline(*asm_module);
}
if (opts.emit_ir) {
std::unique_ptr<ir::Module> ir_module;
if (opts.emit_asm) {
ir_module = GenerateIR(*comp_unit, sema);
} else {
ir_module = GenerateIR(*comp_unit, sema);
}
ir::RunIRPassPipeline(*ir_module);
if (need_blank_line) {
std::cout << "\n";
}
ir::IRPrinter printer;
printer.Print(*ir_module, std::cout);
need_blank_line = true;
}
if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
auto machine_module = mir::LowerToMIR(*asm_module);
mir::RunMIRPreRegAllocPassPipeline(*machine_module);
mir::RunRegAlloc(*machine_module);
mir::RunMIRPostRegAllocPassPipeline(*machine_module);
mir::RunFrameLowering(*machine_module);
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_func, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {
throw std::runtime_error(
FormatError("main", "当前为 parse-only 构建IR/汇编输出已禁用"));
}
#endif
} catch (const std::exception& ex) {
PrintException(std::cerr, ex);
return 1;
}
return 0;
}
mir::PrintAsm(*machine_module, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {
throw std::runtime_error(
FormatError("main", "IR/asm emission is unavailable in parse-only builds"));
}
#endif
} catch (const std::exception& ex) {
PrintException(std::cerr, ex);
return 1;
}
return 0;
}

@ -0,0 +1,140 @@
#include "mir/MIR.h"
#include <cstdint>
#include <unordered_map>
#include <vector>
namespace mir {
namespace {
bool IsHoistCandidate(const MachineFunction& function, int object_index, int use_count) {
const auto& object = function.GetStackObject(object_index);
if (object.kind != StackObjectKind::Local) {
return false;
}
if (use_count < 2) {
return false;
}
if (object.size >= 4096) {
return true;
}
return object.size >= 256 && use_count >= 4;
}
bool IsPlainFrameLea(const MachineInstr& inst, int object_index) {
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress() ||
inst.GetOperands().empty() || inst.GetOperands()[0].GetKind() != OperandKind::VReg) {
return false;
}
const auto& address = inst.GetAddress();
return address.base_kind == AddrBaseKind::FrameObject &&
address.base_index == object_index && address.const_offset == 0 &&
address.scaled_vregs.empty();
}
std::size_t FindEntryInsertPos(const MachineBasicBlock& block) {
const auto& instructions = block.GetInstructions();
std::size_t pos = 0;
while (pos < instructions.size() &&
instructions[pos].GetOpcode() == MachineInstr::Opcode::Arg) {
++pos;
}
return pos;
}
} // namespace
void RunAddressHoisting(MachineModule& module) {
for (auto& function : module.GetFunctions()) {
if (!function || function->GetBlocks().empty()) {
continue;
}
std::unordered_map<int, int> use_counts;
for (auto& block : function->GetBlocks()) {
for (auto& inst : block->GetInstructions()) {
if (!inst.HasAddress()) {
continue;
}
const auto& address = inst.GetAddress();
if (address.base_kind == AddrBaseKind::FrameObject && address.base_index >= 0) {
++use_counts[address.base_index];
}
}
}
std::unordered_map<int, int> base_vregs;
for (const auto& [object_index, count] : use_counts) {
if (!IsHoistCandidate(*function, object_index, count)) {
continue;
}
base_vregs.emplace(object_index, -1);
}
if (base_vregs.empty()) {
continue;
}
for (auto& block : function->GetBlocks()) {
for (auto& inst : block->GetInstructions()) {
if (!inst.HasAddress()) {
continue;
}
const auto& address = inst.GetAddress();
auto it = base_vregs.find(address.base_index);
if (it == base_vregs.end()) {
continue;
}
if (it->second >= 0) {
continue;
}
if (IsPlainFrameLea(inst, address.base_index)) {
it->second = inst.GetOperands()[0].GetVReg();
}
}
}
auto& entry_block = *function->GetBlocks().front();
auto& entry_insts = entry_block.GetInstructions();
std::size_t insert_pos = FindEntryInsertPos(entry_block);
for (auto& [object_index, base_vreg] : base_vregs) {
if (base_vreg >= 0) {
continue;
}
base_vreg = function->NewVReg(ValueType::Ptr);
MachineInstr lea(MachineInstr::Opcode::Lea, {MachineOperand::VReg(base_vreg)});
AddressExpr address;
address.base_kind = AddrBaseKind::FrameObject;
address.base_index = object_index;
lea.SetAddress(std::move(address));
entry_insts.insert(entry_insts.begin() + static_cast<std::ptrdiff_t>(insert_pos),
std::move(lea));
++insert_pos;
}
for (auto& block : function->GetBlocks()) {
for (auto& inst : block->GetInstructions()) {
if (!inst.HasAddress()) {
continue;
}
auto& address = inst.GetAddress();
auto it = base_vregs.find(address.base_index);
if (it == base_vregs.end()) {
continue;
}
if (IsPlainFrameLea(inst, address.base_index) &&
inst.GetOperands()[0].GetKind() == OperandKind::VReg &&
inst.GetOperands()[0].GetVReg() == it->second) {
continue;
}
if (address.base_kind != AddrBaseKind::FrameObject || address.base_index < 0) {
continue;
}
address.base_kind = AddrBaseKind::VReg;
address.base_index = it->second;
}
}
}
}
} // namespace mir

File diff suppressed because it is too large Load Diff

@ -1,24 +1,25 @@
add_library(mir_core STATIC
MIRContext.cpp
MIRFunction.cpp
MIRBasicBlock.cpp
MIRInstr.cpp
Register.cpp
Lowering.cpp
RegAlloc.cpp
FrameLowering.cpp
AsmPrinter.cpp
)
target_link_libraries(mir_core PUBLIC
build_options
ir
)
add_subdirectory(passes)
add_library(mir INTERFACE)
target_link_libraries(mir INTERFACE
mir_core
mir_passes
)
add_library(mir_core STATIC
MIRContext.cpp
MIRFunction.cpp
MIRBasicBlock.cpp
MIRInstr.cpp
Register.cpp
Lowering.cpp
AddressHoisting.cpp
RegAlloc.cpp
FrameLowering.cpp
AsmPrinter.cpp
)
target_link_libraries(mir_core PUBLIC
build_options
ir
)
add_subdirectory(passes)
add_library(mir INTERFACE)
target_link_libraries(mir INTERFACE
mir_core
mir_passes
)

@ -1,45 +1,40 @@
#include "mir/MIR.h"
#include <stdexcept>
#include <vector>
#include "utils/Log.h"
#include <string>
namespace mir {
namespace {
int AlignTo(int value, int align) {
if (align <= 1) {
return value;
}
return ((value + align - 1) / align) * align;
}
} // namespace
void RunFrameLowering(MachineFunction& function) {
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
void RunFrameLowering(MachineModule& module) {
for (auto& function : module.GetFunctions()) {
for (int reg : function->GetUsedCalleeSavedGPRs()) {
function->CreateStackObject(8, 8, StackObjectKind::SavedGPR,
"save.x" + std::to_string(reg));
}
for (int reg : function->GetUsedCalleeSavedFPRs()) {
function->CreateStackObject(8, 8, StackObjectKind::SavedFPR,
"save.v" + std::to_string(reg));
}
}
cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
function.GetFrameSlot(slot.index).offset = -cursor;
}
function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue);
int cursor = 0;
const int object_count = static_cast<int>(function->GetStackObjects().size());
for (int i = 0; i < object_count; ++i) {
auto& object = function->GetStackObject(i);
cursor = AlignTo(cursor, object.align);
cursor += object.size;
object.offset = -cursor;
}
lowered.push_back(inst);
function->SetFrameSize(AlignTo(cursor, 16));
}
insts = std::move(lowered);
}
} // namespace mir
} // namespace mir

File diff suppressed because it is too large Load Diff

@ -7,10 +7,15 @@ namespace mir {
MachineBasicBlock::MachineBasicBlock(std::string name)
: name_(std::move(name)) {}
MachineInstr& MachineBasicBlock::Append(Opcode opcode,
std::initializer_list<Operand> operands) {
instructions_.emplace_back(opcode, std::vector<Operand>(operands));
MachineInstr& MachineBasicBlock::Append(MachineInstr::Opcode opcode,
std::vector<MachineOperand> operands) {
instructions_.emplace_back(opcode, std::move(operands));
return instructions_.back();
}
} // namespace mir
MachineInstr& MachineBasicBlock::Append(MachineInstr instr) {
instructions_.push_back(std::move(instr));
return instructions_.back();
}
} // namespace mir

@ -2,9 +2,10 @@
namespace mir {
MIRContext& DefaultContext() {
static MIRContext ctx;
return ctx;
}
namespace {
MIRContext g_context;
} // namespace
} // namespace mir
MIRContext& DefaultContext() { return g_context; }
} // namespace mir

@ -1,33 +1,106 @@
#include "mir/MIR.h"
#include <algorithm>
#include <stdexcept>
#include <utility>
#include "utils/Log.h"
namespace mir {
MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)), entry_("entry") {}
MachineFunction::MachineFunction(std::string name, ValueType return_type,
std::vector<ValueType> param_types)
: name_(std::move(name)),
return_type_(return_type),
param_types_(std::move(param_types)) {}
MachineBasicBlock* MachineFunction::CreateBlock(const std::string& name) {
auto block = std::make_unique<MachineBasicBlock>(name);
auto* ptr = block.get();
blocks_.push_back(std::move(block));
return ptr;
}
int MachineFunction::NewVReg(ValueType type) {
const int id = static_cast<int>(vregs_.size());
vregs_.push_back({id, type});
allocations_.push_back({});
return id;
}
const VRegInfo& MachineFunction::GetVRegInfo(int id) const {
if (id < 0 || id >= static_cast<int>(vregs_.size())) {
throw std::out_of_range("virtual register index out of range");
}
return vregs_[static_cast<size_t>(id)];
}
VRegInfo& MachineFunction::GetVRegInfo(int id) {
if (id < 0 || id >= static_cast<int>(vregs_.size())) {
throw std::out_of_range("virtual register index out of range");
}
return vregs_[static_cast<size_t>(id)];
}
int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size());
frame_slots_.push_back(FrameSlot{index, size, 0});
int MachineFunction::CreateStackObject(int size, int align, StackObjectKind kind,
const std::string& name) {
const int index = static_cast<int>(stack_objects_.size());
stack_objects_.push_back({index, kind, size, align, 0, name});
return index;
}
FrameSlot& MachineFunction::GetFrameSlot(int index) {
if (index < 0 || index >= static_cast<int>(frame_slots_.size())) {
throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
StackObject& MachineFunction::GetStackObject(int index) {
if (index < 0 || index >= static_cast<int>(stack_objects_.size())) {
throw std::out_of_range("stack object index out of range");
}
return stack_objects_[static_cast<size_t>(index)];
}
const StackObject& MachineFunction::GetStackObject(int index) const {
if (index < 0 || index >= static_cast<int>(stack_objects_.size())) {
throw std::out_of_range("stack object index out of range");
}
return frame_slots_[index];
return stack_objects_[static_cast<size_t>(index)];
}
const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
if (index < 0 || index >= static_cast<int>(frame_slots_.size())) {
throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
void MachineFunction::SetAllocation(int vreg, Allocation allocation) {
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
throw std::out_of_range("allocation index out of range");
}
return frame_slots_[index];
allocations_[static_cast<size_t>(vreg)] = allocation;
}
const Allocation& MachineFunction::GetAllocation(int vreg) const {
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
throw std::out_of_range("allocation index out of range");
}
return allocations_[static_cast<size_t>(vreg)];
}
Allocation& MachineFunction::GetAllocation(int vreg) {
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
throw std::out_of_range("allocation index out of range");
}
return allocations_[static_cast<size_t>(vreg)];
}
void MachineFunction::AddUsedCalleeSavedGPR(int reg_index) {
if (std::find(used_callee_saved_gprs_.begin(), used_callee_saved_gprs_.end(),
reg_index) == used_callee_saved_gprs_.end()) {
used_callee_saved_gprs_.push_back(reg_index);
}
}
void MachineFunction::AddUsedCalleeSavedFPR(int reg_index) {
if (std::find(used_callee_saved_fprs_.begin(), used_callee_saved_fprs_.end(),
reg_index) == used_callee_saved_fprs_.end()) {
used_callee_saved_fprs_.push_back(reg_index);
}
}
MachineFunction* MachineModule::AddFunction(
std::unique_ptr<MachineFunction> function) {
auto* ptr = function.get();
functions_.push_back(std::move(function));
return ptr;
}
} // namespace mir
} // namespace mir

@ -1,23 +1,186 @@
#include "mir/MIR.h"
#include <utility>
namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {}
MachineOperand::MachineOperand(OperandKind kind, int vreg, std::int64_t imm,
std::string text)
: kind_(kind), vreg_(vreg), imm_(imm), text_(std::move(text)) {}
MachineOperand MachineOperand::VReg(int reg) {
return MachineOperand(OperandKind::VReg, reg, 0, "");
}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
MachineOperand MachineOperand::Imm(std::int64_t value) {
return MachineOperand(OperandKind::Imm, -1, value, "");
}
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value);
MachineOperand MachineOperand::Block(std::string name) {
return MachineOperand(OperandKind::Block, -1, 0, std::move(name));
}
Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index);
MachineOperand MachineOperand::Symbol(std::string name) {
return MachineOperand(OperandKind::Symbol, -1, 0, std::move(name));
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
MachineInstr::MachineInstr(Opcode opcode, std::vector<MachineOperand> operands)
: opcode_(opcode), operands_(std::move(operands)) {}
} // namespace mir
bool MachineInstr::IsTerminator() const {
return opcode_ == Opcode::Br || opcode_ == Opcode::CondBr ||
opcode_ == Opcode::Ret || opcode_ == Opcode::Unreachable;
}
std::vector<int> MachineInstr::GetDefs() const {
switch (opcode_) {
case Opcode::Arg:
case Opcode::Copy:
case Opcode::Load:
case Opcode::Lea:
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::ModMul:
case Opcode::ModPow:
case Opcode::DigitExtractPow2:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FSqrt:
case Opcode::FNeg:
case Opcode::ICmp:
case Opcode::FCmp:
case Opcode::ZExt:
case Opcode::ItoF:
case Opcode::FtoI:
if (!operands_.empty() && operands_[0].GetKind() == OperandKind::VReg) {
return {operands_[0].GetVReg()};
}
return {};
case Opcode::Call:
if (call_return_type_ != ValueType::Void && !operands_.empty() &&
operands_[0].GetKind() == OperandKind::VReg) {
return {operands_[0].GetVReg()};
}
return {};
case Opcode::Store:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Ret:
case Opcode::Memset:
case Opcode::Unreachable:
return {};
}
return {};
}
std::vector<int> MachineInstr::GetUses() const {
std::vector<int> uses;
auto push_vreg = [&](const MachineOperand& operand) {
if (operand.GetKind() == OperandKind::VReg) {
uses.push_back(operand.GetVReg());
}
};
auto push_addr_uses = [&]() {
if (!has_address_) {
return;
}
if (address_.base_kind == AddrBaseKind::VReg && address_.base_index >= 0) {
uses.push_back(address_.base_index);
}
for (const auto& term : address_.scaled_vregs) {
uses.push_back(term.first);
}
};
switch (opcode_) {
case Opcode::Arg:
case Opcode::Br:
case Opcode::Unreachable:
break;
case Opcode::Copy:
case Opcode::ZExt:
case Opcode::ItoF:
case Opcode::FtoI:
case Opcode::FSqrt:
case Opcode::FNeg:
if (operands_.size() >= 2) {
push_vreg(operands_[1]);
}
break;
case Opcode::Load:
case Opcode::Lea:
push_addr_uses();
break;
case Opcode::Store:
if (!operands_.empty()) {
push_vreg(operands_[0]);
}
push_addr_uses();
break;
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::ModMul:
case Opcode::ModPow:
case Opcode::DigitExtractPow2:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::ICmp:
case Opcode::FCmp:
if (operands_.size() >= 2) {
push_vreg(operands_[1]);
}
if (operands_.size() >= 3) {
push_vreg(operands_[2]);
}
break;
case Opcode::CondBr:
if (!operands_.empty()) {
push_vreg(operands_[0]);
}
break;
case Opcode::Call: {
size_t arg_begin = call_return_type_ == ValueType::Void ? 0 : 1;
for (size_t i = arg_begin; i < operands_.size(); ++i) {
push_vreg(operands_[i]);
}
break;
}
case Opcode::Ret:
if (!operands_.empty()) {
push_vreg(operands_[0]);
}
break;
case Opcode::Memset:
if (!operands_.empty()) {
push_vreg(operands_[0]);
}
if (operands_.size() >= 2) {
push_vreg(operands_[1]);
}
push_addr_uses();
break;
}
return uses;
}
} // namespace mir

@ -1,36 +1,820 @@
#include "mir/MIR.h"
#include <stdexcept>
#include "mir/MIR.h"
#include <algorithm>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "utils/Log.h"
namespace mir {
namespace {
struct BlockInfo {
int start_pos = 0;
int end_pos = 0;
std::vector<int> successors;
std::vector<std::uint8_t> use;
std::vector<std::uint8_t> def;
std::vector<std::uint8_t> live_in;
std::vector<std::uint8_t> live_out;
};
struct MoveEdge {
int dst = -1;
int src = -1;
};
bool BelongsToClass(ValueType type, RegClass reg_class) {
if (type == ValueType::Void) {
return false;
}
return IsFPR(type) ? reg_class == RegClass::FPR : reg_class == RegClass::GPR;
}
bool IsCalleeSaved(PhysReg reg) {
if (reg.reg_class == RegClass::GPR) {
return reg.index >= 19 && reg.index <= 28;
}
return reg.index >= 8 && reg.index <= 15;
}
#include "utils/Log.h"
bool IsCallerSaved(PhysReg reg) {
return !IsCalleeSaved(reg);
}
namespace mir {
namespace {
bool IsCheapRematerializableInst(const MachineInstr& inst) {
if (inst.GetDefs().size() != 1) {
return false;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Copy) {
const auto& operands = inst.GetOperands();
return operands.size() >= 2 && operands[1].GetKind() == OperandKind::Imm;
}
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress()) {
return false;
}
const auto& address = inst.GetAddress();
return address.base_kind != AddrBaseKind::VReg && address.scaled_vregs.empty();
}
bool IsAllowedReg(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29:
case PhysReg::X30:
case PhysReg::SP:
return true;
std::vector<PhysReg> GetAllocatableRegs(RegClass reg_class) {
std::vector<PhysReg> regs;
if (reg_class == RegClass::FPR) {
for (int i = 19; i <= 31; ++i) {
regs.push_back({RegClass::FPR, i});
}
for (int i = 8; i <= 15; ++i) {
regs.push_back({RegClass::FPR, i});
}
return regs;
}
regs.push_back({RegClass::GPR, 8});
for (int i = 13; i <= 15; ++i) {
regs.push_back({RegClass::GPR, i});
}
for (int i = 19; i <= 28; ++i) {
regs.push_back({RegClass::GPR, i});
}
return false;
return regs;
}
int CreateSpillSlot(MachineFunction& function, int vreg) {
const auto type = function.GetVRegInfo(vreg).type;
return function.CreateStackObject(GetValueSize(type), GetValueAlign(type),
StackObjectKind::Spill,
"spill." + std::to_string(vreg));
}
std::vector<BlockInfo> AnalyzeBlocks(const MachineFunction& function) {
const auto& blocks = function.GetBlocks();
const int num_blocks = static_cast<int>(blocks.size());
const int num_vregs = static_cast<int>(function.GetVRegs().size());
std::vector<BlockInfo> infos(static_cast<size_t>(num_blocks));
std::vector<std::pair<std::string, int>> block_name_to_index;
block_name_to_index.reserve(blocks.size());
for (int i = 0; i < num_blocks; ++i) {
block_name_to_index.push_back({blocks[static_cast<size_t>(i)]->GetName(), i});
}
auto find_block_index = [&](const std::string& name) {
auto it = std::find_if(block_name_to_index.begin(), block_name_to_index.end(),
[&](const auto& item) { return item.first == name; });
if (it == block_name_to_index.end()) {
throw std::runtime_error(FormatError("mir", "unknown basic block label: " + name));
}
return it->second;
};
int position = 0;
for (int block_index = 0; block_index < num_blocks; ++block_index) {
auto& info = infos[static_cast<size_t>(block_index)];
info.start_pos = position;
info.use.assign(static_cast<size_t>(num_vregs), 0);
info.def.assign(static_cast<size_t>(num_vregs), 0);
info.live_in.assign(static_cast<size_t>(num_vregs), 0);
info.live_out.assign(static_cast<size_t>(num_vregs), 0);
const auto& instructions = blocks[static_cast<size_t>(block_index)]->GetInstructions();
for (const auto& inst : instructions) {
for (int use : inst.GetUses()) {
if (use >= 0 && use < num_vregs && !info.def[static_cast<size_t>(use)]) {
info.use[static_cast<size_t>(use)] = 1;
}
}
for (int def : inst.GetDefs()) {
if (def >= 0 && def < num_vregs) {
info.def[static_cast<size_t>(def)] = 1;
}
}
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Br:
if (!inst.GetOperands().empty()) {
info.successors.push_back(find_block_index(inst.GetOperands()[0].GetText()));
}
break;
case MachineInstr::Opcode::CondBr:
if (inst.GetOperands().size() >= 3) {
info.successors.push_back(find_block_index(inst.GetOperands()[1].GetText()));
info.successors.push_back(find_block_index(inst.GetOperands()[2].GetText()));
}
break;
default:
break;
}
position += 2;
}
std::sort(info.successors.begin(), info.successors.end());
info.successors.erase(std::unique(info.successors.begin(), info.successors.end()),
info.successors.end());
info.end_pos = position;
}
bool changed = true;
while (changed) {
changed = false;
for (int block_index = num_blocks - 1; block_index >= 0; --block_index) {
auto& info = infos[static_cast<size_t>(block_index)];
std::vector<std::uint8_t> next_out(static_cast<size_t>(num_vregs), 0);
std::vector<std::uint8_t> next_in(static_cast<size_t>(num_vregs), 0);
for (int succ : info.successors) {
const auto& succ_in = infos[static_cast<size_t>(succ)].live_in;
for (int vreg = 0; vreg < num_vregs; ++vreg) {
next_out[static_cast<size_t>(vreg)] |= succ_in[static_cast<size_t>(vreg)];
}
}
for (int vreg = 0; vreg < num_vregs; ++vreg) {
const size_t idx = static_cast<size_t>(vreg);
next_in[idx] = info.use[idx] |
(next_out[idx] & static_cast<std::uint8_t>(!info.def[idx]));
}
if (next_out != info.live_out || next_in != info.live_in) {
changed = true;
info.live_out = std::move(next_out);
info.live_in = std::move(next_in);
}
}
}
return infos;
}
class GeorgeColoringAllocator {
public:
GeorgeColoringAllocator(MachineFunction& function, RegClass reg_class,
const std::vector<BlockInfo>& block_infos)
: function_(function),
reg_class_(reg_class),
regs_(GetAllocatableRegs(reg_class)),
k_(static_cast<int>(regs_.size())),
block_infos_(block_infos),
num_vregs_(static_cast<int>(function.GetVRegs().size())),
in_class_(static_cast<size_t>(num_vregs_), 0),
live_across_call_(static_cast<size_t>(num_vregs_), 0),
rematerializable_(static_cast<size_t>(num_vregs_), 0),
adjacency_(static_cast<size_t>(num_vregs_)),
degree_(static_cast<size_t>(num_vregs_), 0),
spill_cost_(static_cast<size_t>(num_vregs_), 0.0),
move_list_(static_cast<size_t>(num_vregs_)),
alias_(static_cast<size_t>(num_vregs_), -1),
color_index_(static_cast<size_t>(num_vregs_), -1),
in_select_stack_(static_cast<size_t>(num_vregs_), 0),
is_coalesced_(static_cast<size_t>(num_vregs_), 0),
is_spilled_(static_cast<size_t>(num_vregs_), 0),
is_colored_(static_cast<size_t>(num_vregs_), 0),
simplify_worklist_(static_cast<size_t>(num_vregs_), 0),
freeze_worklist_(static_cast<size_t>(num_vregs_), 0),
spill_worklist_(static_cast<size_t>(num_vregs_), 0) {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
alias_[static_cast<size_t>(vreg)] = vreg;
in_class_[static_cast<size_t>(vreg)] =
BelongsToClass(function_.GetVRegInfo(vreg).type, reg_class_) ? 1 : 0;
}
}
void Run() {
if (k_ == 0) {
throw std::runtime_error(FormatError("mir", "no allocatable physical registers"));
}
MarkRematerializableDefs();
Build();
MakeWorklists();
while (HasNodes(simplify_worklist_) || HasNodes(freeze_worklist_) ||
HasNodes(spill_worklist_) || HasMoves(worklist_moves_)) {
if (HasNodes(simplify_worklist_)) {
Simplify();
} else if (HasMoves(worklist_moves_)) {
Coalesce();
} else if (HasNodes(freeze_worklist_)) {
Freeze();
} else if (HasNodes(spill_worklist_)) {
SelectSpill();
}
}
AssignColors();
CommitAllocations();
}
private:
void MarkRematerializableDefs() {
for (const auto& block : function_.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
if (!IsCheapRematerializableInst(inst)) {
continue;
}
for (int def : inst.GetDefs()) {
if (def >= 0 && def < num_vregs_ && in_class_[static_cast<size_t>(def)]) {
rematerializable_[static_cast<size_t>(def)] = 1;
}
}
}
}
}
void Build() {
const auto& blocks = function_.GetBlocks();
for (size_t block_index = 0; block_index < blocks.size(); ++block_index) {
const auto& block = blocks[block_index];
const auto& info = block_infos_[block_index];
std::vector<std::uint8_t> live = info.live_out;
double block_weight = 1.0;
for (int succ : info.successors) {
if (succ <= static_cast<int>(block_index)) {
block_weight = 8.0;
break;
}
}
const auto& instructions = block->GetInstructions();
for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
const auto& inst = *it;
auto defs = FilterClass(inst.GetDefs());
auto uses = FilterClass(inst.GetUses());
if (inst.GetOpcode() == MachineInstr::Opcode::Call ||
inst.GetOpcode() == MachineInstr::Opcode::Memset) {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (live[static_cast<size_t>(vreg)] &&
in_class_[static_cast<size_t>(vreg)]) {
live_across_call_[static_cast<size_t>(vreg)] = 1;
}
}
}
for (int def : defs) {
spill_cost_[static_cast<size_t>(def)] +=
block_weight * (rematerializable_[static_cast<size_t>(def)] ? 0.25 : 1.0);
}
for (int use : uses) {
spill_cost_[static_cast<size_t>(use)] +=
block_weight * (rematerializable_[static_cast<size_t>(use)] ? 0.25 : 1.0);
}
} // namespace
// All source operands are simultaneously live at the instruction input.
// They must interfere with each other, otherwise two distinct values
// used by the same instruction may be colored to the same register.
for (size_t i = 0; i < uses.size(); ++i) {
for (size_t j = i + 1; j < uses.size(); ++j) {
AddEdge(uses[i], uses[j]);
}
}
void RunRegAlloc(MachineFunction& function) {
for (const auto& inst : function.GetEntry().GetInstructions()) {
for (const auto& operand : inst.GetOperands()) {
if (operand.GetKind() == Operand::Kind::Reg &&
!IsAllowedReg(operand.GetReg())) {
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
const bool is_move = inst.GetOpcode() == MachineInstr::Opcode::Copy &&
defs.size() == 1 && uses.size() == 1 && defs[0] != uses[0];
if (is_move) {
const int dst = defs[0];
const int src = uses[0];
const int move_index = static_cast<int>(moves_.size());
moves_.push_back({dst, src});
move_list_[static_cast<size_t>(dst)].push_back(move_index);
move_list_[static_cast<size_t>(src)].push_back(move_index);
live[static_cast<size_t>(src)] = 0;
}
for (int def : defs) {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!live[static_cast<size_t>(vreg)] || !in_class_[static_cast<size_t>(vreg)]) {
continue;
}
AddEdge(def, vreg);
}
}
for (int def : defs) {
live[static_cast<size_t>(def)] = 0;
}
for (int use : uses) {
live[static_cast<size_t>(use)] = 1;
}
}
}
worklist_moves_.assign(moves_.size(), 1);
active_moves_.assign(moves_.size(), 0);
coalesced_moves_.assign(moves_.size(), 0);
constrained_moves_.assign(moves_.size(), 0);
frozen_moves_.assign(moves_.size(), 0);
}
void MakeWorklists() {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!in_class_[static_cast<size_t>(vreg)]) {
continue;
}
if (degree_[static_cast<size_t>(vreg)] >= k_) {
spill_worklist_[static_cast<size_t>(vreg)] = 1;
} else if (MoveRelated(vreg)) {
freeze_worklist_[static_cast<size_t>(vreg)] = 1;
} else {
simplify_worklist_[static_cast<size_t>(vreg)] = 1;
}
}
}
void Simplify() {
const int node = PickAnyNode(simplify_worklist_);
simplify_worklist_[static_cast<size_t>(node)] = 0;
select_stack_.push_back(node);
in_select_stack_[static_cast<size_t>(node)] = 1;
for (int neighbor : Adjacent(node)) {
DecrementDegree(neighbor);
}
}
void Coalesce() {
const int move_index = PickBestMove();
worklist_moves_[static_cast<size_t>(move_index)] = 0;
int x = GetAlias(moves_[static_cast<size_t>(move_index)].dst);
int y = GetAlias(moves_[static_cast<size_t>(move_index)].src);
if (x == y) {
coalesced_moves_[static_cast<size_t>(move_index)] = 1;
AddWorkList(x);
return;
}
if (AdjacentTo(x, y)) {
constrained_moves_[static_cast<size_t>(move_index)] = 1;
AddWorkList(x);
AddWorkList(y);
return;
}
int u = x;
int v = y;
if (degree_[static_cast<size_t>(v)] > degree_[static_cast<size_t>(u)]) {
std::swap(u, v);
}
if (GeorgeOK(v, u) || ConservativeUnion(u, v)) {
coalesced_moves_[static_cast<size_t>(move_index)] = 1;
Combine(u, v);
AddWorkList(u);
return;
}
active_moves_[static_cast<size_t>(move_index)] = 1;
}
void Freeze() {
const int node = PickAnyNode(freeze_worklist_);
freeze_worklist_[static_cast<size_t>(node)] = 0;
simplify_worklist_[static_cast<size_t>(node)] = 1;
FreezeMoves(node);
}
void SelectSpill() {
int best = -1;
double best_priority = std::numeric_limits<double>::infinity();
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!spill_worklist_[static_cast<size_t>(vreg)]) {
continue;
}
double priority = spill_cost_[static_cast<size_t>(vreg)] /
std::max(1, degree_[static_cast<size_t>(vreg)]);
if (rematerializable_[static_cast<size_t>(vreg)]) {
priority *= 0.2;
}
if (MoveRelated(vreg)) {
priority *= 1.15;
}
if (live_across_call_[static_cast<size_t>(vreg)] &&
!rematerializable_[static_cast<size_t>(vreg)]) {
priority *= 1.25;
}
if (best < 0 || priority < best_priority) {
best = vreg;
best_priority = priority;
}
}
if (best < 0) {
throw std::runtime_error(FormatError("mir", "failed to select spill candidate"));
}
spill_worklist_[static_cast<size_t>(best)] = 0;
simplify_worklist_[static_cast<size_t>(best)] = 1;
FreezeMoves(best);
}
void AssignColors() {
while (!select_stack_.empty()) {
const int node = select_stack_.back();
select_stack_.pop_back();
in_select_stack_[static_cast<size_t>(node)] = 0;
std::vector<std::uint8_t> ok_colors(static_cast<size_t>(regs_.size()), 1);
if (live_across_call_[static_cast<size_t>(node)]) {
for (size_t i = 0; i < regs_.size(); ++i) {
if (IsCallerSaved(regs_[i])) {
ok_colors[i] = 0;
}
}
}
for (int neighbor : adjacency_[static_cast<size_t>(node)]) {
const int alias = GetAlias(neighbor);
if (!is_colored_[static_cast<size_t>(alias)]) {
continue;
}
const int color = color_index_[static_cast<size_t>(alias)];
if (color >= 0 && color < static_cast<int>(regs_.size())) {
ok_colors[static_cast<size_t>(color)] = 0;
}
}
int chosen = -1;
for (size_t i = 0; i < ok_colors.size(); ++i) {
if (ok_colors[i]) {
chosen = static_cast<int>(i);
break;
}
}
if (chosen < 0) {
is_spilled_[static_cast<size_t>(node)] = 1;
continue;
}
is_colored_[static_cast<size_t>(node)] = 1;
color_index_[static_cast<size_t>(node)] = chosen;
}
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!is_coalesced_[static_cast<size_t>(vreg)]) {
continue;
}
const int alias = GetAlias(vreg);
if (is_spilled_[static_cast<size_t>(alias)]) {
is_spilled_[static_cast<size_t>(vreg)] = 1;
} else {
is_colored_[static_cast<size_t>(vreg)] = 1;
color_index_[static_cast<size_t>(vreg)] = color_index_[static_cast<size_t>(alias)];
}
}
}
void CommitAllocations() {
std::unordered_map<int, Allocation> representative_allocations;
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!in_class_[static_cast<size_t>(vreg)]) {
continue;
}
const int rep = GetAlias(vreg);
if (representative_allocations.find(rep) != representative_allocations.end()) {
continue;
}
Allocation allocation;
if (is_spilled_[static_cast<size_t>(rep)]) {
allocation.kind = Allocation::Kind::Spill;
allocation.stack_object = CreateSpillSlot(function_, rep);
} else if (is_colored_[static_cast<size_t>(rep)]) {
allocation.kind = Allocation::Kind::PhysReg;
allocation.phys = regs_[static_cast<size_t>(color_index_[static_cast<size_t>(rep)])];
} else {
allocation.kind = Allocation::Kind::Spill;
allocation.stack_object = CreateSpillSlot(function_, rep);
}
representative_allocations.emplace(rep, allocation);
}
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (!in_class_[static_cast<size_t>(vreg)]) {
continue;
}
const Allocation allocation = representative_allocations.at(GetAlias(vreg));
function_.SetAllocation(vreg, allocation);
if (allocation.kind == Allocation::Kind::PhysReg && IsCalleeSaved(allocation.phys)) {
if (allocation.phys.reg_class == RegClass::GPR) {
function_.AddUsedCalleeSavedGPR(allocation.phys.index);
} else {
function_.AddUsedCalleeSavedFPR(allocation.phys.index);
}
}
}
}
void DecrementDegree(int node) {
const int old_degree = degree_[static_cast<size_t>(node)];
--degree_[static_cast<size_t>(node)];
if (old_degree != k_) {
return;
}
auto neighbors = Adjacent(node);
neighbors.push_back(node);
EnableMoves(neighbors);
spill_worklist_[static_cast<size_t>(node)] = 0;
if (MoveRelated(node)) {
freeze_worklist_[static_cast<size_t>(node)] = 1;
} else {
simplify_worklist_[static_cast<size_t>(node)] = 1;
}
}
void AddWorkList(int node) {
if (!in_class_[static_cast<size_t>(node)] || is_coalesced_[static_cast<size_t>(node)] ||
in_select_stack_[static_cast<size_t>(node)] || degree_[static_cast<size_t>(node)] >= k_ ||
MoveRelated(node)) {
return;
}
freeze_worklist_[static_cast<size_t>(node)] = 0;
spill_worklist_[static_cast<size_t>(node)] = 0;
simplify_worklist_[static_cast<size_t>(node)] = 1;
}
void Combine(int keep, int remove) {
simplify_worklist_[static_cast<size_t>(remove)] = 0;
freeze_worklist_[static_cast<size_t>(remove)] = 0;
spill_worklist_[static_cast<size_t>(remove)] = 0;
is_coalesced_[static_cast<size_t>(remove)] = 1;
alias_[static_cast<size_t>(remove)] = keep;
live_across_call_[static_cast<size_t>(keep)] |=
live_across_call_[static_cast<size_t>(remove)];
auto& keep_moves = move_list_[static_cast<size_t>(keep)];
const auto& remove_moves = move_list_[static_cast<size_t>(remove)];
keep_moves.insert(keep_moves.end(), remove_moves.begin(), remove_moves.end());
EnableMoves({remove});
for (int neighbor : Adjacent(remove)) {
AddEdge(neighbor, keep);
DecrementDegree(neighbor);
}
if (freeze_worklist_[static_cast<size_t>(keep)] && degree_[static_cast<size_t>(keep)] >= k_) {
freeze_worklist_[static_cast<size_t>(keep)] = 0;
spill_worklist_[static_cast<size_t>(keep)] = 1;
}
}
void FreezeMoves(int node) {
for (int move_index : NodeMoves(node)) {
if (worklist_moves_[static_cast<size_t>(move_index)]) {
worklist_moves_[static_cast<size_t>(move_index)] = 0;
} else if (active_moves_[static_cast<size_t>(move_index)]) {
active_moves_[static_cast<size_t>(move_index)] = 0;
} else {
continue;
}
frozen_moves_[static_cast<size_t>(move_index)] = 1;
const auto& move = moves_[static_cast<size_t>(move_index)];
const int x = GetAlias(move.dst);
const int y = GetAlias(move.src);
const int other = y == GetAlias(node) ? x : y;
if (!MoveRelated(other) && degree_[static_cast<size_t>(other)] < k_) {
freeze_worklist_[static_cast<size_t>(other)] = 0;
simplify_worklist_[static_cast<size_t>(other)] = 1;
}
}
}
void EnableMoves(const std::vector<int>& nodes) {
for (int node : nodes) {
for (int move_index : NodeMoves(node)) {
if (active_moves_[static_cast<size_t>(move_index)]) {
active_moves_[static_cast<size_t>(move_index)] = 0;
worklist_moves_[static_cast<size_t>(move_index)] = 1;
}
}
}
}
std::vector<int> Adjacent(int node) const {
std::vector<int> neighbors;
for (int neighbor : adjacency_[static_cast<size_t>(node)]) {
if (in_select_stack_[static_cast<size_t>(neighbor)] ||
is_coalesced_[static_cast<size_t>(neighbor)]) {
continue;
}
neighbors.push_back(neighbor);
}
return neighbors;
}
std::vector<int> NodeMoves(int node) const {
std::vector<int> related_moves;
for (int move_index : move_list_[static_cast<size_t>(node)]) {
if (worklist_moves_[static_cast<size_t>(move_index)] ||
active_moves_[static_cast<size_t>(move_index)]) {
related_moves.push_back(move_index);
}
}
return related_moves;
}
bool MoveRelated(int node) const { return !NodeMoves(node).empty(); }
int GetAlias(int node) const {
int current = node;
while (is_coalesced_[static_cast<size_t>(current)]) {
current = alias_[static_cast<size_t>(current)];
}
return current;
}
bool AdjacentTo(int lhs, int rhs) const {
return adjacency_[static_cast<size_t>(lhs)].find(rhs) !=
adjacency_[static_cast<size_t>(lhs)].end();
}
bool GeorgeOK(int candidate, int target) const {
for (int neighbor : Adjacent(candidate)) {
if (degree_[static_cast<size_t>(neighbor)] >= k_ && !AdjacentTo(neighbor, target)) {
return false;
}
}
return true;
}
bool ConservativeUnion(int lhs, int rhs) const {
std::unordered_set<int> union_neighbors;
for (int neighbor : Adjacent(lhs)) {
union_neighbors.insert(neighbor);
}
for (int neighbor : Adjacent(rhs)) {
union_neighbors.insert(neighbor);
}
int high_degree_count = 0;
for (int neighbor : union_neighbors) {
if (degree_[static_cast<size_t>(neighbor)] >= k_) {
++high_degree_count;
}
}
return high_degree_count < k_;
}
void AddEdge(int lhs, int rhs) {
if (lhs == rhs || !in_class_[static_cast<size_t>(lhs)] ||
!in_class_[static_cast<size_t>(rhs)]) {
return;
}
if (adjacency_[static_cast<size_t>(lhs)].insert(rhs).second) {
adjacency_[static_cast<size_t>(rhs)].insert(lhs);
++degree_[static_cast<size_t>(lhs)];
++degree_[static_cast<size_t>(rhs)];
}
}
std::vector<int> FilterClass(const std::vector<int>& regs) const {
std::vector<int> filtered;
for (int reg : regs) {
if (reg >= 0 && reg < num_vregs_ && in_class_[static_cast<size_t>(reg)]) {
filtered.push_back(reg);
}
}
return filtered;
}
bool HasNodes(const std::vector<std::uint8_t>& worklist) const {
return std::any_of(worklist.begin(), worklist.end(),
[](std::uint8_t flag) { return flag != 0; });
}
bool HasMoves(const std::vector<std::uint8_t>& move_flags) const {
return std::any_of(move_flags.begin(), move_flags.end(),
[](std::uint8_t flag) { return flag != 0; });
}
int PickAnyNode(const std::vector<std::uint8_t>& worklist) const {
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (worklist[static_cast<size_t>(vreg)]) {
return vreg;
}
}
throw std::runtime_error(FormatError("mir", "failed to pick worklist node"));
}
int PickBestMove() const {
int best = -1;
int best_score = std::numeric_limits<int>::max();
for (size_t i = 0; i < worklist_moves_.size(); ++i) {
if (!worklist_moves_[i]) {
continue;
}
const auto& move = moves_[i];
const int dst = GetAlias(move.dst);
const int src = GetAlias(move.src);
int score = degree_[static_cast<size_t>(dst)] + degree_[static_cast<size_t>(src)];
if (live_across_call_[static_cast<size_t>(dst)] !=
live_across_call_[static_cast<size_t>(src)]) {
score += 4;
}
if (rematerializable_[static_cast<size_t>(dst)] ||
rematerializable_[static_cast<size_t>(src)]) {
score += 2;
}
if (score < best_score) {
best = static_cast<int>(i);
best_score = score;
}
}
if (best >= 0) {
return best;
}
throw std::runtime_error(FormatError("mir", "failed to pick worklist move"));
}
}
private:
MachineFunction& function_;
RegClass reg_class_;
std::vector<PhysReg> regs_;
int k_ = 0;
const std::vector<BlockInfo>& block_infos_;
int num_vregs_ = 0;
} // namespace mir
std::vector<std::uint8_t> in_class_;
std::vector<std::uint8_t> live_across_call_;
std::vector<std::uint8_t> rematerializable_;
std::vector<std::unordered_set<int>> adjacency_;
std::vector<int> degree_;
std::vector<double> spill_cost_;
std::vector<std::vector<int>> move_list_;
std::vector<MoveEdge> moves_;
std::vector<int> alias_;
std::vector<int> color_index_;
std::vector<std::uint8_t> in_select_stack_;
std::vector<std::uint8_t> is_coalesced_;
std::vector<std::uint8_t> is_spilled_;
std::vector<std::uint8_t> is_colored_;
std::vector<std::uint8_t> simplify_worklist_;
std::vector<std::uint8_t> freeze_worklist_;
std::vector<std::uint8_t> spill_worklist_;
std::vector<int> select_stack_;
std::vector<std::uint8_t> worklist_moves_;
std::vector<std::uint8_t> active_moves_;
std::vector<std::uint8_t> coalesced_moves_;
std::vector<std::uint8_t> constrained_moves_;
std::vector<std::uint8_t> frozen_moves_;
};
} // namespace
void RunRegAlloc(MachineModule& module) {
for (auto& function : module.GetFunctions()) {
const auto block_infos = AnalyzeBlocks(*function);
GeorgeColoringAllocator gpr_allocator(*function, RegClass::GPR, block_infos);
gpr_allocator.Run();
GeorgeColoringAllocator fpr_allocator(*function, RegClass::FPR, block_infos);
fpr_allocator.Run();
}
}
} // namespace mir

@ -2,26 +2,75 @@
#include <stdexcept>
#include "utils/Log.h"
namespace mir {
namespace {
const char* kWRegNames[] = {
"w0", "w1", "w2", "w3", "w4", "w5", "w6", "w7",
"w8", "w9", "w10", "w11", "w12", "w13", "w14", "w15",
"w16", "w17", "w18", "w19", "w20", "w21", "w22", "w23",
"w24", "w25", "w26", "w27", "w28", "w29", "w30"};
const char* kXRegNames[] = {
"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
"x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15",
"x16", "x17", "x18", "x19", "x20", "x21", "x22", "x23",
"x24", "x25", "x26", "x27", "x28", "x29", "x30"};
const char* kSRegNames[] = {
"s0", "s1", "s2", "s3", "s4", "s5", "s6", "s7",
"s8", "s9", "s10", "s11", "s12", "s13", "s14", "s15",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s24", "s25", "s26", "s27", "s28", "s29", "s30", "s31"};
} // namespace
bool IsGPR(ValueType type) {
return type == ValueType::I1 || type == ValueType::I32 || type == ValueType::Ptr;
}
bool IsFPR(ValueType type) { return type == ValueType::F32; }
int GetValueSize(ValueType type) {
switch (type) {
case ValueType::Void:
return 0;
case ValueType::I1:
case ValueType::I32:
case ValueType::F32:
return 4;
case ValueType::Ptr:
return 8;
}
return 0;
}
const char* PhysRegName(PhysReg reg) {
switch (reg) {
case PhysReg::W0:
return "w0";
case PhysReg::W8:
return "w8";
case PhysReg::W9:
return "w9";
case PhysReg::X29:
return "x29";
case PhysReg::X30:
return "x30";
case PhysReg::SP:
return "sp";
int GetValueAlign(ValueType type) {
switch (type) {
case ValueType::Void:
return 1;
case ValueType::Ptr:
return 8;
case ValueType::I1:
case ValueType::I32:
case ValueType::F32:
return 4;
}
return 1;
}
const char* GetPhysRegName(PhysReg reg, ValueType type) {
if (!reg.IsValid()) {
throw std::runtime_error("invalid physical register");
}
if (reg.reg_class == RegClass::FPR) {
if (reg.index < 0 || reg.index >= 32) {
throw std::runtime_error("float register index out of range");
}
return kSRegNames[reg.index];
}
if (reg.index < 0 || reg.index >= 31) {
throw std::runtime_error("gpr register index out of range");
}
throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
return type == ValueType::Ptr ? kXRegNames[reg.index] : kWRegNames[reg.index];
}
} // namespace mir
} // namespace mir

@ -0,0 +1,239 @@
#include "mir/MIR.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace mir {
namespace {
using BlockList = std::vector<std::unique_ptr<MachineBasicBlock>>;
int FindBlockIndex(const MachineFunction& function, const std::string& name) {
const auto& blocks = function.GetBlocks();
for (size_t i = 0; i < blocks.size(); ++i) {
if (blocks[i] && blocks[i]->GetName() == name) {
return static_cast<int>(i);
}
}
return -1;
}
std::vector<int> CollectSuccessors(const MachineFunction& function, int index) {
std::vector<int> succs;
const auto& blocks = function.GetBlocks();
if (index < 0 || index >= static_cast<int>(blocks.size()) || !blocks[index]) {
return succs;
}
const auto& instructions = blocks[index]->GetInstructions();
if (instructions.empty()) {
return succs;
}
const auto& term = instructions.back();
if (term.GetOpcode() == MachineInstr::Opcode::Br && !term.GetOperands().empty() &&
term.GetOperands()[0].GetKind() == OperandKind::Block) {
const int succ = FindBlockIndex(function, term.GetOperands()[0].GetText());
if (succ >= 0) {
succs.push_back(succ);
}
return succs;
}
if (term.GetOpcode() == MachineInstr::Opcode::CondBr &&
term.GetOperands().size() >= 3) {
for (size_t i = 1; i <= 2; ++i) {
if (term.GetOperands()[i].GetKind() != OperandKind::Block) {
continue;
}
const int succ = FindBlockIndex(function, term.GetOperands()[i].GetText());
if (succ >= 0 &&
std::find(succs.begin(), succs.end(), succ) == succs.end()) {
succs.push_back(succ);
}
}
}
return succs;
}
std::vector<int> BuildPredecessorCount(const MachineFunction& function) {
std::vector<int> preds(function.GetBlocks().size(), 0);
for (size_t i = 0; i < function.GetBlocks().size(); ++i) {
for (int succ : CollectSuccessors(function, static_cast<int>(i))) {
++preds[static_cast<size_t>(succ)];
}
}
return preds;
}
bool IsTrivialJumpBlock(const MachineFunction& function, int index) {
const auto& blocks = function.GetBlocks();
if (index < 0 || index >= static_cast<int>(blocks.size()) || !blocks[index]) {
return false;
}
const auto& instructions = blocks[index]->GetInstructions();
return instructions.size() == 1 &&
instructions.front().GetOpcode() == MachineInstr::Opcode::Br &&
!instructions.front().GetOperands().empty() &&
instructions.front().GetOperands()[0].GetKind() == OperandKind::Block;
}
std::string ResolveJumpChain(const MachineFunction& function, const std::string& target) {
std::string current = target;
std::unordered_set<std::string> visited{current};
while (true) {
const int index = FindBlockIndex(function, current);
if (index < 0 || !IsTrivialJumpBlock(function, index)) {
return current;
}
const auto& inst = function.GetBlocks()[static_cast<size_t>(index)]->GetInstructions().front();
const std::string& next = inst.GetOperands()[0].GetText();
if (!visited.insert(next).second) {
return current;
}
current = next;
}
}
bool RewriteBranchTargets(MachineFunction& function) {
bool changed = false;
for (auto& block : function.GetBlocks()) {
if (!block || block->GetInstructions().empty()) {
continue;
}
auto& term = block->GetInstructions().back();
auto& operands = term.GetOperands();
if (term.GetOpcode() == MachineInstr::Opcode::Br && !operands.empty() &&
operands[0].GetKind() == OperandKind::Block) {
const std::string resolved = ResolveJumpChain(function, operands[0].GetText());
if (resolved != operands[0].GetText()) {
operands[0] = MachineOperand::Block(resolved);
changed = true;
}
continue;
}
if (term.GetOpcode() != MachineInstr::Opcode::CondBr || operands.size() < 3) {
continue;
}
for (size_t i = 1; i <= 2; ++i) {
if (operands[i].GetKind() != OperandKind::Block) {
continue;
}
const std::string resolved = ResolveJumpChain(function, operands[i].GetText());
if (resolved != operands[i].GetText()) {
operands[i] = MachineOperand::Block(resolved);
changed = true;
}
}
if (operands[1].GetKind() == OperandKind::Block &&
operands[2].GetKind() == OperandKind::Block &&
operands[1].GetText() == operands[2].GetText()) {
term = MachineInstr(MachineInstr::Opcode::Br, {operands[1]});
changed = true;
}
}
return changed;
}
bool RemoveUnreachableBlocks(MachineFunction& function) {
auto& blocks = function.GetBlocks();
if (blocks.empty() || !blocks.front()) {
return false;
}
std::unordered_set<std::string> reachable;
std::vector<std::string> stack{blocks.front()->GetName()};
while (!stack.empty()) {
std::string name = stack.back();
stack.pop_back();
if (!reachable.insert(name).second) {
continue;
}
const int index = FindBlockIndex(function, name);
if (index < 0) {
continue;
}
for (int succ : CollectSuccessors(function, index)) {
stack.push_back(blocks[static_cast<size_t>(succ)]->GetName());
}
}
const size_t old_size = blocks.size();
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
[&](const std::unique_ptr<MachineBasicBlock>& block) {
return block && reachable.count(block->GetName()) == 0;
}),
blocks.end());
return blocks.size() != old_size;
}
bool MergeLinearBlocks(MachineFunction& function) {
auto preds = BuildPredecessorCount(function);
auto& blocks = function.GetBlocks();
for (size_t i = 0; i < blocks.size(); ++i) {
auto& block = blocks[i];
if (!block || block->GetInstructions().empty()) {
continue;
}
auto& insts = block->GetInstructions();
auto& term = insts.back();
if (term.GetOpcode() != MachineInstr::Opcode::Br || term.GetOperands().empty() ||
term.GetOperands()[0].GetKind() != OperandKind::Block) {
continue;
}
const int succ_index = FindBlockIndex(function, term.GetOperands()[0].GetText());
if (succ_index <= 0 || succ_index == static_cast<int>(i) ||
preds[static_cast<size_t>(succ_index)] != 1) {
continue;
}
auto& succ = blocks[static_cast<size_t>(succ_index)];
if (!succ || succ->GetInstructions().empty()) {
continue;
}
insts.pop_back();
auto& succ_insts = succ->GetInstructions();
insts.insert(insts.end(),
std::make_move_iterator(succ_insts.begin()),
std::make_move_iterator(succ_insts.end()));
blocks.erase(blocks.begin() + succ_index);
return true;
}
return false;
}
bool RunCFGCleanupOnFunction(MachineFunction& function) {
bool changed = false;
while (true) {
bool local_changed = false;
local_changed |= RewriteBranchTargets(function);
local_changed |= RemoveUnreachableBlocks(function);
if (MergeLinearBlocks(function)) {
local_changed = true;
}
changed |= local_changed;
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunCFGCleanup(MachineModule& module) {
bool changed = false;
for (auto& function : module.GetFunctions()) {
if (function) {
changed |= RunCFGCleanupOnFunction(*function);
}
}
return changed;
}
} // namespace mir

@ -1,6 +1,8 @@
add_library(mir_passes STATIC
PassManager.cpp
Peephole.cpp
SpillReduction.cpp
CFGCleanup.cpp
)
target_link_libraries(mir_passes PUBLIC

@ -1,4 +1,53 @@
// MIR Pass 管理:
// - 组织后端 pass 的运行顺序PreRA/PostRA/PEI 等阶段)
// - 统一运行 pass 与调试输出(按需要扩展)
#include "mir/MIR.h"
#include <cstdlib>
namespace mir {
void RunMIRPreRegAllocPassPipeline(MachineModule& module) {
const char* disable_spill_reduction = std::getenv("NUDTC_DISABLE_MIR_SPILL_REDUCTION");
const bool run_spill_reduction =
disable_spill_reduction == nullptr || disable_spill_reduction[0] == '\0' ||
disable_spill_reduction[0] == '0';
const char* disable_cfg_cleanup = std::getenv("NUDTC_DISABLE_MIR_CFG_CLEANUP");
const bool run_cfg_cleanup =
disable_cfg_cleanup == nullptr || disable_cfg_cleanup[0] == '\0' ||
disable_cfg_cleanup[0] == '0';
if (run_spill_reduction) {
RunSpillReduction(module);
}
RunAddressHoisting(module);
constexpr int kMaxIterations = 4;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
bool changed = false;
changed |= RunPeephole(module);
if (run_cfg_cleanup) {
changed |= RunCFGCleanup(module);
}
if (!changed) {
break;
}
}
}
void RunMIRPostRegAllocPassPipeline(MachineModule& module) {
const char* disable_cfg_cleanup = std::getenv("NUDTC_DISABLE_MIR_CFG_CLEANUP");
const bool run_cfg_cleanup =
disable_cfg_cleanup == nullptr || disable_cfg_cleanup[0] == '\0' ||
disable_cfg_cleanup[0] == '0';
constexpr int kMaxIterations = 2;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
bool changed = false;
changed |= RunPeephole(module);
if (run_cfg_cleanup) {
changed |= RunCFGCleanup(module);
}
if (!changed) {
break;
}
}
}
} // namespace mir

@ -1,4 +1,913 @@
// 窥孔优化Peephole
// - 删除冗余 move、合并常见指令模式
// - 提升最终汇编质量(按实现范围裁剪)
#include "mir/MIR.h"
#include "ir/IR.h"
#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace mir {
namespace {
using AliasMap = std::unordered_map<int, MachineOperand>;
struct CFGInfo {
std::vector<std::vector<int>> predecessors;
std::vector<std::vector<int>> successors;
};
struct AddressKey {
AddrBaseKind base_kind = AddrBaseKind::None;
int base_index = -1;
std::string symbol;
std::int64_t const_offset = 0;
std::vector<std::pair<int, std::int64_t>> scaled_vregs;
bool operator==(const AddressKey& rhs) const {
return base_kind == rhs.base_kind && base_index == rhs.base_index &&
symbol == rhs.symbol && const_offset == rhs.const_offset &&
scaled_vregs == rhs.scaled_vregs;
}
};
struct AddressKeyHash {
std::size_t operator()(const AddressKey& key) const {
std::size_t h = static_cast<std::size_t>(key.base_kind);
h ^= std::hash<int>{}(key.base_index) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::string>{}(key.symbol) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::int64_t>{}(key.const_offset) + 0x9e3779b9 + (h << 6) + (h >> 2);
for (const auto& term : key.scaled_vregs) {
h ^= std::hash<int>{}(term.first) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::int64_t>{}(term.second) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
struct MemoryState {
MachineOperand value;
ValueType type = ValueType::Void;
int pending_store_index = -1;
};
using MemoryMap = std::unordered_map<AddressKey, MemoryState, AddressKeyHash>;
bool IsImm(const MachineOperand& operand, std::int64_t value) {
return operand.GetKind() == OperandKind::Imm && operand.GetImm() == value;
}
bool SameExactOperand(const MachineOperand& lhs, const MachineOperand& rhs) {
if (lhs.GetKind() != rhs.GetKind()) {
return false;
}
switch (lhs.GetKind()) {
case OperandKind::Invalid:
return true;
case OperandKind::VReg:
return lhs.GetVReg() == rhs.GetVReg();
case OperandKind::Imm:
return lhs.GetImm() == rhs.GetImm();
case OperandKind::Block:
case OperandKind::Symbol:
return lhs.GetText() == rhs.GetText();
}
return false;
}
bool SameResolvedLocation(const MachineFunction& function, int lhs_vreg, int rhs_vreg) {
if (lhs_vreg == rhs_vreg) {
return true;
}
const auto& lhs = function.GetAllocation(lhs_vreg);
const auto& rhs = function.GetAllocation(rhs_vreg);
if (lhs.kind == Allocation::Kind::Unassigned || rhs.kind == Allocation::Kind::Unassigned ||
lhs.kind != rhs.kind) {
return false;
}
if (lhs.kind == Allocation::Kind::PhysReg) {
return lhs.phys == rhs.phys;
}
if (lhs.kind == Allocation::Kind::Spill) {
return lhs.stack_object == rhs.stack_object;
}
return false;
}
bool SameResolvedOperand(const MachineFunction& function, const MachineOperand& lhs,
const MachineOperand& rhs) {
if (SameExactOperand(lhs, rhs)) {
return true;
}
if (lhs.GetKind() == OperandKind::VReg && rhs.GetKind() == OperandKind::VReg) {
return SameResolvedLocation(function, lhs.GetVReg(), rhs.GetVReg());
}
return false;
}
MachineOperand ResolveAlias(const AliasMap& aliases, const MachineOperand& operand) {
if (operand.GetKind() != OperandKind::VReg) {
return operand;
}
int current = operand.GetVReg();
std::unordered_set<int> visited;
visited.insert(current);
while (true) {
auto it = aliases.find(current);
if (it == aliases.end()) {
return MachineOperand::VReg(current);
}
if (it->second.GetKind() != OperandKind::VReg) {
return it->second;
}
const int next = it->second.GetVReg();
if (!visited.insert(next).second) {
return MachineOperand::VReg(current);
}
current = next;
}
}
bool RewriteOperand(MachineOperand& operand, const AliasMap& aliases) {
const auto rewritten = ResolveAlias(aliases, operand);
if (SameExactOperand(rewritten, operand)) {
return false;
}
operand = rewritten;
return true;
}
bool RewriteAddress(AddressExpr& address, const AliasMap& aliases) {
bool changed = false;
if (address.base_kind == AddrBaseKind::VReg && address.base_index >= 0) {
const auto rewritten = ResolveAlias(aliases, MachineOperand::VReg(address.base_index));
if (rewritten.GetKind() == OperandKind::VReg &&
rewritten.GetVReg() != address.base_index) {
address.base_index = rewritten.GetVReg();
changed = true;
}
}
std::vector<std::pair<int, std::int64_t>> rewritten_scaled;
rewritten_scaled.reserve(address.scaled_vregs.size());
for (const auto& term : address.scaled_vregs) {
const auto rewritten = ResolveAlias(aliases, MachineOperand::VReg(term.first));
if (rewritten.GetKind() == OperandKind::Imm) {
address.const_offset += rewritten.GetImm() * term.second;
changed = true;
continue;
}
if (rewritten.GetKind() == OperandKind::VReg && rewritten.GetVReg() != term.first) {
rewritten_scaled.push_back({rewritten.GetVReg(), term.second});
changed = true;
continue;
}
rewritten_scaled.push_back(term);
}
if (rewritten_scaled.size() != address.scaled_vregs.size()) {
changed = true;
}
address.scaled_vregs = std::move(rewritten_scaled);
return changed;
}
bool RewriteUses(MachineInstr& inst, const AliasMap& aliases) {
bool changed = false;
auto& operands = inst.GetOperands();
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Copy:
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
case MachineInstr::Opcode::FSqrt:
case MachineInstr::Opcode::FNeg:
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
break;
case MachineInstr::Opcode::Store:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv:
case MachineInstr::Opcode::ICmp:
case MachineInstr::Opcode::FCmp:
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
if (operands.size() >= 3) {
changed |= RewriteOperand(operands[2], aliases);
}
break;
case MachineInstr::Opcode::CondBr:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Call: {
const size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
for (size_t i = arg_begin; i < operands.size(); ++i) {
changed |= RewriteOperand(operands[i], aliases);
}
break;
}
case MachineInstr::Opcode::Ret:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Memset:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
break;
case MachineInstr::Opcode::Arg:
case MachineInstr::Opcode::Load:
case MachineInstr::Opcode::Lea:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::Unreachable:
break;
}
if (inst.HasAddress()) {
changed |= RewriteAddress(inst.GetAddress(), aliases);
}
return changed;
}
MachineInstr MakeCopyLike(const MachineInstr& inst, MachineOperand source) {
return MachineInstr(MachineInstr::Opcode::Copy,
{inst.GetOperands()[0], std::move(source)});
}
bool SimplifyCopy(const MachineFunction& function, MachineInstr& inst) {
if (inst.GetOpcode() != MachineInstr::Opcode::Copy) {
return false;
}
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[0].GetKind() != OperandKind::VReg) {
return false;
}
return SameResolvedOperand(function, operands[0], operands[1]);
}
bool SimplifyZExt(MachineInstr& inst) {
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[1].GetKind() != OperandKind::Imm) {
return false;
}
inst = MakeCopyLike(inst, MachineOperand::Imm(operands[1].GetImm() != 0 ? 1 : 0));
return true;
}
bool SimplifyIntegerBinary(MachineInstr& inst) {
const auto opcode = inst.GetOpcode();
const auto& operands = inst.GetOperands();
if (operands.size() < 3) {
return false;
}
const auto& lhs = operands[1];
const auto& rhs = operands[2];
switch (opcode) {
case MachineInstr::Opcode::Add:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
return false;
case MachineInstr::Opcode::Sub:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
case MachineInstr::Opcode::Mul:
if (IsImm(rhs, 1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 1)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
if (IsImm(rhs, 0) || IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, MachineOperand::Imm(0));
return true;
}
return false;
case MachineInstr::Opcode::Div:
if (IsImm(rhs, 1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
case MachineInstr::Opcode::And:
if (IsImm(rhs, -1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, -1)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
if (IsImm(rhs, 0) || IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, MachineOperand::Imm(0));
return true;
}
return false;
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
return false;
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
default:
return false;
}
}
bool SimplifyCondBr(MachineInstr& inst) {
auto& operands = inst.GetOperands();
if (operands.size() < 3) {
return false;
}
if (operands[1].GetKind() == OperandKind::Block &&
operands[2].GetKind() == OperandKind::Block &&
operands[1].GetText() == operands[2].GetText()) {
inst = MachineInstr(MachineInstr::Opcode::Br, {operands[1]});
return true;
}
if (operands[0].GetKind() != OperandKind::Imm) {
return false;
}
inst = MachineInstr(MachineInstr::Opcode::Br,
{operands[0].GetImm() != 0 ? operands[1] : operands[2]});
return true;
}
bool SimplifyInstruction(MachineInstr& inst) {
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::ZExt:
return SimplifyZExt(inst);
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
return SimplifyIntegerBinary(inst);
case MachineInstr::Opcode::CondBr:
return SimplifyCondBr(inst);
default:
return false;
}
}
bool TrackAlias(const MachineInstr& inst, AliasMap& aliases) {
if (inst.GetOpcode() != MachineInstr::Opcode::Copy) {
return false;
}
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[0].GetKind() != OperandKind::VReg) {
return false;
}
aliases[operands[0].GetVReg()] = operands[1];
return true;
}
AddressKey MakeAddressKey(const AddressExpr& address) {
return {address.base_kind, address.base_index, address.symbol, address.const_offset,
address.scaled_vregs};
}
bool HasTrackedAddress(const MachineInstr& inst) {
return inst.HasAddress() && inst.GetAddress().base_kind != AddrBaseKind::None;
}
const ir::Function* LookupSourceCallee(const MachineModule& module,
const MachineInstr& inst) {
if (inst.GetOpcode() != MachineInstr::Opcode::Call || inst.GetCallee().empty()) {
return nullptr;
}
return module.GetSourceModule().GetFunction(inst.GetCallee());
}
bool CallMayReadMemory(const MachineModule& module, const MachineInstr& inst) {
auto* callee = LookupSourceCallee(module, inst);
return callee == nullptr || callee->MayReadMemory();
}
bool CallMayWriteMemory(const MachineModule& module, const MachineInstr& inst) {
auto* callee = LookupSourceCallee(module, inst);
return callee == nullptr || callee->MayWriteMemory();
}
bool SameMemoryStateValue(const MemoryState& lhs, const MemoryState& rhs) {
return lhs.type == rhs.type && SameExactOperand(lhs.value, rhs.value);
}
bool SameMemoryMap(const MemoryMap& lhs, const MemoryMap& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto& [key, value] : lhs) {
auto it = rhs.find(key);
if (it == rhs.end() || !SameMemoryStateValue(value, it->second)) {
return false;
}
}
return true;
}
MemoryMap MeetMemoryStates(const std::vector<const MemoryMap*>& predecessors) {
if (predecessors.empty()) {
return {};
}
MemoryMap in = *predecessors.front();
for (auto it = in.begin(); it != in.end();) {
bool keep = true;
for (std::size_t i = 1; i < predecessors.size(); ++i) {
auto pred_it = predecessors[i]->find(it->first);
if (pred_it == predecessors[i]->end() ||
!SameMemoryStateValue(it->second, pred_it->second)) {
keep = false;
break;
}
}
if (!keep) {
it = in.erase(it);
continue;
}
++it;
}
return in;
}
CFGInfo BuildCFG(const MachineFunction& function) {
CFGInfo cfg;
const auto& blocks = function.GetBlocks();
cfg.predecessors.resize(blocks.size());
cfg.successors.resize(blocks.size());
std::unordered_map<std::string, int> name_to_index;
for (std::size_t i = 0; i < blocks.size(); ++i) {
name_to_index.emplace(blocks[i]->GetName(), static_cast<int>(i));
}
auto add_edge = [&](int pred, const std::string& succ_name) {
auto it = name_to_index.find(succ_name);
if (it == name_to_index.end()) {
return;
}
cfg.successors[static_cast<std::size_t>(pred)].push_back(it->second);
cfg.predecessors[static_cast<std::size_t>(it->second)].push_back(pred);
};
for (std::size_t i = 0; i < blocks.size(); ++i) {
const auto& instructions = blocks[i]->GetInstructions();
if (instructions.empty()) {
continue;
}
const auto& terminator = instructions.back();
if (terminator.GetOpcode() == MachineInstr::Opcode::Br &&
!terminator.GetOperands().empty()) {
add_edge(static_cast<int>(i), terminator.GetOperands()[0].GetText());
} else if (terminator.GetOpcode() == MachineInstr::Opcode::CondBr &&
terminator.GetOperands().size() >= 3) {
add_edge(static_cast<int>(i), terminator.GetOperands()[1].GetText());
add_edge(static_cast<int>(i), terminator.GetOperands()[2].GetText());
}
auto& succs = cfg.successors[i];
std::sort(succs.begin(), succs.end());
succs.erase(std::unique(succs.begin(), succs.end()), succs.end());
}
for (auto& preds : cfg.predecessors) {
std::sort(preds.begin(), preds.end());
preds.erase(std::unique(preds.begin(), preds.end()), preds.end());
}
return cfg;
}
bool SameBaseObject(const AddressKey& lhs, const AddressKey& rhs) {
if (lhs.base_kind != rhs.base_kind) {
return false;
}
switch (lhs.base_kind) {
case AddrBaseKind::FrameObject:
case AddrBaseKind::VReg:
return lhs.base_index == rhs.base_index;
case AddrBaseKind::Global:
return lhs.symbol == rhs.symbol;
case AddrBaseKind::None:
return false;
}
return false;
}
void InvalidateMemoryState(std::unordered_map<AddressKey, MemoryState, AddressKeyHash>& states,
const AddressKey* store_key) {
if (store_key == nullptr) {
states.clear();
return;
}
if (store_key->base_kind == AddrBaseKind::VReg) {
states.clear();
return;
}
for (auto it = states.begin(); it != states.end();) {
if (it->first.base_kind == AddrBaseKind::VReg || SameBaseObject(it->first, *store_key)) {
it = states.erase(it);
continue;
}
++it;
}
}
void ObservePendingStores(MemoryMap& states) {
for (auto& [_, state] : states) {
state.pending_store_index = -1;
}
}
bool TryOptimizeMemoryInstruction(
const MachineModule& module, const MachineFunction& function,
MachineInstr& inst,
MemoryMap& states,
std::vector<bool>& removed,
std::size_t current_index,
bool* remove_current) {
*remove_current = false;
if (inst.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayWriteMemory(module, inst)) {
InvalidateMemoryState(states, nullptr);
}
return false;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Memset) {
InvalidateMemoryState(states, nullptr);
return false;
}
if (!HasTrackedAddress(inst)) {
return false;
}
const AddressKey key = MakeAddressKey(inst.GetAddress());
if (inst.GetOpcode() == MachineInstr::Opcode::Load) {
ValueType load_type = ValueType::Void;
if (!inst.GetOperands().empty() && inst.GetOperands()[0].GetKind() == OperandKind::VReg) {
load_type = function.GetVRegInfo(inst.GetOperands()[0].GetVReg()).type;
}
auto it = states.find(key);
if (it != states.end() && it->second.type == load_type) {
inst = MakeCopyLike(inst, it->second.value);
it->second.pending_store_index = -1;
return true;
}
auto dest = inst.GetOperands()[0];
states[key] = {dest, load_type, -1};
return false;
}
if (inst.GetOpcode() != MachineInstr::Opcode::Store) {
return false;
}
const auto value = inst.GetOperands()[0];
auto existing = states.find(key);
if (existing != states.end() && existing->second.type == inst.GetValueType() &&
SameExactOperand(existing->second.value, value)) {
*remove_current = true;
return true;
}
if (existing != states.end() && existing->second.pending_store_index >= 0) {
removed[static_cast<std::size_t>(existing->second.pending_store_index)] = true;
}
InvalidateMemoryState(states, &key);
states[key] = {value, inst.GetValueType(), static_cast<int>(current_index)};
return false;
}
void ApplyMemoryDataflowInstruction(const MachineModule& module, const MachineInstr& inst,
MemoryMap& states) {
if (inst.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayWriteMemory(module, inst)) {
InvalidateMemoryState(states, nullptr);
}
return;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Memset) {
InvalidateMemoryState(states, nullptr);
return;
}
if (!HasTrackedAddress(inst)) {
return;
}
const AddressKey key = MakeAddressKey(inst.GetAddress());
if (inst.GetOpcode() == MachineInstr::Opcode::Store) {
InvalidateMemoryState(states, &key);
states[key] = {inst.GetOperands()[0], inst.GetValueType(), -1};
return;
}
}
MemoryMap SimulateBlockMemory(const MachineModule& module, const MachineBasicBlock& block,
const MemoryMap& in_state) {
MemoryMap state = in_state;
for (const auto& inst : block.GetInstructions()) {
ApplyMemoryDataflowInstruction(module, inst, state);
}
return state;
}
bool RunPeepholeOnBlock(const MachineModule& module, const MachineFunction& function,
MachineBasicBlock& block, const MemoryMap& in_state) {
bool changed = false;
AliasMap aliases;
MemoryMap memory_states = in_state;
std::vector<MachineInstr> rewritten;
std::vector<bool> removed;
rewritten.reserve(block.GetInstructions().size());
removed.reserve(block.GetInstructions().size());
for (const auto& original : block.GetInstructions()) {
MachineInstr inst = original;
changed |= RewriteUses(inst, aliases);
changed |= SimplifyInstruction(inst);
if (SimplifyCopy(function, inst)) {
changed = true;
continue;
}
rewritten.push_back(std::move(inst));
removed.push_back(false);
MachineInstr& current = rewritten.back();
bool remove_current = false;
changed |= TryOptimizeMemoryInstruction(module, function, current, memory_states, removed,
rewritten.size() - 1, &remove_current);
if (remove_current) {
removed.back() = true;
changed = true;
continue;
}
changed |= SimplifyInstruction(current);
if (SimplifyCopy(function, current)) {
removed.back() = true;
changed = true;
continue;
}
if (current.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayReadMemory(module, current) || CallMayWriteMemory(module, current)) {
ObservePendingStores(memory_states);
}
} else if (current.GetOpcode() == MachineInstr::Opcode::Memset) {
ObservePendingStores(memory_states);
}
TrackAlias(current, aliases);
}
std::vector<MachineInstr> compacted;
compacted.reserve(rewritten.size());
for (std::size_t i = 0; i < rewritten.size(); ++i) {
if (!removed[i]) {
compacted.push_back(std::move(rewritten[i]));
} else {
changed = true;
}
}
if (compacted.size() != block.GetInstructions().size()) {
changed = true;
}
if (changed) {
block.GetInstructions() = std::move(compacted);
}
return changed;
}
bool IsSideEffectFree(const MachineInstr& inst) {
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Arg:
case MachineInstr::Opcode::Copy:
case MachineInstr::Opcode::Load:
case MachineInstr::Opcode::Lea:
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv:
case MachineInstr::Opcode::FNeg:
case MachineInstr::Opcode::ICmp:
case MachineInstr::Opcode::FCmp:
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
return true;
case MachineInstr::Opcode::FSqrt:
return !inst.HasAddress();
case MachineInstr::Opcode::Store:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::CondBr:
case MachineInstr::Opcode::Call:
case MachineInstr::Opcode::Ret:
case MachineInstr::Opcode::Memset:
case MachineInstr::Opcode::Unreachable:
return false;
}
return false;
}
bool RunDeadInstrElimination(MachineFunction& function) {
bool changed = false;
while (true) {
std::unordered_map<int, int> use_counts;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
for (int use : inst.GetUses()) {
++use_counts[use];
}
}
}
bool local_changed = false;
for (auto& block : function.GetBlocks()) {
std::vector<MachineInstr> rewritten;
rewritten.reserve(block->GetInstructions().size());
for (auto& inst : block->GetInstructions()) {
const auto defs = inst.GetDefs();
const bool has_live_def =
defs.empty() || use_counts.find(defs.front()) != use_counts.end();
if (has_live_def || !IsSideEffectFree(inst)) {
rewritten.push_back(inst);
continue;
}
local_changed = true;
}
if (local_changed) {
block->GetInstructions() = std::move(rewritten);
}
}
if (!local_changed) {
break;
}
changed = true;
}
return changed;
}
bool HasAssignedAllocations(const MachineFunction& function) {
for (const auto& vreg : function.GetVRegs()) {
if (function.GetAllocation(vreg.id).kind != Allocation::Kind::Unassigned) {
return true;
}
}
return false;
}
} // namespace
bool RunPeephole(MachineModule& module) {
bool changed = false;
for (auto& function : module.GetFunctions()) {
if (!function) {
continue;
}
bool function_changed = false;
const auto cfg = BuildCFG(*function);
std::vector<MemoryMap> in_states(function->GetBlocks().size());
std::vector<MemoryMap> out_states(function->GetBlocks().size());
bool dataflow_changed = true;
while (dataflow_changed) {
dataflow_changed = false;
for (std::size_t i = 0; i < function->GetBlocks().size(); ++i) {
MemoryMap in_state;
if (i != 0) {
std::vector<const MemoryMap*> predecessors;
for (int pred : cfg.predecessors[i]) {
predecessors.push_back(&out_states[static_cast<std::size_t>(pred)]);
}
in_state = MeetMemoryStates(predecessors);
}
auto out_state =
SimulateBlockMemory(module, *function->GetBlocks()[i], in_state);
if (!SameMemoryMap(in_states[i], in_state)) {
in_states[i] = std::move(in_state);
dataflow_changed = true;
}
if (!SameMemoryMap(out_states[i], out_state)) {
out_states[i] = std::move(out_state);
dataflow_changed = true;
}
}
}
for (std::size_t i = 0; i < function->GetBlocks().size(); ++i) {
function_changed |=
RunPeepholeOnBlock(module, *function, *function->GetBlocks()[i], in_states[i]);
}
if (!HasAssignedAllocations(*function)) {
function_changed |= RunDeadInstrElimination(*function);
}
changed |= function_changed;
}
return changed;
}
} // namespace mir

@ -0,0 +1,257 @@
#include "mir/MIR.h"
#include <unordered_map>
#include <utility>
#include <vector>
namespace mir {
namespace {
struct RematDef {
enum class Kind { Invalid, ImmCopy, Lea };
Kind kind = Kind::Invalid;
ValueType type = ValueType::Void;
MachineOperand source;
AddressExpr address;
};
bool IsCheapRematerializableDef(const MachineInstr& inst, RematDef& def) {
const auto defs = inst.GetDefs();
if (defs.size() != 1) {
return false;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Copy) {
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[1].GetKind() != OperandKind::Imm) {
return false;
}
def.kind = RematDef::Kind::ImmCopy;
def.type = inst.GetValueType();
def.source = operands[1];
return true;
}
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress()) {
return false;
}
const auto& address = inst.GetAddress();
if (address.base_kind == AddrBaseKind::VReg || !address.scaled_vregs.empty()) {
return false;
}
def.kind = RematDef::Kind::Lea;
def.type = ValueType::Ptr;
def.address = address;
return true;
}
MachineInstr BuildRematInstr(int dst_vreg, const RematDef& def) {
switch (def.kind) {
case RematDef::Kind::ImmCopy: {
MachineInstr inst(MachineInstr::Opcode::Copy,
{MachineOperand::VReg(dst_vreg), def.source});
inst.SetValueType(def.type);
return inst;
}
case RematDef::Kind::Lea: {
MachineInstr inst(MachineInstr::Opcode::Lea, {MachineOperand::VReg(dst_vreg)});
inst.SetAddress(def.address);
inst.SetValueType(ValueType::Ptr);
return inst;
}
case RematDef::Kind::Invalid:
break;
}
return MachineInstr(MachineInstr::Opcode::Unreachable, {});
}
bool RewriteMappedOperand(MachineOperand& operand,
const std::unordered_map<int, int>& rename_map) {
if (operand.GetKind() != OperandKind::VReg) {
return false;
}
auto it = rename_map.find(operand.GetVReg());
if (it == rename_map.end() || it->second == operand.GetVReg()) {
return false;
}
operand = MachineOperand::VReg(it->second);
return true;
}
bool RewriteMappedAddress(AddressExpr& address,
const std::unordered_map<int, int>& rename_map) {
bool changed = false;
if (address.base_kind == AddrBaseKind::VReg && address.base_index >= 0) {
auto it = rename_map.find(address.base_index);
if (it != rename_map.end() && it->second != address.base_index) {
address.base_index = it->second;
changed = true;
}
}
for (auto& term : address.scaled_vregs) {
auto it = rename_map.find(term.first);
if (it != rename_map.end() && it->second != term.first) {
term.first = it->second;
changed = true;
}
}
return changed;
}
bool RewriteUses(MachineInstr& inst, const std::unordered_map<int, int>& rename_map) {
bool changed = false;
auto& operands = inst.GetOperands();
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Copy:
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
case MachineInstr::Opcode::FSqrt:
case MachineInstr::Opcode::FNeg:
if (operands.size() >= 2) {
changed |= RewriteMappedOperand(operands[1], rename_map);
}
break;
case MachineInstr::Opcode::Store:
if (!operands.empty()) {
changed |= RewriteMappedOperand(operands[0], rename_map);
}
break;
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv:
case MachineInstr::Opcode::ICmp:
case MachineInstr::Opcode::FCmp:
if (operands.size() >= 2) {
changed |= RewriteMappedOperand(operands[1], rename_map);
}
if (operands.size() >= 3) {
changed |= RewriteMappedOperand(operands[2], rename_map);
}
break;
case MachineInstr::Opcode::CondBr:
if (!operands.empty()) {
changed |= RewriteMappedOperand(operands[0], rename_map);
}
break;
case MachineInstr::Opcode::Call: {
const size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
for (size_t i = arg_begin; i < operands.size(); ++i) {
changed |= RewriteMappedOperand(operands[i], rename_map);
}
break;
}
case MachineInstr::Opcode::Ret:
if (!operands.empty()) {
changed |= RewriteMappedOperand(operands[0], rename_map);
}
break;
case MachineInstr::Opcode::Memset:
if (!operands.empty()) {
changed |= RewriteMappedOperand(operands[0], rename_map);
}
if (operands.size() >= 2) {
changed |= RewriteMappedOperand(operands[1], rename_map);
}
break;
case MachineInstr::Opcode::Arg:
case MachineInstr::Opcode::Load:
case MachineInstr::Opcode::Lea:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::Unreachable:
break;
}
if (inst.HasAddress()) {
changed |= RewriteMappedAddress(inst.GetAddress(), rename_map);
}
return changed;
}
bool RunSpillReductionOnFunction(MachineFunction& function) {
bool changed = false;
for (auto& block_ptr : function.GetBlocks()) {
auto& instructions = block_ptr->GetInstructions();
std::unordered_map<int, RematDef> available_defs;
std::unordered_map<int, RematDef> after_call_defs;
std::unordered_map<int, int> rename_map;
bool after_call = false;
for (size_t i = 0; i < instructions.size(); ++i) {
if (after_call) {
const auto uses = instructions[i].GetUses();
for (int use : uses) {
if (rename_map.count(use) != 0) {
continue;
}
auto it = after_call_defs.find(use);
if (it == after_call_defs.end()) {
continue;
}
const int new_vreg = function.NewVReg(function.GetVRegInfo(use).type);
instructions.insert(instructions.begin() + static_cast<long long>(i),
BuildRematInstr(new_vreg, it->second));
++i;
rename_map[use] = new_vreg;
available_defs[new_vreg] = it->second;
changed = true;
}
RewriteUses(instructions[i], rename_map);
}
const auto defs = instructions[i].GetDefs();
for (int def : defs) {
available_defs.erase(def);
after_call_defs.erase(def);
rename_map.erase(def);
}
RematDef def;
if (IsCheapRematerializableDef(instructions[i], def)) {
for (int vreg : defs) {
available_defs[vreg] = def;
}
}
if (instructions[i].GetOpcode() == MachineInstr::Opcode::Call ||
instructions[i].GetOpcode() == MachineInstr::Opcode::Memset) {
after_call_defs = available_defs;
rename_map.clear();
after_call = true;
}
}
}
return changed;
}
} // namespace
bool RunSpillReduction(MachineModule& module) {
bool changed = false;
for (auto& function : module.GetFunctions()) {
if (function) {
changed |= RunSpillReductionOnFunction(*function);
}
}
return changed;
}
} // namespace mir

@ -6,5 +6,6 @@ add_library(sem STATIC
target_link_libraries(sem PUBLIC
build_options
frontend
${ANTLR4_RUNTIME_TARGET}
)
)

@ -1,200 +1,685 @@
#include "sem/Sema.h"
#include "sem/Sema.h"
#include <any>
#include <stdexcept>
#include <algorithm>
#include <string>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace {
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("sema", "非法左值"));
enum class MemoryRoot {
None,
Local,
Global,
Param,
Unknown,
};
struct SymbolInfo {
SemanticType type = SemanticType::Int;
bool is_array = false;
bool is_param_array = false;
MemoryRoot root = MemoryRoot::Local;
std::vector<int> dims;
};
struct ExprInfo {
MemoryRoot root = MemoryRoot::None;
bool is_array = false;
};
struct CallSiteInfo {
std::string callee;
std::vector<MemoryRoot> arg_roots;
};
struct DirectFunctionAnalysis {
FunctionSemanticInfo info;
std::vector<CallSiteInfo> calls;
};
class ScopedSymbols {
public:
void EnterScope() { scopes_.emplace_back(); }
void ExitScope() {
if (!scopes_.empty()) {
scopes_.pop_back();
}
}
bool Insert(const std::string& name, const SymbolInfo& info) {
if (scopes_.empty()) {
EnterScope();
}
return scopes_.back().emplace(name, info).second;
}
const SymbolInfo* Lookup(const std::string& name) const {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
}
private:
std::vector<std::unordered_map<std::string, SymbolInfo>> scopes_;
};
std::string ExpectIdent(antlr4::tree::TerminalNode* ident) {
return ident == nullptr ? std::string{} : ident->getText();
}
SemanticType ParseBType(SysYParser::BTypeContext* ctx) {
if (ctx != nullptr && ctx->FLOAT() != nullptr) {
return SemanticType::Float;
}
return lvalue.ID()->getText();
return SemanticType::Int;
}
class SemaVisitor final : public SysYBaseVisitor {
public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
SemanticType ParseFuncType(SysYParser::FuncTypeContext* ctx) {
if (ctx == nullptr || ctx->VOID() != nullptr) {
return SemanticType::Void;
}
if (ctx->FLOAT() != nullptr) {
return SemanticType::Float;
}
return SemanticType::Int;
}
std::vector<int> MakeShape(std::size_t rank) {
return std::vector<int>(rank, -1);
}
void RegisterBuiltinFunctions(SemanticContext& context) {
struct BuiltinSpec {
const char* name;
SemanticType return_type;
std::vector<bool> param_is_array;
bool reads_global_memory = false;
bool writes_global_memory = false;
bool reads_param_memory = false;
bool writes_param_memory = false;
bool has_io = false;
};
const std::vector<BuiltinSpec> builtins = {
{"getint", SemanticType::Int, {}, false, false, false, false, true},
{"getch", SemanticType::Int, {}, false, false, false, false, true},
{"getfloat", SemanticType::Float, {}, false, false, false, false, true},
{"getarray", SemanticType::Int, {true}, false, false, false, true, true},
{"getfarray", SemanticType::Int, {true}, false, false, false, true, true},
{"putint", SemanticType::Void, {false}, false, false, false, false, true},
{"putch", SemanticType::Void, {false}, false, false, false, false, true},
{"putfloat", SemanticType::Void, {false}, false, false, false, false, true},
{"putarray", SemanticType::Void, {false, true}, false, false, true, false, true},
{"putfarray", SemanticType::Void, {false, true}, false, false, true, false, true},
{"starttime", SemanticType::Void, {}, false, false, false, false, true},
{"stoptime", SemanticType::Void, {}, false, false, false, false, true},
};
for (const auto& builtin : builtins) {
auto& info = context.UpsertFunction(builtin.name);
info.return_type = builtin.return_type;
info.param_is_array = builtin.param_is_array;
info.is_builtin = true;
info.is_defined = false;
info.reads_global_memory = builtin.reads_global_memory;
info.writes_global_memory = builtin.writes_global_memory;
info.reads_param_memory = builtin.reads_param_memory;
info.writes_param_memory = builtin.writes_param_memory;
info.has_io = builtin.has_io;
info.has_unknown_effects = false;
info.is_recursive = false;
info.direct_callees.clear();
}
}
void CollectGlobalDecl(SemanticContext& context, SysYParser::DeclContext& ctx) {
if (auto* const_decl = ctx.constDecl()) {
const auto type = ParseBType(const_decl->bType());
for (auto* def : const_decl->constDef()) {
auto& info = context.UpsertGlobal(ExpectIdent(def->Ident()));
info.type = type;
info.is_const = true;
info.is_array = !def->constExp().empty();
info.dims = MakeShape(def->constExp().size());
}
return;
}
if (auto* var_decl = ctx.varDecl()) {
const auto type = ParseBType(var_decl->bType());
for (auto* def : var_decl->varDef()) {
auto& info = context.UpsertGlobal(ExpectIdent(def->Ident()));
info.type = type;
info.is_const = false;
info.is_array = !def->constExp().empty();
info.dims = MakeShape(def->constExp().size());
}
auto* func = ctx->funcDef();
if (!func || !func->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
}
void CollectFunctionSignature(SemanticContext& context, SysYParser::FuncDefContext& ctx) {
auto& info = context.UpsertFunction(ExpectIdent(ctx.Ident()));
info.return_type = ParseFuncType(ctx.funcType());
info.param_is_array.clear();
info.is_builtin = false;
info.is_defined = true;
info.reads_global_memory = false;
info.writes_global_memory = false;
info.reads_param_memory = false;
info.writes_param_memory = false;
info.has_io = false;
info.has_unknown_effects = false;
info.is_recursive = false;
info.direct_callees.clear();
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
info.param_is_array.push_back(!param->LBRACK().empty());
}
if (!func->ID() || func->ID()->getText() != "main") {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
}
bool SameInfo(const FunctionSemanticInfo& lhs, const FunctionSemanticInfo& rhs) {
return lhs.return_type == rhs.return_type &&
lhs.param_is_array == rhs.param_is_array &&
lhs.is_builtin == rhs.is_builtin && lhs.is_defined == rhs.is_defined &&
lhs.reads_global_memory == rhs.reads_global_memory &&
lhs.writes_global_memory == rhs.writes_global_memory &&
lhs.reads_param_memory == rhs.reads_param_memory &&
lhs.writes_param_memory == rhs.writes_param_memory &&
lhs.has_io == rhs.has_io &&
lhs.has_unknown_effects == rhs.has_unknown_effects &&
lhs.is_recursive == rhs.is_recursive &&
lhs.direct_callees == rhs.direct_callees;
}
class FunctionAnalyzer {
public:
FunctionAnalyzer(const SemanticContext& context, DirectFunctionAnalysis& analysis)
: context_(context), analysis_(analysis) {}
void Analyze(SysYParser::FuncDefContext& ctx) {
symbols_.EnterScope();
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
SymbolInfo info;
info.type = ParseBType(param->bType());
info.is_array = !param->LBRACK().empty();
info.is_param_array = info.is_array;
info.root = info.is_array ? MemoryRoot::Param : MemoryRoot::Local;
info.dims = MakeShape(param->exp().size());
symbols_.Insert(ExpectIdent(param->Ident()), info);
}
}
AnalyzeBlock(*ctx.block(), false);
symbols_.ExitScope();
}
private:
struct LValueShape {
MemoryRoot root = MemoryRoot::None;
bool is_array = false;
};
const SymbolInfo* LookupSymbol(const std::string& name) const {
if (const auto* local = symbols_.Lookup(name)) {
return local;
}
func->accept(this);
if (!seen_return_) {
throw std::runtime_error(
FormatError("sema", "main 函数必须包含 return 语句"));
if (const auto* global = context_.LookupGlobal(name)) {
static thread_local SymbolInfo scratch;
scratch.type = global->type;
scratch.is_array = global->is_array;
scratch.is_param_array = false;
scratch.root = MemoryRoot::Global;
scratch.dims = global->dims;
return &scratch;
}
return {};
return nullptr;
}
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
void AnalyzeBlock(SysYParser::BlockContext& ctx, bool create_scope) {
if (create_scope) {
symbols_.EnterScope();
}
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
for (auto* item : ctx.blockItem()) {
AnalyzeBlockItem(*item);
}
const auto& items = ctx->blockStmt()->blockItem();
if (items.empty()) {
throw std::runtime_error(
FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
if (create_scope) {
symbols_.ExitScope();
}
ctx->blockStmt()->accept(this);
return {};
}
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少语句块"));
void AnalyzeBlockItem(SysYParser::BlockItemContext& ctx) {
if (auto* decl = ctx.decl()) {
AnalyzeDecl(*decl);
} else if (auto* stmt = ctx.stmt()) {
AnalyzeStmt(*stmt);
}
const auto& items = ctx->blockItem();
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
}
if (seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
void AnalyzeDecl(SysYParser::DeclContext& ctx) {
if (auto* const_decl = ctx.constDecl()) {
AnalyzeConstDecl(*const_decl);
} else if (auto* var_decl = ctx.varDecl()) {
AnalyzeVarDecl(*var_decl);
}
}
void AnalyzeVarDecl(SysYParser::VarDeclContext& ctx) {
const auto type = ParseBType(ctx.bType());
for (auto* def : ctx.varDef()) {
SymbolInfo info;
info.type = type;
info.is_array = !def->constExp().empty();
info.root = MemoryRoot::Local;
info.dims = MakeShape(def->constExp().size());
symbols_.Insert(ExpectIdent(def->Ident()), info);
if (def->initVal() != nullptr) {
AnalyzeInitVal(def->initVal());
}
current_item_index_ = i;
total_items_ = items.size();
item->accept(this);
}
return {};
}
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
void AnalyzeConstDecl(SysYParser::ConstDeclContext& ctx) {
const auto type = ParseBType(ctx.bType());
for (auto* def : ctx.constDef()) {
SymbolInfo info;
info.type = type;
info.is_array = !def->constExp().empty();
info.root = MemoryRoot::Local;
info.dims = MakeShape(def->constExp().size());
symbols_.Insert(ExpectIdent(def->Ident()), info);
AnalyzeConstInitVal(def->constInitVal());
}
if (ctx->decl()) {
ctx->decl()->accept(this);
return {};
}
void AnalyzeInitVal(SysYParser::InitValContext* ctx) {
if (ctx == nullptr) {
return;
}
if (ctx->stmt()) {
ctx->stmt()->accept(this);
return {};
if (ctx->exp() != nullptr) {
AnalyzeExp(*ctx->exp());
return;
}
for (auto* child : ctx->initVal()) {
AnalyzeInitVal(child);
}
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
void AnalyzeConstInitVal(SysYParser::ConstInitValContext* ctx) {
if (ctx == nullptr) {
return;
}
if (ctx->constExp() != nullptr) {
AnalyzeAddExp(*ctx->constExp()->addExp());
return;
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
for (auto* child : ctx->constInitVal()) {
AnalyzeConstInitVal(child);
}
auto* var_def = ctx->varDef();
if (!var_def || !var_def->lValue()) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
void AnalyzeStmt(SysYParser::StmtContext& ctx) {
if (ctx.lVal() != nullptr && ctx.ASSIGN() != nullptr) {
AnalyzeLValWrite(*ctx.lVal());
AnalyzeExp(*ctx.exp());
return;
}
const std::string name = GetLValueName(*var_def->lValue());
if (table_.Contains(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
if (ctx.block() != nullptr) {
AnalyzeBlock(*ctx.block(), true);
return;
}
if (auto* init = var_def->initValue()) {
if (!init->exp()) {
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
if (ctx.IF() != nullptr) {
if (ctx.cond() != nullptr) {
AnalyzeLOrExp(*ctx.cond()->lOrExp());
}
init->exp()->accept(this);
if (!ctx.stmt().empty()) {
AnalyzeStmt(*ctx.stmt()[0]);
}
if (ctx.stmt().size() > 1 && ctx.stmt()[1] != nullptr) {
AnalyzeStmt(*ctx.stmt()[1]);
}
return;
}
table_.Add(name, var_def);
return {};
if (ctx.WHILE() != nullptr) {
if (ctx.cond() != nullptr) {
AnalyzeLOrExp(*ctx.cond()->lOrExp());
}
if (!ctx.stmt().empty() && ctx.stmt()[0] != nullptr) {
AnalyzeStmt(*ctx.stmt()[0]);
}
return;
}
if (ctx.RETURN() != nullptr) {
if (ctx.exp() != nullptr) {
AnalyzeExp(*ctx.exp());
}
return;
}
if (ctx.exp() != nullptr) {
AnalyzeExp(*ctx.exp());
}
}
ExprInfo AnalyzeExp(SysYParser::ExpContext& ctx) {
return AnalyzeAddExp(*ctx.addExp());
}
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx || !ctx->returnStmt()) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
ExprInfo AnalyzeAddExp(SysYParser::AddExpContext& ctx) {
if (ctx.addExp() != nullptr) {
AnalyzeAddExp(*ctx.addExp());
AnalyzeMulExp(*ctx.mulExp());
return {};
}
ctx->returnStmt()->accept(this);
return {};
return AnalyzeMulExp(*ctx.mulExp());
}
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
ExprInfo AnalyzeMulExp(SysYParser::MulExpContext& ctx) {
if (ctx.mulExp() != nullptr) {
AnalyzeMulExp(*ctx.mulExp());
AnalyzeUnaryExp(*ctx.unaryExp());
return {};
}
ctx->exp()->accept(this);
seen_return_ = true;
if (current_item_index_ + 1 != total_items_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
return AnalyzeUnaryExp(*ctx.unaryExp());
}
ExprInfo AnalyzeUnaryExp(SysYParser::UnaryExpContext& ctx) {
if (ctx.primaryExp() != nullptr) {
return AnalyzePrimaryExp(*ctx.primaryExp());
}
if (ctx.Ident() != nullptr) {
const auto name = ExpectIdent(ctx.Ident());
CallSiteInfo call;
call.callee = name;
const auto* callee = context_.LookupFunction(name);
const auto args = ctx.funcRParams() == nullptr ? std::vector<SysYParser::ExpContext*>{}
: ctx.funcRParams()->exp();
call.arg_roots.resize(args.size(), MemoryRoot::None);
for (std::size_t i = 0; i < args.size(); ++i) {
auto arg_info = AnalyzeExp(*args[i]);
if (callee != nullptr && i < callee->param_is_array.size() && callee->param_is_array[i]) {
call.arg_roots[i] = arg_info.is_array ? arg_info.root : MemoryRoot::Unknown;
}
}
if (callee == nullptr) {
analysis_.info.has_unknown_effects = true;
}
analysis_.calls.push_back(std::move(call));
return {};
}
if (ctx.unaryExp() != nullptr) {
AnalyzeUnaryExp(*ctx.unaryExp());
}
return {};
}
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "非法括号表达式"));
ExprInfo AnalyzePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp() != nullptr) {
return AnalyzeExp(*ctx.exp());
}
if (ctx.lVal() != nullptr) {
return AnalyzeLValRead(*ctx.lVal());
}
ctx->exp()->accept(this);
return {};
}
std::any visitVarExp(SysYParser::VarExpContext* ctx) override {
if (!ctx || !ctx->var()) {
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
ExprInfo AnalyzeRelExp(SysYParser::RelExpContext& ctx) {
if (ctx.relExp() != nullptr) {
AnalyzeRelExp(*ctx.relExp());
AnalyzeAddExp(*ctx.addExp());
return {};
}
return AnalyzeAddExp(*ctx.addExp());
}
ExprInfo AnalyzeEqExp(SysYParser::EqExpContext& ctx) {
if (ctx.eqExp() != nullptr) {
AnalyzeEqExp(*ctx.eqExp());
AnalyzeRelExp(*ctx.relExp());
return {};
}
ctx->var()->accept(this);
return {};
return AnalyzeRelExp(*ctx.relExp());
}
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量"));
ExprInfo AnalyzeLAndExp(SysYParser::LAndExpContext& ctx) {
if (ctx.lAndExp() != nullptr) {
AnalyzeLAndExp(*ctx.lAndExp());
AnalyzeEqExp(*ctx.eqExp());
return {};
}
return {};
return AnalyzeEqExp(*ctx.eqExp());
}
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
ExprInfo AnalyzeLOrExp(SysYParser::LOrExpContext& ctx) {
if (ctx.lOrExp() != nullptr) {
AnalyzeLOrExp(*ctx.lOrExp());
AnalyzeLAndExp(*ctx.lAndExp());
return {};
}
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
return AnalyzeLAndExp(*ctx.lAndExp());
}
std::any visitVar(SysYParser::VarContext* ctx) override {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
LValueShape DescribeLVal(SysYParser::LValContext& ctx) {
for (auto* index : ctx.exp()) {
AnalyzeExp(*index);
}
const std::string name = ctx->ID()->getText();
auto* decl = table_.Lookup(name);
if (!decl) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
const auto* symbol = LookupSymbol(ExpectIdent(ctx.Ident()));
if (symbol == nullptr) {
return {};
}
sema_.BindVarUse(ctx, decl);
return {};
if (!symbol->is_array) {
return {symbol->root, false};
}
const auto index_count = ctx.exp().size();
bool still_array = false;
if (symbol->is_param_array) {
if (index_count == 0) {
still_array = true;
} else if (index_count <= symbol->dims.size()) {
still_array = true;
}
} else {
still_array = index_count < symbol->dims.size();
}
return {symbol->root, still_array};
}
SemanticContext TakeSemanticContext() { return std::move(sema_); }
ExprInfo AnalyzeLValRead(SysYParser::LValContext& ctx) {
const auto shape = DescribeLVal(ctx);
if (!shape.is_array) {
if (shape.root == MemoryRoot::Global) {
analysis_.info.reads_global_memory = true;
} else if (shape.root == MemoryRoot::Param) {
analysis_.info.reads_param_memory = true;
}
}
return {shape.root, shape.is_array};
}
private:
SymbolTable table_;
SemanticContext sema_;
bool seen_return_ = false;
size_t current_item_index_ = 0;
size_t total_items_ = 0;
void AnalyzeLValWrite(SysYParser::LValContext& ctx) {
const auto shape = DescribeLVal(ctx);
if (shape.root == MemoryRoot::Global) {
analysis_.info.writes_global_memory = true;
} else if (shape.root == MemoryRoot::Param) {
analysis_.info.writes_param_memory = true;
}
}
const SemanticContext& context_;
DirectFunctionAnalysis& analysis_;
ScopedSymbols symbols_;
};
void PropagateCallEffects(SemanticContext& context,
const std::unordered_map<std::string, DirectFunctionAnalysis>& analyses) {
bool changed = true;
while (changed) {
changed = false;
for (const auto& [name, analysis] : analyses) {
auto next = analysis.info;
std::unordered_set<std::string> callees_seen;
for (const auto& call : analysis.calls) {
if (!call.callee.empty()) {
callees_seen.insert(call.callee);
}
const auto* callee = context.LookupFunction(call.callee);
if (callee == nullptr) {
next.has_unknown_effects = true;
continue;
}
next.has_io = next.has_io || callee->has_io;
next.has_unknown_effects = next.has_unknown_effects || callee->has_unknown_effects;
next.reads_global_memory =
next.reads_global_memory || callee->reads_global_memory;
next.writes_global_memory =
next.writes_global_memory || callee->writes_global_memory;
const auto arg_count = std::min(call.arg_roots.size(), callee->param_is_array.size());
for (std::size_t i = 0; i < arg_count; ++i) {
if (!callee->param_is_array[i]) {
continue;
}
switch (call.arg_roots[i]) {
case MemoryRoot::Global:
next.reads_global_memory =
next.reads_global_memory || callee->reads_param_memory;
next.writes_global_memory =
next.writes_global_memory || callee->writes_param_memory;
break;
case MemoryRoot::Param:
next.reads_param_memory =
next.reads_param_memory || callee->reads_param_memory;
next.writes_param_memory =
next.writes_param_memory || callee->writes_param_memory;
break;
case MemoryRoot::Unknown:
if (callee->reads_param_memory || callee->writes_param_memory) {
next.has_unknown_effects = true;
}
break;
case MemoryRoot::None:
case MemoryRoot::Local:
break;
}
}
}
next.direct_callees.assign(callees_seen.begin(), callees_seen.end());
std::sort(next.direct_callees.begin(), next.direct_callees.end());
auto* current = context.LookupFunction(name);
if (current == nullptr || !SameInfo(next, *current)) {
context.UpsertFunction(name) = std::move(next);
changed = true;
}
}
}
}
bool ReachesSelf(const SemanticContext& context, const std::string& root,
const std::string& current,
std::unordered_set<std::string>& visiting) {
const auto* info = context.LookupFunction(current);
if (info == nullptr) {
return false;
}
if (!visiting.insert(current).second) {
return false;
}
for (const auto& callee : info->direct_callees) {
if (callee == root) {
return true;
}
if (ReachesSelf(context, root, callee, visiting)) {
return true;
}
}
return false;
}
void MarkRecursiveFunctions(SemanticContext& context) {
std::vector<std::string> function_names;
function_names.reserve(context.GetFunctions().size());
for (const auto& [name, info] : context.GetFunctions()) {
if (!info.is_builtin) {
function_names.push_back(name);
}
}
for (const auto& name : function_names) {
std::unordered_set<std::string> visiting;
if (ReachesSelf(context, name, name, visiting)) {
context.UpsertFunction(name).is_recursive = true;
}
}
}
} // namespace
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor;
comp_unit.accept(&visitor);
return visitor.TakeSemanticContext();
SemanticContext context;
RegisterBuiltinFunctions(context);
for (auto* child : comp_unit.children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
CollectGlobalDecl(context, *decl);
} else if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
CollectFunctionSignature(context, *func);
}
}
std::unordered_map<std::string, DirectFunctionAnalysis> analyses;
for (auto* child : comp_unit.children) {
auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child);
if (func == nullptr) {
continue;
}
const auto name = ExpectIdent(func->Ident());
auto* existing = context.LookupFunction(name);
if (existing == nullptr) {
continue;
}
DirectFunctionAnalysis analysis;
analysis.info = *existing;
analysis.info.reads_global_memory = false;
analysis.info.writes_global_memory = false;
analysis.info.reads_param_memory = false;
analysis.info.writes_param_memory = false;
analysis.info.has_io = false;
analysis.info.has_unknown_effects = false;
analysis.info.is_recursive = false;
analysis.info.direct_callees.clear();
FunctionAnalyzer analyzer(context, analysis);
analyzer.Analyze(*func);
analyses.emplace(name, std::move(analysis));
}
for (const auto& [name, analysis] : analyses) {
context.UpsertFunction(name) = analysis.info;
}
PropagateCallEffects(context, analyses);
MarkRecursiveFunctions(context);
return context;
}

@ -1,17 +1,43 @@
// 维护局部变量声明的注册与查找。
#include "sem/SymbolTable.h"
#include "sem/SymbolTable.h"
void SymbolTable::Clear() { scopes_.clear(); }
void SymbolTable::Add(const std::string& name,
SysYParser::VarDefContext* decl) {
table_[name] = decl;
void SymbolTable::EnterScope() { scopes_.emplace_back(); }
void SymbolTable::ExitScope() {
if (!scopes_.empty()) {
scopes_.pop_back();
}
}
bool SymbolTable::Insert(const std::string& name, const SymbolEntry& entry) {
if (scopes_.empty()) {
EnterScope();
}
auto& scope = scopes_.back();
return scope.emplace(name, entry).second;
}
bool SymbolTable::Contains(const std::string& name) const {
return table_.find(name) != table_.end();
bool SymbolTable::ContainsInCurrentScope(const std::string& name) const {
return !scopes_.empty() && scopes_.back().find(name) != scopes_.back().end();
}
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
auto it = table_.find(name);
return it == table_.end() ? nullptr : it->second;
SymbolEntry* SymbolTable::Lookup(const std::string& name) {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
}
const SymbolEntry* SymbolTable::Lookup(const std::string& name) const {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
}

@ -1,4 +1,81 @@
// SysY 运行库实现:
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为
#include "sylib.h"
#include <stdio.h>
#include <stdlib.h>
static int read_char_normalized(void) {
int ch = getchar();
if (ch == '\r') {
int next = getchar();
if (next != '\n' && next != EOF) {
ungetc(next, stdin);
}
return '\n';
}
return ch;
}
static float read_float_token(void) {
char buffer[256];
if (scanf("%255s", buffer) != 1) {
return 0.0f;
}
return strtof(buffer, NULL);
}
int getint(void) {
int value = 0;
if (scanf("%d", &value) != 1) {
return 0;
}
return value;
}
int getch(void) {
int ch = read_char_normalized();
return ch == EOF ? 0 : ch;
}
float getfloat(void) { return read_float_token(); }
int getarray(int a[]) {
int n = getint();
for (int i = 0; i < n; ++i) {
a[i] = getint();
}
return n;
}
int getfarray(float a[]) {
int n = getint();
for (int i = 0; i < n; ++i) {
a[i] = getfloat();
}
return n;
}
void putint(int x) { printf("%d", x); }
void putch(int x) { putchar(x); }
void putfloat(float x) { printf("%a", (double)x); }
void putarray(int n, const int a[]) {
printf("%d:", n);
for (int i = 0; i < n; ++i) {
printf(" %d", a[i]);
}
putchar('\n');
}
void putfarray(int n, const float a[]) {
printf("%d:", n);
for (int i = 0; i < n; ++i) {
printf(" %a", (double)a[i]);
}
putchar('\n');
}
void starttime(void) {}
void stoptime(void) {}

@ -1,4 +1,17 @@
// SysY 运行库头文件:
// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用)
// - 与 sylib.c 配套,按规范逐步补齐声明
#ifndef SYLIB_H_
#define SYLIB_H_
int getint(void);
int getch(void);
float getfloat(void);
int getarray(int a[]);
int getfarray(float a[]);
void putint(int x);
void putch(int x);
void putfloat(float x);
void putarray(int n, const int a[]);
void putfarray(int n, const float a[]);
void starttime(void);
void stoptime(void);
#endif

@ -1,9 +0,0 @@
int main(){
const int a[4][2] = {{1, 2}, {3, 4}, {}, 7};
const int N = 3;
int b[4][2] = {};
int c[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int d[N + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
int e[4][2][1] = {{d[2][1], {c[2][1]}}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1][0] + e[0][0][0] + e[0][1][0] + d[3][0];
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save