diff --git a/doc/competition-optimization-round.md b/doc/competition-optimization-round.md new file mode 100644 index 0000000..433e1c4 --- /dev/null +++ b/doc/competition-optimization-round.md @@ -0,0 +1,114 @@ +# 比赛性能优化记录 + +日期:2026-04-27 + +## 本轮已落地 + +### 1. FFT:模乘/模幂 idiom lowering + +目标用例:`fft1`、`fft0`。 + +已实现: + +- 在 MIR 增加 `ModMul`,识别递归 `multiply(a, b)` 的模乘 idiom,lower 成 `smull + sdiv + msub`,消除 `multiply` 递归调用。 +- 在 MIR 增加 `ModPow`,识别递归 `power(a, b)` 的快速幂 idiom,lower 成后端内联循环,消除 `power` 递归调用。 +- `fft1` 汇编中 `bl multiply` / `bl power` 数量降为 0,仅保留算法本身的 `fft` 递归。 + +主要位置: + +- `include/mir/MIR.h` +- `src/mir/Lowering.cpp` +- `src/mir/AsmPrinter.cpp` +- `src/mir/MIRInstr.cpp` +- `src/mir/passes/Peephole.cpp` +- `src/mir/passes/SpillReduction.cpp` + +验证结果: + +- `fft1`:输出匹配,qemu 本地约 `0.42s`。 +- `fft0`:输出匹配,qemu 本地约 `0.23s`。 + +### 2. 03_sort2:power-of-two digit extraction + +目标用例:`03_sort2`。 + +已实现: + +- 识别 `while (i < pos) num = num / 16; return num % 16;` 这类 power-of-two radix digit helper。 +- IR 内联器会跳过该 helper,避免把小函数展开成大量循环。 +- 后端用 `DigitExtractPow2` 直接 lower 成移位、带符号除法修正和取余序列,消除 `bl getNumPos`。 +- 修复 GVN/CSE 的常量等价键,避免等值常量因对象地址不同而错过跨块消冗余。 + +主要位置: + +- `src/ir/passes/MathIdiomUtils.h` +- `src/ir/passes/Inline.cpp` +- `src/ir/passes/GVN.cpp` +- `src/ir/passes/CSE.cpp` +- `src/mir/Lowering.cpp` +- `src/mir/AsmPrinter.cpp` + +验证结果: + +- `03_sort2`:输出匹配,qemu 本地约 `19.56s`。 +- 对比此前表中 `31.317s`,该项收益明显。 + +### 3. matmul / 2025-MYO-20:标量基础优化 + +目标用例:`matmul1/2/3`、`2025-MYO-20`。 + +已实现: + +- 新增 IR `ArithmeticSimplify`,把 `% power_of_two == 0` 化成 bit-test,例如 `x % 2 == 0` 变为 `(x & 1) == 0`。 +- 增强 `LoadStoreElim`,允许安全的跨块 load forwarding,解决 `if` 前已加载、then 块重复加载的问题。 +- 修复 `DominatorTree` 的 immediate dominator 判定方向,恢复跨块 GVN/LICM/LSE 的基础支配关系。 +- `matmul2` 的内层核心从重复 load + 重复 mul 变为复用同一个乘积。 + +主要位置: + +- `src/ir/passes/ArithmeticSimplify.cpp` +- `src/ir/passes/LoadStoreElim.cpp` +- `src/ir/analysis/DominatorTree.cpp` +- `src/ir/passes/PassManager.cpp` + +验证结果: + +- `matmul2`:输出匹配,qemu 本地约 `7.09s`。 +- 对比此前表中 `8.407s`,已有收益。 + +尚未完成: + +- 真正的 NEON 向量化、矩阵 loop interchange/blocking 还没有落地。当前 MIR 没有 SIMD value type、NEON 寄存器类、向量 load/store、向量 arithmetic,也没有稳定的 loop-nest interchange/blocking 框架。硬塞样例级重写风险过高,不适合作为通用比赛编译器优化。 + +### 4. gameoflife:stencil 前置优化 + +目标用例:`gameoflife-*`。 + +已实现: + +- 通过支配树修复和跨块 load forwarding,让 stencil 里的重复地址计算和重复 load 有更多被 GVN/LSE 消除的机会。 + +验证结果: + +- `gameoflife-oscillator`:输出匹配,qemu 本地约 `8.82s`。 + +尚未完成: + +- 真正的 stencil NEON/行缓存优化还未落地。需要先补 SIMD MIR 和更明确的二维数组滑窗识别,否则容易做成样例特化。 + +### 5. 65_color + +该用例加速比难看但绝对损失很小,本轮未优先处理。后续应只在大头用例收敛后再看。 + +## 本轮验证 + +- `cmake --build build -j`:通过。 +- 单例 qemu 对比均做了 stdout + exit code 的规范化 diff。 +- 未运行全量测试,避免耗时过长。 + +## 下一步优先级 + +1. 为 MIR 增加 NEON value type、向量寄存器类、vector load/store 和基础 i32x4/f32x4 arithmetic。 +2. 在 IR 层补 loop-nest 识别,先做安全的矩阵 loop interchange,再考虑 blocking。 +3. 对 `gameoflife` 做通用 stencil matcher,先生成 scalar row-cache,再接 NEON。 +4. 对 `2025-MYO-20` 单独用 `scripts/analyze_case.sh` 保存 IR/ASM,与 GCC 汇编对照后决定是否值得做 matmul micro-kernel lowering。 diff --git a/doc/lab3-latest-test-analysis.md b/doc/lab3-latest-test-analysis.md new file mode 100644 index 0000000..59253be --- /dev/null +++ b/doc/lab3-latest-test-analysis.md @@ -0,0 +1,104 @@ +# Lab3 最新测试结果分析 + +日期:2026-04-29 + +## 数据源 + +- 我方测试日志:`output/logs/lab3/lab3_20260429_192016/whole.log` +- 我方计时表:`output/logs/lab3/lab3_20260429_192016/timing.tsv` +- GCC baseline:`output/baseline/gcc_timing.tsv` + +本轮我方结果: + +```text +summary: 214 PASS / 0 FAIL / total 214 +build elapsed: 0.72401s +validation elapsed: 632.18659s +total elapsed: 632.91658s +``` + +GCC baseline 结果: + +```text +Summary: 214 DONE / 0 SKIP (cached) / 0 FAIL / total 214 +Total elapsed : 484.24024s +Timing TSV : output/baseline/gcc_timing.tsv (213 entries) +``` + +## 总体结论 + +本轮功能正确性已经通过,`214/214 PASS`。但性能口径需要分开看: + +| 口径 | 我方 | GCC baseline | 差值 | +| --- | ---: | ---: | ---: | +| 脚本整轮墙钟时间 | 632.91658s | 484.24024s | +148.67634s | +| 程序运行时间总和 | 485.95009s | 425.55356s | +60.39653s | + +程序运行时间口径下,当前总体 speedup 为: + +```text +425.55356 / 485.95009 = 0.8757x +``` + +也就是说,生成代码运行时间目前整体比 GCC baseline 慢约 `60.40s`。脚本整轮慢约 `148.68s`,其中额外约 `88s` 来自我方逐样例编译、汇编、链接、校验等流程开销,不完全等价于生成代码性能。 + +补充说明:`timing.tsv` 有 214 行,当前 `gcc_timing.tsv` 有 213 行;额外项是 `class_test_case/functional/05_arr_defn4`。严格汇总时按当前 baseline 文件可精确匹配的 213 条计算,上表采用这个口径。 + +## 最大亏损样例 + +这些样例是当前最值得优先优化的对象,按“我方运行时间 - GCC 运行时间”排序: + +| 样例 | 我方 | GCC | 慢多少 | +| --- | ---: | ---: | ---: | +| `class_test_case/performance/2025-MYO-20` | 54.01749s | 29.75174s | +24.26575s | +| `test_case/h_performance/h-14-01` | 33.94136s | 26.19856s | +7.74280s | +| `test_case/h_performance/h-11-01` | 60.07281s | 52.58051s | +7.49230s | +| `test_case/h_performance/h-1-01` | 25.46834s | 20.48401s | +4.98433s | +| `test_case/h_performance/h-12-01` | 20.04854s | 15.68926s | +4.35928s | +| `test_case/h_performance/matmul3` | 7.04411s | 2.87407s | +4.17004s | +| `test_case/h_performance/matmul1` | 7.02077s | 2.86589s | +4.15488s | +| `test_case/h_performance/matmul2` | 6.92980s | 2.92273s | +4.00707s | +| `test_case/h_performance/gameoflife-gosper` | 10.77375s | 7.53120s | +3.24255s | +| `test_case/h_performance/gameoflife-oscillator` | 9.72381s | 6.73087s | +2.99294s | + +主要问题集中在四类: + +- `2025-MYO-20` 是最大单点亏损,单独慢约 `24.27s`,应作为第一分析对象。 +- `matmul1/2/3` 合计慢约 `12.33s`,说明矩阵类内核还缺少有效的 NEON、地址递推、缓存友好变换或循环分块。 +- `gameoflife*` 合计慢约 `11s+`,说明 stencil 型访问还没有做到行缓存、重复 load 消除或向量化。 +- `h-14-01`、`h-11-01`、`h-1-01`、`h-12-01` 总体占比较大,需要逐个看 IR 和汇编,判断是中端 load/store 没消掉,还是后端 spill/address 质量差。 + +## 最大收益样例 + +这些样例说明当前已有优化确实生效: + +| 样例 | 我方 | GCC | 快多少 | +| --- | ---: | ---: | ---: | +| `test_case/h_performance/fft1` | 0.42533s | 6.63117s | -6.20584s | +| `class_test_case/performance/fft0` | 0.20593s | 3.13259s | -2.92666s | +| `test_case/h_performance/fft0` | 0.21674s | 3.12871s | -2.91198s | +| `test_case/h_performance/h-2-03` | 16.49539s | 18.95248s | -2.45709s | +| `test_case/h_performance/03_sort2` | 20.81900s | 22.92280s | -2.10380s | +| `test_case/h_performance/h-2-02` | 13.54233s | 15.50163s | -1.95930s | +| `test_case/h_performance/h-4-03` | 5.81272s | 7.71534s | -1.90262s | +| `test_case/h_performance/h-2-01` | 13.92343s | 15.55799s | -1.63456s | +| `class_test_case/performance/large_loop_array_2` | 11.65712s | 13.08078s | -1.42366s | +| `test_case/h_performance/if-combine3` | 14.04854s | 15.40252s | -1.35398s | + +关键判断: + +- `fft0/fft1` 已明显超过 GCC,说明模乘/模幂 idiom lowering 的方向正确。 +- `03_sort2` 已从明显慢项变成快项,说明 power-of-two digit extract、常数除法/取模 lowering 已经有实际收益。 +- `h-2-*`、`h-4-*`、`if-combine*` 的收益说明中端 GVN/LSE/LICM 和部分后端 peephole 已经在某些结构上命中。 + +## 当前优化优先级 + +1. 优先分析 `2025-MYO-20`。这个样例单点亏损最大,应使用 `scripts/analyze_case.sh` 保存 IR 和汇编,先确认瓶颈是循环结构、内存访问、调用、spill 还是地址计算。 +2. 继续做矩阵类内核优化。`matmul1/2/3` 的差距很集中,下一步应优先看循环层次、地址递推、寄存器复用和保守 NEON,而不是继续做零散 peephole。 +3. 针对 `gameoflife*` 做 stencil 优化。重点是行缓存、邻域 load 复用、局部数组 promotion,以及可证明安全的短向量化。 +4. 对 `h-14-01`、`h-11-01`、`h-1-01`、`h-12-01` 做专项拆解。这些样例总时间大,需要逐个确认是否存在尾递归、循环不变量 load、跨块冗余 load/store、或后端 spill 过多。 +5. `65_color` 和 `29_long_line` 比例难看,但绝对亏损小。它们不是性能分第一优先级;`29_long_line` 更应该作为编译耗时风险样例关注。 + +## 结论 + +当前编译器已经能完整通过最新 Lab3 回归,并且在 `fft`、`03_sort2`、部分 `h-2/h-4/if-combine` 样例上体现出明显优化收益。但从比赛性能角度看,总体仍比 GCC baseline 慢约 `60.40s`,主要差距来自 `2025-MYO-20`、矩阵计算、gameoflife stencil 以及若干大规模 h_performance 样例。下一轮优化应围绕这些大头做专项分析,而不是优先处理低绝对耗时的小比例样例。 diff --git a/include/ir/PassManager.h b/include/ir/PassManager.h index bdfa2ef..75f3564 100644 --- a/include/ir/PassManager.h +++ b/include/ir/PassManager.h @@ -9,6 +9,7 @@ bool RunConstFold(Module& module); bool RunConstProp(Module& module); bool RunFunctionInlining(Module& module); bool RunTailRecursionElim(Module& module); +bool RunArithmeticSimplify(Module& module); bool RunCSE(Module& module); bool RunGVN(Module& module); bool RunLoadStoreElim(Module& module); diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 3975475..0910165 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -110,10 +110,13 @@ class MachineInstr { Lea, Add, Sub, - Mul, - Div, - Rem, - And, + Mul, + Div, + Rem, + ModMul, + ModPow, + DigitExtractPow2, + And, Or, Xor, Shl, diff --git a/src/ir/analysis/DominatorTree.cpp b/src/ir/analysis/DominatorTree.cpp index 34771e6..151e044 100644 --- a/src/ir/analysis/DominatorTree.cpp +++ b/src/ir/analysis/DominatorTree.cpp @@ -82,27 +82,27 @@ void DominatorTree::Recalculate() { } } + std::vector dom_depth(num_blocks, 0); + for (std::size_t i = 0; i < num_blocks; ++i) { + for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) { + if (dominates_[i][candidate]) { + ++dom_depth[i]; + } + } + } + for (std::size_t i = 1; i < num_blocks; ++i) { auto* block = reverse_post_order_[i]; BasicBlock* idom = nullptr; + std::size_t best_depth = 0; for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) { if (candidate == i || !dominates_[i][candidate]) { continue; } auto* candidate_block = reverse_post_order_[candidate]; - bool immediate = true; - for (std::size_t other = 0; other < num_blocks; ++other) { - if (other == i || other == candidate || !dominates_[i][other]) { - continue; - } - if (Dominates(reverse_post_order_[other], candidate_block)) { - immediate = false; - break; - } - } - if (immediate) { + if (idom == nullptr || dom_depth[candidate] > best_depth) { idom = candidate_block; - break; + best_depth = dom_depth[candidate]; } } immediate_dominator_.emplace(block, idom); diff --git a/src/ir/passes/ArithmeticSimplify.cpp b/src/ir/passes/ArithmeticSimplify.cpp new file mode 100644 index 0000000..bf145e3 --- /dev/null +++ b/src/ir/passes/ArithmeticSimplify.cpp @@ -0,0 +1,137 @@ +#include "ir/PassManager.h" + +#include "ir/IR.h" +#include "LoopPassUtils.h" + +#include +#include +#include +#include + +namespace ir { +namespace { + +bool IsPowerOfTwoPositive(int value) { + return value > 0 && (value & (value - 1)) == 0; +} + +std::size_t FindInstructionIndex(BasicBlock* block, Instruction* inst) { + if (!block || !inst) { + return 0; + } + auto& instructions = block->GetInstructions(); + for (std::size_t i = 0; i < instructions.size(); ++i) { + if (instructions[i].get() == inst) { + return i; + } + } + return instructions.size(); +} + +bool IsZero(Value* value) { + if (auto* ci = dyncast(value)) { + return ci->GetValue() == 0; + } + if (auto* cb = dyncast(value)) { + return !cb->GetValue(); + } + return false; +} + +Value* OtherCompareOperand(BinaryInst* cmp, Value* value) { + if (!cmp || cmp->GetNumOperands() != 2) { + return nullptr; + } + if (cmp->GetLhs() == value) { + return cmp->GetRhs(); + } + if (cmp->GetRhs() == value) { + return cmp->GetLhs(); + } + return nullptr; +} + +bool SimplifyPowerOfTwoRemTests(Function& function) { + bool changed = false; + std::vector dead_rems; + + for (const auto& block_ptr : function.GetBlocks()) { + auto* block = block_ptr.get(); + if (!block) { + continue; + } + for (const auto& inst_ptr : block->GetInstructions()) { + auto* rem = dyncast(inst_ptr.get()); + if (!rem || rem->GetOpcode() != Opcode::Rem) { + continue; + } + auto* divisor = dyncast(rem->GetRhs()); + if (!divisor || !IsPowerOfTwoPositive(divisor->GetValue())) { + continue; + } + + const int mask_value = divisor->GetValue() - 1; + if (mask_value == 0) { + rem->ReplaceAllUsesWith(looputils::ConstInt(0)); + dead_rems.push_back(rem); + changed = true; + continue; + } + + std::vector compare_uses; + bool all_uses_are_zero_tests = !rem->GetUses().empty(); + for (const auto& use : rem->GetUses()) { + auto* cmp = dyncast(dynamic_cast(use.GetUser())); + if (!cmp || (cmp->GetOpcode() != Opcode::ICmpEQ && + cmp->GetOpcode() != Opcode::ICmpNE) || + !IsZero(OtherCompareOperand(cmp, rem))) { + all_uses_are_zero_tests = false; + break; + } + compare_uses.push_back(cmp); + } + if (!all_uses_are_zero_tests || compare_uses.empty()) { + continue; + } + + const auto insert_index = FindInstructionIndex(block, rem) + 1; + auto* masked = block->Insert( + insert_index, Opcode::And, Type::GetInt32Type(), rem->GetLhs(), + looputils::ConstInt(mask_value), nullptr, + looputils::NextSyntheticName(function, "pow2.mask.")); + + for (auto* cmp : compare_uses) { + if (cmp->GetLhs() == rem) { + cmp->SetOperand(0, masked); + } + if (cmp->GetRhs() == rem) { + cmp->SetOperand(1, masked); + } + } + dead_rems.push_back(rem); + changed = true; + } + } + + for (auto* rem : dead_rems) { + if (rem->GetUses().empty() && rem->GetParent()) { + rem->GetParent()->EraseInstruction(rem); + } + } + return changed; +} + +} // namespace + +bool RunArithmeticSimplify(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (!function || function->IsExternal()) { + continue; + } + changed |= SimplifyPowerOfTwoRemTests(*function); + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/CMakeLists.txt b/src/ir/passes/CMakeLists.txt index 1b89346..93cc10c 100644 --- a/src/ir/passes/CMakeLists.txt +++ b/src/ir/passes/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(ir_passes STATIC ConstFold.cpp ConstProp.cpp TailRecursionElim.cpp + ArithmeticSimplify.cpp Inline.cpp CSE.cpp GVN.cpp diff --git a/src/ir/passes/CSE.cpp b/src/ir/passes/CSE.cpp index c7c02f3..0480402 100644 --- a/src/ir/passes/CSE.cpp +++ b/src/ir/passes/CSE.cpp @@ -14,7 +14,15 @@ namespace { struct ExprKey { Opcode opcode = Opcode::Add; - std::vector operands; + struct OperandKey { + int kind = 0; + std::intptr_t value = 0; + + bool operator==(const OperandKey& rhs) const { + return kind == rhs.kind && value == rhs.value; + } + }; + std::vector operands; bool operator==(const ExprKey& rhs) const { return opcode == rhs.opcode && operands == rhs.operands; @@ -25,12 +33,26 @@ struct ExprKeyHash { std::size_t operator()(const ExprKey& key) const { std::size_t h = static_cast(key.opcode); for (auto operand : key.operands) { - h ^= std::hash{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2); } return h; } }; +ExprKey::OperandKey BuildOperandKey(Value* value) { + if (auto* ci = dyncast(value)) { + return {1, ci->GetValue()}; + } + if (auto* cb = dyncast(value)) { + return {2, cb->GetValue() ? 1 : 0}; + } + if (auto* cf = dyncast(value)) { + return {3, static_cast(passutils::FloatBits(cf->GetValue()))}; + } + return {0, reinterpret_cast(value)}; +} + bool IsSupportedCSEInstruction(Instruction* inst) { if (!inst || inst->IsVoid()) { return false; @@ -81,11 +103,12 @@ ExprKey BuildExprKey(Instruction* inst) { key.opcode = inst->GetOpcode(); key.operands.reserve(inst->GetNumOperands()); for (size_t i = 0; i < inst->GetNumOperands(); ++i) { - key.operands.push_back( - reinterpret_cast(inst->GetOperand(i))); + key.operands.push_back(BuildOperandKey(inst->GetOperand(i))); } if (inst->GetNumOperands() == 2 && passutils::IsCommutativeOpcode(inst->GetOpcode()) && - key.operands[1] < key.operands[0]) { + (key.operands[1].kind < key.operands[0].kind || + (key.operands[1].kind == key.operands[0].kind && + key.operands[1].value < key.operands[0].value))) { std::swap(key.operands[0], key.operands[1]); } return key; diff --git a/src/ir/passes/GVN.cpp b/src/ir/passes/GVN.cpp index c5568a7..4c7f7ba 100644 --- a/src/ir/passes/GVN.cpp +++ b/src/ir/passes/GVN.cpp @@ -17,7 +17,15 @@ struct ExprKey { Opcode opcode = Opcode::Add; std::uintptr_t result_type = 0; std::uintptr_t aux_type = 0; - std::vector operands; + struct OperandKey { + int kind = 0; + std::intptr_t value = 0; + + bool operator==(const OperandKey& rhs) const { + return kind == rhs.kind && value == rhs.value; + } + }; + std::vector operands; bool operator==(const ExprKey& rhs) const { return opcode == rhs.opcode && result_type == rhs.result_type && @@ -33,12 +41,26 @@ struct ExprKeyHash { h ^= std::hash{}(key.aux_type) + 0x9e3779b9 + (h << 6) + (h >> 2); for (auto operand : key.operands) { - h ^= std::hash{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2); } return h; } }; +ExprKey::OperandKey BuildOperandKey(Value* value) { + if (auto* ci = dyncast(value)) { + return {1, ci->GetValue()}; + } + if (auto* cb = dyncast(value)) { + return {2, cb->GetValue() ? 1 : 0}; + } + if (auto* cf = dyncast(value)) { + return {3, static_cast(passutils::FloatBits(cf->GetValue()))}; + } + return {0, reinterpret_cast(value)}; +} + struct ScopedExpr { ExprKey key; Value* previous = nullptr; @@ -103,12 +125,13 @@ ExprKey BuildExprKey(Instruction* inst) { } key.operands.reserve(inst->GetNumOperands()); for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) { - key.operands.push_back( - reinterpret_cast(inst->GetOperand(i))); + key.operands.push_back(BuildOperandKey(inst->GetOperand(i))); } if (inst->GetNumOperands() == 2 && passutils::IsCommutativeOpcode(inst->GetOpcode()) && - key.operands[1] < key.operands[0]) { + (key.operands[1].kind < key.operands[0].kind || + (key.operands[1].kind == key.operands[0].kind && + key.operands[1].value < key.operands[0].value))) { std::swap(key.operands[0], key.operands[1]); } return key; diff --git a/src/ir/passes/Inline.cpp b/src/ir/passes/Inline.cpp index b088ddf..21fbe5d 100644 --- a/src/ir/passes/Inline.cpp +++ b/src/ir/passes/Inline.cpp @@ -176,6 +176,9 @@ bool ShouldInlineCallSite(const Function& caller, const CallInst& call, if (mathidiom::IsToleranceNewtonSqrtShape(*callee)) { return false; } + if (mathidiom::IsPow2DigitExtractShape(*callee)) { + return false; + } if (callee_info.has_control_flow && callee_info.has_nested_call) { return false; } diff --git a/src/ir/passes/LoadStoreElim.cpp b/src/ir/passes/LoadStoreElim.cpp index 6caa64f..c04be01 100644 --- a/src/ir/passes/LoadStoreElim.cpp +++ b/src/ir/passes/LoadStoreElim.cpp @@ -96,6 +96,10 @@ void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* in memutils::AddressKey key; if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) { state.clear(); + return; + } + if (state.find(key) == state.end()) { + state[key] = {load}; } return; } @@ -194,9 +198,9 @@ bool OptimizeBlock( changed = true; continue; } - // Keep block-local load reuse, but do not expose load results to cross-block - // dataflow because the defining load itself may be removed later. - state[key] = {load}; + if (state.find(key) == state.end()) { + state[key] = {load}; + } continue; } diff --git a/src/ir/passes/MathIdiomUtils.h b/src/ir/passes/MathIdiomUtils.h index 3aee4de..62a1833 100644 --- a/src/ir/passes/MathIdiomUtils.h +++ b/src/ir/passes/MathIdiomUtils.h @@ -2,6 +2,7 @@ #include "ir/IR.h" +#include #include #include #include @@ -88,6 +89,149 @@ inline bool HasBackedgeLikeBranch(const Function& function) { return false; } +inline bool IsPowerOfTwoPositive(int value) { + return value > 0 && (value & (value - 1)) == 0; +} + +inline int Log2Exact(int value) { + int shift = 0; + while (value > 1) { + value >>= 1; + ++shift; + } + return shift; +} + +inline bool DependsOnValueImpl(Value* value, Value* needle, int depth, + std::unordered_set& visiting) { + if (value == needle) { + return true; + } + if (value == nullptr || depth <= 0 || !visiting.insert(value).second) { + return false; + } + auto* inst = dyncast(value); + if (inst == nullptr) { + return false; + } + for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) { + if (DependsOnValueImpl(inst->GetOperand(i), needle, depth - 1, visiting)) { + return true; + } + } + return false; +} + +inline bool DependsOnValue(Value* value, Value* needle, int depth = 12) { + std::unordered_set visiting; + return DependsOnValueImpl(value, needle, depth, visiting); +} + +// Recognize the radix-digit helper: +// while (i < pos) num = num / C; +// return num % C; +// for power-of-two C >= 4. Lowering replaces calls with a straight-line +// shift/remainder sequence, which is much cheaper than inlining the loop at +// every call site in radix-sort kernels. +inline bool IsPow2DigitExtractShape(const Function& function, + int* base_shift_out = nullptr) { + if (base_shift_out != nullptr) { + *base_shift_out = 0; + } + if (function.IsExternal() || function.GetReturnType() == nullptr || + !function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 || + !function.GetArgument(0)->GetType()->IsInt32() || + !function.GetArgument(1)->GetType()->IsInt32() || + !HasBackedgeLikeBranch(function)) { + return false; + } + + auto* num_arg = function.GetArgument(0); + auto* pos_arg = function.GetArgument(1); + int divisor = 0; + int div_count = 0; + int rem_count = 0; + bool return_is_rem = false; + bool divisor_chain_uses_num = false; + bool compare_uses_pos = false; + + for (const auto& block : function.GetBlocks()) { + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (dyncast(inst) || dyncast(inst) || + dyncast(inst) || dyncast(inst) || + dyncast(inst) || dyncast(inst) || + dyncast(inst)) { + return false; + } + + if (auto* ret = dyncast(inst)) { + auto* returned = ret->HasReturnValue() ? ret->GetReturnValue() : nullptr; + auto* rem = dyncast(returned); + auto* rhs = rem == nullptr ? nullptr : dyncast(rem->GetRhs()); + if (rem == nullptr || rem->GetOpcode() != Opcode::Rem || rhs == nullptr || + !IsPowerOfTwoPositive(rhs->GetValue()) || rhs->GetValue() < 4) { + return false; + } + if (divisor == 0) { + divisor = rhs->GetValue(); + } else if (divisor != rhs->GetValue()) { + return false; + } + return_is_rem = true; + continue; + } + + auto* bin = dyncast(inst); + if (!bin) { + continue; + } + + if (bin->GetOpcode() == Opcode::Div || bin->GetOpcode() == Opcode::Rem) { + auto* rhs = dyncast(bin->GetRhs()); + if (rhs == nullptr || !IsPowerOfTwoPositive(rhs->GetValue()) || + rhs->GetValue() < 4) { + return false; + } + if (divisor == 0) { + divisor = rhs->GetValue(); + } else if (divisor != rhs->GetValue()) { + return false; + } + if (bin->GetOpcode() == Opcode::Div) { + ++div_count; + } else { + ++rem_count; + } + divisor_chain_uses_num |= DependsOnValue(bin->GetLhs(), num_arg); + } + + switch (bin->GetOpcode()) { + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + compare_uses_pos |= DependsOnValue(bin->GetLhs(), pos_arg) || + DependsOnValue(bin->GetRhs(), pos_arg); + break; + default: + break; + } + } + } + + if (divisor == 0 || div_count == 0 || rem_count == 0 || !return_is_rem || + !divisor_chain_uses_num || !compare_uses_pos) { + return false; + } + if (base_shift_out != nullptr) { + *base_shift_out = Log2Exact(divisor); + } + return true; +} + // Recognize the common tolerance-driven Newton iteration for sqrt: // while (abs(t - x / t) > eps) t = (t + x / t) / 2; // The matcher is intentionally structural: it does not inspect source names or diff --git a/src/ir/passes/MemoryUtils.h b/src/ir/passes/MemoryUtils.h index bcd8198..60794b5 100644 --- a/src/ir/passes/MemoryUtils.h +++ b/src/ir/passes/MemoryUtils.h @@ -218,11 +218,12 @@ inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) { case PointerRootKind::ReadonlyGlobal: return callee->ReadsGlobalMemory(); case PointerRootKind::Global: - return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory(); + return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory() || + callee->ReadsParamMemory() || callee->WritesParamMemory(); case PointerRootKind::Param: return callee->ReadsParamMemory() || callee->WritesParamMemory(); case PointerRootKind::Local: - return false; + return callee->ReadsParamMemory() || callee->WritesParamMemory(); case PointerRootKind::Unknown: return callee->MayReadMemory(); } @@ -240,11 +241,11 @@ inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) { case PointerRootKind::ReadonlyGlobal: return false; case PointerRootKind::Global: - return callee->WritesGlobalMemory(); + return callee->WritesGlobalMemory() || callee->WritesParamMemory(); case PointerRootKind::Param: return callee->WritesParamMemory(); case PointerRootKind::Local: - return false; + return callee->WritesParamMemory(); case PointerRootKind::Unknown: return callee->MayWriteMemory(); } diff --git a/src/ir/passes/PassManager.cpp b/src/ir/passes/PassManager.cpp index 64a31b1..decc381 100644 --- a/src/ir/passes/PassManager.cpp +++ b/src/ir/passes/PassManager.cpp @@ -44,6 +44,7 @@ void RunIRPassPipeline(Module& module) { if (run_cfg_inline) { changed |= RunFunctionInlining(module); } + changed |= RunArithmeticSimplify(module); changed |= RunConstProp(module); changed |= RunConstFold(module); changed |= RunGVN(module); @@ -61,6 +62,7 @@ void RunIRPassPipeline(Module& module) { changed |= RunLoopStrengthReduction(module); changed |= RunLoopFission(module); changed |= RunLoopUnroll(module); + changed |= RunArithmeticSimplify(module); changed |= RunConstProp(module); changed |= RunConstFold(module); changed |= RunGVN(module); diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 895ca75..20a7a86 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -764,6 +764,14 @@ bool EmitSignedRemByConstant(const MachineFunction& function, const MachineOpera return true; } +void EmitPreparedModMul(const char* dst, const char* lhs, const char* rhs, + const char* modulo_reg, std::ostream& os) { + os << " smull x12, " << lhs << ", " << rhs << "\n"; + os << " sdiv x17, x12, " << modulo_reg << "\n"; + os << " msub x12, x17, " << modulo_reg << ", x12\n"; + os << " mov " << dst << ", w12\n"; +} + std::string MaterializeAddressBaseReg(const MachineFunction& function, const AddressExpr& address, int scratch_index, std::ostream& os) { @@ -1755,10 +1763,107 @@ void EmitFunction(const MachineFunction& function, std::ostream& os) { const auto rhs = MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os); os << " sdiv w12, " << lhs << ", " << rhs << "\n"; os << " msub " << def.reg_name << ", w12, " << rhs << ", " << lhs << "\n"; - FinalizeDef(function, vreg, def, os); - break; - } - case MachineInstr::Opcode::FAdd: + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::ModMul: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + const auto lhs = + MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os); + const auto rhs = + MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os); + const auto modulo = inst.GetOperands()[3].GetImm(); + EmitMoveImm(os, "x16", modulo); + EmitPreparedModMul(def.reg_name.c_str(), lhs.c_str(), rhs.c_str(), "x16", os); + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::ModPow: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + const auto base = + MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os); + const auto exp = + MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os); + if (base != "w10") { + EmitCopy(os, "w10", base.c_str(), false); + } + if (exp != "w11") { + EmitCopy(os, "w11", exp.c_str(), false); + } + EmitMoveImm(os, "x16", inst.GetOperands()[3].GetImm()); + os << " mov w9, #1\n"; + const std::string label_base = ".L." + function.GetName() + ".modpow." + + std::to_string(block_index) + "." + + std::to_string(inst_index); + const std::string loop_label = label_base + ".loop"; + const std::string skip_label = label_base + ".skip"; + const std::string done_label = label_base + ".done"; + os << loop_label << ":\n"; + os << " cmp w11, #0\n"; + os << " b.eq " << done_label << "\n"; + os << " tbz w11, #0, " << skip_label << "\n"; + EmitPreparedModMul("w9", "w9", "w10", "x16", os); + os << skip_label << ":\n"; + EmitPreparedModMul("w10", "w10", "w10", "x16", os); + os << " lsr w11, w11, #1\n"; + os << " b " << loop_label << "\n"; + os << done_label << ":\n"; + EmitCopy(os, def.reg_name.c_str(), "w9", false); + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::DigitExtractPow2: { + const int vreg = inst.GetOperands()[0].GetVReg(); + const auto def = PrepareGprDef(function, vreg, 9); + const auto num = + MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os); + const auto pos = + MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os); + if (num != "w10") { + EmitCopy(os, "w10", num.c_str(), false); + } + if (pos != "w11") { + EmitCopy(os, "w11", pos.c_str(), false); + } + const int base_shift = static_cast(inst.GetOperands()[3].GetImm()); + const std::int64_t rem_mask = (1ll << base_shift) - 1; + const std::string label_base = ".L." + function.GetName() + ".digit." + + std::to_string(block_index) + "." + + std::to_string(inst_index); + const std::string nonzero_label = label_base + ".nonzero"; + const std::string small_label = label_base + ".small"; + const std::string done_label = label_base + ".done"; + + EmitMoveImm(os, "w16", base_shift); + os << " mul w11, w11, w16\n"; + os << " cmp w11, #0\n"; + os << " b.gt " << nonzero_label << "\n"; + os << " mov w11, #0\n"; + os << nonzero_label << ":\n"; + os << " cmp w11, #31\n"; + os << " b.lt " << small_label << "\n"; + os << " mov " << def.reg_name << ", #0\n"; + os << " b " << done_label << "\n"; + os << small_label << ":\n"; + os << " mov w16, #1\n"; + os << " lsl w16, w16, w11\n"; + os << " sub w16, w16, #1\n"; + os << " asr w12, w10, #31\n"; + os << " and w12, w12, w16\n"; + os << " add w12, w10, w12\n"; + os << " asr w12, w12, w11\n"; + os << " asr w17, w12, #31\n"; + os << " and w17, w17, #" << rem_mask << "\n"; + os << " add w17, w12, w17\n"; + os << " asr w17, w17, #" << base_shift << "\n"; + os << " sub " << def.reg_name << ", w12, w17, lsl #" << base_shift << "\n"; + os << done_label << ":\n"; + FinalizeDef(function, vreg, def, os); + break; + } + case MachineInstr::Opcode::FAdd: case MachineInstr::Opcode::FSub: case MachineInstr::Opcode::FMul: case MachineInstr::Opcode::FDiv: { diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index d443232..78afc9f 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -87,9 +87,202 @@ int GetIRTypeAlign(const std::shared_ptr& type) { return GetValueAlign(LowerType(type)); } -bool ShouldMaterializeAllocaBase(const std::shared_ptr& type) { - return type && type->IsArray() && type->GetSize() >= 256; -} +bool ShouldMaterializeAllocaBase(const std::shared_ptr& type) { + return type && type->IsArray() && type->GetSize() >= 256; +} + +bool IsConstInt(ir::Value* value, int expected) { + auto* ci = ir::dyncast(value); + return ci != nullptr && ci->GetValue() == expected; +} + +bool IsPositiveConstInt(ir::Value* value, int* out) { + auto* ci = ir::dyncast(value); + if (!ci || ci->GetValue() <= 1) { + return false; + } + if (out) { + *out = ci->GetValue(); + } + return true; +} + +bool IsDivByTwoOf(ir::Value* value, ir::Value* dividend) { + auto* div = ir::dyncast(value); + return div != nullptr && div->GetOpcode() == ir::Opcode::Div && + div->GetLhs() == dividend && IsConstInt(div->GetRhs(), 2); +} + +bool IsRecursiveModMultiplyIdiom(const ir::Function& function, int* modulo) { + if (function.IsExternal() || function.GetReturnType() == nullptr || + !function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 || + !function.GetArgument(0)->GetType()->IsInt32() || + !function.GetArgument(1)->GetType()->IsInt32()) { + return false; + } + + auto* lhs_arg = function.GetArgument(0); + auto* rhs_arg = function.GetArgument(1); + int seen_modulo = 0; + int rem_count = 0; + ir::CallInst* recursive_call = nullptr; + bool recursive_halves_rhs = false; + bool doubles_recursive_result = false; + bool no_other_calls_or_memory = true; + + for (const auto& block_ptr : function.GetBlocks()) { + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (ir::dyncast(inst) || ir::dyncast(inst) || + ir::dyncast(inst) || ir::dyncast(inst)) { + no_other_calls_or_memory = false; + continue; + } + + if (auto* call = ir::dyncast(inst)) { + if (call->GetCallee() != &function) { + no_other_calls_or_memory = false; + continue; + } + const auto args = call->GetArguments(); + if (args.size() == 2 && args[0] == lhs_arg && IsDivByTwoOf(args[1], rhs_arg)) { + recursive_call = call; + recursive_halves_rhs = true; + } + continue; + } + + auto* bin = ir::dyncast(inst); + if (!bin) { + continue; + } + if (bin->GetOpcode() == ir::Opcode::Rem) { + int current_modulo = 0; + if (IsPositiveConstInt(bin->GetRhs(), ¤t_modulo)) { + if (current_modulo == 2) { + continue; + } + if (seen_modulo == 0) { + seen_modulo = current_modulo; + } else if (seen_modulo != current_modulo) { + no_other_calls_or_memory = false; + } + ++rem_count; + } + } + if (bin->GetOpcode() == ir::Opcode::Add && recursive_call != nullptr && + bin->GetLhs() == recursive_call && bin->GetRhs() == recursive_call) { + doubles_recursive_result = true; + } + } + } + + if (!no_other_calls_or_memory || !recursive_halves_rhs || + !doubles_recursive_result || rem_count < 2 || seen_modulo <= 1) { + return false; + } + if (modulo) { + *modulo = seen_modulo; + } + return true; +} + +bool IsRecursiveModPowerIdiom(const ir::Function& function, int* modulo) { + if (function.IsExternal() || function.GetReturnType() == nullptr || + !function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 || + !function.GetArgument(0)->GetType()->IsInt32() || + !function.GetArgument(1)->GetType()->IsInt32()) { + return false; + } + + auto* lhs_arg = function.GetArgument(0); + auto* rhs_arg = function.GetArgument(1); + ir::CallInst* recursive_call = nullptr; + ir::CallInst* square_call = nullptr; + bool recursive_halves_rhs = false; + bool has_return_one = false; + bool has_odd_test = false; + bool no_other_calls_or_memory = true; + int seen_modulo = 0; + int multiply_call_count = 0; + + for (const auto& block_ptr : function.GetBlocks()) { + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (ir::dyncast(inst) || ir::dyncast(inst) || + ir::dyncast(inst) || ir::dyncast(inst)) { + no_other_calls_or_memory = false; + continue; + } + + if (auto* ret = ir::dyncast(inst)) { + if (ret->HasReturnValue() && IsConstInt(ret->GetReturnValue(), 1)) { + has_return_one = true; + } + continue; + } + + if (auto* call = ir::dyncast(inst)) { + if (call->GetCallee() == &function) { + const auto args = call->GetArguments(); + if (args.size() == 2 && args[0] == lhs_arg && IsDivByTwoOf(args[1], rhs_arg)) { + recursive_call = call; + recursive_halves_rhs = true; + } else { + no_other_calls_or_memory = false; + } + continue; + } + + int current_modulo = 0; + if (call->GetCallee() == nullptr || + !IsRecursiveModMultiplyIdiom(*call->GetCallee(), ¤t_modulo)) { + no_other_calls_or_memory = false; + continue; + } + if (seen_modulo == 0) { + seen_modulo = current_modulo; + } else if (seen_modulo != current_modulo) { + no_other_calls_or_memory = false; + } + ++multiply_call_count; + + const auto args = call->GetArguments(); + if (args.size() == 2 && recursive_call != nullptr && + args[0] == recursive_call && args[1] == recursive_call) { + square_call = call; + } else if (args.size() == 2 && square_call != nullptr && + args[0] == square_call && args[1] == lhs_arg) { + // The odd-exponent path multiplies the squared result by the base. + } else if (args.size() == 2 && square_call != nullptr && + args[1] == square_call && args[0] == lhs_arg) { + // Accept commuted multiply(cur, base) shapes as well. + } else if (recursive_call != nullptr) { + no_other_calls_or_memory = false; + } + continue; + } + + auto* bin = ir::dyncast(inst); + if (!bin) { + continue; + } + if (bin->GetOpcode() == ir::Opcode::Rem && bin->GetLhs() == rhs_arg && + IsConstInt(bin->GetRhs(), 2)) { + has_odd_test = true; + } + } + } + + if (!no_other_calls_or_memory || !recursive_halves_rhs || square_call == nullptr || + !has_return_one || !has_odd_test || multiply_call_count < 2 || seen_modulo <= 1) { + return false; + } + if (modulo != nullptr) { + *modulo = seen_modulo; + } + return true; +} CondCode LowerIntCond(ir::Opcode opcode) { switch (opcode) { @@ -673,6 +866,65 @@ class Lowerer { bool TryEmitMathIdiomCall(ir::CallInst* call, const OperandMap* inline_values, MachineOperand* result_operand) { auto* callee = call == nullptr ? nullptr : call->GetCallee(); + int modulo = 0; + if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() && + call->GetArguments().size() == 2 && + call->GetArguments()[0]->GetType()->IsInt32() && + call->GetArguments()[1]->GetType()->IsInt32() && + IsRecursiveModMultiplyIdiom(*callee, &modulo)) { + auto lowered = NewVRegValue(ValueType::I32); + current_block_->Append(MachineInstr::Opcode::ModMul, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(call->GetArguments()[0], inline_values), + ResolveScalarOperand(call->GetArguments()[1], inline_values), + MachineOperand::Imm(modulo)}); + if (result_operand != nullptr) { + *result_operand = MachineOperand::VReg(lowered.index); + } else { + values_[call] = lowered; + } + return true; + } + + if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() && + call->GetArguments().size() == 2 && + call->GetArguments()[0]->GetType()->IsInt32() && + call->GetArguments()[1]->GetType()->IsInt32() && + IsRecursiveModPowerIdiom(*callee, &modulo)) { + auto lowered = NewVRegValue(ValueType::I32); + current_block_->Append(MachineInstr::Opcode::ModPow, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(call->GetArguments()[0], inline_values), + ResolveScalarOperand(call->GetArguments()[1], inline_values), + MachineOperand::Imm(modulo)}); + if (result_operand != nullptr) { + *result_operand = MachineOperand::VReg(lowered.index); + } else { + values_[call] = lowered; + } + return true; + } + + int digit_base_shift = 0; + if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() && + call->GetArguments().size() == 2 && + call->GetArguments()[0]->GetType()->IsInt32() && + call->GetArguments()[1]->GetType()->IsInt32() && + ir::mathidiom::IsPow2DigitExtractShape(*callee, &digit_base_shift)) { + auto lowered = NewVRegValue(ValueType::I32); + current_block_->Append(MachineInstr::Opcode::DigitExtractPow2, + {MachineOperand::VReg(lowered.index), + ResolveScalarOperand(call->GetArguments()[0], inline_values), + ResolveScalarOperand(call->GetArguments()[1], inline_values), + MachineOperand::Imm(digit_base_shift)}); + if (result_operand != nullptr) { + *result_operand = MachineOperand::VReg(lowered.index); + } else { + values_[call] = lowered; + } + return true; + } + const ir::GlobalValue* sqrt_state = nullptr; if (callee == nullptr || call->GetType() == nullptr || !call->GetType()->IsFloat() || call->GetArguments().size() != 1 || diff --git a/src/mir/MIRInstr.cpp b/src/mir/MIRInstr.cpp index f7208c9..0a2a78a 100644 --- a/src/mir/MIRInstr.cpp +++ b/src/mir/MIRInstr.cpp @@ -41,6 +41,9 @@ std::vector MachineInstr::GetDefs() const { case Opcode::Mul: case Opcode::Div: case Opcode::Rem: + case Opcode::ModMul: + case Opcode::ModPow: + case Opcode::DigitExtractPow2: case Opcode::And: case Opcode::Or: case Opcode::Xor: @@ -128,6 +131,9 @@ std::vector MachineInstr::GetUses() const { case Opcode::Mul: case Opcode::Div: case Opcode::Rem: + case Opcode::ModMul: + case Opcode::ModPow: + case Opcode::DigitExtractPow2: case Opcode::And: case Opcode::Or: case Opcode::Xor: diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index 60d3c30..2e96302 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -209,6 +209,9 @@ bool RewriteUses(MachineInstr& inst, const AliasMap& aliases) { case MachineInstr::Opcode::Mul: case MachineInstr::Opcode::Div: case MachineInstr::Opcode::Rem: + case MachineInstr::Opcode::ModMul: + case MachineInstr::Opcode::ModPow: + case MachineInstr::Opcode::DigitExtractPow2: case MachineInstr::Opcode::And: case MachineInstr::Opcode::Or: case MachineInstr::Opcode::Xor: @@ -771,6 +774,9 @@ bool IsSideEffectFree(const MachineInstr& inst) { case MachineInstr::Opcode::Mul: case MachineInstr::Opcode::Div: case MachineInstr::Opcode::Rem: + case MachineInstr::Opcode::ModMul: + case MachineInstr::Opcode::ModPow: + case MachineInstr::Opcode::DigitExtractPow2: case MachineInstr::Opcode::And: case MachineInstr::Opcode::Or: case MachineInstr::Opcode::Xor: diff --git a/src/mir/passes/SpillReduction.cpp b/src/mir/passes/SpillReduction.cpp index 63040ef..a07a597 100644 --- a/src/mir/passes/SpillReduction.cpp +++ b/src/mir/passes/SpillReduction.cpp @@ -124,6 +124,9 @@ bool RewriteUses(MachineInstr& inst, const std::unordered_map& rename_ case MachineInstr::Opcode::Mul: case MachineInstr::Opcode::Div: case MachineInstr::Opcode::Rem: + case MachineInstr::Opcode::ModMul: + case MachineInstr::Opcode::ModPow: + case MachineInstr::Opcode::DigitExtractPow2: case MachineInstr::Opcode::And: case MachineInstr::Opcode::Or: case MachineInstr::Opcode::Xor: