diff --git a/doc/Lab3-指令选择与汇编生成.md b/doc/Lab3-指令选择与汇编生成.md index 3d8551f..7f1d382 100644 --- a/doc/Lab3-指令选择与汇编生成.md +++ b/doc/Lab3-指令选择与汇编生成.md @@ -1,398 +1,548 @@ -# Lab3 代码生成实现核查说明 +# Lab3:指令选择与汇编生成 ## 1. 文档定位 -本文档不是继续宣传“Lab3 已经按课件要求完整自研实现”,而是对当前仓库中的 Lab3 代码生成方案做一次基于课件标准的核查。 +本文档用于说明当前仓库中 Lab3 后端的实际实现状态,面向组内协作与代码核验。 -核查依据是以下三份材料: +重点回答 4 个问题: -- `lab03-code generation-2026.pdf` -- `lecture05-instruction selection-169.pdf` -- `lecture11-register allocation-part2-169.pdf` +1. `--emit-asm` 这条链路现在到底经过了哪些阶段。 +2. 指令选择、寄存器分配、栈布局分别是怎么做的。 +3. 这套后端和课件要求的对应关系是什么。 +4. 当前是否已经跑通全部测试;如果没有,还剩什么问题。 -本文档要回答的问题只有两个: - -1. 当前仓库里的 Lab3,是否严格按这三份课件的标准完成。 -2. 如果没有,是哪一部分没有按标准完成;当前真实实现到底是什么。 +本文档基于当前分支的真实代码状态编写,不基于“理想目标”描述,不省略已知限制。 --- -## 2. 核查结论 +## 2. 当前实现结论 -结论先给出: +当前 Lab3 已经从“借助外部后端输出汇编”的模式,切换为“仓库内自研 MIR 后端 + AArch64 汇编打印”的模式。 -- 当前 Lab3 **可以跑通测试**。 -- 当前 Lab3 **不符合**“按 lecture05 手写指令选择 + 按 lecture11 手写线性扫描寄存器分配 + 按 lab03 手写栈布局”的严格标准。 +`compiler --emit-asm` 的主流程现在是: -更具体地说: +1. ANTLR 解析得到语法树。 +2. `Sema` 做语义检查。 +3. `IRGen` 生成 IR。 +4. `RunIRPassPipeline` 在 IR 层运行优化管线。 + 这里包含 SSA / Mem2Reg 等 IR 级处理,因此后端看到的不是最初的纯内存式 IR。 +5. `LowerToMIR` 把 IR 降到自定义 MIR。 +6. `RunRegAlloc` 对 MIR 做寄存器分配。 +7. `RunFrameLowering` 计算栈帧布局。 +8. `PrintAsm` 输出最终 AArch64 汇编。 -1. 当前 `--emit-asm` 不是走仓库内自研 `mir` 后端生成汇编。 -2. 当前仓库里也不存在一个 lecture11 风格的线性扫描寄存器分配器。 -3. 当前仓库里也不存在一套覆盖完整 SysY 测试集的自研 AArch64 栈布局与调用约定实现。 -4. 当前测试之所以能全通过,是因为 `--emit-asm` 实际改成了: - - `SysY 前端 -> IR -> IR Pass Pipeline -> llc -> AArch64 汇编` +也就是说,当前 `--emit-asm` 已经不再直接依赖 LLVM 的 `llc` 生成汇编,汇编代码由项目内后端自己产生。 -所以如果按照“工程结果”衡量: +当前状态下,后端功能已经基本打通,最近一次批量结果将失败收敛到 1 个测试: -- 这轮实现是成功的,汇编能生成,`test/` 能全过。 +- `test/class_test_case/performance/vector_mul3.sy` -如果按照“课程实现路径”衡量: +这个用例当前不是“结果算错”,而是在默认 300 秒测评时限下仍然超时。 +因此,当前结论应当表述为: -- 这轮实现不是严格按你现在指定的课件标准完成的。 +- Lab3 后端链路已基本完成。 +- 正确性问题已经大幅收敛。 +- 仍有 1 个性能型尾项没有消除,暂时不能宣称“全部样例完全跑通”。 --- -## 3. 课件标准是什么 +## 3. 关键文件与职责 + +### 3.1 主流程入口 + +- `src/main.cpp` + +当前 `main` 中与 Lab3 直接相关的改动是: + +- `--emit-asm` 路径先生成 IR。 +- 然后运行 `RunIRPassPipeline(*asm_module)`。 +- 再依次运行 `mir::LowerToMIR`、`mir::RunRegAlloc`、`mir::RunFrameLowering`、`mir::PrintAsm`。 + +这保证了后端接收的是经过 IR 优化后的模块,而不是最原始的前端输出。 -## 3.1 lecture05 对指令选择的要求 +### 3.2 MIR 数据结构 -从 `lecture05-instruction selection-169.pdf` 中可以提炼出两点核心要求: +- `include/mir/MIR.h` +- `src/mir/Register.cpp` +- `src/mir/MIRInstr.cpp` +- `src/mir/MIRBasicBlock.cpp` +- `src/mir/MIRFunction.cpp` +- `src/mir/MIRContext.cpp` -1. 指令选择是把 IR 翻译为目标 ISA 指令序列的过程。 -2. 对本课程实验语境,采用的是“宏扩展 / 逐条翻译(one-by-one translation)”的思路。 +这一组文件定义并实现了自研后端使用的 MIR 层,包括: -也就是说,按这份课件理解,Lab3 期待的实现方式应当是: +- `ValueType`:后端值类型,当前核心覆盖 `Void`、`I1`、`I32`、`F32`、`Ptr`。 +- `PhysReg` / `RegClass`:物理寄存器与寄存器类别。 +- `MachineOperand`:MIR 操作数,支持虚拟寄存器、立即数、块标签、符号。 +- `AddressExpr`:地址表达式,支持栈对象、全局符号、基址虚拟寄存器、常量偏移、缩放索引。 +- `MachineInstr`:MIR 指令,覆盖算术、比较、跳转、访存、调用、返回、转换等。 +- `MachineBasicBlock` / `MachineFunction` / `MachineModule`:后端基本块、函数、模块容器。 +- `StackObject` / `Allocation`:栈对象信息与寄存器分配结果。 -- 编译器自己遍历 IR -- 根据每一条 IR 指令选择对应的 AArch64 指令序列 -- 这个逻辑应体现在仓库自己的 lowering / instruction selection 代码中 +这部分是 Lab3 后端的基础设施层。 -而不是把 IR 直接交给外部成熟后端黑箱完成。 +### 3.3 IR -> MIR Lowering -## 3.2 lecture11 对寄存器分配的要求 +- `src/mir/Lowering.cpp` -从 `lecture11-register allocation-part2-169.pdf` 的 `10.6 基于线性扫描的寄存器分配方法` 可提炼出下面这些明确要素: +该文件负责把 IR 逐条降为 MIR,是后端最关键的“前半段”。 -1. 中间表示应是三地址码或伪指令。 -2. 操作数应先用虚寄存器表示。 -3. 基本块线性排序,得到可编号的指令序列。 -4. 计算 live interval。 -5. 维护按结束点排序的 `active` 表。 -6. 实现: - - `ExpireOldIntervals(i)` - - `SpillAtInterval(i)` -7. 
当物理寄存器不够时,做溢出并分配栈位置。 +### 3.4 寄存器分配 -也就是说,按课件标准,仓库里应该能看到一套清晰的: +- `src/mir/RegAlloc.cpp` -- 虚寄存器表示 -- 活跃区间构造 -- 线性扫描主循环 -- spill / reload 策略 +该文件实现当前使用的线性扫描寄存器分配。 -## 3.3 lab03 对栈布局与代码生成的要求 +### 3.5 栈帧布局 -从 `lab03-code generation-2026.pdf` 中,可以提炼出下面这些与实现直接相关的要求: +- `src/mir/FrameLowering.cpp` -1. 实验方法明确写的是: - - 基于宏扩展的指令选择方法 - - 自顶向下逐条翻译 -2. 应从 `IR Module` 开始遍历,逐条翻译生成 ARM 汇编。 -3. 函数调用遵循 AAPCS64: - - 前 8 个整数参数通过 `x0~x7` - - 其余参数通过栈传递 - - 返回值通过 `x0` -4. 栈采用 `Full Descending`,高地址向低地址增长。 -5. 栈帧大小按 16 字节对齐。 -6. 需要正确处理: - - `x29(fp)` / `x30(lr)` - - `sp` - - caller / callee 保存规则 - - 栈上传参与局部对象布局 -7. 课件中给出的实验实现路线是: - - 遍历 Module / Function / BasicBlock / Instruction - - 将 IR 逐条翻译成对应的汇编代码 +该文件负责把 spill、局部对象、被使用的 callee-saved 寄存器保存槽统一放入栈帧,并计算偏移。 -也就是说,按 Lab3 讲义标准,期望看到的是一套在仓库内显式实现的: +### 3.6 汇编输出 -- instruction lowering -- frame lowering -- call lowering -- assembly printing +- `src/mir/AsmPrinter.cpp` + +该文件把 MIR 打印成 AArch64 汇编,是“真正落地到汇编文本”的阶段。 + +### 3.7 测试脚本 + +- `scripts/lab3_build_test.sh` +- `scripts/verify_asm.sh` + +这两份脚本负责 Lab3 的批量构建与运行验证。 --- -## 4. 当前代码实际是什么 +## 4. 后端总流程说明 -## 4.1 `src/main.cpp` +### 4.1 IR 优化后再进入后端 -当前 `--emit-asm` 的入口逻辑已经不是: +当前实现不是“前端直接把 alloca/load/store 风格 IR 原样送到汇编阶段”,而是: -- `mir::LowerToMIR` -- `mir::RunRegAlloc` -- `mir::RunFrameLowering` -- `mir::PrintAsm` +1. 前端先生成 IR。 +2. IR Pass Pipeline 在中间层先做优化。 +3. 后端对优化后的 IR 进行 Lowering。 -而是: +这样做有两个直接好处: -1. 生成 IR。 -2. 跑 `ir::RunIRPassPipeline(*module)`。 -3. 调用 `EmitAsmWithLLC(*module, std::cout)`。 -4. 在 `EmitAsmWithLLC` 中: - - 用 `IRPrinter` 把 IR 写到临时 `.ll` - - 调用外部 `llc -mtriple=aarch64-linux-gnu -filetype=asm` - - 读取生成的 `.s` 并输出 +- 后端处理的 IR 更接近 SSA 形态,便于后续寄存器分配。 +- 前端仍然可以保持“边查符号表边生成 IR”的清晰实现,不需要在 visitor 里硬做 SSA 构造。 -因此,当前真正完成指令选择、寄存器分配、栈布局和 ABI 细节的主体,不是仓库内 `mir`,而是 LLVM 的 `llc`。 +这和课件里的课程分层是匹配的:前端生成可用 IR,中端做 IR 级优化,后端负责代码生成。 -这是当前核查结论里最关键的一点。 +### 4.2 Lowering 的基本思路 -## 4.2 `include/mir/MIR.h` +`LowerToMIR` 的总体策略是: -当前 `MIR.h` 仍然只是非常小的骨架,主要特征如下: +- 每个 IR `Function` 降成一个 `MachineFunction`。 +- 每个 IR `BasicBlock` 对应一个 `MachineBasicBlock`。 +- 标量 SSA 值尽量映射为 MIR 虚拟寄存器。 +- 地址类对象保留为 `AddressExpr`,在汇编打印时再真正落成基址加偏移的地址计算。 -1. 物理寄存器只有: - - `W0` - - `W8` - - `W9` - - `X29` - - `X30` - - `SP` -2. 指令种类只有: - - `Prologue` - - `Epilogue` - - `MovImm` - - `LoadStack` - - `StoreStack` - - `AddRR` - - `Ret` -3. `MachineFunction` 只有一个 `entry_` 基本块。 -4. 没有虚寄存器。 -5. 没有 live interval 结构。 -6. 没有 CFG 级别的机器基本块管理。 +当前 Lowering 已经支持的主要 IR 类型包括: -这套数据结构本身就还没有达到 lecture11 线性扫描寄存器分配所需的表示能力。 +- 整数/浮点二元算术 +- 整数/浮点比较 +- `alloca` / `load` / `store` +- `br` / `condbr` / `return` +- `call` +- `gep` +- `zext` / `itof` / `ftoi` +- `memset` +- `unreachable` -## 4.3 `src/mir/Lowering.cpp` +### 4.3 Phi 的处理方式 -当前 `Lowering.cpp` 的事实是: +当前后端没有在 MIR 层引入块参数模型,而是采用更直观的“前驱块插拷贝”策略处理 Phi: -1. 只支持极少数 IR 指令: - - `Alloca` - - `Store` - - `Load` - - `Add` - - `Ret` -2. `Sub`、`Mul` 直接报不支持。 -3. 其他大多数 IR 指令直接报不支持。 -4. 只支持单函数,且函数名必须是 `main`。 -5. 只处理入口基本块。 +1. 先为 Phi 结果预分配目标虚拟寄存器。 +2. 在正常指令 Lowering 时跳过 Phi 本体。 +3. 遍历 Phi incoming,把对应 copy 插到前驱块 terminator 之前。 -这与 lab03 讲义中要求的“遍历 Module / Function / BasicBlock / Instruction,逐条生成汇编”明显不一致。 +这是一种符合课程实验规模、实现成本可控的做法。 -## 4.4 `src/mir/RegAlloc.cpp` +### 4.4 小函数直接内联 -当前 `RunRegAlloc` 的行为仅仅是: +`Lowering.cpp` 里还加入了一个很窄的直接调用内联路径: -- 遍历当前 MIR 指令的操作数 -- 检查寄存器是不是落在一小组允许的物理寄存器集合中 +- 只对内部函数生效。 +- 只对单基本块、指令种类很简单的纯标量函数生效。 +- 不处理递归或复杂控制流。 -它没有做下面任何一件 lecture11 线性扫描要求的事情: +这部分不是“通用函数内联优化器”,只是为减少高频小函数调用开销加入的一个后端小优化。 -1. 没有虚寄存器。 -2. 没有线性化指令编号。 -3. 没有 live interval。 -4. 没有 `active` 表。 -5. 没有 `ExpireOldIntervals`。 -6. 没有 `SpillAtInterval`。 -7. 
没有物理寄存器池管理。 -8. 没有 spill / reload 代码插入。 +--- -所以这一部分不能被称为“按 lecture11 实现了线性扫描寄存器分配”。 +## 5. 指令选择实现说明 -## 4.5 `src/mir/FrameLowering.cpp` +本节对应 lecture05 的“指令选择”要求。 -当前 `RunFrameLowering` 做的事情只有: +### 5.1 当前的指令选择分层 -1. 按 frame slot 顺序累计大小。 -2. 给每个 slot 分配一个负偏移。 -3. 把 frame size 对齐到 16 字节。 -4. 在入口插 `Prologue`。 -5. 在 `Ret` 前插 `Epilogue`。 +当前实现把“指令选择”拆成两层: -它没有显式实现 lab03 课件要求的完整内容,例如: +1. `Lowering.cpp` 负责把 IR 操作翻译为抽象的 MIR 指令。 +2. `AsmPrinter.cpp` 负责把 MIR 指令选成具体的 AArch64 汇编形式。 -1. 参数区布局。 -2. 栈上传参。 -3. caller / callee saved 寄存器管理。 -4. 叶子函数 / 非叶子函数差异。 -5. 函数调用下的 outgoing arg area。 -6. 完整的 AAPCS64 栈帧组织。 - -因此,这一部分也不能称为“已经按 lab03 讲义要求完成了完整栈布局实现”。 +也就是说,MIR 是后端内部的“选择前中间层”。 -## 4.6 `src/mir/AsmPrinter.cpp` +### 5.2 算术与逻辑运算 -当前汇编打印器只会输出极少数内容: +当前整数运算支持映射到: -- `.text/.global/.type/.size` -- `stp/ldp` -- `mov` -- `sub/add sp` -- `ldur/stur` - `add` -- `ret` +- `sub` +- `mul` +- `sdiv` +- `and` +- `orr` +- `eor` +- `lsl` +- `asr` +- `lsr` + +当前浮点运算支持映射到: + +- `fadd` +- `fsub` +- `fmul` +- `fdiv` +- `fneg` + +### 5.3 比较与分支 + +当前支持: + +- 整数比较:`cmp` + 条件结果 / 条件跳转 +- 浮点比较:`fcmp` + 条件结果 / 条件跳转 + +本轮额外补了一个比较分支融合优化: + +- 如果出现“`ICmp/FCmp` 的结果只被紧跟的 `CondBr` 使用一次”, +- 则直接打印成 `cmp/fcmp + b. + b`, +- 不再额外输出 `cset + cbnz` 这一对中间指令。 + +这能显著减少循环头部的控制流开销。 + +### 5.4 立即数指令选择 + +之前的版本中,很多简单常量都会先 `movz/movk` 到寄存器,再做真正运算,这会拉低热点循环效率。 + +当前已经补上的选择规则包括: + +- `add/sub` 在立即数可编码时优先输出立即数形式。 +- 对大栈偏移的地址调整,不再一律走 `movz/movk + add`,而是优先拆成多条 `add/sub #imm` 或 `add/sub #imm, lsl #12`。 +- 对 `sdiv by 2^k`,加入了带符号修正的快速展开,而不是一律发通用 `sdiv`。 + +例如,当前会把“有符号除以 2”正确展开为: + +1. 先取符号位修正量。 +2. 把修正量加回被除数。 +3. 再做算术右移。 + +这样可以保持“向 0 取整”的语义,而不是错误地直接用 `asr` 代替有符号除法。 + +### 5.5 地址选择 + +当前地址打印支持: + +- 全局符号:`adrp + add :lo12:` +- 栈对象:`x29` 相对地址 +- 寄存器基址:先取基址寄存器,再加常量偏移或缩放索引 +- `GEP` 缩放索引: + - 若 stride 是较小 2 的幂,优先使用 `add ..., sxtw #shift` + - 否则回退为 `sxtw + mul + add` + +这个选择逻辑已经可以覆盖当前实验大部分数组和指针寻址场景。 + +--- + +## 6. 寄存器分配实现说明 + +本节对应 lecture11 的线性扫描寄存器分配要求。 + +### 6.1 总体算法 + +当前 `src/mir/RegAlloc.cpp` 实现的是“基于活跃区间的线性扫描寄存器分配”,整体流程为: + +1. 先按基本块分析 use/def/successor 信息。 +2. 在 CFG 上迭代求 `live_in / live_out`。 +3. 基于指令位置和活跃信息构造每个虚拟寄存器的 live interval。 +4. 按区间起点排序。 +5. 对 GPR 和 FPR 分别运行线性扫描。 +6. 寄存器不够时创建 spill slot。 + +这不是简单的“单基本块顺序扫描”,而是考虑了 CFG 活跃信息之后构造区间,因此能正确处理回边、循环和跨块活跃值。 + +### 6.2 区间构造 + +当前做法是: + +- 为每条 MIR 指令分配一个线性位置编号。 +- 先根据 defs / uses 触碰区间端点。 +- 再根据基本块的 `live_in / live_out` 把区间扩展到块边界。 + +这样可以覆盖诸如“变量在循环头定义、在回边继续活跃”的情况。 -没有覆盖: +### 6.3 active 集合与 spill 逻辑 -- 条件跳转 -- 比较 -- 调用 `bl` -- 参数传递 -- 浮点指令 -- 更完整的算术和逻辑指令 -- 全局地址访问 +当前实现维护 `active` 集合,并按区间结束点排序: -因此也不可能独立支撑完整 SysY 测试集的后端输出。 +- 扫描到新 interval 时,先 expire 已结束的 active interval。 +- 若有空闲物理寄存器,则直接分配。 +- 若没有空闲寄存器,则执行 `spill_at_interval`: + - 比较当前区间和 active 中结束最晚区间的 end。 + - 谁更“不划算”,谁 spill。 + +这是线性扫描算法的标准核心流程。 + +### 6.4 当前可分配寄存器集合 + +当前实现选择: + +- GPR:`x19` 到 `x28` +- FPR:`v8` 到 `v15` + +也就是优先使用 callee-saved 寄存器作为分配目标。这样做的直接结果是: + +- 函数内较长生命周期的值更稳定。 +- 需要在函数序言/结语中保存和恢复被用到的 callee-saved 寄存器。 + +### 6.5 当前实现的意义 + +这一版寄存器分配相比最初“仅按出现顺序分配”的方式,关键提升在于: + +- 可以正确处理循环中的跨迭代活跃值。 +- 避免把仍然 live 的寄存器过早复用。 +- 对矩阵和多参数类用例的稳定性明显更好。 --- -## 5. 对照结论:哪些符合,哪些不符合 +## 7. 栈布局实现说明 + +本节对应 Lab3 课件中的栈帧布局要求。 + +### 7.1 栈帧建立方式 -## 5.1 是否按 lecture05 的指令选择标准完成 +当前函数序言固定为: -结论:**不符合严格标准**。 +```asm +stp x29, x30, [sp, #-16]! 
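+// 先保存调用者的 x29/x30 并让 x29 指向新的帧底;下一条 sub 预留经 FrameLowering 计算、按 16 字节对齐的帧空间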
+mov x29, sp +sub sp, sp, +``` -原因: +函数结语固定为: -- 课件要求的是编译器内部做“宏扩展、逐条翻译”。 -- 当前仓库实际生成汇编时,已经绕开了内部 `mir::LowerToMIR` 主路径,改由 `llc` 完成最终 AArch64 instruction selection。 +```asm +mov sp, x29 +ldp x29, x30, [sp], #16 +ret +``` -可以说: +这意味着: -- 当前结果在“效果上”完成了指令选择。 -- 但不是“仓库内按 lecture05 自己写出来的 instruction selection”。 +- `x29` 作为 frame pointer 使用。 +- 当前函数内部分配的局部栈对象都通过 `x29` 的负偏移访问。 +- 调用者传进来的栈参数位于 `x29 + 16 + offset` 一侧。 -## 5.2 是否按 lecture11 的线性扫描寄存器分配标准完成 +### 7.2 栈对象种类 -结论:**不符合**。 +当前 `StackObjectKind` 至少覆盖: -原因: +- `Local` +- `Spill` +- `SavedGPR` +- `SavedFPR` -- 当前 `RegAlloc.cpp` 只是一个物理寄存器白名单检查器。 -- 没有虚寄存器、活跃区间、`active` 表、溢出策略,也没有任何线性扫描主循环。 +具体来源包括: -因此不能把当前实现描述成“已经按 lecture11 完成线性扫描寄存器分配”。 +- 前端/Lowering 创建的局部对象 +- 寄存器分配阶段生成的 spill slot +- FrameLowering 阶段为已用到的 callee-saved 寄存器补的保存槽 -## 5.3 是否按 lab03 的栈布局和代码生成标准完成 +### 7.3 偏移计算 -结论:**不符合严格标准**。 +`RunFrameLowering` 的计算策略是: -原因: +1. 先把需要保存的 callee-saved GPR/FPR 也视为普通栈对象加入列表。 +2. 从 `cursor = 0` 开始向下增长。 +3. 每个对象按自身对齐要求做 `AlignTo`。 +4. 累加对象大小后,把该对象偏移记为 `-cursor`。 +5. 最终 frame size 再按 16 字节对齐。 -- 当前仓库内自研 `FrameLowering` 只实现了极小的顺序布局和简单序言/尾声。 -- 完整的 AAPCS64 参数传递、调用保存/被调用保存、栈上传参、非叶子函数等逻辑并未在仓库自研后端里完整实现。 -- 当前这些工作实际由 `llc` 负责完成。 +这是一个简单但清晰的线性布局方案。 -所以: +### 7.4 当前局限 -- 按“汇编结果”看,栈布局和 ABI 是正确的。 -- 按“仓库内是否手写完成了 lab03 要求的实现”看,不是。 +当前栈布局还不是“课件里的最优工程版”,主要还缺: + +- 更激进的对象重排 +- 栈槽复用 +- 针对超大局部数组的专门基址缓存优化 + +这也是 `vector_mul3` 仍然表现吃力的重要原因之一: + +- 程序里有多个 100000 长度的局部浮点数组。 +- 即使大偏移已经改成 `sub #imm, lsl #12` 分解,数组基址计算仍然非常频繁。 +- 若不进一步做基址提升/循环不变式外提,热点循环仍会被地址准备成本拖慢。 --- -## 6. 当前实现为什么仍然能全通过测试 +## 8. 调用约定与汇编打印 + +### 8.1 参数与返回值 + +当前实现遵循 AArch64 常见调用约定: + +- 整数/指针参数优先放 `x0` 到 `x7` +- 浮点参数优先放 `s0` 到 `s7` / `v0` 到 `v7` +- 超出寄存器数量的参数放到调用栈上传递 +- 整数返回值走 `w0/x0` +- 浮点返回值走 `s0` + +### 8.2 callee-saved 保存恢复 + +寄存器分配如果把虚拟寄存器分到 `x19-x28` 或 `v8-v15`,则: + +- `FrameLowering` 会给这些寄存器创建保存槽。 +- `AsmPrinter` 会在函数序言里 `str` 保存,在结语里 `ldr` 恢复。 + +这部分与当前寄存器分配策略是一一对应的。 -虽然不符合你现在指定的“自研实现路径”标准,但当前版本仍然能跑通测试,原因很直接: +### 8.3 全局变量与常量 -1. 前端和 IR 生成已经稳定。 -2. IR pass pipeline 已经可用。 -3. 输出的 IR 足够接近 LLVM IR。 -4. `llc` 负责了: - - AArch64 指令选择 - - 寄存器分配 - - 栈帧布局 - - 调用约定处理 -5. `verify_asm.sh` 会把生成的汇编和 `sylib/sylib.c` 一起链接。 -6. `lab3_build_test.sh` 会批量编译并在 `qemu-aarch64` 上运行验证。 +全局对象当前已支持: -因此,当前全通过是一个真实结果,但其来源是“LLVM 后端能力”,不是“仓库内自研完整后端能力”。 +- `.data` +- `.bss` +- `.rodata` +- `ConstantArrayValue` 的扁平化输出 +- 零初始化对象的 `.zero` + +因此,后端不只处理函数体,也能完整输出模块级全局数据。 --- -## 7. 当前真实可宣称的成果 +## 9. 与课件要求的对应关系 + +### 9.1 指令选择 + +与 lecture05 的对应关系: + +- 已实现从中间表示到目标机 AArch64 指令的逐类映射。 +- 已包含算术、比较、访存、调用、返回、类型转换、条件跳转等核心指令选择。 +- 已补若干与热点性能直接相关的 peephole:立即数形式、比较分支融合、幂次除法快速展开、大偏移地址分解。 + +### 9.2 寄存器分配 + +与 lecture11 的对应关系: -当前仓库可以真实宣称的 Lab3 成果是: +- 当前确实采用了 live interval + active list 的线性扫描框架。 +- 区间构造考虑了 CFG 上的活跃信息,而不是只看单块顺序。 +- 溢出通过 spill slot 落栈。 -1. `compiler --emit-asm` 已可生成 AArch64 汇编。 -2. 生成的汇编可与 `sylib` 链接,并在 `qemu-aarch64` 上运行。 -3. 当前 `test/` 目录全量测试通过。 -4. 已具备 Lab3 的单样例验证脚本和批量验证脚本。 +### 9.3 栈布局 -当前**不能**真实宣称的成果是: +与 Lab3 栈布局要求的对应关系: -1. 已按 lecture05 完整自研指令选择。 -2. 已按 lecture11 完整自研线性扫描寄存器分配。 -3. 已按 lab03 完整自研 AArch64 栈布局与调用约定实现。 +- 当前使用 `x29` 作为 frame pointer。 +- 本地对象、spill、callee-saved 保存槽统一进入同一个栈帧布局过程。 +- 栈帧大小保持 16 字节对齐。 + +因此,从“是否已经形成一条完整后端流程”的角度看,当前实现已经满足 Lab3 的主体框架要求。 --- -## 8. 本次测试结果 +## 10. 测试脚本与验证方式 + +### 10.1 推荐测试顺序 -本次实际跑过的 Lab3 批量验证结果是: +为了节省调试时间,当前建议: -- `214 PASS / 0 FAIL / total 214` +1. 先重跑失败缓存。 +2. 
再跑全量。 -默认测试范围包括: +对应命令: -- `test/test_case` -- `test/class_test_case` +```bash +./scripts/lab3_build_test.sh --failed-only +./scripts/lab3_build_test.sh +``` -对应日志目录为: +### 10.2 当前脚本行为 -- `output/logs/lab3/lab3_20260410_104639` +`scripts/lab3_build_test.sh` 现在具备以下行为: -完整日志文件为: +- 默认同时扫描 `test/test_case` 和 `test/class_test_case` +- 每轮生成独立日志目录:`output/logs/lab3/lab3_日期_时间` +- 生成整轮 `whole.log` +- 终端输出带颜色: + - `PASS` 为绿色 + - `FAIL` 为红色 +- 成功样例的中间文件自动删除 +- 失败样例才保留中间文件和 `error.log` +- 支持从 `last_failed.txt` 中读取失败缓存重跑 -- `output/logs/lab3/lab3_20260410_104639/whole.log` +这部分主要是为了让 Lab3 的回归测试可持续,而不是每次手工清理大量中间文件。 --- -## 9. 如果要真正改成“按课件标准完成”,还缺什么 - -如果后续目标改成: - -- 必须让仓库里的自研后端本身满足 lecture05 / lecture11 / lab03 - -那么后续至少还需要补下面这些内容: - -1. 扩展 MIR 数据结构 - - 虚寄存器 - - 多基本块 - - CFG - - use/def 与编号 - - spill slot / stack object 表示 -2. 重新实现 instruction selection - - 覆盖 IR 的主要指令种类 - - 覆盖函数、调用、分支、数组、全局对象 -3. 手写线性扫描寄存器分配 - - linear order - - live interval - - active 表 - - spill / reload -4. 手写 frame lowering - - 参数区 - - caller/callee saved - - outgoing arg area - - 16-byte alignment - - 叶子/非叶子函数处理 -5. 扩展 asm printer - - 条件跳转 - - 比较 - - 调用 - - 整数/浮点算术 - - 内存寻址 - - 全局地址获取 - -也就是说,如果按课程标准继续推进,当前仓库还差的是“一套真正可独立工作的自研 AArch64 后端”,而不是只差一两个小补丁。 +## 11. 当前验证结果与剩余问题 + +### 11.1 当前已知结果 + +基于当前仓库状态,可以确认: + +- 后端链路可以稳定生成 AArch64 汇编。 +- 大部分功能样例、矩阵样例、多参数样例已经能通过。 +- 当前失败缓存只剩: + - `test/class_test_case/performance/vector_mul3.sy` + +### 11.2 `vector_mul3` 的问题性质 + +这个用例当前表现为: + +- 程序最终在 300 秒超时。 +- 不是直接崩溃。 +- 不是明显的输出错误。 +- 属于性能未达标,而不是语义错误未修复。 + +### 11.3 已经针对它做过的优化 + +围绕该用例,后端已经补过的关键优化包括: + +- 小内部函数的直接内联 +- `add/sub` 立即数选择 +- `sdiv by 2^k` 的快速展开 +- `cmp + cset + cbnz` 融合成直接条件分支 +- 大栈偏移地址分解,减少 `movz/movk` 频率 + +### 11.4 仍然可能需要继续做的工作 + +如果要把最后这个性能尾项也压下去,下一步优先级最高的方向是: + +1. 对超大局部数组做更激进的基址缓存。 +2. 对循环内重复出现的地址准备做提升或复用。 +3. 继续减少不必要的 `mov/fmov` 与访存往返。 +4. 视情况补更细的算术选择与简单强度削弱。 --- -## 10. 最终建议 +## 12. 
总结 + +截至当前分支状态,Lab3 已经完成了“自研后端”的主体工程: -对当前仓库的 Lab3 状态,建议对外统一表述为: +- 已有自定义 MIR。 +- 已有 IR 到 MIR 的 Lowering。 +- 已有基于线性扫描的寄存器分配。 +- 已有统一的栈帧布局。 +- 已有 AArch64 汇编打印。 +- 已有配套批量测试脚本。 -- 当前版本已经具备完整的 AArch64 汇编生成与测试验证链路,且 `test/` 全量通过。 -- 但当前实现采用的是 `IR -> llc` 的代码生成路径,不应表述为“已按 lecture05 / lecture11 / lab03 在仓库内部完整自研实现后端”。 +但这份结论必须附带当前真实状态: -这是当前最准确、也最不容易误导队友的说法。 +- 现在还不能完全声称“所有测试样例全部跑通”。 +- 仍剩 `vector_mul3` 这一项性能超时问题。 +- 因此,Lab3 可以认为已经完成主体实现并进入最后性能收尾阶段。 diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 47b8959..9d6b2d5 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -1,9 +1,10 @@ #pragma once -#include +#include #include #include #include +#include #include namespace ir { @@ -19,57 +20,168 @@ class MIRContext { MIRContext& DefaultContext(); -enum class PhysReg { W0, W8, W9, X29, X30, SP }; +enum class ValueType { Void, I1, I32, F32, Ptr }; -const char* PhysRegName(PhysReg reg); +enum class RegClass { GPR, FPR }; -enum class Opcode { - Prologue, - Epilogue, - MovImm, - LoadStack, - StoreStack, - AddRR, - Ret, +enum class CondCode { EQ, NE, LT, GT, LE, GE }; + +enum class StackObjectKind { Local, Spill, SavedGPR, SavedFPR }; + +enum class AddrBaseKind { None, FrameObject, Global, VReg }; + +enum class OperandKind { Invalid, VReg, Imm, Block, Symbol }; + +struct PhysReg { + RegClass reg_class = RegClass::GPR; + int index = -1; + + bool IsValid() const { return index >= 0; } + bool operator==(const PhysReg& rhs) const { + return reg_class == rhs.reg_class && index == rhs.index; + } }; -class Operand { - public: - enum class Kind { Reg, Imm, FrameIndex }; +bool IsGPR(ValueType type); +bool IsFPR(ValueType type); +int GetValueSize(ValueType type); +int GetValueAlign(ValueType type); +const char* GetPhysRegName(PhysReg reg, ValueType type); - static Operand Reg(PhysReg reg); - static Operand Imm(int value); - static Operand FrameIndex(int index); +class MachineOperand { + public: + MachineOperand() = default; + static MachineOperand VReg(int reg); + static MachineOperand Imm(std::int64_t value); + static MachineOperand Block(std::string name); + static MachineOperand Symbol(std::string name); - Kind GetKind() const { return kind_; } - PhysReg GetReg() const { return reg_; } - int GetImm() const { return imm_; } - int GetFrameIndex() const { return imm_; } + OperandKind GetKind() const { return kind_; } + int GetVReg() const { return vreg_; } + std::int64_t GetImm() const { return imm_; } + const std::string& GetText() const { return text_; } private: - Operand(Kind kind, PhysReg reg, int imm); + MachineOperand(OperandKind kind, int vreg, std::int64_t imm, std::string text); + + OperandKind kind_ = OperandKind::Invalid; + int vreg_ = -1; + std::int64_t imm_ = 0; + std::string text_; +}; - Kind kind_; - PhysReg reg_; - int imm_; +struct AddressExpr { + AddrBaseKind base_kind = AddrBaseKind::None; + int base_index = -1; + std::string symbol; + std::int64_t const_offset = 0; + std::vector> scaled_vregs; +}; + +struct StackObject { + int index = -1; + StackObjectKind kind = StackObjectKind::Local; + int size = 0; + int align = 1; + int offset = 0; + std::string name; +}; + +struct VRegInfo { + int id = -1; + ValueType type = ValueType::Void; +}; + +struct Allocation { + enum class Kind { Unassigned, PhysReg, Spill }; + + Kind kind = Kind::Unassigned; + PhysReg phys; + int stack_object = -1; }; class MachineInstr { public: - MachineInstr(Opcode opcode, std::vector operands = {}); + enum class Opcode { + Arg, + Copy, + Load, + Store, + Lea, + Add, + Sub, + Mul, + Div, + Rem, + And, + Or, + Xor, + 
Shl, + AShr, + LShr, + FAdd, + FSub, + FMul, + FDiv, + FNeg, + ICmp, + FCmp, + ZExt, + ItoF, + FtoI, + Br, + CondBr, + Call, + Ret, + Memset, + Unreachable, + }; + + explicit MachineInstr(Opcode opcode, + std::vector operands = {}); Opcode GetOpcode() const { return opcode_; } - const std::vector& GetOperands() const { return operands_; } + const std::vector& GetOperands() const { return operands_; } + std::vector& GetOperands() { return operands_; } + + void SetCondCode(CondCode code) { cond_code_ = code; } + CondCode GetCondCode() const { return cond_code_; } + + void SetAddress(AddressExpr address) { + address_ = std::move(address); + has_address_ = true; + } + bool HasAddress() const { return has_address_; } + const AddressExpr& GetAddress() const { return address_; } + AddressExpr& GetAddress() { return address_; } + + void SetCallInfo(std::string callee, std::vector arg_types, + ValueType return_type) { + callee_ = std::move(callee); + call_arg_types_ = std::move(arg_types); + call_return_type_ = return_type; + } + const std::string& GetCallee() const { return callee_; } + const std::vector& GetCallArgTypes() const { return call_arg_types_; } + ValueType GetCallReturnType() const { return call_return_type_; } + + void SetValueType(ValueType type) { value_type_ = type; } + ValueType GetValueType() const { return value_type_; } + + bool IsTerminator() const; + std::vector GetDefs() const; + std::vector GetUses() const; private: Opcode opcode_; - std::vector operands_; -}; - -struct FrameSlot { - int index = 0; - int size = 4; - int offset = 0; + std::vector operands_; + CondCode cond_code_ = CondCode::EQ; + AddressExpr address_; + bool has_address_ = false; + std::string callee_; + std::vector call_arg_types_; + ValueType call_return_type_ = ValueType::Void; + ValueType value_type_ = ValueType::Void; }; class MachineBasicBlock { @@ -80,8 +192,9 @@ class MachineBasicBlock { std::vector& GetInstructions() { return instructions_; } const std::vector& GetInstructions() const { return instructions_; } - MachineInstr& Append(Opcode opcode, - std::initializer_list operands = {}); + MachineInstr& Append(MachineInstr::Opcode opcode, + std::vector operands = {}); + MachineInstr& Append(MachineInstr instr); private: std::string name_; @@ -90,30 +203,88 @@ class MachineBasicBlock { class MachineFunction { public: - explicit MachineFunction(std::string name); + MachineFunction(std::string name, ValueType return_type, + std::vector param_types); const std::string& GetName() const { return name_; } - MachineBasicBlock& GetEntry() { return entry_; } - const MachineBasicBlock& GetEntry() const { return entry_; } + ValueType GetReturnType() const { return return_type_; } + const std::vector& GetParamTypes() const { return param_types_; } - int CreateFrameIndex(int size = 4); - FrameSlot& GetFrameSlot(int index); - const FrameSlot& GetFrameSlot(int index) const; - const std::vector& GetFrameSlots() const { return frame_slots_; } + MachineBasicBlock* CreateBlock(const std::string& name); + std::vector>& GetBlocks() { return blocks_; } + const std::vector>& GetBlocks() const { + return blocks_; + } + + int NewVReg(ValueType type); + const VRegInfo& GetVRegInfo(int id) const; + VRegInfo& GetVRegInfo(int id); + const std::vector& GetVRegs() const { return vregs_; } + + int CreateStackObject(int size, int align, StackObjectKind kind, + const std::string& name = ""); + StackObject& GetStackObject(int index); + const StackObject& GetStackObject(int index) const; + const std::vector& GetStackObjects() const { 
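+    // 栈对象列表同时包含局部对象、寄存器分配产生的 spill 槽,以及 FrameLowering 为 callee-saved 寄存器补充的保存槽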
return stack_objects_; } + + void SetAllocation(int vreg, Allocation allocation); + const Allocation& GetAllocation(int vreg) const; + Allocation& GetAllocation(int vreg); + + void AddUsedCalleeSavedGPR(int reg_index); + void AddUsedCalleeSavedFPR(int reg_index); + const std::vector& GetUsedCalleeSavedGPRs() const { + return used_callee_saved_gprs_; + } + const std::vector& GetUsedCalleeSavedFPRs() const { + return used_callee_saved_fprs_; + } - int GetFrameSize() const { return frame_size_; } void SetFrameSize(int size) { frame_size_ = size; } + int GetFrameSize() const { return frame_size_; } + + void SetMaxOutgoingArgBytes(int bytes) { max_outgoing_arg_bytes_ = bytes; } + int GetMaxOutgoingArgBytes() const { return max_outgoing_arg_bytes_; } private: std::string name_; - MachineBasicBlock entry_; - std::vector frame_slots_; + ValueType return_type_ = ValueType::Void; + std::vector param_types_; + std::vector> blocks_; + std::vector vregs_; + std::vector stack_objects_; + std::vector allocations_; + std::vector used_callee_saved_gprs_; + std::vector used_callee_saved_fprs_; int frame_size_ = 0; + int max_outgoing_arg_bytes_ = 0; +}; + +class MachineModule { + public: + explicit MachineModule(const ir::Module& source) : source_(&source) {} + + const ir::Module& GetSourceModule() const { return *source_; } + std::vector>& GetFunctions() { return functions_; } + const std::vector>& GetFunctions() const { + return functions_; + } + + MachineFunction* AddFunction(std::unique_ptr function); + + private: + const ir::Module* source_ = nullptr; + std::vector> functions_; }; -std::unique_ptr LowerToMIR(const ir::Module& module); -void RunRegAlloc(MachineFunction& function); -void RunFrameLowering(MachineFunction& function); -void PrintAsm(const MachineFunction& function, std::ostream& os); +std::unique_ptr LowerToMIR(const ir::Module& module); +void RunRegAlloc(MachineModule& module); +void RunFrameLowering(MachineModule& module); +void PrintAsm(const MachineModule& module, std::ostream& os); } // namespace mir + + + + + diff --git a/src/main.cpp b/src/main.cpp index 786168b..a63f533 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,17 +1,6 @@ -#include #include -#include -#include #include #include -#include -#include -#include -#include - -#if !defined(_WIN32) -#include -#endif #include "frontend/AntlrDriver.h" #include "frontend/SyntaxTreePrinter.h" @@ -19,115 +8,12 @@ #include "ir/IR.h" #include "ir/PassManager.h" #include "irgen/IRGen.h" +#include "mir/MIR.h" #include "sem/Sema.h" #endif #include "utils/CLI.h" #include "utils/Log.h" -#if !COMPILER_PARSE_ONLY -namespace { -namespace fs = std::filesystem; - -std::string ShellEscape(std::string_view text) { - std::string escaped; - escaped.reserve(text.size() + 2); - escaped.push_back('\''); - for (char ch : text) { - if (ch == '\'') { - escaped += "'\\''"; - } else { - escaped.push_back(ch); - } - } - escaped.push_back('\''); - return escaped; -} - -fs::path CreateTempFile(const char* pattern) { - fs::path temp_dir = fs::temp_directory_path(); - std::string templ = (temp_dir / pattern).string(); - std::vector buffer(templ.begin(), templ.end()); - buffer.push_back('\0'); - -#if defined(_WIN32) - if (_mktemp_s(buffer.data(), buffer.size()) != 0) { - throw std::runtime_error(FormatError("lab3", "failed to allocate a temporary file name")); - } - std::ofstream touch(buffer.data(), std::ios::binary); - if (!touch) { - throw std::runtime_error(FormatError("lab3", "failed to create a temporary file")); - } -#else - int fd = mkstemp(buffer.data()); - 
if (fd < 0) { - throw std::runtime_error(FormatError("lab3", "failed to create a temporary file")); - } - close(fd); -#endif - - return fs::path(buffer.data()); -} - -class ScopedTempFile { - public: - explicit ScopedTempFile(const char* pattern) : path_(CreateTempFile(pattern)) {} - ~ScopedTempFile() { - std::error_code ec; - fs::remove(path_, ec); - } - - const fs::path& path() const { return path_; } - - private: - fs::path path_; -}; - -void WriteIRToFile(const ir::Module& module, const fs::path& path) { - std::ofstream output(path, std::ios::binary | std::ios::trunc); - if (!output) { - throw std::runtime_error(FormatError("lab3", "failed to open temporary IR file")); - } - ir::IRPrinter printer; - printer.Print(module, output); - if (!output) { - throw std::runtime_error(FormatError("lab3", "failed to write temporary IR file")); - } -} - -void StreamFileToStdout(const fs::path& path, std::ostream& os) { - std::ifstream input(path, std::ios::binary); - if (!input) { - throw std::runtime_error(FormatError("lab3", "failed to open generated assembly file")); - } - os << input.rdbuf(); - if (!os) { - throw std::runtime_error(FormatError("lab3", "failed to write assembly output")); - } -} - -void EmitAsmWithLLC(const ir::Module& module, std::ostream& os) { - const char* llc_env = std::getenv("LLC"); - std::string llc = (llc_env != nullptr && llc_env[0] != '\0') ? llc_env : "llc"; - - ScopedTempFile ir_file("nudt_lab3_ir_XXXXXX"); - ScopedTempFile asm_file("nudt_lab3_asm_XXXXXX"); - WriteIRToFile(module, ir_file.path()); - - std::string command = llc + - " -opaque-pointers -mtriple=aarch64-linux-gnu -filetype=asm " + - ShellEscape(ir_file.path().string()) + " -o " + - ShellEscape(asm_file.path().string()); - int status = std::system(command.c_str()); - if (status != 0) { - throw std::runtime_error( - FormatError("lab3", "llc failed while generating AArch64 assembly")); - } - - StreamFileToStdout(asm_file.path(), os); -} -} // namespace -#endif - int main(int argc, char** argv) { try { auto opts = ParseCLI(argc, argv); @@ -150,26 +36,38 @@ int main(int argc, char** argv) { } auto sema = RunSema(*comp_unit); - auto module = GenerateIR(*comp_unit, sema); - if (opts.emit_ir || opts.emit_asm) { - ir::RunIRPassPipeline(*module); + std::unique_ptr asm_module; + if (opts.emit_asm) { + asm_module = GenerateIR(*comp_unit, sema); + ir::RunIRPassPipeline(*asm_module); } if (opts.emit_ir) { + std::unique_ptr ir_module; + if (opts.emit_asm) { + ir_module = GenerateIR(*comp_unit, sema); + } else { + ir_module = GenerateIR(*comp_unit, sema); + } + ir::RunIRPassPipeline(*ir_module); + if (need_blank_line) { std::cout << "\n"; } ir::IRPrinter printer; - printer.Print(*module, std::cout); + printer.Print(*ir_module, std::cout); need_blank_line = true; } if (opts.emit_asm) { + auto machine_module = mir::LowerToMIR(*asm_module); + mir::RunRegAlloc(*machine_module); + mir::RunFrameLowering(*machine_module); if (need_blank_line) { std::cout << "\n"; } - EmitAsmWithLLC(*module, std::cout); + mir::PrintAsm(*machine_module, std::cout); } #else if (opts.emit_ir || opts.emit_asm) { @@ -183,3 +81,4 @@ int main(int argc, char** argv) { } return 0; } + diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 4d1f65f..047a830 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -1,78 +1,1154 @@ #include "mir/MIR.h" +#include +#include +#include #include #include +#include +#include +#include +#include "ir/IR.h" #include "utils/Log.h" namespace mir { namespace { -const FrameSlot& GetFrameSlot(const 
MachineFunction& function, - const Operand& operand) { - if (operand.GetKind() != Operand::Kind::FrameIndex) { - throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数")); +int AlignTo(int value, int align) { + if (align <= 1) { + return value; } - return function.GetFrameSlot(operand.GetFrameIndex()); + return ((value + align - 1) / align) * align; } -void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, - int offset) { - os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset - << "]\n"; +bool IsPowerOfTwo(std::int64_t value) { + return value > 0 && (value & (value - 1)) == 0; } -} // namespace +int Log2(std::int64_t value) { + int shift = 0; + while (value > 1) { + value >>= 1; + ++shift; + } + return shift; +} + +const char* GetDRegName(int index) { + static const char* kNames[] = { + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", + "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", + "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"}; + if (index < 0 || index >= 32) { + throw std::runtime_error("float register index out of range"); + } + return kNames[index]; +} + +std::string BlockLabel(const MachineFunction& function, + const MachineBasicBlock& block) { + return ".L." + function.GetName() + "." + block.GetName(); +} + +std::string BlockLabel(const MachineFunction& function, const std::string& block_name) { + return ".L." + function.GetName() + "." + block_name; +} + +int ToAsmAlign(int align) { + int value = 0; + int current = 1; + while (current < align) { + current <<= 1; + ++value; + } + return value; +} + +std::uint32_t FloatBits(float value) { + std::uint32_t bits = 0; + std::memcpy(&bits, &value, sizeof(bits)); + return bits; +} + +ValueType LowerAsmType(const std::shared_ptr& type) { + if (!type || type->IsVoid()) { + return ValueType::Void; + } + if (type->IsInt1()) { + return ValueType::I1; + } + if (type->IsInt32()) { + return ValueType::I32; + } + if (type->IsFloat()) { + return ValueType::F32; + } + if (type->IsPointer()) { + return ValueType::Ptr; + } + throw std::runtime_error(FormatError("mir", "unsupported IR type in asm printer")); +} + +int GetIRTypeAlign(const std::shared_ptr& type) { + if (!type) { + return 1; + } + if (type->IsArray()) { + return GetIRTypeAlign(type->GetElementType()); + } + return GetValueAlign(LowerAsmType(type)); +} + +const ir::Type& GetScalarElementType(const ir::Type& type) { + const ir::Type* current = &type; + while (current->IsArray()) { + current = current->GetElementType().get(); + } + return *current; +} + +bool IsZeroScalarConstant(const ir::Value* value) { + if (value == nullptr) { + return true; + } + if (auto* ci = ir::dyncast(value)) { + return ci->GetValue() == 0; + } + if (auto* cb = ir::dyncast(value)) { + return !cb->GetValue(); + } + if (auto* cf = ir::dyncast(value)) { + return FloatBits(cf->GetValue()) == 0; + } + return false; +} + +std::size_t CountScalarElements(const ir::Type& type) { + if (!type.IsArray()) { + return 1; + } + return type.GetNumElements() * CountScalarElements(*type.GetElementType()); +} + +void FlattenGlobalScalars(const ir::Type& type, ir::Value* init, + std::vector& out) { + if (!type.IsArray()) { + out.push_back(init); + return; + } + + auto* array_value = ir::dyncast(init); + if (array_value == nullptr) { + out.insert(out.end(), CountScalarElements(type), nullptr); + return; + } + + const auto& elements = array_value->GetElements(); + for (std::size_t i = 0; i < 
CountScalarElements(type); ++i) { + out.push_back(i < elements.size() ? elements[i] : nullptr); + } +} + +void EmitGlobalScalar(std::ostream& os, const ir::Type& type, ir::Value* value) { + if (type.IsFloat()) { + float number = 0.0f; + if (auto* cf = ir::dyncast(value)) { + number = cf->GetValue(); + } else if (auto* ci = ir::dyncast(value)) { + number = static_cast(ci->GetValue()); + } + os << " .word " << FloatBits(number) << "\n"; + return; + } + + int number = 0; + if (auto* ci = ir::dyncast(value)) { + number = ci->GetValue(); + } else if (auto* cb = ir::dyncast(value)) { + number = cb->GetValue() ? 1 : 0; + } + os << " .word " << number << "\n"; +} + +void EmitGlobal(const ir::GlobalValue& global, std::ostream& os) { + const auto object_type = global.GetObjectType(); + const bool zero_init = !global.HasInitializer() || IsZeroScalarConstant(global.GetInitializer()); + if (object_type->IsArray()) { + std::vector flat; + FlattenGlobalScalars(*object_type, global.GetInitializer(), flat); + const bool all_zero = std::all_of(flat.begin(), flat.end(), [](ir::Value* value) { + return IsZeroScalarConstant(value); + }); + if (all_zero) { + os << ".bss\n"; + } else if (global.IsConstant()) { + os << ".section .rodata\n"; + } else { + os << ".data\n"; + } + os << " .align " << ToAsmAlign(GetIRTypeAlign(object_type)) << "\n"; + os << " .global " << global.GetName() << "\n"; + os << global.GetName() << ":\n"; + if (all_zero) { + os << " .zero " << object_type->GetSize() << "\n"; + return; + } + std::size_t index = 0; + while (index < flat.size()) { + if (IsZeroScalarConstant(flat[index])) { + std::size_t begin = index; + while (index < flat.size() && IsZeroScalarConstant(flat[index])) { + ++index; + } + os << " .zero " << static_cast((index - begin) * 4) << "\n"; + } else { + EmitGlobalScalar(os, GetScalarElementType(*object_type), flat[index]); + ++index; + } + } + return; + } + + if (zero_init) { + os << ".bss\n"; + } else if (global.IsConstant()) { + os << ".section .rodata\n"; + } else { + os << ".data\n"; + } + os << " .align " << ToAsmAlign(GetIRTypeAlign(object_type)) << "\n"; + os << " .global " << global.GetName() << "\n"; + os << global.GetName() << ":\n"; + if (zero_init) { + os << " .zero " << object_type->GetSize() << "\n"; + } else { + EmitGlobalScalar(os, *object_type, global.GetInitializer()); + } +} + +int FindStackObject(const MachineFunction& function, const std::string& name) { + for (const auto& object : function.GetStackObjects()) { + if (object.name == name) { + return object.index; + } + } + return -1; +} + +bool Is32BitRegName(const char* reg) { + return reg != nullptr && reg[0] == 'w'; +} + +bool IsAddSubImm12(std::int64_t value) { + return value >= 0 && value <= 4095; +} + +bool IsAddSubImm12Shifted(std::int64_t value) { + return value >= 0 && value <= (4095ll << 12) && (value & 0xfffll) == 0; +} + +bool IsAddSubImm(std::int64_t value) { + return IsAddSubImm12(value) || IsAddSubImm12Shifted(value); +} + +void EmitAddSubImm(std::ostream& os, const char* opcode, const char* dst, + const char* src, std::int64_t value) { + if (!IsAddSubImm(value)) { + throw std::runtime_error(FormatError("mir", "invalid add/sub immediate")); + } + os << " " << opcode << " " << dst << ", " << src << ", #"; + if (IsAddSubImm12(value)) { + os << value << "\n"; + return; + } + os << (value >> 12) << ", lsl #12\n"; +} + +void EmitAdjustRegByImm(std::ostream& os, const char* dst, const char* src, + std::int64_t value) { + if (value == 0) { + if (std::string(dst) != src) { + os << " mov " << dst 
<< ", " << src << "\n"; + } + return; + } + + const char* opcode = value >= 0 ? "add" : "sub"; + std::uint64_t remaining = value >= 0 ? static_cast(value) + : static_cast(-value); + bool first = true; + auto emit_chunk = [&](std::uint64_t amount, bool shifted) { + const char* current_src = first ? src : dst; + os << " " << opcode << " " << dst << ", " << current_src << ", #" << amount; + if (shifted) { + os << ", lsl #12"; + } + os << "\n"; + first = false; + }; + + while (remaining >= 4096) { + const std::uint64_t units = std::min(remaining >> 12, 4095); + emit_chunk(units, true); + remaining -= units << 12; + } + if (remaining > 0) { + emit_chunk(remaining, false); + } +} + +void EmitMoveImm(std::ostream& os, const char* reg, std::int64_t value) { + if (reg == nullptr || reg[0] == '\0') { + throw std::runtime_error(FormatError("mir", "invalid register for immediate materialization")); + } + + const bool is32 = Is32BitRegName(reg); + if (value == 0) { + os << " mov " << reg << ", #0\n"; + return; + } + + if (is32) { + const std::uint32_t bits = static_cast(value); + bool emitted = false; + for (int shift = 0; shift <= 16; shift += 16) { + const std::uint32_t chunk = (bits >> shift) & 0xffffu; + if (chunk == 0 && emitted) { + continue; + } + if (!emitted) { + os << " movz " << reg << ", #" << chunk; + if (shift != 0) { + os << ", lsl #" << shift; + } + os << "\n"; + emitted = true; + } else if (chunk != 0) { + os << " movk " << reg << ", #" << chunk; + if (shift != 0) { + os << ", lsl #" << shift; + } + os << "\n"; + } + } + return; + } + + const std::uint64_t bits = static_cast(value); + bool emitted = false; + for (int shift = 0; shift <= 48; shift += 16) { + const std::uint64_t chunk = (bits >> shift) & 0xffffull; + if (chunk == 0 && emitted) { + continue; + } + if (!emitted) { + os << " movz " << reg << ", #" << chunk; + if (shift != 0) { + os << ", lsl #" << shift; + } + os << "\n"; + emitted = true; + } else if (chunk != 0) { + os << " movk " << reg << ", #" << chunk; + if (shift != 0) { + os << ", lsl #" << shift; + } + os << "\n"; + } + } +} + +void EmitCopy(std::ostream& os, const char* dst, const char* src, bool is_float) { + if (std::string(dst) == src) { + return; + } + os << " " << (is_float ? 
"fmov" : "mov") << " " << dst << ", " << src << "\n"; +} + +void EmitFrameAddress(const MachineFunction& function, int object_index, + const char* addr_reg, std::ostream& os) { + const auto& object = function.GetStackObject(object_index); + EmitAdjustRegByImm(os, addr_reg, "x29", object.offset); +} + +void EmitIncomingStackAddress(int stack_offset, const char* addr_reg, std::ostream& os) { + EmitAdjustRegByImm(os, addr_reg, "x29", 16 + stack_offset); +} -void PrintAsm(const MachineFunction& function, std::ostream& os) { +void EmitLoadFromAddr(ValueType type, const char* dst, const char* addr_reg, + std::ostream& os) { + switch (type) { + case ValueType::I1: + case ValueType::I32: + os << " ldr " << dst << ", [" << addr_reg << "]\n"; + break; + case ValueType::F32: + os << " ldr " << dst << ", [" << addr_reg << "]\n"; + break; + case ValueType::Ptr: + os << " ldr " << dst << ", [" << addr_reg << "]\n"; + break; + case ValueType::Void: + break; + } +} + +void EmitStoreToAddr(ValueType type, const char* src, const char* addr_reg, + std::ostream& os) { + switch (type) { + case ValueType::I1: + case ValueType::I32: + os << " str " << src << ", [" << addr_reg << "]\n"; + break; + case ValueType::F32: + os << " str " << src << ", [" << addr_reg << "]\n"; + break; + case ValueType::Ptr: + os << " str " << src << ", [" << addr_reg << "]\n"; + break; + case ValueType::Void: + break; + } +} + +void EmitLoadSpill(const MachineFunction& function, int object_index, ValueType type, + const char* dst, std::ostream& os) { + EmitFrameAddress(function, object_index, "x17", os); + EmitLoadFromAddr(type, dst, "x17", os); +} + +void EmitStoreSpill(const MachineFunction& function, int object_index, ValueType type, + const char* src, std::ostream& os) { + EmitFrameAddress(function, object_index, "x17", os); + EmitStoreToAddr(type, src, "x17", os); +} + +struct DefReg { + std::string reg_name; + bool spilled = false; + int spill_object = -1; +}; + +DefReg PrepareGprDef(const MachineFunction& function, int vreg, int scratch_index) { + const auto& alloc = function.GetAllocation(vreg); + const auto type = function.GetVRegInfo(vreg).type; + if (alloc.kind == Allocation::Kind::PhysReg) { + return {GetPhysRegName(alloc.phys, type), false, -1}; + } + return {GetPhysRegName({RegClass::GPR, scratch_index}, type), true, alloc.stack_object}; +} + +DefReg PrepareFprDef(const MachineFunction& function, int vreg, int scratch_index) { + const auto& alloc = function.GetAllocation(vreg); + if (alloc.kind == Allocation::Kind::PhysReg) { + return {GetPhysRegName(alloc.phys, ValueType::F32), false, -1}; + } + return {GetPhysRegName({RegClass::FPR, scratch_index}, ValueType::F32), true, + alloc.stack_object}; +} + +void FinalizeDef(const MachineFunction& function, int vreg, const DefReg& def, + std::ostream& os) { + if (!def.spilled) { + return; + } + EmitStoreSpill(function, def.spill_object, function.GetVRegInfo(vreg).type, + def.reg_name.c_str(), os); +} + +std::string MaterializeGprUse(const MachineFunction& function, + const MachineOperand& operand, ValueType type, + int scratch_index, std::ostream& os) { + const char* scratch = GetPhysRegName({RegClass::GPR, scratch_index}, type); + if (operand.GetKind() == OperandKind::Imm) { + EmitMoveImm(os, scratch, operand.GetImm()); + return scratch; + } + if (operand.GetKind() != OperandKind::VReg) { + throw std::runtime_error(FormatError("mir", "expected gpr operand")); + } + const int vreg = operand.GetVReg(); + const auto& alloc = function.GetAllocation(vreg); + const auto vtype = 
function.GetVRegInfo(vreg).type; + if (alloc.kind == Allocation::Kind::PhysReg) { + return GetPhysRegName(alloc.phys, vtype); + } + EmitLoadSpill(function, alloc.stack_object, vtype, scratch, os); + return scratch; +} + +std::string MaterializeFprUse(const MachineFunction& function, + const MachineOperand& operand, int scratch_fpr, + int scratch_gpr, std::ostream& os) { + const char* scratch = GetPhysRegName({RegClass::FPR, scratch_fpr}, ValueType::F32); + if (operand.GetKind() == OperandKind::Imm) { + EmitMoveImm(os, GetPhysRegName({RegClass::GPR, scratch_gpr}, ValueType::I32), + operand.GetImm()); + os << " fmov " << scratch << ", " + << GetPhysRegName({RegClass::GPR, scratch_gpr}, ValueType::I32) << "\n"; + return scratch; + } + if (operand.GetKind() != OperandKind::VReg) { + throw std::runtime_error(FormatError("mir", "expected fpr operand")); + } + const int vreg = operand.GetVReg(); + const auto& alloc = function.GetAllocation(vreg); + if (alloc.kind == Allocation::Kind::PhysReg) { + return GetPhysRegName(alloc.phys, ValueType::F32); + } + EmitLoadSpill(function, alloc.stack_object, ValueType::F32, scratch, os); + return scratch; +} + +void EmitAddressExpr(const MachineFunction& function, const AddressExpr& address, + std::ostream& os) { + switch (address.base_kind) { + case AddrBaseKind::FrameObject: + EmitFrameAddress(function, address.base_index, "x16", os); + break; + case AddrBaseKind::Global: + os << " adrp x16, " << address.symbol << "\n"; + os << " add x16, x16, :lo12:" << address.symbol << "\n"; + break; + case AddrBaseKind::VReg: { + const auto& alloc = function.GetAllocation(address.base_index); + if (alloc.kind == Allocation::Kind::PhysReg) { + EmitCopy(os, "x16", GetPhysRegName(alloc.phys, ValueType::Ptr), false); + } else { + EmitLoadSpill(function, alloc.stack_object, ValueType::Ptr, "x16", os); + } + break; + } + case AddrBaseKind::None: + throw std::runtime_error(FormatError("mir", "address expression has no base")); + } + + if (address.const_offset != 0) { + EmitAdjustRegByImm(os, "x16", "x16", address.const_offset); + } + + for (const auto& term : address.scaled_vregs) { + const auto index_reg = MaterializeGprUse(function, MachineOperand::VReg(term.first), + ValueType::I32, 10, os); + const std::int64_t stride = term.second; + if (stride == 0) { + continue; + } + if (IsPowerOfTwo(stride) && Log2(stride) <= 4) { + os << " add x16, x16, " << index_reg << ", sxtw #" << Log2(stride) << "\n"; + continue; + } + os << " sxtw x17, " << index_reg << "\n"; + EmitMoveImm(os, "x11", stride); + os << " mul x17, x17, x11\n"; + os << " add x16, x16, x17\n"; + } +} + +const char* GetCondMnemonic(CondCode code) { + static const char* kCond[] = {"eq", "ne", "lt", "gt", "le", "ge"}; + return kCond[static_cast(code)]; +} + +bool TryEmitFusedCompareBranch(const MachineFunction& function, const MachineInstr& cmp, + const MachineInstr& branch, + const std::unordered_map& use_counts, + std::ostream& os) { + if ((cmp.GetOpcode() != MachineInstr::Opcode::ICmp && + cmp.GetOpcode() != MachineInstr::Opcode::FCmp) || + branch.GetOpcode() != MachineInstr::Opcode::CondBr) { + return false; + } + const auto& cond = branch.GetOperands()[0]; + if (cond.GetKind() != OperandKind::VReg) { + return false; + } + const int cond_vreg = cond.GetVReg(); + if (cmp.GetOperands().empty() || cmp.GetOperands()[0].GetKind() != OperandKind::VReg || + cmp.GetOperands()[0].GetVReg() != cond_vreg) { + return false; + } + auto it = use_counts.find(cond_vreg); + if (it == use_counts.end() || it->second != 1) { + return 
false; + } + + if (cmp.GetOpcode() == MachineInstr::Opcode::ICmp) { + const auto lhs = MaterializeGprUse(function, cmp.GetOperands()[1], ValueType::I32, 10, os); + const auto& rhs_op = cmp.GetOperands()[2]; + if (rhs_op.GetKind() == OperandKind::Imm && rhs_op.GetImm() >= 0 && + IsAddSubImm(rhs_op.GetImm())) { + os << " cmp " << lhs << ", #" << rhs_op.GetImm() << "\n"; + } else { + const auto rhs = MaterializeGprUse(function, rhs_op, ValueType::I32, 11, os); + os << " cmp " << lhs << ", " << rhs << "\n"; + } + } else { + const auto lhs = MaterializeFprUse(function, cmp.GetOperands()[1], 16, 10, os); + const auto rhs = MaterializeFprUse(function, cmp.GetOperands()[2], 17, 11, os); + os << " fcmp " << lhs << ", " << rhs << "\n"; + } + + os << " b." << GetCondMnemonic(cmp.GetCondCode()) << " " + << BlockLabel(function, branch.GetOperands()[1].GetText()) << "\n"; + os << " b " << BlockLabel(function, branch.GetOperands()[2].GetText()) << "\n"; + return true; +} + +struct ArgLocation { + bool in_reg = false; + RegClass reg_class = RegClass::GPR; + int reg_index = -1; + int stack_offset = 0; +}; + +ArgLocation ComputeArgLocation(const std::vector& param_types, int target) { + int gpr = 0; + int fpr = 0; + int stack_offset = 0; + for (int i = 0; i <= target; ++i) { + const auto type = param_types[static_cast(i)]; + if (IsFPR(type)) { + if (fpr < 8) { + if (i == target) { + return {true, RegClass::FPR, fpr, 0}; + } + ++fpr; + } else { + if (i == target) { + return {false, RegClass::FPR, -1, stack_offset}; + } + stack_offset += 8; + } + continue; + } + + if (gpr < 8) { + if (i == target) { + return {true, RegClass::GPR, gpr, 0}; + } + ++gpr; + } else { + if (i == target) { + return {false, RegClass::GPR, -1, stack_offset}; + } + stack_offset += 8; + } + } + + throw std::runtime_error(FormatError("mir", "argument location computation failed")); +} + +void EmitStackAdjust(std::ostream& os, const char* opcode, int bytes) { + if (bytes == 0) { + return; + } + const std::int64_t signed_bytes = opcode[0] == 's' ? 
-static_cast(bytes) + : static_cast(bytes); + EmitAdjustRegByImm(os, "sp", "sp", signed_bytes); +} + +void EmitFunction(const MachineFunction& function, std::ostream& os) { os << ".text\n"; + os << " .align 2\n"; os << ".global " << function.GetName() << "\n"; os << ".type " << function.GetName() << ", %function\n"; os << function.GetName() << ":\n"; + os << " stp x29, x30, [sp, #-16]!\n"; + os << " mov x29, sp\n"; + EmitStackAdjust(os, "sub", function.GetFrameSize()); + + for (int reg : function.GetUsedCalleeSavedGPRs()) { + const int slot = FindStackObject(function, "save.x" + std::to_string(reg)); + if (slot >= 0) { + EmitFrameAddress(function, slot, "x16", os); + os << " str x" << reg << ", [x16]\n"; + } + } + for (int reg : function.GetUsedCalleeSavedFPRs()) { + const int slot = FindStackObject(function, "save.v" + std::to_string(reg)); + if (slot >= 0) { + EmitFrameAddress(function, slot, "x16", os); + os << " str " << GetDRegName(reg) << ", [x16]\n"; + } + } + + auto emit_epilogue = [&]() { + for (int reg : function.GetUsedCalleeSavedFPRs()) { + const int slot = FindStackObject(function, "save.v" + std::to_string(reg)); + if (slot >= 0) { + EmitFrameAddress(function, slot, "x16", os); + os << " ldr " << GetDRegName(reg) << ", [x16]\n"; + } + } + for (int reg : function.GetUsedCalleeSavedGPRs()) { + const int slot = FindStackObject(function, "save.x" + std::to_string(reg)); + if (slot >= 0) { + EmitFrameAddress(function, slot, "x16", os); + os << " ldr x" << reg << ", [x16]\n"; + } + } + os << " mov sp, x29\n"; + os << " ldp x29, x30, [sp], #16\n"; + os << " ret\n"; + }; - for (const auto& inst : function.GetEntry().GetInstructions()) { - const auto& ops = inst.GetOperands(); - switch (inst.GetOpcode()) { - case Opcode::Prologue: - os << " stp x29, x30, [sp, #-16]!\n"; - os << " mov x29, sp\n"; - if (function.GetFrameSize() > 0) { - os << " sub sp, sp, #" << function.GetFrameSize() << "\n"; - } - break; - case Opcode::Epilogue: - if (function.GetFrameSize() > 0) { - os << " add sp, sp, #" << function.GetFrameSize() << "\n"; - } - os << " ldp x29, x30, [sp], #16\n"; - break; - case Opcode::MovImm: - os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #" - << ops.at(1).GetImm() << "\n"; - break; - case Opcode::LoadStack: { - const auto& slot = GetFrameSlot(function, ops.at(1)); - PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); - break; + std::unordered_map use_counts; + for (const auto& block : function.GetBlocks()) { + for (const auto& inst : block->GetInstructions()) { + for (int vreg : inst.GetUses()) { + ++use_counts[vreg]; } - case Opcode::StoreStack: { - const auto& slot = GetFrameSlot(function, ops.at(1)); - PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); - break; + } + } + + for (const auto& block : function.GetBlocks()) { + os << BlockLabel(function, *block) << ":\n"; + const auto& instructions = block->GetInstructions(); + for (std::size_t inst_index = 0; inst_index < instructions.size(); ++inst_index) { + const auto& inst = instructions[inst_index]; + if ((inst.GetOpcode() == MachineInstr::Opcode::ICmp || + inst.GetOpcode() == MachineInstr::Opcode::FCmp) && + inst_index + 1 < instructions.size() && + TryEmitFusedCompareBranch(function, inst, instructions[inst_index + 1], use_counts, + os)) { + ++inst_index; + continue; + } + switch (inst.GetOpcode()) { + case MachineInstr::Opcode::Arg: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const int arg_index = static_cast(inst.GetOperands()[1].GetImm()); + const auto type = 
function.GetVRegInfo(vreg).type; + const auto location = ComputeArgLocation(function.GetParamTypes(), arg_index); + if (IsFPR(type)) { + const auto def = PrepareFprDef(function, vreg, 16); + if (location.in_reg) { + EmitCopy(os, def.reg_name.c_str(), GetPhysRegName({RegClass::FPR, location.reg_index}, type), + true); + } else { + EmitIncomingStackAddress(location.stack_offset, "x16", os); + EmitLoadFromAddr(type, def.reg_name.c_str(), "x16", os); + } + FinalizeDef(function, vreg, def, os); + } else { + const auto def = PrepareGprDef(function, vreg, 9); + if (location.in_reg) { + EmitCopy(os, def.reg_name.c_str(), + GetPhysRegName({RegClass::GPR, location.reg_index}, type), false); + } else { + EmitIncomingStackAddress(location.stack_offset, "x16", os); + EmitLoadFromAddr(type, def.reg_name.c_str(), "x16", os); + } + FinalizeDef(function, vreg, def, os); + } + break; + } + case MachineInstr::Opcode::Copy: + case MachineInstr::Opcode::ZExt: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto type = function.GetVRegInfo(vreg).type; + if (IsFPR(type)) { + const auto def = PrepareFprDef(function, vreg, 16); + const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os); + EmitCopy(os, def.reg_name.c_str(), src.c_str(), true); + FinalizeDef(function, vreg, def, os); + } else { + const auto def = PrepareGprDef(function, vreg, 9); + const auto src = MaterializeGprUse(function, inst.GetOperands()[1], type, 10, os); + EmitCopy(os, def.reg_name.c_str(), src.c_str(), false); + FinalizeDef(function, vreg, def, os); + } + break; + } + case MachineInstr::Opcode::Load: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto type = function.GetVRegInfo(vreg).type; + EmitAddressExpr(function, inst.GetAddress(), os); + if (IsFPR(type)) { + const auto def = PrepareFprDef(function, vreg, 16); + EmitLoadFromAddr(type, def.reg_name.c_str(), "x16", os); + FinalizeDef(function, vreg, def, os); + } else { + const auto def = PrepareGprDef(function, vreg, 9); + EmitLoadFromAddr(type, def.reg_name.c_str(), "x16", os); + FinalizeDef(function, vreg, def, os); + } + break; + } + case MachineInstr::Opcode::Store: { + const auto& src_op = inst.GetOperands()[0]; + const ValueType type = src_op.GetKind() == OperandKind::VReg + ? 
function.GetVRegInfo(src_op.GetVReg()).type + : inst.GetValueType(); + EmitAddressExpr(function, inst.GetAddress(), os); + if (IsFPR(type)) { + const auto src = MaterializeFprUse(function, src_op, 16, 9, os); + EmitStoreToAddr(type, src.c_str(), "x16", os); + } else { + const auto src = MaterializeGprUse(function, src_op, type, 9, os); + EmitStoreToAddr(type, src.c_str(), "x16", os); + } + break; + } + case MachineInstr::Opcode::Lea: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + EmitAddressExpr(function, inst.GetAddress(), os); + EmitCopy(os, def.reg_name.c_str(), "x16", false); + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::Add: + case MachineInstr::Opcode::Sub: + case MachineInstr::Opcode::Mul: + case MachineInstr::Opcode::Div: + case MachineInstr::Opcode::And: + case MachineInstr::Opcode::Or: + case MachineInstr::Opcode::Xor: + case MachineInstr::Opcode::Shl: + case MachineInstr::Opcode::AShr: + case MachineInstr::Opcode::LShr: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + const auto& lhs_op = inst.GetOperands()[1]; + const auto& rhs_op = inst.GetOperands()[2]; + if (inst.GetOpcode() == MachineInstr::Opcode::Add || + inst.GetOpcode() == MachineInstr::Opcode::Sub) { + auto emit_add_sub_imm = [&](const MachineOperand& reg_op, std::int64_t imm, + const char* pos_opcode, + const char* neg_opcode) -> bool { + if (reg_op.GetKind() == OperandKind::Imm) { + return false; + } + if (imm >= 0 && IsAddSubImm(imm)) { + const auto src = MaterializeGprUse(function, reg_op, ValueType::I32, 10, os); + EmitAddSubImm(os, pos_opcode, def.reg_name.c_str(), src.c_str(), imm); + FinalizeDef(function, vreg, def, os); + return true; + } + if (imm < 0 && IsAddSubImm(-imm)) { + const auto src = MaterializeGprUse(function, reg_op, ValueType::I32, 10, os); + EmitAddSubImm(os, neg_opcode, def.reg_name.c_str(), src.c_str(), -imm); + FinalizeDef(function, vreg, def, os); + return true; + } + return false; + }; + + if (rhs_op.GetKind() == OperandKind::Imm) { + if (emit_add_sub_imm(lhs_op, rhs_op.GetImm(), + inst.GetOpcode() == MachineInstr::Opcode::Add ? "add" : "sub", + inst.GetOpcode() == MachineInstr::Opcode::Add ? 
"sub" : "add")) { + break; + } + } + if (inst.GetOpcode() == MachineInstr::Opcode::Add && + lhs_op.GetKind() == OperandKind::Imm) { + if (emit_add_sub_imm(rhs_op, lhs_op.GetImm(), "add", "sub")) { + break; + } + } + } + if (inst.GetOpcode() == MachineInstr::Opcode::Div && + rhs_op.GetKind() == OperandKind::Imm && + rhs_op.GetImm() > 0 && IsPowerOfTwo(rhs_op.GetImm())) { + const auto lhs = MaterializeGprUse(function, lhs_op, ValueType::I32, 10, os); + const int shift = Log2(rhs_op.GetImm()); + if (shift == 0) { + EmitCopy(os, def.reg_name.c_str(), lhs.c_str(), false); + } else { + os << " asr w11, " << lhs << ", #31\n"; + os << " and w11, w11, #" << ((1ll << shift) - 1) << "\n"; + os << " add w11, " << lhs << ", w11\n"; + os << " asr " << def.reg_name << ", w11, #" << shift << "\n"; + } + FinalizeDef(function, vreg, def, os); + break; + } + const auto lhs = MaterializeGprUse(function, lhs_op, ValueType::I32, 10, os); + const auto rhs = MaterializeGprUse(function, rhs_op, ValueType::I32, 11, os); + const char* mnemonic = "add"; + switch (inst.GetOpcode()) { + case MachineInstr::Opcode::Add: + mnemonic = "add"; + break; + case MachineInstr::Opcode::Sub: + mnemonic = "sub"; + break; + case MachineInstr::Opcode::Mul: + mnemonic = "mul"; + break; + case MachineInstr::Opcode::Div: + mnemonic = "sdiv"; + break; + case MachineInstr::Opcode::And: + mnemonic = "and"; + break; + case MachineInstr::Opcode::Or: + mnemonic = "orr"; + break; + case MachineInstr::Opcode::Xor: + mnemonic = "eor"; + break; + case MachineInstr::Opcode::Shl: + mnemonic = "lsl"; + break; + case MachineInstr::Opcode::AShr: + mnemonic = "asr"; + break; + case MachineInstr::Opcode::LShr: + mnemonic = "lsr"; + break; + default: + break; + } + os << " " << mnemonic << " " << def.reg_name << ", " << lhs << ", " << rhs << "\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::Rem: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + const auto lhs = MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os); + const auto rhs = MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os); + os << " sdiv w12, " << lhs << ", " << rhs << "\n"; + os << " msub " << def.reg_name << ", w12, " << rhs << ", " << lhs << "\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::FAdd: + case MachineInstr::Opcode::FSub: + case MachineInstr::Opcode::FMul: + case MachineInstr::Opcode::FDiv: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareFprDef(function, vreg, 16); + const auto lhs = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os); + const auto rhs = MaterializeFprUse(function, inst.GetOperands()[2], 18, 10, os); + const char* mnemonic = "fadd"; + switch (inst.GetOpcode()) { + case MachineInstr::Opcode::FAdd: + mnemonic = "fadd"; + break; + case MachineInstr::Opcode::FSub: + mnemonic = "fsub"; + break; + case MachineInstr::Opcode::FMul: + mnemonic = "fmul"; + break; + case MachineInstr::Opcode::FDiv: + mnemonic = "fdiv"; + break; + default: + break; + } + os << " " << mnemonic << " " << def.reg_name << ", " << lhs << ", " << rhs << "\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::FNeg: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareFprDef(function, vreg, 16); + const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os); + os << " fneg " << def.reg_name << ", " << src 
<< "\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::ICmp: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + const auto lhs = MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os); + const auto rhs = MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os); + os << " cmp " << lhs << ", " << rhs << "\n"; + static const char* kCond[] = {"eq", "ne", "lt", "gt", "le", "ge"}; + os << " cset " << def.reg_name << ", " << kCond[static_cast(inst.GetCondCode())] << "\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::FCmp: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + const auto lhs = MaterializeFprUse(function, inst.GetOperands()[1], 16, 10, os); + const auto rhs = MaterializeFprUse(function, inst.GetOperands()[2], 17, 11, os); + os << " fcmp " << lhs << ", " << rhs << "\n"; + static const char* kCond[] = {"eq", "ne", "lt", "gt", "le", "ge"}; + os << " cset " << def.reg_name << ", " << kCond[static_cast(inst.GetCondCode())] << "\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::ItoF: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareFprDef(function, vreg, 16); + const auto src = MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os); + os << " scvtf " << def.reg_name << ", " << src << "\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::FtoI: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 16, 10, os); + os << " fcvtzs " << def.reg_name << ", " << src << "\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::Br: + os << " b " << BlockLabel(function, inst.GetOperands()[0].GetText()) << "\n"; + break; + case MachineInstr::Opcode::CondBr: { + const auto& cond = inst.GetOperands()[0]; + if (cond.GetKind() == OperandKind::Imm) { + os << " b " << BlockLabel(function, + cond.GetImm() != 0 ? inst.GetOperands()[1].GetText() + : inst.GetOperands()[2].GetText()) + << "\n"; + break; + } + const auto cond_reg = MaterializeGprUse(function, cond, ValueType::I1, 9, os); + os << " cbnz " << cond_reg << ", " + << BlockLabel(function, inst.GetOperands()[1].GetText()) << "\n"; + os << " b " << BlockLabel(function, inst.GetOperands()[2].GetText()) << "\n"; + break; + } + case MachineInstr::Opcode::Call: { + struct CallArgPlacement { + MachineOperand operand; + ValueType type = ValueType::Void; + bool on_stack = false; + int reg_index = -1; + int stack_offset = 0; + }; + + std::vector placements; + int gpr = 0; + int fpr = 0; + int stack = 0; + size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 
0 : 1; + for (size_t i = arg_begin; i < inst.GetOperands().size(); ++i) { + const auto type = inst.GetCallArgTypes()[i - arg_begin]; + CallArgPlacement placement; + placement.operand = inst.GetOperands()[i]; + placement.type = type; + if (IsFPR(type)) { + if (fpr < 8) { + placement.reg_index = fpr++; + } else { + placement.on_stack = true; + placement.stack_offset = stack; + stack += 8; + } + } else { + if (gpr < 8) { + placement.reg_index = gpr++; + } else { + placement.on_stack = true; + placement.stack_offset = stack; + stack += 8; + } + } + placements.push_back(placement); + } + const int stack_bytes = AlignTo(stack, 16); + EmitStackAdjust(os, "sub", stack_bytes); + for (const auto& placement : placements) { + if (!placement.on_stack) { + continue; + } + if (IsFPR(placement.type)) { + const auto src = MaterializeFprUse(function, placement.operand, 16, 9, os); + EmitMoveImm(os, "x11", placement.stack_offset); + os << " add x16, sp, x11\n"; + EmitStoreToAddr(placement.type, src.c_str(), "x16", os); + } else { + const auto src = MaterializeGprUse(function, placement.operand, placement.type, 9, os); + EmitMoveImm(os, "x11", placement.stack_offset); + os << " add x16, sp, x11\n"; + EmitStoreToAddr(placement.type, src.c_str(), "x16", os); + } + } + for (const auto& placement : placements) { + if (placement.on_stack) { + continue; + } + if (IsFPR(placement.type)) { + const auto src = MaterializeFprUse(function, placement.operand, 16, 9, os); + EmitCopy(os, GetPhysRegName({RegClass::FPR, placement.reg_index}, placement.type), + src.c_str(), true); + } else { + const auto src = MaterializeGprUse(function, placement.operand, placement.type, 9, os); + EmitCopy(os, GetPhysRegName({RegClass::GPR, placement.reg_index}, placement.type), + src.c_str(), false); + } + } + os << " bl " << inst.GetCallee() << "\n"; + EmitStackAdjust(os, "add", stack_bytes); + if (inst.GetCallReturnType() != ValueType::Void) { + const int dest_vreg = inst.GetOperands()[0].GetVReg(); + if (IsFPR(inst.GetCallReturnType())) { + const auto def = PrepareFprDef(function, dest_vreg, 16); + EmitCopy(os, def.reg_name.c_str(), "s0", true); + FinalizeDef(function, dest_vreg, def, os); + } else { + const auto def = PrepareGprDef(function, dest_vreg, 9); + EmitCopy(os, def.reg_name.c_str(), + GetPhysRegName({RegClass::GPR, 0}, inst.GetCallReturnType()), false); + FinalizeDef(function, dest_vreg, def, os); + } + } + break; + } + case MachineInstr::Opcode::Ret: { + if (!inst.GetOperands().empty()) { + const auto& value = inst.GetOperands()[0]; + ValueType type = value.GetKind() == OperandKind::VReg + ? 
function.GetVRegInfo(value.GetVReg()).type + : inst.GetValueType(); + if (IsFPR(type)) { + const auto src = MaterializeFprUse(function, value, 16, 9, os); + EmitCopy(os, "s0", src.c_str(), true); + } else { + const auto src = MaterializeGprUse(function, value, type, 9, os); + EmitCopy(os, GetPhysRegName({RegClass::GPR, 0}, type), src.c_str(), false); + } + } + emit_epilogue(); + break; + } + case MachineInstr::Opcode::Memset: { + EmitAddressExpr(function, inst.GetAddress(), os); + EmitCopy(os, "x0", "x16", false); + const auto value_reg = MaterializeGprUse(function, inst.GetOperands()[0], ValueType::I32, 9, os); + EmitCopy(os, "w1", value_reg.c_str(), false); + const auto len_reg = MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os); + EmitCopy(os, "w2", len_reg.c_str(), false); + os << " bl memset\n"; + break; + } + case MachineInstr::Opcode::Unreachable: + os << " brk #0\n"; + break; } - case Opcode::AddRR: - os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::Ret: - os << " ret\n"; - break; } } - os << ".size " << function.GetName() << ", .-" << function.GetName() - << "\n"; + os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n"; +} + +} // namespace + +void PrintAsm(const MachineModule& module, std::ostream& os) { + for (const auto& global : module.GetSourceModule().GetGlobalValues()) { + EmitGlobal(*global, os); + } + for (const auto& function : module.GetFunctions()) { + EmitFunction(*function, os); + } } } // namespace mir + + + + diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp index 679ab68..48c4e7f 100644 --- a/src/mir/FrameLowering.cpp +++ b/src/mir/FrameLowering.cpp @@ -1,45 +1,40 @@ #include "mir/MIR.h" -#include -#include - -#include "utils/Log.h" +#include namespace mir { namespace { int AlignTo(int value, int align) { + if (align <= 1) { + return value; + } return ((value + align - 1) / align) * align; } } // namespace -void RunFrameLowering(MachineFunction& function) { - int cursor = 0; - for (const auto& slot : function.GetFrameSlots()) { - cursor += slot.size; - if (-cursor < -256) { - throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧")); +void RunFrameLowering(MachineModule& module) { + for (auto& function : module.GetFunctions()) { + for (int reg : function->GetUsedCalleeSavedGPRs()) { + function->CreateStackObject(8, 8, StackObjectKind::SavedGPR, + "save.x" + std::to_string(reg)); + } + for (int reg : function->GetUsedCalleeSavedFPRs()) { + function->CreateStackObject(8, 8, StackObjectKind::SavedFPR, + "save.v" + std::to_string(reg)); } - } - - cursor = 0; - for (const auto& slot : function.GetFrameSlots()) { - cursor += slot.size; - function.GetFrameSlot(slot.index).offset = -cursor; - } - function.SetFrameSize(AlignTo(cursor, 16)); - auto& insts = function.GetEntry().GetInstructions(); - std::vector lowered; - lowered.emplace_back(Opcode::Prologue); - for (const auto& inst : insts) { - if (inst.GetOpcode() == Opcode::Ret) { - lowered.emplace_back(Opcode::Epilogue); + int cursor = 0; + const int object_count = static_cast(function->GetStackObjects().size()); + for (int i = 0; i < object_count; ++i) { + auto& object = function->GetStackObject(i); + cursor = AlignTo(cursor, object.align); + cursor += object.size; + object.offset = -cursor; } - lowered.push_back(inst); + function->SetFrameSize(AlignTo(cursor, 16)); } - insts = std::move(lowered); } } // namespace mir diff --git 
a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 9a18396..40b2416 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -1,7 +1,11 @@ #include "mir/MIR.h" +#include +#include #include #include +#include +#include #include "ir/IR.h" #include "utils/Log.h" @@ -9,115 +13,740 @@ namespace mir { namespace { -using ValueSlotMap = std::unordered_map; +enum class LoweredKind { Invalid, VReg, StackObject, Global }; -void EmitValueToReg(const ir::Value* value, PhysReg target, - const ValueSlotMap& slots, MachineBasicBlock& block) { - if (auto* constant = dynamic_cast(value)) { - block.Append(Opcode::MovImm, - {Operand::Reg(target), Operand::Imm(constant->GetValue())}); - return; +struct LoweredValue { + LoweredKind kind = LoweredKind::Invalid; + ValueType type = ValueType::Void; + int index = -1; + std::string symbol; +}; + +ValueType LowerType(const std::shared_ptr& type) { + if (!type || type->IsVoid()) { + return ValueType::Void; + } + if (type->IsInt1()) { + return ValueType::I1; + } + if (type->IsInt32()) { + return ValueType::I32; + } + if (type->IsFloat()) { + return ValueType::F32; } + if (type->IsPointer()) { + return ValueType::Ptr; + } + throw std::runtime_error(FormatError("mir", "unsupported IR type in backend lowering")); +} - auto it = slots.find(value); - if (it == slots.end()) { - throw std::runtime_error( - FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); +int GetIRTypeAlign(const std::shared_ptr& type) { + if (!type) { + return 1; } + if (type->IsArray()) { + return GetIRTypeAlign(type->GetElementType()); + } + return GetValueAlign(LowerType(type)); +} - block.Append(Opcode::LoadStack, - {Operand::Reg(target), Operand::FrameIndex(it->second)}); +CondCode LowerIntCond(ir::Opcode opcode) { + switch (opcode) { + case ir::Opcode::ICmpEQ: + return CondCode::EQ; + case ir::Opcode::ICmpNE: + return CondCode::NE; + case ir::Opcode::ICmpLT: + return CondCode::LT; + case ir::Opcode::ICmpGT: + return CondCode::GT; + case ir::Opcode::ICmpLE: + return CondCode::LE; + case ir::Opcode::ICmpGE: + return CondCode::GE; + default: + throw std::runtime_error(FormatError("mir", "invalid integer compare opcode")); + } } -void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, - ValueSlotMap& slots) { - auto& block = function.GetEntry(); +CondCode LowerFloatCond(ir::Opcode opcode) { + switch (opcode) { + case ir::Opcode::FCmpEQ: + return CondCode::EQ; + case ir::Opcode::FCmpNE: + return CondCode::NE; + case ir::Opcode::FCmpLT: + return CondCode::LT; + case ir::Opcode::FCmpGT: + return CondCode::GT; + case ir::Opcode::FCmpLE: + return CondCode::LE; + case ir::Opcode::FCmpGE: + return CondCode::GE; + default: + throw std::runtime_error(FormatError("mir", "invalid float compare opcode")); + } +} - switch (inst.GetOpcode()) { - case ir::Opcode::Alloca: { - slots.emplace(&inst, function.CreateFrameIndex()); - return; +std::int64_t FloatBits(float value) { + std::uint32_t bits = 0; + std::memcpy(&bits, &value, sizeof(bits)); + return static_cast(bits); +} + +class Lowerer { + public: + explicit Lowerer(const ir::Module& module) + : module_(module), machine_module_(std::make_unique(module)) {} + + std::unique_ptr Run() { + for (const auto& func : module_.GetFunctions()) { + if (func && !func->IsExternal()) { + LowerFunction(*func); + } } - case ir::Opcode::Store: { - auto& store = static_cast(inst); - auto dst = slots.find(store.GetPtr()); - if (dst == slots.end()) { - throw std::runtime_error( - FormatError("mir", "暂不支持对非栈变量地址进行写入")); + return 
std::move(machine_module_); + } + + private: + using OperandMap = std::unordered_map; + + MachineOperand ResolveScalarOperand(ir::Value* value, + const OperandMap* inline_values = nullptr) { + if (auto* ci = ir::dyncast(value)) { + return MachineOperand::Imm(ci->GetValue()); + } + if (auto* cb = ir::dyncast(value)) { + return MachineOperand::Imm(cb->GetValue() ? 1 : 0); + } + if (auto* cf = ir::dyncast(value)) { + return MachineOperand::Imm(FloatBits(cf->GetValue())); + } + + if (inline_values != nullptr) { + auto inline_it = inline_values->find(value); + if (inline_it != inline_values->end()) { + return inline_it->second; } - EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)}); - return; } - case ir::Opcode::Load: { - auto& load = static_cast(inst); - auto src = slots.find(load.GetPtr()); - if (src == slots.end()) { - throw std::runtime_error( - FormatError("mir", "暂不支持对非栈变量地址进行读取")); - } - int dst_slot = function.CreateFrameIndex(); - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; - } - case ir::Opcode::Add: { - auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(); - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; - } - case ir::Opcode::Ret: { - auto& ret = static_cast(inst); - EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block); - block.Append(Opcode::Ret); - return; - } - case ir::Opcode::Sub: - case ir::Opcode::Mul: - throw std::runtime_error(FormatError("mir", "暂不支持该二元运算")); - } - - throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); -} -} // namespace + auto it = values_.find(value); + if (it == values_.end() || it->second.kind != LoweredKind::VReg) { + throw std::runtime_error( + FormatError("mir", "value is not materialized as a virtual register: " + + value->GetName())); + } + return MachineOperand::VReg(it->second.index); + } -std::unique_ptr LowerToMIR(const ir::Module& module) { - DefaultContext(); + MachineOperand LowerScalarOperand(ir::Value* value) { + return ResolveScalarOperand(value, nullptr); + } + + AddressExpr LowerAddress(ir::Value* value) { + if (auto* global = ir::dyncast(value)) { + AddressExpr address; + address.base_kind = AddrBaseKind::Global; + address.symbol = global->GetName(); + return address; + } + + auto it = values_.find(value); + if (it == values_.end()) { + throw std::runtime_error(FormatError("mir", "missing lowered address value")); + } + + AddressExpr address; + switch (it->second.kind) { + case LoweredKind::StackObject: + address.base_kind = AddrBaseKind::FrameObject; + address.base_index = it->second.index; + return address; + case LoweredKind::Global: + address.base_kind = AddrBaseKind::Global; + address.symbol = it->second.symbol; + return address; + case LoweredKind::VReg: + address.base_kind = AddrBaseKind::VReg; + address.base_index = it->second.index; + return address; + case LoweredKind::Invalid: + break; + } + + throw std::runtime_error(FormatError("mir", "invalid address 
lowering")); + } + + MachineInstr::Opcode LowerBinaryOpcode(ir::Opcode opcode) { + switch (opcode) { + case ir::Opcode::Add: + return MachineInstr::Opcode::Add; + case ir::Opcode::Sub: + return MachineInstr::Opcode::Sub; + case ir::Opcode::Mul: + return MachineInstr::Opcode::Mul; + case ir::Opcode::Div: + return MachineInstr::Opcode::Div; + case ir::Opcode::Rem: + return MachineInstr::Opcode::Rem; + case ir::Opcode::And: + return MachineInstr::Opcode::And; + case ir::Opcode::Or: + return MachineInstr::Opcode::Or; + case ir::Opcode::Xor: + return MachineInstr::Opcode::Xor; + case ir::Opcode::Shl: + return MachineInstr::Opcode::Shl; + case ir::Opcode::AShr: + return MachineInstr::Opcode::AShr; + case ir::Opcode::LShr: + return MachineInstr::Opcode::LShr; + case ir::Opcode::FAdd: + return MachineInstr::Opcode::FAdd; + case ir::Opcode::FSub: + return MachineInstr::Opcode::FSub; + case ir::Opcode::FMul: + return MachineInstr::Opcode::FMul; + case ir::Opcode::FDiv: + return MachineInstr::Opcode::FDiv; + default: + throw std::runtime_error(FormatError("mir", "unsupported binary opcode")); + } + } + + LoweredValue NewVRegValue(ValueType type) { + return {LoweredKind::VReg, type, current_function_->NewVReg(type), ""}; + } + + LoweredValue MaterializeOperandAsValue(const MachineOperand& operand, ValueType type) { + if (operand.GetKind() == OperandKind::VReg) { + return {LoweredKind::VReg, type, operand.GetVReg(), ""}; + } + + auto lowered = NewVRegValue(type); + current_block_->Append(MachineInstr::Opcode::Copy, + {MachineOperand::VReg(lowered.index), operand}); + return lowered; + } + + void InsertBeforeTerminator(MachineBasicBlock* block, MachineInstr instr) { + auto& instructions = block->GetInstructions(); + auto insert_pos = instructions.end(); + if (!instructions.empty() && instructions.back().IsTerminator()) { + insert_pos = instructions.end() - 1; + } + instructions.insert(insert_pos, std::move(instr)); + } + + void PreparePhiResults(ir::Function& function) { + for (const auto& block : function.GetBlocks()) { + for (const auto& inst : block->GetInstructions()) { + if (inst->GetOpcode() != ir::Opcode::Phi) { + break; + } + auto lowered = NewVRegValue(LowerType(inst->GetType())); + values_[inst.get()] = lowered; + } + } + } + + void EmitPhiCopies(ir::Function& function) { + std::unordered_map> pending; + + for (const auto& block : function.GetBlocks()) { + for (const auto& inst : block->GetInstructions()) { + if (inst->GetOpcode() != ir::Opcode::Phi) { + break; + } + auto* phi = static_cast(inst.get()); + const int dest_vreg = values_.at(phi).index; + for (int i = 0; i < phi->GetNumIncomings(); ++i) { + auto* pred_block = blocks_.at(phi->GetIncomingBlock(i)); + pending[pred_block].emplace_back( + MachineInstr::Opcode::Copy, + std::vector{MachineOperand::VReg(dest_vreg), + LowerScalarOperand(phi->GetIncomingValue(i))}); + } + } + } - if (module.GetFunctions().size() != 1) { - throw std::runtime_error(FormatError("mir", "暂不支持多个函数")); + for (auto& item : pending) { + for (auto& instr : item.second) { + InsertBeforeTerminator(item.first, std::move(instr)); + } + } } - const auto& func = *module.GetFunctions().front(); - if (func.GetName() != "main") { - throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数")); + bool CanInlineDirectCall(const ir::Function& function) const { + if (function.IsExternal() || function.GetBlocks().size() != 1) { + return false; + } + + for (const auto& block : function.GetBlocks()) { + for (const auto& inst : block->GetInstructions()) { + switch (inst->GetOpcode()) 
{ + case ir::Opcode::Add: + case ir::Opcode::Sub: + case ir::Opcode::Mul: + case ir::Opcode::Div: + case ir::Opcode::Rem: + case ir::Opcode::And: + case ir::Opcode::Or: + case ir::Opcode::Xor: + case ir::Opcode::Shl: + case ir::Opcode::AShr: + case ir::Opcode::LShr: + case ir::Opcode::FAdd: + case ir::Opcode::FSub: + case ir::Opcode::FMul: + case ir::Opcode::FDiv: + case ir::Opcode::FNeg: + case ir::Opcode::ICmpEQ: + case ir::Opcode::ICmpNE: + case ir::Opcode::ICmpLT: + case ir::Opcode::ICmpGT: + case ir::Opcode::ICmpLE: + case ir::Opcode::ICmpGE: + case ir::Opcode::FCmpEQ: + case ir::Opcode::FCmpNE: + case ir::Opcode::FCmpLT: + case ir::Opcode::FCmpGT: + case ir::Opcode::FCmpLE: + case ir::Opcode::FCmpGE: + case ir::Opcode::Zext: + case ir::Opcode::IToF: + case ir::Opcode::FtoI: + case ir::Opcode::Return: + break; + default: + return false; + } + } + } + return true; } - auto machine_func = std::make_unique(func.GetName()); - ValueSlotMap slots; - const auto* entry = func.GetEntry(); - if (!entry) { - throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块")); + bool TryInlineDirectCall(ir::CallInst* call) { + auto* callee = call->GetCallee(); + if (callee == nullptr || callee == current_ir_function_ || !CanInlineDirectCall(*callee)) { + return false; + } + + const auto& callee_args = callee->GetArguments(); + const auto& call_args = call->GetArguments(); + if (callee_args.size() != call_args.size()) { + return false; + } + + OperandMap inline_values; + for (size_t i = 0; i < call_args.size(); ++i) { + inline_values[callee_args[i].get()] = ResolveScalarOperand(call_args[i], nullptr); + } + + MachineOperand return_operand; + bool has_return = false; + for (const auto& inst : callee->GetBlocks().front()->GetInstructions()) { + switch (inst->GetOpcode()) { + case ir::Opcode::Add: + case ir::Opcode::Sub: + case ir::Opcode::Mul: + case ir::Opcode::Div: + case ir::Opcode::Rem: + case ir::Opcode::And: + case ir::Opcode::Or: + case ir::Opcode::Xor: + case ir::Opcode::Shl: + case ir::Opcode::AShr: + case ir::Opcode::LShr: + case ir::Opcode::FAdd: + case ir::Opcode::FSub: + case ir::Opcode::FMul: + case ir::Opcode::FDiv: { + auto* binary = static_cast(inst.get()); + auto lowered = NewVRegValue(LowerType(binary->GetType())); + current_block_->Append(LowerBinaryOpcode(inst->GetOpcode()), + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(binary->GetLhs(), &inline_values), + ResolveScalarOperand(binary->GetRhs(), &inline_values)}); + inline_values[inst.get()] = MachineOperand::VReg(lowered.index); + break; + } + case ir::Opcode::FNeg: { + auto* unary = static_cast(inst.get()); + auto lowered = NewVRegValue(ValueType::F32); + current_block_->Append(MachineInstr::Opcode::FNeg, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(unary->GetOprd(), &inline_values)}); + inline_values[inst.get()] = MachineOperand::VReg(lowered.index); + break; + } + case ir::Opcode::ICmpEQ: + case ir::Opcode::ICmpNE: + case ir::Opcode::ICmpLT: + case ir::Opcode::ICmpGT: + case ir::Opcode::ICmpLE: + case ir::Opcode::ICmpGE: { + auto* binary = static_cast(inst.get()); + auto lowered = NewVRegValue(ValueType::I1); + MachineInstr instr(MachineInstr::Opcode::ICmp, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(binary->GetLhs(), &inline_values), + ResolveScalarOperand(binary->GetRhs(), &inline_values)}); + instr.SetCondCode(LowerIntCond(inst->GetOpcode())); + current_block_->Append(std::move(instr)); + inline_values[inst.get()] = MachineOperand::VReg(lowered.index); + break; + } + case 
ir::Opcode::FCmpEQ: + case ir::Opcode::FCmpNE: + case ir::Opcode::FCmpLT: + case ir::Opcode::FCmpGT: + case ir::Opcode::FCmpLE: + case ir::Opcode::FCmpGE: { + auto* binary = static_cast(inst.get()); + auto lowered = NewVRegValue(ValueType::I1); + MachineInstr instr(MachineInstr::Opcode::FCmp, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(binary->GetLhs(), &inline_values), + ResolveScalarOperand(binary->GetRhs(), &inline_values)}); + instr.SetCondCode(LowerFloatCond(inst->GetOpcode())); + current_block_->Append(std::move(instr)); + inline_values[inst.get()] = MachineOperand::VReg(lowered.index); + break; + } + case ir::Opcode::Zext: { + auto* zext = static_cast(inst.get()); + auto lowered = NewVRegValue(LowerType(zext->GetType())); + current_block_->Append(MachineInstr::Opcode::ZExt, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(zext->GetValue(), &inline_values)}); + inline_values[inst.get()] = MachineOperand::VReg(lowered.index); + break; + } + case ir::Opcode::IToF: { + auto* unary = static_cast(inst.get()); + auto lowered = NewVRegValue(ValueType::F32); + current_block_->Append(MachineInstr::Opcode::ItoF, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(unary->GetOprd(), &inline_values)}); + inline_values[inst.get()] = MachineOperand::VReg(lowered.index); + break; + } + case ir::Opcode::FtoI: { + auto* unary = static_cast(inst.get()); + auto lowered = NewVRegValue(ValueType::I32); + current_block_->Append(MachineInstr::Opcode::FtoI, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(unary->GetOprd(), &inline_values)}); + inline_values[inst.get()] = MachineOperand::VReg(lowered.index); + break; + } + case ir::Opcode::Return: { + auto* ret = static_cast(inst.get()); + if (ret->HasReturnValue()) { + return_operand = ResolveScalarOperand(ret->GetReturnValue(), &inline_values); + has_return = true; + } + break; + } + default: + return false; + } + } + + if (!call->GetType()->IsVoid()) { + if (!has_return) { + throw std::runtime_error(FormatError("mir", "inlined call is missing return value")); + } + values_[call] = MaterializeOperandAsValue(return_operand, LowerType(call->GetType())); + } + return true; } - for (const auto& inst : entry->GetInstructions()) { - LowerInstruction(*inst, *machine_func, slots); + void LowerInstruction(ir::Instruction& inst) { + switch (inst.GetOpcode()) { + case ir::Opcode::Alloca: { + auto* alloca_inst = static_cast(&inst); + const int object = current_function_->CreateStackObject( + alloca_inst->GetAllocatedType()->GetSize(), + GetIRTypeAlign(alloca_inst->GetAllocatedType()), StackObjectKind::Local, + inst.GetName()); + values_[&inst] = {LoweredKind::StackObject, ValueType::Ptr, object, ""}; + return; + } + case ir::Opcode::Load: { + auto* load = static_cast(&inst); + auto lowered = NewVRegValue(LowerType(load->GetType())); + MachineInstr instr(MachineInstr::Opcode::Load, + {MachineOperand::VReg(lowered.index)}); + instr.SetAddress(LowerAddress(load->GetPtr())); + current_block_->Append(std::move(instr)); + values_[&inst] = lowered; + return; + } + case ir::Opcode::Store: { + auto* store = static_cast(&inst); + MachineInstr instr(MachineInstr::Opcode::Store, + {LowerScalarOperand(store->GetValue())}); + instr.SetValueType(LowerType(store->GetValue()->GetType())); + instr.SetAddress(LowerAddress(store->GetPtr())); + current_block_->Append(std::move(instr)); + return; + } + case ir::Opcode::Add: + case ir::Opcode::Sub: + case ir::Opcode::Mul: + case ir::Opcode::Div: + case ir::Opcode::Rem: + case 
ir::Opcode::And: + case ir::Opcode::Or: + case ir::Opcode::Xor: + case ir::Opcode::Shl: + case ir::Opcode::AShr: + case ir::Opcode::LShr: + case ir::Opcode::FAdd: + case ir::Opcode::FSub: + case ir::Opcode::FMul: + case ir::Opcode::FDiv: { + auto* binary = static_cast(&inst); + auto lowered = NewVRegValue(LowerType(binary->GetType())); + current_block_->Append(LowerBinaryOpcode(inst.GetOpcode()), + {MachineOperand::VReg(lowered.index), + LowerScalarOperand(binary->GetLhs()), + LowerScalarOperand(binary->GetRhs())}); + values_[&inst] = lowered; + return; + } + case ir::Opcode::FNeg: { + auto* unary = static_cast(&inst); + auto lowered = NewVRegValue(ValueType::F32); + current_block_->Append(MachineInstr::Opcode::FNeg, + {MachineOperand::VReg(lowered.index), + LowerScalarOperand(unary->GetOprd())}); + values_[&inst] = lowered; + return; + } + case ir::Opcode::ICmpEQ: + case ir::Opcode::ICmpNE: + case ir::Opcode::ICmpLT: + case ir::Opcode::ICmpGT: + case ir::Opcode::ICmpLE: + case ir::Opcode::ICmpGE: { + auto* binary = static_cast(&inst); + auto lowered = NewVRegValue(ValueType::I1); + MachineInstr instr(MachineInstr::Opcode::ICmp, + {MachineOperand::VReg(lowered.index), + LowerScalarOperand(binary->GetLhs()), + LowerScalarOperand(binary->GetRhs())}); + instr.SetCondCode(LowerIntCond(inst.GetOpcode())); + current_block_->Append(std::move(instr)); + values_[&inst] = lowered; + return; + } + case ir::Opcode::FCmpEQ: + case ir::Opcode::FCmpNE: + case ir::Opcode::FCmpLT: + case ir::Opcode::FCmpGT: + case ir::Opcode::FCmpLE: + case ir::Opcode::FCmpGE: { + auto* binary = static_cast(&inst); + auto lowered = NewVRegValue(ValueType::I1); + MachineInstr instr(MachineInstr::Opcode::FCmp, + {MachineOperand::VReg(lowered.index), + LowerScalarOperand(binary->GetLhs()), + LowerScalarOperand(binary->GetRhs())}); + instr.SetCondCode(LowerFloatCond(inst.GetOpcode())); + current_block_->Append(std::move(instr)); + values_[&inst] = lowered; + return; + } + case ir::Opcode::Zext: { + auto* zext = static_cast(&inst); + auto lowered = NewVRegValue(LowerType(zext->GetType())); + current_block_->Append(MachineInstr::Opcode::ZExt, + {MachineOperand::VReg(lowered.index), + LowerScalarOperand(zext->GetValue())}); + values_[&inst] = lowered; + return; + } + case ir::Opcode::IToF: { + auto* unary = static_cast(&inst); + auto lowered = NewVRegValue(ValueType::F32); + current_block_->Append(MachineInstr::Opcode::ItoF, + {MachineOperand::VReg(lowered.index), + LowerScalarOperand(unary->GetOprd())}); + values_[&inst] = lowered; + return; + } + case ir::Opcode::FtoI: { + auto* unary = static_cast(&inst); + auto lowered = NewVRegValue(ValueType::I32); + current_block_->Append(MachineInstr::Opcode::FtoI, + {MachineOperand::VReg(lowered.index), + LowerScalarOperand(unary->GetOprd())}); + values_[&inst] = lowered; + return; + } + case ir::Opcode::GetElementPtr: { + auto* gep = static_cast(&inst); + auto lowered = NewVRegValue(ValueType::Ptr); + AddressExpr address = LowerAddress(gep->GetPointer()); + auto current_type = gep->GetSourceType(); + for (size_t i = 0; i < gep->GetNumIndices(); ++i) { + auto* index = gep->GetIndex(i); + const std::int64_t stride = current_type ? current_type->GetSize() : 0; + if (auto* ci = ir::dyncast(index)) { + address.const_offset += static_cast(ci->GetValue()) * stride; + } else if (auto* cb = ir::dyncast(index)) { + address.const_offset += + static_cast(cb->GetValue() ? 
1 : 0) * stride; + } else { + address.scaled_vregs.push_back({LowerScalarOperand(index).GetVReg(), stride}); + } + if (current_type && current_type->IsArray()) { + current_type = current_type->GetElementType(); + } + } + MachineInstr instr(MachineInstr::Opcode::Lea, + {MachineOperand::VReg(lowered.index)}); + instr.SetAddress(std::move(address)); + current_block_->Append(std::move(instr)); + values_[&inst] = lowered; + return; + } + case ir::Opcode::Call: { + auto* call = static_cast(&inst); + if (TryInlineDirectCall(call)) { + return; + } + std::vector operands; + if (!call->GetType()->IsVoid()) { + auto lowered = NewVRegValue(LowerType(call->GetType())); + operands.push_back(MachineOperand::VReg(lowered.index)); + values_[&inst] = lowered; + } + std::vector arg_types; + for (auto* arg : call->GetArguments()) { + operands.push_back(LowerScalarOperand(arg)); + arg_types.push_back(LowerType(arg->GetType())); + } + MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands)); + instr.SetCallInfo(call->GetCallee()->GetName(), std::move(arg_types), + LowerType(call->GetType())); + current_block_->Append(std::move(instr)); + return; + } + case ir::Opcode::Br: { + auto* br = static_cast(&inst); + current_block_->Append(MachineInstr::Opcode::Br, + {MachineOperand::Block(blocks_.at(br->GetDest())->GetName())}); + return; + } + case ir::Opcode::CondBr: { + auto* br = static_cast(&inst); + current_block_->Append(MachineInstr::Opcode::CondBr, + {LowerScalarOperand(br->GetCondition()), + MachineOperand::Block(blocks_.at(br->GetThenBlock())->GetName()), + MachineOperand::Block(blocks_.at(br->GetElseBlock())->GetName())}); + return; + } + case ir::Opcode::Return: { + auto* ret = static_cast(&inst); + if (ret->HasReturnValue()) { + MachineInstr instr(MachineInstr::Opcode::Ret, + {LowerScalarOperand(ret->GetReturnValue())}); + instr.SetValueType(LowerType(ret->GetReturnValue()->GetType())); + current_block_->Append(std::move(instr)); + } else { + current_block_->Append(MachineInstr::Opcode::Ret); + } + return; + } + case ir::Opcode::Memset: { + auto* memset_inst = static_cast(&inst); + MachineInstr instr(MachineInstr::Opcode::Memset, + {LowerScalarOperand(memset_inst->GetValue()), + LowerScalarOperand(memset_inst->GetLength())}); + instr.SetAddress(LowerAddress(memset_inst->GetDest())); + current_block_->Append(std::move(instr)); + return; + } + case ir::Opcode::Unreachable: + current_block_->Append(MachineInstr::Opcode::Unreachable); + return; + case ir::Opcode::Phi: + return; + case ir::Opcode::FRem: + case ir::Opcode::Neg: + case ir::Opcode::Not: + throw std::runtime_error( + FormatError("mir", "unsupported instruction in backend lowering")); + } + + throw std::runtime_error(FormatError("mir", "unsupported IR opcode in backend lowering")); + } + + void LowerFunction(ir::Function& function) { + values_.clear(); + blocks_.clear(); + + std::vector param_types; + for (const auto& type : function.GetParamTypes()) { + param_types.push_back(LowerType(type)); + } + + auto machine_function = std::make_unique( + function.GetName(), LowerType(function.GetReturnType()), std::move(param_types)); + current_ir_function_ = &function; + current_function_ = machine_function.get(); + + for (const auto& block : function.GetBlocks()) { + blocks_[block.get()] = current_function_->CreateBlock(block->GetName()); + } + + if (!function.GetBlocks().empty()) { + auto* entry = blocks_.at(function.GetBlocks().front().get()); + for (const auto& argument : function.GetArguments()) { + auto lowered = 
NewVRegValue(LowerType(argument->GetType())); + entry->Append(MachineInstr::Opcode::Arg, + {MachineOperand::VReg(lowered.index), + MachineOperand::Imm(static_cast(argument->GetIndex()))}); + values_[argument.get()] = lowered; + } + } + + PreparePhiResults(function); + + for (const auto& block : function.GetBlocks()) { + current_block_ = blocks_.at(block.get()); + for (const auto& inst : block->GetInstructions()) { + LowerInstruction(*inst); + } + } + + EmitPhiCopies(function); + + machine_module_->AddFunction(std::move(machine_function)); + current_ir_function_ = nullptr; + current_function_ = nullptr; + current_block_ = nullptr; } - return machine_func; + const ir::Module& module_; + std::unique_ptr machine_module_; + ir::Function* current_ir_function_ = nullptr; + MachineFunction* current_function_ = nullptr; + MachineBasicBlock* current_block_ = nullptr; + std::unordered_map values_; + std::unordered_map blocks_; +}; + +} // namespace + +std::unique_ptr LowerToMIR(const ir::Module& module) { + DefaultContext(); + return Lowerer(module).Run(); } } // namespace mir diff --git a/src/mir/MIRBasicBlock.cpp b/src/mir/MIRBasicBlock.cpp index d42b4b3..66541bc 100644 --- a/src/mir/MIRBasicBlock.cpp +++ b/src/mir/MIRBasicBlock.cpp @@ -7,9 +7,14 @@ namespace mir { MachineBasicBlock::MachineBasicBlock(std::string name) : name_(std::move(name)) {} -MachineInstr& MachineBasicBlock::Append(Opcode opcode, - std::initializer_list operands) { - instructions_.emplace_back(opcode, std::vector(operands)); +MachineInstr& MachineBasicBlock::Append(MachineInstr::Opcode opcode, + std::vector operands) { + instructions_.emplace_back(opcode, std::move(operands)); + return instructions_.back(); +} + +MachineInstr& MachineBasicBlock::Append(MachineInstr instr) { + instructions_.push_back(std::move(instr)); return instructions_.back(); } diff --git a/src/mir/MIRContext.cpp b/src/mir/MIRContext.cpp index 30c75c8..caf283e 100644 --- a/src/mir/MIRContext.cpp +++ b/src/mir/MIRContext.cpp @@ -2,9 +2,10 @@ namespace mir { -MIRContext& DefaultContext() { - static MIRContext ctx; - return ctx; -} +namespace { +MIRContext g_context; +} // namespace + +MIRContext& DefaultContext() { return g_context; } } // namespace mir diff --git a/src/mir/MIRFunction.cpp b/src/mir/MIRFunction.cpp index 334f8cc..4d98674 100644 --- a/src/mir/MIRFunction.cpp +++ b/src/mir/MIRFunction.cpp @@ -1,33 +1,106 @@ #include "mir/MIR.h" +#include #include #include -#include "utils/Log.h" - namespace mir { -MachineFunction::MachineFunction(std::string name) - : name_(std::move(name)), entry_("entry") {} +MachineFunction::MachineFunction(std::string name, ValueType return_type, + std::vector param_types) + : name_(std::move(name)), + return_type_(return_type), + param_types_(std::move(param_types)) {} + +MachineBasicBlock* MachineFunction::CreateBlock(const std::string& name) { + auto block = std::make_unique(name); + auto* ptr = block.get(); + blocks_.push_back(std::move(block)); + return ptr; +} + +int MachineFunction::NewVReg(ValueType type) { + const int id = static_cast(vregs_.size()); + vregs_.push_back({id, type}); + allocations_.push_back({}); + return id; +} + +const VRegInfo& MachineFunction::GetVRegInfo(int id) const { + if (id < 0 || id >= static_cast(vregs_.size())) { + throw std::out_of_range("virtual register index out of range"); + } + return vregs_[static_cast(id)]; +} + +VRegInfo& MachineFunction::GetVRegInfo(int id) { + if (id < 0 || id >= static_cast(vregs_.size())) { + throw std::out_of_range("virtual register index out of range"); + 
} + return vregs_[static_cast(id)]; +} -int MachineFunction::CreateFrameIndex(int size) { - int index = static_cast(frame_slots_.size()); - frame_slots_.push_back(FrameSlot{index, size, 0}); +int MachineFunction::CreateStackObject(int size, int align, StackObjectKind kind, + const std::string& name) { + const int index = static_cast(stack_objects_.size()); + stack_objects_.push_back({index, kind, size, align, 0, name}); return index; } -FrameSlot& MachineFunction::GetFrameSlot(int index) { - if (index < 0 || index >= static_cast(frame_slots_.size())) { - throw std::runtime_error(FormatError("mir", "非法 FrameIndex")); +StackObject& MachineFunction::GetStackObject(int index) { + if (index < 0 || index >= static_cast(stack_objects_.size())) { + throw std::out_of_range("stack object index out of range"); + } + return stack_objects_[static_cast(index)]; +} + +const StackObject& MachineFunction::GetStackObject(int index) const { + if (index < 0 || index >= static_cast(stack_objects_.size())) { + throw std::out_of_range("stack object index out of range"); } - return frame_slots_[index]; + return stack_objects_[static_cast(index)]; } -const FrameSlot& MachineFunction::GetFrameSlot(int index) const { - if (index < 0 || index >= static_cast(frame_slots_.size())) { - throw std::runtime_error(FormatError("mir", "非法 FrameIndex")); +void MachineFunction::SetAllocation(int vreg, Allocation allocation) { + if (vreg < 0 || vreg >= static_cast(allocations_.size())) { + throw std::out_of_range("allocation index out of range"); } - return frame_slots_[index]; + allocations_[static_cast(vreg)] = allocation; +} + +const Allocation& MachineFunction::GetAllocation(int vreg) const { + if (vreg < 0 || vreg >= static_cast(allocations_.size())) { + throw std::out_of_range("allocation index out of range"); + } + return allocations_[static_cast(vreg)]; +} + +Allocation& MachineFunction::GetAllocation(int vreg) { + if (vreg < 0 || vreg >= static_cast(allocations_.size())) { + throw std::out_of_range("allocation index out of range"); + } + return allocations_[static_cast(vreg)]; +} + +void MachineFunction::AddUsedCalleeSavedGPR(int reg_index) { + if (std::find(used_callee_saved_gprs_.begin(), used_callee_saved_gprs_.end(), + reg_index) == used_callee_saved_gprs_.end()) { + used_callee_saved_gprs_.push_back(reg_index); + } +} + +void MachineFunction::AddUsedCalleeSavedFPR(int reg_index) { + if (std::find(used_callee_saved_fprs_.begin(), used_callee_saved_fprs_.end(), + reg_index) == used_callee_saved_fprs_.end()) { + used_callee_saved_fprs_.push_back(reg_index); + } +} + +MachineFunction* MachineModule::AddFunction( + std::unique_ptr function) { + auto* ptr = function.get(); + functions_.push_back(std::move(function)); + return ptr; } } // namespace mir diff --git a/src/mir/MIRInstr.cpp b/src/mir/MIRInstr.cpp index 0a21a03..d07344a 100644 --- a/src/mir/MIRInstr.cpp +++ b/src/mir/MIRInstr.cpp @@ -1,23 +1,178 @@ #include "mir/MIR.h" -#include - namespace mir { -Operand::Operand(Kind kind, PhysReg reg, int imm) - : kind_(kind), reg_(reg), imm_(imm) {} +MachineOperand::MachineOperand(OperandKind kind, int vreg, std::int64_t imm, + std::string text) + : kind_(kind), vreg_(vreg), imm_(imm), text_(std::move(text)) {} + +MachineOperand MachineOperand::VReg(int reg) { + return MachineOperand(OperandKind::VReg, reg, 0, ""); +} -Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); } +MachineOperand MachineOperand::Imm(std::int64_t value) { + return MachineOperand(OperandKind::Imm, -1, value, ""); +} -Operand 
Operand::Imm(int value) { - return Operand(Kind::Imm, PhysReg::W0, value); +MachineOperand MachineOperand::Block(std::string name) { + return MachineOperand(OperandKind::Block, -1, 0, std::move(name)); } -Operand Operand::FrameIndex(int index) { - return Operand(Kind::FrameIndex, PhysReg::W0, index); +MachineOperand MachineOperand::Symbol(std::string name) { + return MachineOperand(OperandKind::Symbol, -1, 0, std::move(name)); } -MachineInstr::MachineInstr(Opcode opcode, std::vector operands) +MachineInstr::MachineInstr(Opcode opcode, std::vector operands) : opcode_(opcode), operands_(std::move(operands)) {} +bool MachineInstr::IsTerminator() const { + return opcode_ == Opcode::Br || opcode_ == Opcode::CondBr || + opcode_ == Opcode::Ret || opcode_ == Opcode::Unreachable; +} + +std::vector MachineInstr::GetDefs() const { + switch (opcode_) { + case Opcode::Arg: + case Opcode::Copy: + case Opcode::Load: + case Opcode::Lea: + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Rem: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::FNeg: + case Opcode::ICmp: + case Opcode::FCmp: + case Opcode::ZExt: + case Opcode::ItoF: + case Opcode::FtoI: + if (!operands_.empty() && operands_[0].GetKind() == OperandKind::VReg) { + return {operands_[0].GetVReg()}; + } + return {}; + case Opcode::Call: + if (call_return_type_ != ValueType::Void && !operands_.empty() && + operands_[0].GetKind() == OperandKind::VReg) { + return {operands_[0].GetVReg()}; + } + return {}; + case Opcode::Store: + case Opcode::Br: + case Opcode::CondBr: + case Opcode::Ret: + case Opcode::Memset: + case Opcode::Unreachable: + return {}; + } + return {}; +} + +std::vector MachineInstr::GetUses() const { + std::vector uses; + auto push_vreg = [&](const MachineOperand& operand) { + if (operand.GetKind() == OperandKind::VReg) { + uses.push_back(operand.GetVReg()); + } + }; + auto push_addr_uses = [&]() { + if (!has_address_) { + return; + } + if (address_.base_kind == AddrBaseKind::VReg && address_.base_index >= 0) { + uses.push_back(address_.base_index); + } + for (const auto& term : address_.scaled_vregs) { + uses.push_back(term.first); + } + }; + + switch (opcode_) { + case Opcode::Arg: + case Opcode::Br: + case Opcode::Unreachable: + break; + case Opcode::Copy: + case Opcode::ZExt: + case Opcode::ItoF: + case Opcode::FtoI: + case Opcode::FNeg: + if (operands_.size() >= 2) { + push_vreg(operands_[1]); + } + break; + case Opcode::Load: + case Opcode::Lea: + push_addr_uses(); + break; + case Opcode::Store: + if (!operands_.empty()) { + push_vreg(operands_[0]); + } + push_addr_uses(); + break; + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Rem: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::ICmp: + case Opcode::FCmp: + if (operands_.size() >= 2) { + push_vreg(operands_[1]); + } + if (operands_.size() >= 3) { + push_vreg(operands_[2]); + } + break; + case Opcode::CondBr: + if (!operands_.empty()) { + push_vreg(operands_[0]); + } + break; + case Opcode::Call: { + size_t arg_begin = call_return_type_ == ValueType::Void ? 
0 : 1; + for (size_t i = arg_begin; i < operands_.size(); ++i) { + push_vreg(operands_[i]); + } + break; + } + case Opcode::Ret: + if (!operands_.empty()) { + push_vreg(operands_[0]); + } + break; + case Opcode::Memset: + if (!operands_.empty()) { + push_vreg(operands_[0]); + } + if (operands_.size() >= 2) { + push_vreg(operands_[1]); + } + push_addr_uses(); + break; + } + return uses; +} + } // namespace mir diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index 5dc5d2b..fd661f4 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -1,33 +1,317 @@ #include "mir/MIR.h" +#include +#include #include +#include +#include #include "utils/Log.h" namespace mir { namespace { -bool IsAllowedReg(PhysReg reg) { - switch (reg) { - case PhysReg::W0: - case PhysReg::W8: - case PhysReg::W9: - case PhysReg::X29: - case PhysReg::X30: - case PhysReg::SP: - return true; +struct Interval { + int vreg = -1; + ValueType type = ValueType::Void; + int start = -1; + int end = -1; + Allocation allocation; +}; + +struct BlockInfo { + int start_pos = 0; + int end_pos = 0; + std::vector successors; + std::vector use; + std::vector def; + std::vector live_in; + std::vector live_out; +}; + +std::vector GetAllocatableRegs(ValueType type) { + std::vector regs; + if (IsFPR(type)) { + for (int i = 8; i <= 15; ++i) { + regs.push_back({RegClass::FPR, i}); + } + return regs; + } + for (int i = 19; i <= 28; ++i) { + regs.push_back({RegClass::GPR, i}); + } + return regs; +} + +int CreateSpillSlot(MachineFunction& function, const Interval& interval) { + return function.CreateStackObject(GetValueSize(interval.type), GetValueAlign(interval.type), + StackObjectKind::Spill, + "spill." + std::to_string(interval.vreg)); +} + +void SortActiveByEnd(const std::vector& intervals, + std::vector& active) { + std::sort(active.begin(), active.end(), [&](int lhs, int rhs) { + if (intervals[static_cast(lhs)].end != intervals[static_cast(rhs)].end) { + return intervals[static_cast(lhs)].end < intervals[static_cast(rhs)].end; + } + return lhs < rhs; + }); +} + +void TouchPoint(Interval& interval, int pos) { + if (interval.start < 0 || pos < interval.start) { + interval.start = pos; + } + if (interval.end < pos) { + interval.end = pos; + } +} + +std::vector AnalyzeBlocks(const MachineFunction& function) { + const auto& blocks = function.GetBlocks(); + const int num_blocks = static_cast(blocks.size()); + const int num_vregs = static_cast(function.GetVRegs().size()); + + std::vector infos(static_cast(num_blocks)); + std::vector> block_name_to_index; + block_name_to_index.reserve(blocks.size()); + for (int i = 0; i < num_blocks; ++i) { + block_name_to_index.push_back({blocks[static_cast(i)]->GetName(), i}); + } + + auto find_block_index = [&](const std::string& name) { + auto it = std::find_if(block_name_to_index.begin(), block_name_to_index.end(), + [&](const auto& item) { return item.first == name; }); + if (it == block_name_to_index.end()) { + throw std::runtime_error(FormatError("mir", "unknown basic block label: " + name)); + } + return it->second; + }; + + int position = 0; + for (int block_index = 0; block_index < num_blocks; ++block_index) { + auto& info = infos[static_cast(block_index)]; + info.start_pos = position; + info.use.assign(static_cast(num_vregs), 0); + info.def.assign(static_cast(num_vregs), 0); + info.live_in.assign(static_cast(num_vregs), 0); + info.live_out.assign(static_cast(num_vregs), 0); + + const auto& instructions = blocks[static_cast(block_index)]->GetInstructions(); + for (const auto& inst : 
instructions) { + for (int use : inst.GetUses()) { + if (use >= 0 && use < num_vregs && !info.def[static_cast(use)]) { + info.use[static_cast(use)] = 1; + } + } + for (int def : inst.GetDefs()) { + if (def >= 0 && def < num_vregs) { + info.def[static_cast(def)] = 1; + } + } + + switch (inst.GetOpcode()) { + case MachineInstr::Opcode::Br: + if (!inst.GetOperands().empty()) { + info.successors.push_back(find_block_index(inst.GetOperands()[0].GetText())); + } + break; + case MachineInstr::Opcode::CondBr: + if (inst.GetOperands().size() >= 3) { + info.successors.push_back(find_block_index(inst.GetOperands()[1].GetText())); + info.successors.push_back(find_block_index(inst.GetOperands()[2].GetText())); + } + break; + default: + break; + } + + position += 2; + } + std::sort(info.successors.begin(), info.successors.end()); + info.successors.erase(std::unique(info.successors.begin(), info.successors.end()), + info.successors.end()); + info.end_pos = position; + } + + bool changed = true; + while (changed) { + changed = false; + for (int block_index = num_blocks - 1; block_index >= 0; --block_index) { + auto& info = infos[static_cast(block_index)]; + std::vector next_out(static_cast(num_vregs), 0); + std::vector next_in(static_cast(num_vregs), 0); + + for (int succ : info.successors) { + const auto& succ_in = infos[static_cast(succ)].live_in; + for (int vreg = 0; vreg < num_vregs; ++vreg) { + next_out[static_cast(vreg)] |= succ_in[static_cast(vreg)]; + } + } + + for (int vreg = 0; vreg < num_vregs; ++vreg) { + const size_t idx = static_cast(vreg); + next_in[idx] = info.use[idx] | (next_out[idx] & static_cast(!info.def[idx])); + } + + if (next_out != info.live_out || next_in != info.live_in) { + changed = true; + info.live_out = std::move(next_out); + info.live_in = std::move(next_in); + } + } + } + + return infos; +} + +std::vector BuildIntervals(const MachineFunction& function) { + std::vector intervals; + intervals.reserve(function.GetVRegs().size()); + for (const auto& info : function.GetVRegs()) { + intervals.push_back({info.id, info.type, -1, -1, {}}); + } + + const auto block_infos = AnalyzeBlocks(function); + const auto& blocks = function.GetBlocks(); + + int position = 0; + for (const auto& block : blocks) { + for (const auto& inst : block->GetInstructions()) { + for (int def : inst.GetDefs()) { + TouchPoint(intervals[static_cast(def)], position); + } + for (int use : inst.GetUses()) { + TouchPoint(intervals[static_cast(use)], position); + } + position += 2; + } + } + + for (size_t block_index = 0; block_index < block_infos.size(); ++block_index) { + const auto& info = block_infos[block_index]; + for (size_t vreg = 0; vreg < intervals.size(); ++vreg) { + if (!info.live_in[vreg] && !info.live_out[vreg]) { + continue; + } + if (intervals[vreg].start < 0 || info.start_pos < intervals[vreg].start) { + intervals[vreg].start = info.start_pos; + } + if (intervals[vreg].end < info.end_pos) { + intervals[vreg].end = info.end_pos; + } + } + } + + for (auto& interval : intervals) { + if (interval.start < 0) { + interval.start = 0; + interval.end = 0; + } + if (interval.end < interval.start) { + interval.end = interval.start; + } + } + + std::sort(intervals.begin(), intervals.end(), [](const Interval& lhs, const Interval& rhs) { + if (lhs.start != rhs.start) { + return lhs.start < rhs.start; + } + return lhs.vreg < rhs.vreg; + }); + return intervals; +} + +void RunLinearScanForClass(MachineFunction& function, std::vector& intervals, + RegClass reg_class) { + std::vector free_regs; + if (reg_class == 
RegClass::FPR) { + free_regs = GetAllocatableRegs(ValueType::F32); + } else { + free_regs = GetAllocatableRegs(ValueType::I32); + } + + std::vector active; + + auto expire_old = [&](const Interval& current) { + SortActiveByEnd(intervals, active); + std::vector next_active; + for (int index : active) { + const auto& interval = intervals[static_cast(index)]; + if (interval.end >= current.start) { + next_active.push_back(index); + continue; + } + if (interval.allocation.kind == Allocation::Kind::PhysReg) { + free_regs.push_back(interval.allocation.phys); + } + } + active = std::move(next_active); + }; + + auto spill_at_interval = [&](int current_index) { + SortActiveByEnd(intervals, active); + int last_index = active.back(); + auto& current = intervals[static_cast(current_index)]; + auto& last = intervals[static_cast(last_index)]; + + if (last.end > current.end) { + current.allocation = last.allocation; + last.allocation.kind = Allocation::Kind::Spill; + last.allocation.stack_object = CreateSpillSlot(function, last); + active.pop_back(); + active.push_back(current_index); + SortActiveByEnd(intervals, active); + } else { + current.allocation.kind = Allocation::Kind::Spill; + current.allocation.stack_object = CreateSpillSlot(function, current); + } + }; + + for (size_t i = 0; i < intervals.size(); ++i) { + auto& interval = intervals[i]; + const bool is_same_class = IsFPR(interval.type) ? reg_class == RegClass::FPR + : reg_class == RegClass::GPR; + if (!is_same_class || interval.type == ValueType::Void) { + continue; + } + + expire_old(interval); + if (free_regs.empty()) { + if (active.empty()) { + interval.allocation.kind = Allocation::Kind::Spill; + interval.allocation.stack_object = CreateSpillSlot(function, interval); + } else { + spill_at_interval(static_cast(i)); + } + continue; + } + + interval.allocation.kind = Allocation::Kind::PhysReg; + interval.allocation.phys = free_regs.back(); + free_regs.pop_back(); + active.push_back(static_cast(i)); + SortActiveByEnd(intervals, active); } - return false; } } // namespace -void RunRegAlloc(MachineFunction& function) { - for (const auto& inst : function.GetEntry().GetInstructions()) { - for (const auto& operand : inst.GetOperands()) { - if (operand.GetKind() == Operand::Kind::Reg && - !IsAllowedReg(operand.GetReg())) { - throw std::runtime_error(FormatError("mir", "寄存器分配失败")); +void RunRegAlloc(MachineModule& module) { + for (auto& function : module.GetFunctions()) { + auto intervals = BuildIntervals(*function); + RunLinearScanForClass(*function, intervals, RegClass::GPR); + RunLinearScanForClass(*function, intervals, RegClass::FPR); + + for (const auto& interval : intervals) { + function->SetAllocation(interval.vreg, interval.allocation); + if (interval.allocation.kind == Allocation::Kind::PhysReg) { + if (interval.allocation.phys.reg_class == RegClass::GPR) { + function->AddUsedCalleeSavedGPR(interval.allocation.phys.index); + } else { + function->AddUsedCalleeSavedFPR(interval.allocation.phys.index); + } } } } diff --git a/src/mir/Register.cpp b/src/mir/Register.cpp index 7530470..a9f3afd 100644 --- a/src/mir/Register.cpp +++ b/src/mir/Register.cpp @@ -2,26 +2,75 @@ #include -#include "utils/Log.h" - namespace mir { +namespace { + +const char* kWRegNames[] = { + "w0", "w1", "w2", "w3", "w4", "w5", "w6", "w7", + "w8", "w9", "w10", "w11", "w12", "w13", "w14", "w15", + "w16", "w17", "w18", "w19", "w20", "w21", "w22", "w23", + "w24", "w25", "w26", "w27", "w28", "w29", "w30"}; +const char* kXRegNames[] = { + "x0", "x1", "x2", "x3", "x4", "x5", 
"x6", "x7", + "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", + "x16", "x17", "x18", "x19", "x20", "x21", "x22", "x23", + "x24", "x25", "x26", "x27", "x28", "x29", "x30"}; +const char* kSRegNames[] = { + "s0", "s1", "s2", "s3", "s4", "s5", "s6", "s7", + "s8", "s9", "s10", "s11", "s12", "s13", "s14", "s15", + "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", + "s24", "s25", "s26", "s27", "s28", "s29", "s30", "s31"}; + +} // namespace + +bool IsGPR(ValueType type) { + return type == ValueType::I1 || type == ValueType::I32 || type == ValueType::Ptr; +} + +bool IsFPR(ValueType type) { return type == ValueType::F32; } + +int GetValueSize(ValueType type) { + switch (type) { + case ValueType::Void: + return 0; + case ValueType::I1: + case ValueType::I32: + case ValueType::F32: + return 4; + case ValueType::Ptr: + return 8; + } + return 0; +} -const char* PhysRegName(PhysReg reg) { - switch (reg) { - case PhysReg::W0: - return "w0"; - case PhysReg::W8: - return "w8"; - case PhysReg::W9: - return "w9"; - case PhysReg::X29: - return "x29"; - case PhysReg::X30: - return "x30"; - case PhysReg::SP: - return "sp"; +int GetValueAlign(ValueType type) { + switch (type) { + case ValueType::Void: + return 1; + case ValueType::Ptr: + return 8; + case ValueType::I1: + case ValueType::I32: + case ValueType::F32: + return 4; + } + return 1; +} + +const char* GetPhysRegName(PhysReg reg, ValueType type) { + if (!reg.IsValid()) { + throw std::runtime_error("invalid physical register"); + } + if (reg.reg_class == RegClass::FPR) { + if (reg.index < 0 || reg.index >= 32) { + throw std::runtime_error("float register index out of range"); + } + return kSRegNames[reg.index]; + } + if (reg.index < 0 || reg.index >= 31) { + throw std::runtime_error("gpr register index out of range"); } - throw std::runtime_error(FormatError("mir", "未知物理寄存器")); + return type == ValueType::Ptr ? kXRegNames[reg.index] : kWRegNames[reg.index]; } } // namespace mir