Analysis of the latest test results

master
tangttangtang 1 week ago
parent 69892ef133
commit e55421f447

@ -0,0 +1,114 @@
# Competition Performance Optimization Log
Date: 2026-04-27
## Landed this round
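### 1. FFT: modular multiply / modular power idiom lowering
Target cases: `fft1`, `fft0`.
Implemented:
- Added `ModMul` to MIR. The recursive `multiply(a, b)` modular-multiplication idiom is recognized and lowered to `smull + sdiv + msub`, eliminating the recursive `multiply` calls.
- Added `ModPow` to MIR. The recursive `power(a, b)` fast-exponentiation idiom is recognized and lowered to an inline loop in the backend, eliminating the recursive `power` calls.
- In the `fft1` assembly, the number of `bl multiply` / `bl power` instructions drops to 0; only the `fft` recursion inherent to the algorithm remains.
Main locations:
- `include/mir/MIR.h`
- `src/mir/Lowering.cpp`
- `src/mir/AsmPrinter.cpp`
- `src/mir/MIRInstr.cpp`
- `src/mir/passes/Peephole.cpp`
- `src/mir/passes/SpillReduction.cpp`
Verification results:
- `fft1`: output matches; about `0.42s` locally under qemu.
- `fft0`: output matches; about `0.23s` locally under qemu.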
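As a rough illustration of why this lowering is valid, the sketch below compares a double-and-add recursive helper with the straight-line 64-bit computation that `smull + sdiv + msub` performs, plus an iterative square-and-multiply loop in the spirit of the inline `ModPow` loop. The modulus `998244353`, the exact shape of `MultiplyRecursive`, and all function names are assumptions made for the example; they are not taken from the test sources. Inputs are assumed to lie in `[0, kMod)`.

```cpp
#include <cassert>
#include <cstdint>

constexpr int kMod = 998244353;  // assumed modulus, for illustration only

// Hypothetical source-level shape of the recursive helper (double-and-add).
// Assumes 0 <= a < kMod and b >= 0 so the int additions cannot overflow.
int MultiplyRecursive(int a, int b) {
  if (b == 0) return 0;
  if (b == 1) return a % kMod;
  int cur = MultiplyRecursive(a, b / 2);
  cur = (cur + cur) % kMod;
  return (b % 2 == 1) ? (cur + a) % kMod : cur;
}

// Straight-line equivalent of the lowered three-instruction sequence.
int ModMulDirect(int a, int b, int mod = kMod) {
  const std::int64_t product = static_cast<std::int64_t>(a) * b;  // smull
  const std::int64_t quotient = product / mod;                    // sdiv
  return static_cast<int>(product - quotient * mod);              // msub
}

// Iterative square-and-multiply: test the low bit, square the base,
// shift the exponent right logically until it reaches zero.
int ModPowDirect(int base, int exp, int mod = kMod) {
  int result = 1;
  auto bits = static_cast<std::uint32_t>(exp);
  while (bits != 0) {
    if (bits & 1u) result = ModMulDirect(result, base, mod);
    base = ModMulDirect(base, base, mod);
    bits >>= 1;
  }
  return result;
}

// Spot checks: the recursive and direct forms agree on in-range inputs.
void CheckModIdioms() {
  assert(MultiplyRecursive(123456789, 987654321) ==
         ModMulDirect(123456789, 987654321));
  assert(ModPowDirect(2, 10) == 1024);
}
```

The recursive form needs about `log2(b)` calls per multiplication, while the direct form is three instructions, which is consistent with the large `fft` gains reported here. The equivalence shown only covers non-negative, reduced inputs.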
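### 2. 03_sort2: power-of-two digit extraction
Target case: `03_sort2`.
Implemented:
- Recognizes power-of-two radix digit helpers of the form `while (i < pos) num = num / 16; return num % 16;`.
- The IR inliner skips this helper, so the small function is not expanded into a large number of loops.
- The backend lowers it directly via `DigitExtractPow2` into a sequence of shifts, signed-division correction, and remainder, eliminating `bl getNumPos`.
- Fixed the constant equivalence key in GVN/CSE, so equal-valued constants no longer miss cross-block redundancy elimination just because their object addresses differ.
Main locations:
- `src/ir/passes/MathIdiomUtils.h`
- `src/ir/passes/Inline.cpp`
- `src/ir/passes/GVN.cpp`
- `src/ir/passes/CSE.cpp`
- `src/mir/Lowering.cpp`
- `src/mir/AsmPrinter.cpp`
Verification results:
- `03_sort2`: output matches; about `19.56s` locally under qemu.
- Compared with `31.317s` in the earlier table, this is a clear gain of roughly `11.76s`.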
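The digit-extraction lowering described above can be sketched in plain C++ as follows. `DigitLoop` is the reference shape of the helper; `DigitShift` is a straight-line version that uses a shift with a bias for negative values (the signed-division correction) and then a remainder built the same way. The function names are hypothetical, and the code assumes that `>>` on a negative `int` is an arithmetic shift, which is guaranteed from C++20 and is the usual behavior of mainstream compilers before that.

```cpp
#include <cassert>

// Reference shape of the helper: pos truncating divisions by 16, then % 16.
int DigitLoop(int num, int pos) {
  for (int i = 0; i < pos; ++i) num /= 16;
  return num % 16;
}

// Straight-line form for radix 16 (4 bits per digit).
int DigitShift(int num, int pos) {
  if (pos >= 8) return 0;  // 16^8 = 2^32: every 32-bit value divides to 0
  const int shift = pos > 0 ? 4 * pos : 0;
  // Truncating division by 2^shift: bias negative values before shifting.
  const int div_bias = num < 0 ? (1 << shift) - 1 : 0;
  const int quotient = (num + div_bias) >> shift;
  // Truncating remainder by 16, using the same bias trick.
  const int rem_bias = quotient < 0 ? 15 : 0;
  const int high = (quotient + rem_bias) >> 4;
  return quotient - high * 16;
}

// Spot checks over positive, negative, and out-of-range positions.
void CheckDigits() {
  const int samples[] = {0, 1, 255, 0x12345678, -1, -17, -256, 2147483647};
  for (int num : samples) {
    for (int pos = 0; pos < 10; ++pos) {
      assert(DigitLoop(num, pos) == DigitShift(num, pos));
    }
  }
}
```

Repeated truncating division by 16 equals a single truncating division by `16^pos`, which is what lets the loop collapse into one shift.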
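### 3. matmul / 2025-MYO-20: scalar baseline optimizations
Target cases: `matmul1/2/3`, `2025-MYO-20`.
Implemented:
- Added the IR pass `ArithmeticSimplify`, which rewrites `% power_of_two == 0` into a bit test; for example, `x % 2 == 0` becomes `(x & 1) == 0`.
- Enhanced `LoadStoreElim` to allow safe cross-block load forwarding, which fixes the case where a value loaded before an `if` is loaded again in the then block.
- Fixed the direction of the immediate-dominator check in `DominatorTree`, restoring the basic dominance relation that cross-block GVN/LICM/LSE depend on.
- The inner kernel of `matmul2` changed from repeated loads plus repeated multiplications to reusing a single product.
Main locations:
- `src/ir/passes/ArithmeticSimplify.cpp`
- `src/ir/passes/LoadStoreElim.cpp`
- `src/ir/analysis/DominatorTree.cpp`
- `src/ir/passes/PassManager.cpp`
Verification results:
- `matmul2`: output matches; about `7.09s` locally under qemu.
- Compared with `8.407s` in the earlier table, this is already a gain of roughly `1.32s`.
Not yet done:
- Real NEON vectorization and matrix loop interchange/blocking have not landed. MIR currently has no SIMD value type, NEON register class, vector load/store, or vector arithmetic, and there is no stable loop-nest interchange/blocking framework. Forcing in a test-case-level rewrite is too risky and is not suitable as a general optimization for a competition compiler.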
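The `% power_of_two == 0` rewrite listed among the implemented items is only safe because the result is compared against zero. A small sketch, with hypothetical helper names and assuming two's-complement `int`, shows both the equivalence and the case that must not be rewritten:

```cpp
#include <cassert>

// Original form: remainder compared against zero.
constexpr bool RemIsZero(int x, int pow2) { return x % pow2 == 0; }

// Rewritten form: mask the low bits. pow2 must be a positive power of two.
constexpr bool MaskIsZero(int x, int pow2) { return (x & (pow2 - 1)) == 0; }

// The two forms agree for negative x, because only "is it zero" is observed.
static_assert(RemIsZero(-8, 8) && MaskIsZero(-8, 8), "multiple of 8");
static_assert(!RemIsZero(-6, 4) && !MaskIsZero(-6, 4), "not a multiple of 4");

// The raw values differ for negative x, so a test such as `x % 2 == 1`
// cannot be rewritten into `(x & 1) == 1`.
static_assert(-3 % 2 == -1 && (-3 & 1) == 1, "remainder keeps the sign");

// Exhaustive check over a small range around zero.
void CheckBitTest() {
  for (int x = -64; x <= 64; ++x) {
    assert(RemIsZero(x, 8) == MaskIsZero(x, 8));
  }
}
```

This matches the guard in the pass, which only fires when every use of the remainder is an equality or inequality comparison with zero.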
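### 4. gameoflife: stencil groundwork
Target cases: `gameoflife-*`.
Implemented:
- With the dominator-tree fix and cross-block load forwarding, repeated address computations and repeated loads in the stencil have more chances to be eliminated by GVN/LSE.
Verification results:
- `gameoflife-oscillator`: output matches; about `8.82s` locally under qemu.
Not yet done:
- Real stencil NEON/row-cache optimization has not landed. SIMD MIR and more explicit recognition of 2D-array sliding windows are needed first; otherwise the work easily turns into test-case specialization.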
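Because the cross-block passes in this section depend on correct immediate dominators, here is a minimal sketch of one way to pick them from a full dominance matrix: among the strict dominators of a block, choose the one that itself has the most dominators. The matrix layout and all names are assumptions for the example; this is not the repository's `DominatorTree` code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// dominates[b][d] is true when block d dominates block b. Every block
// dominates itself and block 0 is the entry. This layout is an assumption.
using DomMatrix = std::vector<std::vector<bool>>;

std::vector<int> ComputeIdoms(const DomMatrix& dominates) {
  const std::size_t n = dominates.size();
  std::vector<std::size_t> depth(n, 0);  // number of dominators per block
  for (std::size_t b = 0; b < n; ++b) {
    for (std::size_t d = 0; d < n; ++d) {
      if (dominates[b][d]) ++depth[b];
    }
  }
  std::vector<int> idom(n, -1);  // -1 means "no immediate dominator" (entry)
  for (std::size_t b = 1; b < n; ++b) {
    for (std::size_t d = 0; d < n; ++d) {
      if (d == b || !dominates[b][d]) continue;
      // Strict dominators form a chain, so the deepest one is the closest.
      if (idom[b] < 0 || depth[d] > depth[static_cast<std::size_t>(idom[b])]) {
        idom[b] = static_cast<int>(d);
      }
    }
  }
  return idom;
}

// Diamond 0 -> {1, 2} -> 3, followed by 3 -> 4.
DomMatrix DiamondExample() {
  return {
      {true, false, false, false, false},
      {true, true, false, false, false},
      {true, false, true, false, false},
      {true, false, false, true, false},
      {true, false, false, true, true},
  };
}

void CheckIdoms() {
  const std::vector<int> expected{-1, 0, 0, 0, 3};
  assert(ComputeIdoms(DiamondExample()) == expected);
}
```

In the diamond, block 3 is dominated only by the entry, so its immediate dominator is block 0 rather than either branch arm. Getting this direction wrong would make load forwarding across the `if` either unsound or unavailable.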
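### 5. 65_color
This case has a poor speedup ratio but a very small absolute loss, so it was not prioritized this round. It should be revisited only after the large cases have converged.
## Verification this round
- `cmake --build build -j`: passes.
- Every single-case qemu comparison used a normalized diff of stdout + exit code.
- The full test suite was not run, to avoid excessive run time.
## Next-step priorities
1. Add a NEON value type, a vector register class, vector load/store, and basic i32x4/f32x4 arithmetic to MIR.
2. Add loop-nest recognition at the IR level; do safe matrix loop interchange first, then consider blocking.
3. Build a general stencil matcher for `gameoflife`; generate a scalar row cache first, then connect NEON.
4. For `2025-MYO-20`, save the IR/ASM separately with `scripts/analyze_case.sh`, compare against the GCC assembly, and then decide whether matmul micro-kernel lowering is worthwhile.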

@ -0,0 +1,104 @@
# Lab3: Analysis of the Latest Test Results
Date: 2026-04-29
## Data sources
- Our test log: `output/logs/lab3/lab3_20260429_192016/whole.log`
- Our timing table: `output/logs/lab3/lab3_20260429_192016/timing.tsv`
- GCC baseline: `output/baseline/gcc_timing.tsv`
Our results this round:
```text
summary: 214 PASS / 0 FAIL / total 214
build elapsed: 0.72401s
validation elapsed: 632.18659s
total elapsed: 632.91658s
```
GCC baseline results:
```text
Summary: 214 DONE / 0 SKIP (cached) / 0 FAIL / total 214
Total elapsed : 484.24024s
Timing TSV : output/baseline/gcc_timing.tsv (213 entries)
```
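## Overall conclusion
Functional correctness passes this round: `214/214 PASS`. Performance, however, has to be read under two separate metrics:
| Metric | Ours | GCC baseline | Difference |
| --- | ---: | ---: | ---: |
| Wall-clock time of the whole script run | 632.91658s | 484.24024s | +148.67634s |
| Sum of program run times | 485.95009s | 425.55356s | +60.39653s |
Under the program-run-time metric, the current overall speedup is: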
```text
425.55356 / 485.95009 = 0.8757x
```
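In other words, the generated code currently runs about `60.40s` slower than the GCC baseline overall. The whole script run is about `148.68s` slower; the additional `88s` or so comes from our per-case compile, assemble, link, and validation overhead, and is not a direct measure of generated-code performance.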
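The figures in this paragraph follow directly from the two table rows. The short sketch below recomputes them; the struct and function names are made up for the example, and the numbers are copied from the table.

```cpp
#include <cassert>
#include <cmath>

struct Timing {
  double ours;  // seconds
  double gcc;   // seconds
};

constexpr Timing kWallClock{632.91658, 484.24024};  // whole script run
constexpr Timing kRunTime{485.95009, 425.55356};    // sum of program run times

// Speedup relative to GCC: values below 1.0 mean we are slower.
constexpr double Speedup(Timing t) { return t.gcc / t.ours; }

// Absolute gap in seconds: positive means we are slower.
constexpr double Gap(Timing t) { return t.ours - t.gcc; }

// Part of the wall-clock gap that is not explained by generated-code run time.
constexpr double PipelineOverhead() { return Gap(kWallClock) - Gap(kRunTime); }

void CheckTimings() {
  assert(std::fabs(Speedup(kRunTime) - 0.8757) < 1e-4);
  assert(std::fabs(Gap(kRunTime) - 60.39653) < 1e-6);
  assert(std::fabs(Gap(kWallClock) - 148.67634) < 1e-6);
  assert(std::fabs(PipelineOverhead() - 88.27981) < 1e-6);
}
```

Keeping the two gaps separate makes it easier to see that only the `60.40s` part can be recovered by improving code generation.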
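Note: `timing.tsv` has 214 rows, while the current `gcc_timing.tsv` has 213; the extra entry is `class_test_case/functional/05_arr_defn4`. For strict aggregation, only the 213 entries that match the current baseline file exactly are counted, and the table above uses this basis.
## Largest losses
These cases are the most worthwhile optimization targets right now, sorted by "our run time minus GCC run time":
| Case | Ours | GCC | Slower by |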
| --- | ---: | ---: | ---: |
| `class_test_case/performance/2025-MYO-20` | 54.01749s | 29.75174s | +24.26575s |
| `test_case/h_performance/h-14-01` | 33.94136s | 26.19856s | +7.74280s |
| `test_case/h_performance/h-11-01` | 60.07281s | 52.58051s | +7.49230s |
| `test_case/h_performance/h-1-01` | 25.46834s | 20.48401s | +4.98433s |
| `test_case/h_performance/h-12-01` | 20.04854s | 15.68926s | +4.35928s |
| `test_case/h_performance/matmul3` | 7.04411s | 2.87407s | +4.17004s |
| `test_case/h_performance/matmul1` | 7.02077s | 2.86589s | +4.15488s |
| `test_case/h_performance/matmul2` | 6.92980s | 2.92273s | +4.00707s |
| `test_case/h_performance/gameoflife-gosper` | 10.77375s | 7.53120s | +3.24255s |
| `test_case/h_performance/gameoflife-oscillator` | 9.72381s | 6.73087s | +2.99294s |
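The main problems fall into four groups:
- `2025-MYO-20` is the largest single loss, about `24.27s` slower on its own, and should be the first case to analyze.
- `matmul1/2/3` are about `12.33s` slower in total, which shows that the matrix kernels still lack effective NEON, incremental address computation, cache-friendly transformations, or loop blocking.
- `gameoflife*` are more than `11s` slower in total, which shows that stencil-style accesses do not yet get row caching, repeated-load elimination, or vectorization.
- `h-14-01`, `h-11-01`, `h-1-01`, and `h-12-01` account for a large share overall. Each needs its IR and assembly inspected to decide whether the mid-end failed to remove loads/stores or the backend produces poor spill/address code.
## Largest gains
These cases show that the existing optimizations are taking effect:
| Case | Ours | GCC | Faster by |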
| --- | ---: | ---: | ---: |
| `test_case/h_performance/fft1` | 0.42533s | 6.63117s | -6.20584s |
| `class_test_case/performance/fft0` | 0.20593s | 3.13259s | -2.92666s |
| `test_case/h_performance/fft0` | 0.21674s | 3.12871s | -2.91198s |
| `test_case/h_performance/h-2-03` | 16.49539s | 18.95248s | -2.45709s |
| `test_case/h_performance/03_sort2` | 20.81900s | 22.92280s | -2.10380s |
| `test_case/h_performance/h-2-02` | 13.54233s | 15.50163s | -1.95930s |
| `test_case/h_performance/h-4-03` | 5.81272s | 7.71534s | -1.90262s |
| `test_case/h_performance/h-2-01` | 13.92343s | 15.55799s | -1.63456s |
| `class_test_case/performance/large_loop_array_2` | 11.65712s | 13.08078s | -1.42366s |
| `test_case/h_performance/if-combine3` | 14.04854s | 15.40252s | -1.35398s |
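Key takeaways:
- `fft0/fft1` are now clearly faster than GCC, which shows that the modular multiply / modular power idiom lowering is the right direction.
- `03_sort2` has gone from a clearly slow case to a fast one, which shows that power-of-two digit extraction and constant division/modulo lowering already deliver real gains.
- The gains on `h-2-*`, `h-4-*`, and `if-combine*` show that the mid-end GVN/LSE/LICM and some backend peepholes already hit on certain code shapes.
## Current optimization priorities
1. Analyze `2025-MYO-20` first. It has the largest single loss; use `scripts/analyze_case.sh` to save the IR and assembly, and first confirm whether the bottleneck is loop structure, memory access, calls, spills, or address computation.
2. Continue optimizing the matrix kernels. The gap on `matmul1/2/3` is very concentrated; the next step should look at loop nesting, incremental address computation, register reuse, and conservative NEON first, rather than adding more scattered peepholes.
3. Do stencil optimization for `gameoflife*`. The focus is row caching, neighborhood load reuse, local array promotion, and provably safe short-vector vectorization.
4. Break down `h-14-01`, `h-11-01`, `h-1-01`, and `h-12-01` individually. These cases have large total times; each needs a check for tail recursion, loop-invariant loads, cross-block redundant loads/stores, or excessive backend spills.
5. `65_color` and `29_long_line` have poor ratios but small absolute losses. They are not the top priority for the performance score; `29_long_line` deserves attention mainly as a compile-time risk case.
## Conclusion
The compiler now passes the latest Lab3 regression in full and shows clear optimization gains on `fft`, `03_sort2`, and some of the `h-2/h-4/if-combine` cases. From a competition-performance standpoint, however, it is still about `60.40s` slower than the GCC baseline overall. The main gaps come from `2025-MYO-20`, matrix computation, the gameoflife stencil, and several large h_performance cases. The next round should focus on targeted analysis of these large items rather than on cases with a poor ratio but low absolute time.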

@ -9,6 +9,7 @@ bool RunConstFold(Module& module);
bool RunConstProp(Module& module);
bool RunFunctionInlining(Module& module);
bool RunTailRecursionElim(Module& module);
bool RunArithmeticSimplify(Module& module);
bool RunCSE(Module& module);
bool RunGVN(Module& module);
bool RunLoadStoreElim(Module& module);

@ -110,10 +110,13 @@ class MachineInstr {
Lea,
Add,
Sub,
Mul,
Div,
Rem,
And,
Mul,
Div,
Rem,
ModMul,
ModPow,
DigitExtractPow2,
And,
Or,
Xor,
Shl,

@ -82,27 +82,27 @@ void DominatorTree::Recalculate() {
}
}
std::vector<std::size_t> dom_depth(num_blocks, 0);
for (std::size_t i = 0; i < num_blocks; ++i) {
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
if (dominates_[i][candidate]) {
++dom_depth[i];
}
}
}
for (std::size_t i = 1; i < num_blocks; ++i) {
auto* block = reverse_post_order_[i];
BasicBlock* idom = nullptr;
std::size_t best_depth = 0;
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
if (candidate == i || !dominates_[i][candidate]) {
continue;
}
auto* candidate_block = reverse_post_order_[candidate];
bool immediate = true;
for (std::size_t other = 0; other < num_blocks; ++other) {
if (other == i || other == candidate || !dominates_[i][other]) {
continue;
}
if (Dominates(reverse_post_order_[other], candidate_block)) {
immediate = false;
break;
}
}
if (immediate) {
if (idom == nullptr || dom_depth[candidate] > best_depth) {
idom = candidate_block;
break;
best_depth = dom_depth[candidate];
}
}
immediate_dominator_.emplace(block, idom);

@ -0,0 +1,137 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>
namespace ir {
namespace {
bool IsPowerOfTwoPositive(int value) {
return value > 0 && (value & (value - 1)) == 0;
}
std::size_t FindInstructionIndex(BasicBlock* block, Instruction* inst) {
if (!block || !inst) {
return 0;
}
auto& instructions = block->GetInstructions();
for (std::size_t i = 0; i < instructions.size(); ++i) {
if (instructions[i].get() == inst) {
return i;
}
}
return instructions.size();
}
bool IsZero(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return ci->GetValue() == 0;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return !cb->GetValue();
}
return false;
}
Value* OtherCompareOperand(BinaryInst* cmp, Value* value) {
if (!cmp || cmp->GetNumOperands() != 2) {
return nullptr;
}
if (cmp->GetLhs() == value) {
return cmp->GetRhs();
}
if (cmp->GetRhs() == value) {
return cmp->GetLhs();
}
return nullptr;
}
bool SimplifyPowerOfTwoRemTests(Function& function) {
bool changed = false;
std::vector<Instruction*> dead_rems;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (!block) {
continue;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* rem = dyncast<BinaryInst>(inst_ptr.get());
if (!rem || rem->GetOpcode() != Opcode::Rem) {
continue;
}
auto* divisor = dyncast<ConstantInt>(rem->GetRhs());
if (!divisor || !IsPowerOfTwoPositive(divisor->GetValue())) {
continue;
}
const int mask_value = divisor->GetValue() - 1;
if (mask_value == 0) {
rem->ReplaceAllUsesWith(looputils::ConstInt(0));
dead_rems.push_back(rem);
changed = true;
continue;
}
std::vector<BinaryInst*> compare_uses;
bool all_uses_are_zero_tests = !rem->GetUses().empty();
for (const auto& use : rem->GetUses()) {
auto* cmp = dyncast<BinaryInst>(dynamic_cast<Value*>(use.GetUser()));
if (!cmp || (cmp->GetOpcode() != Opcode::ICmpEQ &&
cmp->GetOpcode() != Opcode::ICmpNE) ||
!IsZero(OtherCompareOperand(cmp, rem))) {
all_uses_are_zero_tests = false;
break;
}
compare_uses.push_back(cmp);
}
if (!all_uses_are_zero_tests || compare_uses.empty()) {
continue;
}
const auto insert_index = FindInstructionIndex(block, rem) + 1;
auto* masked = block->Insert<BinaryInst>(
insert_index, Opcode::And, Type::GetInt32Type(), rem->GetLhs(),
looputils::ConstInt(mask_value), nullptr,
looputils::NextSyntheticName(function, "pow2.mask."));
for (auto* cmp : compare_uses) {
if (cmp->GetLhs() == rem) {
cmp->SetOperand(0, masked);
}
if (cmp->GetRhs() == rem) {
cmp->SetOperand(1, masked);
}
}
dead_rems.push_back(rem);
changed = true;
}
}
for (auto* rem : dead_rems) {
if (rem->GetUses().empty() && rem->GetParent()) {
rem->GetParent()->EraseInstruction(rem);
}
}
return changed;
}
} // namespace
bool RunArithmeticSimplify(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (!function || function->IsExternal()) {
continue;
}
changed |= SimplifyPowerOfTwoRemTests(*function);
}
return changed;
}
} // namespace ir

@ -4,6 +4,7 @@ add_library(ir_passes STATIC
ConstFold.cpp
ConstProp.cpp
TailRecursionElim.cpp
ArithmeticSimplify.cpp
Inline.cpp
CSE.cpp
GVN.cpp

@ -14,7 +14,15 @@ namespace {
struct ExprKey {
Opcode opcode = Opcode::Add;
std::vector<std::uintptr_t> operands;
struct OperandKey {
int kind = 0;
std::intptr_t value = 0;
bool operator==(const OperandKey& rhs) const {
return kind == rhs.kind && value == rhs.value;
}
};
std::vector<OperandKey> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && operands == rhs.operands;
@ -25,12 +33,26 @@ struct ExprKeyHash {
std::size_t operator()(const ExprKey& key) const {
std::size_t h = static_cast<std::size_t>(key.opcode);
for (auto operand : key.operands) {
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<int>{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::intptr_t>{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
ExprKey::OperandKey BuildOperandKey(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {1, ci->GetValue()};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {2, cb->GetValue() ? 1 : 0};
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return {3, static_cast<std::intptr_t>(passutils::FloatBits(cf->GetValue()))};
}
return {0, reinterpret_cast<std::intptr_t>(value)};
}
bool IsSupportedCSEInstruction(Instruction* inst) {
if (!inst || inst->IsVoid()) {
return false;
@ -81,11 +103,12 @@ ExprKey BuildExprKey(Instruction* inst) {
key.opcode = inst->GetOpcode();
key.operands.reserve(inst->GetNumOperands());
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(
reinterpret_cast<std::uintptr_t>(inst->GetOperand(i)));
key.operands.push_back(BuildOperandKey(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 && passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
key.operands[1] < key.operands[0]) {
(key.operands[1].kind < key.operands[0].kind ||
(key.operands[1].kind == key.operands[0].kind &&
key.operands[1].value < key.operands[0].value))) {
std::swap(key.operands[0], key.operands[1]);
}
return key;

@ -17,7 +17,15 @@ struct ExprKey {
Opcode opcode = Opcode::Add;
std::uintptr_t result_type = 0;
std::uintptr_t aux_type = 0;
std::vector<std::uintptr_t> operands;
struct OperandKey {
int kind = 0;
std::intptr_t value = 0;
bool operator==(const OperandKey& rhs) const {
return kind == rhs.kind && value == rhs.value;
}
};
std::vector<OperandKey> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && result_type == rhs.result_type &&
@ -33,12 +41,26 @@ struct ExprKeyHash {
h ^= std::hash<std::uintptr_t>{}(key.aux_type) + 0x9e3779b9 + (h << 6) +
(h >> 2);
for (auto operand : key.operands) {
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<int>{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::intptr_t>{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
ExprKey::OperandKey BuildOperandKey(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {1, ci->GetValue()};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {2, cb->GetValue() ? 1 : 0};
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return {3, static_cast<std::intptr_t>(passutils::FloatBits(cf->GetValue()))};
}
return {0, reinterpret_cast<std::intptr_t>(value)};
}
struct ScopedExpr {
ExprKey key;
Value* previous = nullptr;
@ -103,12 +125,13 @@ ExprKey BuildExprKey(Instruction* inst) {
}
key.operands.reserve(inst->GetNumOperands());
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(
reinterpret_cast<std::uintptr_t>(inst->GetOperand(i)));
key.operands.push_back(BuildOperandKey(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 &&
passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
key.operands[1] < key.operands[0]) {
(key.operands[1].kind < key.operands[0].kind ||
(key.operands[1].kind == key.operands[0].kind &&
key.operands[1].value < key.operands[0].value))) {
std::swap(key.operands[0], key.operands[1]);
}
return key;

@ -176,6 +176,9 @@ bool ShouldInlineCallSite(const Function& caller, const CallInst& call,
if (mathidiom::IsToleranceNewtonSqrtShape(*callee)) {
return false;
}
if (mathidiom::IsPow2DigitExtractShape(*callee)) {
return false;
}
if (callee_info.has_control_flow && callee_info.has_nested_call) {
return false;
}

@ -96,6 +96,10 @@ void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* in
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
state.clear();
return;
}
if (state.find(key) == state.end()) {
state[key] = {load};
}
return;
}
@ -194,9 +198,9 @@ bool OptimizeBlock(
changed = true;
continue;
}
// Keep block-local load reuse, but do not expose load results to cross-block
// dataflow because the defining load itself may be removed later.
state[key] = {load};
if (state.find(key) == state.end()) {
state[key] = {load};
}
continue;
}

@ -2,6 +2,7 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
@ -88,6 +89,149 @@ inline bool HasBackedgeLikeBranch(const Function& function) {
return false;
}
inline bool IsPowerOfTwoPositive(int value) {
return value > 0 && (value & (value - 1)) == 0;
}
inline int Log2Exact(int value) {
int shift = 0;
while (value > 1) {
value >>= 1;
++shift;
}
return shift;
}
inline bool DependsOnValueImpl(Value* value, Value* needle, int depth,
std::unordered_set<Value*>& visiting) {
if (value == needle) {
return true;
}
if (value == nullptr || depth <= 0 || !visiting.insert(value).second) {
return false;
}
auto* inst = dyncast<Instruction>(value);
if (inst == nullptr) {
return false;
}
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (DependsOnValueImpl(inst->GetOperand(i), needle, depth - 1, visiting)) {
return true;
}
}
return false;
}
inline bool DependsOnValue(Value* value, Value* needle, int depth = 12) {
std::unordered_set<Value*> visiting;
return DependsOnValueImpl(value, needle, depth, visiting);
}
// Recognize the radix-digit helper:
// while (i < pos) num = num / C;
// return num % C;
// for power-of-two C >= 4. Lowering replaces calls with a straight-line
// shift/remainder sequence, which is much cheaper than inlining the loop at
// every call site in radix-sort kernels.
inline bool IsPow2DigitExtractShape(const Function& function,
int* base_shift_out = nullptr) {
if (base_shift_out != nullptr) {
*base_shift_out = 0;
}
if (function.IsExternal() || function.GetReturnType() == nullptr ||
!function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
!function.GetArgument(0)->GetType()->IsInt32() ||
!function.GetArgument(1)->GetType()->IsInt32() ||
!HasBackedgeLikeBranch(function)) {
return false;
}
auto* num_arg = function.GetArgument(0);
auto* pos_arg = function.GetArgument(1);
int divisor = 0;
int div_count = 0;
int rem_count = 0;
bool return_is_rem = false;
bool divisor_chain_uses_num = false;
bool compare_uses_pos = false;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<CallInst>(inst) || dyncast<LoadInst>(inst) ||
dyncast<StoreInst>(inst) || dyncast<AllocaInst>(inst) ||
dyncast<GetElementPtrInst>(inst) || dyncast<MemsetInst>(inst) ||
dyncast<UnreachableInst>(inst)) {
return false;
}
if (auto* ret = dyncast<ReturnInst>(inst)) {
auto* returned = ret->HasReturnValue() ? ret->GetReturnValue() : nullptr;
auto* rem = dyncast<BinaryInst>(returned);
auto* rhs = rem == nullptr ? nullptr : dyncast<ConstantInt>(rem->GetRhs());
if (rem == nullptr || rem->GetOpcode() != Opcode::Rem || rhs == nullptr ||
!IsPowerOfTwoPositive(rhs->GetValue()) || rhs->GetValue() < 4) {
return false;
}
if (divisor == 0) {
divisor = rhs->GetValue();
} else if (divisor != rhs->GetValue()) {
return false;
}
return_is_rem = true;
continue;
}
auto* bin = dyncast<BinaryInst>(inst);
if (!bin) {
continue;
}
if (bin->GetOpcode() == Opcode::Div || bin->GetOpcode() == Opcode::Rem) {
auto* rhs = dyncast<ConstantInt>(bin->GetRhs());
if (rhs == nullptr || !IsPowerOfTwoPositive(rhs->GetValue()) ||
rhs->GetValue() < 4) {
return false;
}
if (divisor == 0) {
divisor = rhs->GetValue();
} else if (divisor != rhs->GetValue()) {
return false;
}
if (bin->GetOpcode() == Opcode::Div) {
++div_count;
} else {
++rem_count;
}
divisor_chain_uses_num |= DependsOnValue(bin->GetLhs(), num_arg);
}
switch (bin->GetOpcode()) {
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
compare_uses_pos |= DependsOnValue(bin->GetLhs(), pos_arg) ||
DependsOnValue(bin->GetRhs(), pos_arg);
break;
default:
break;
}
}
}
if (divisor == 0 || div_count == 0 || rem_count == 0 || !return_is_rem ||
!divisor_chain_uses_num || !compare_uses_pos) {
return false;
}
if (base_shift_out != nullptr) {
*base_shift_out = Log2Exact(divisor);
}
return true;
}
// Recognize the common tolerance-driven Newton iteration for sqrt:
// while (abs(t - x / t) > eps) t = (t + x / t) / 2;
// The matcher is intentionally structural: it does not inspect source names or

@ -218,11 +218,12 @@ inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) {
case PointerRootKind::ReadonlyGlobal:
return callee->ReadsGlobalMemory();
case PointerRootKind::Global:
return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory();
return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory() ||
callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Param:
return callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Local:
return false;
return callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Unknown:
return callee->MayReadMemory();
}
@ -240,11 +241,11 @@ inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) {
case PointerRootKind::ReadonlyGlobal:
return false;
case PointerRootKind::Global:
return callee->WritesGlobalMemory();
return callee->WritesGlobalMemory() || callee->WritesParamMemory();
case PointerRootKind::Param:
return callee->WritesParamMemory();
case PointerRootKind::Local:
return false;
return callee->WritesParamMemory();
case PointerRootKind::Unknown:
return callee->MayWriteMemory();
}

@ -44,6 +44,7 @@ void RunIRPassPipeline(Module& module) {
if (run_cfg_inline) {
changed |= RunFunctionInlining(module);
}
changed |= RunArithmeticSimplify(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
@ -61,6 +62,7 @@ void RunIRPassPipeline(Module& module) {
changed |= RunLoopStrengthReduction(module);
changed |= RunLoopFission(module);
changed |= RunLoopUnroll(module);
changed |= RunArithmeticSimplify(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);

@ -764,6 +764,14 @@ bool EmitSignedRemByConstant(const MachineFunction& function, const MachineOpera
return true;
}
void EmitPreparedModMul(const char* dst, const char* lhs, const char* rhs,
const char* modulo_reg, std::ostream& os) {
os << " smull x12, " << lhs << ", " << rhs << "\n";
os << " sdiv x17, x12, " << modulo_reg << "\n";
os << " msub x12, x17, " << modulo_reg << ", x12\n";
os << " mov " << dst << ", w12\n";
}
std::string MaterializeAddressBaseReg(const MachineFunction& function,
const AddressExpr& address, int scratch_index,
std::ostream& os) {
@ -1755,10 +1763,107 @@ void EmitFunction(const MachineFunction& function, std::ostream& os) {
const auto rhs = MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
os << " sdiv w12, " << lhs << ", " << rhs << "\n";
os << " msub " << def.reg_name << ", w12, " << rhs << ", " << lhs << "\n";
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::FAdd:
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::ModMul: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareGprDef(function, vreg, 9);
const auto lhs =
MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
const auto rhs =
MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
const auto modulo = inst.GetOperands()[3].GetImm();
EmitMoveImm(os, "x16", modulo);
EmitPreparedModMul(def.reg_name.c_str(), lhs.c_str(), rhs.c_str(), "x16", os);
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::ModPow: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareGprDef(function, vreg, 9);
const auto base =
MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
const auto exp =
MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
if (base != "w10") {
EmitCopy(os, "w10", base.c_str(), false);
}
if (exp != "w11") {
EmitCopy(os, "w11", exp.c_str(), false);
}
EmitMoveImm(os, "x16", inst.GetOperands()[3].GetImm());
os << " mov w9, #1\n";
const std::string label_base = ".L." + function.GetName() + ".modpow." +
std::to_string(block_index) + "." +
std::to_string(inst_index);
const std::string loop_label = label_base + ".loop";
const std::string skip_label = label_base + ".skip";
const std::string done_label = label_base + ".done";
os << loop_label << ":\n";
os << " cmp w11, #0\n";
os << " b.eq " << done_label << "\n";
os << " tbz w11, #0, " << skip_label << "\n";
EmitPreparedModMul("w9", "w9", "w10", "x16", os);
os << skip_label << ":\n";
EmitPreparedModMul("w10", "w10", "w10", "x16", os);
os << " lsr w11, w11, #1\n";
os << " b " << loop_label << "\n";
os << done_label << ":\n";
EmitCopy(os, def.reg_name.c_str(), "w9", false);
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::DigitExtractPow2: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareGprDef(function, vreg, 9);
const auto num =
MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
const auto pos =
MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
if (num != "w10") {
EmitCopy(os, "w10", num.c_str(), false);
}
if (pos != "w11") {
EmitCopy(os, "w11", pos.c_str(), false);
}
const int base_shift = static_cast<int>(inst.GetOperands()[3].GetImm());
const std::int64_t rem_mask = (1ll << base_shift) - 1;
const std::string label_base = ".L." + function.GetName() + ".digit." +
std::to_string(block_index) + "." +
std::to_string(inst_index);
const std::string nonzero_label = label_base + ".nonzero";
const std::string small_label = label_base + ".small";
const std::string done_label = label_base + ".done";
EmitMoveImm(os, "w16", base_shift);
os << " mul w11, w11, w16\n";
os << " cmp w11, #0\n";
os << " b.gt " << nonzero_label << "\n";
os << " mov w11, #0\n";
os << nonzero_label << ":\n";
os << " cmp w11, #31\n";
os << " b.lt " << small_label << "\n";
os << " mov " << def.reg_name << ", #0\n";
os << " b " << done_label << "\n";
os << small_label << ":\n";
os << " mov w16, #1\n";
os << " lsl w16, w16, w11\n";
os << " sub w16, w16, #1\n";
os << " asr w12, w10, #31\n";
os << " and w12, w12, w16\n";
os << " add w12, w10, w12\n";
os << " asr w12, w12, w11\n";
os << " asr w17, w12, #31\n";
os << " and w17, w17, #" << rem_mask << "\n";
os << " add w17, w12, w17\n";
os << " asr w17, w17, #" << base_shift << "\n";
os << " sub " << def.reg_name << ", w12, w17, lsl #" << base_shift << "\n";
os << done_label << ":\n";
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv: {

@ -87,9 +87,202 @@ int GetIRTypeAlign(const std::shared_ptr<ir::Type>& type) {
return GetValueAlign(LowerType(type));
}
bool ShouldMaterializeAllocaBase(const std::shared_ptr<ir::Type>& type) {
return type && type->IsArray() && type->GetSize() >= 256;
}
bool ShouldMaterializeAllocaBase(const std::shared_ptr<ir::Type>& type) {
return type && type->IsArray() && type->GetSize() >= 256;
}
bool IsConstInt(ir::Value* value, int expected) {
auto* ci = ir::dyncast<ir::ConstantInt>(value);
return ci != nullptr && ci->GetValue() == expected;
}
bool IsPositiveConstInt(ir::Value* value, int* out) {
auto* ci = ir::dyncast<ir::ConstantInt>(value);
if (!ci || ci->GetValue() <= 1) {
return false;
}
if (out) {
*out = ci->GetValue();
}
return true;
}
bool IsDivByTwoOf(ir::Value* value, ir::Value* dividend) {
auto* div = ir::dyncast<ir::BinaryInst>(value);
return div != nullptr && div->GetOpcode() == ir::Opcode::Div &&
div->GetLhs() == dividend && IsConstInt(div->GetRhs(), 2);
}
bool IsRecursiveModMultiplyIdiom(const ir::Function& function, int* modulo) {
if (function.IsExternal() || function.GetReturnType() == nullptr ||
!function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
!function.GetArgument(0)->GetType()->IsInt32() ||
!function.GetArgument(1)->GetType()->IsInt32()) {
return false;
}
auto* lhs_arg = function.GetArgument(0);
auto* rhs_arg = function.GetArgument(1);
int seen_modulo = 0;
int rem_count = 0;
ir::CallInst* recursive_call = nullptr;
bool recursive_halves_rhs = false;
bool doubles_recursive_result = false;
bool no_other_calls_or_memory = true;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (ir::dyncast<ir::LoadInst>(inst) || ir::dyncast<ir::StoreInst>(inst) ||
ir::dyncast<ir::AllocaInst>(inst) || ir::dyncast<ir::MemsetInst>(inst)) {
no_other_calls_or_memory = false;
continue;
}
if (auto* call = ir::dyncast<ir::CallInst>(inst)) {
if (call->GetCallee() != &function) {
no_other_calls_or_memory = false;
continue;
}
const auto args = call->GetArguments();
if (args.size() == 2 && args[0] == lhs_arg && IsDivByTwoOf(args[1], rhs_arg)) {
recursive_call = call;
recursive_halves_rhs = true;
}
continue;
}
auto* bin = ir::dyncast<ir::BinaryInst>(inst);
if (!bin) {
continue;
}
if (bin->GetOpcode() == ir::Opcode::Rem) {
int current_modulo = 0;
if (IsPositiveConstInt(bin->GetRhs(), &current_modulo)) {
if (current_modulo == 2) {
continue;
}
if (seen_modulo == 0) {
seen_modulo = current_modulo;
} else if (seen_modulo != current_modulo) {
no_other_calls_or_memory = false;
}
++rem_count;
}
}
if (bin->GetOpcode() == ir::Opcode::Add && recursive_call != nullptr &&
bin->GetLhs() == recursive_call && bin->GetRhs() == recursive_call) {
doubles_recursive_result = true;
}
}
}
if (!no_other_calls_or_memory || !recursive_halves_rhs ||
!doubles_recursive_result || rem_count < 2 || seen_modulo <= 1) {
return false;
}
if (modulo) {
*modulo = seen_modulo;
}
return true;
}
// Recognizes a recursive fast-power helper power(base, exp) that:
//   - returns the constant 1 on some path (the exp == 0 base case),
//   - recurses as power(base, exp / 2),
//   - squares the recursive result through a recognized modular-multiply helper,
//   - tests exp % 2 and multiplies by the base on the odd path,
//   - touches no memory and makes no other calls.
// On success, *modulo receives the modulus shared by all multiply calls.
bool IsRecursiveModPowerIdiom(const ir::Function& function, int* modulo) {
  if (function.IsExternal() || function.GetReturnType() == nullptr ||
      !function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
      !function.GetArgument(0)->GetType()->IsInt32() ||
      !function.GetArgument(1)->GetType()->IsInt32()) {
    return false;
  }
  auto* lhs_arg = function.GetArgument(0);
  auto* rhs_arg = function.GetArgument(1);
  ir::CallInst* recursive_call = nullptr;
  ir::CallInst* square_call = nullptr;
  bool recursive_halves_rhs = false;
  bool has_return_one = false;
  bool has_odd_test = false;
  bool no_other_calls_or_memory = true;
  int seen_modulo = 0;
  int multiply_call_count = 0;
  for (const auto& block_ptr : function.GetBlocks()) {
    for (const auto& inst_ptr : block_ptr->GetInstructions()) {
      auto* inst = inst_ptr.get();
      // Any memory access disqualifies the function: the idiom must be pure.
      if (ir::dyncast<ir::LoadInst>(inst) || ir::dyncast<ir::StoreInst>(inst) ||
          ir::dyncast<ir::AllocaInst>(inst) || ir::dyncast<ir::MemsetInst>(inst)) {
        no_other_calls_or_memory = false;
        continue;
      }
      if (auto* ret = ir::dyncast<ir::ReturnInst>(inst)) {
        if (ret->HasReturnValue() && IsConstInt(ret->GetReturnValue(), 1)) {
          has_return_one = true;
        }
        continue;
      }
      if (auto* call = ir::dyncast<ir::CallInst>(inst)) {
        if (call->GetCallee() == &function) {
          // The only accepted self-call is power(base, exp / 2).
          const auto args = call->GetArguments();
          if (args.size() == 2 && args[0] == lhs_arg && IsDivByTwoOf(args[1], rhs_arg)) {
            recursive_call = call;
            recursive_halves_rhs = true;
          } else {
            no_other_calls_or_memory = false;
          }
          continue;
        }
        // Every other call must target a recognized modular-multiply helper,
        // and all such helpers must agree on the modulus.
        int current_modulo = 0;
        if (call->GetCallee() == nullptr ||
            !IsRecursiveModMultiplyIdiom(*call->GetCallee(), &current_modulo)) {
          no_other_calls_or_memory = false;
          continue;
        }
        if (seen_modulo == 0) {
          seen_modulo = current_modulo;
        } else if (seen_modulo != current_modulo) {
          no_other_calls_or_memory = false;
        }
        ++multiply_call_count;
        const auto args = call->GetArguments();
        if (args.size() == 2 && recursive_call != nullptr &&
            args[0] == recursive_call && args[1] == recursive_call) {
          square_call = call;
        } else if (args.size() == 2 && square_call != nullptr &&
                   args[0] == square_call && args[1] == lhs_arg) {
          // The odd-exponent path multiplies the squared result by the base.
        } else if (args.size() == 2 && square_call != nullptr &&
                   args[1] == square_call && args[0] == lhs_arg) {
          // Accept commuted multiply(base, cur) shapes as well.
        } else if (recursive_call != nullptr) {
          no_other_calls_or_memory = false;
        }
        continue;
      }
      auto* bin = ir::dyncast<ir::BinaryInst>(inst);
      if (!bin) {
        continue;
      }
      // The parity test on the exponent: exp % 2.
      if (bin->GetOpcode() == ir::Opcode::Rem && bin->GetLhs() == rhs_arg &&
          IsConstInt(bin->GetRhs(), 2)) {
        has_odd_test = true;
      }
    }
  }
  if (!no_other_calls_or_memory || !recursive_halves_rhs || square_call == nullptr ||
      !has_return_one || !has_odd_test || multiply_call_count < 2 || seen_modulo <= 1) {
    return false;
  }
  if (modulo != nullptr) {
    *modulo = seen_modulo;
  }
  return true;
}
CondCode LowerIntCond(ir::Opcode opcode) {
switch (opcode) {
@@ -673,6 +866,65 @@ class Lowerer {
  // Replaces a call to a recognized math-idiom helper with a single MIR
  // pseudo-instruction. Each pseudo-op takes {dst, arg0, arg1, immediate}.
  bool TryEmitMathIdiomCall(ir::CallInst* call, const OperandMap* inline_values,
                            MachineOperand* result_operand) {
    auto* callee = call == nullptr ? nullptr : call->GetCallee();
    int modulo = 0;
    // multiply(a, b) -> ModMul dst, a, b, #modulo
    if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() &&
        call->GetArguments().size() == 2 &&
        call->GetArguments()[0]->GetType()->IsInt32() &&
        call->GetArguments()[1]->GetType()->IsInt32() &&
        IsRecursiveModMultiplyIdiom(*callee, &modulo)) {
      auto lowered = NewVRegValue(ValueType::I32);
      current_block_->Append(MachineInstr::Opcode::ModMul,
                             {MachineOperand::VReg(lowered.index),
                              ResolveScalarOperand(call->GetArguments()[0], inline_values),
                              ResolveScalarOperand(call->GetArguments()[1], inline_values),
                              MachineOperand::Imm(modulo)});
      if (result_operand != nullptr) {
        *result_operand = MachineOperand::VReg(lowered.index);
      } else {
        values_[call] = lowered;
      }
      return true;
    }
    // power(a, b) -> ModPow dst, a, b, #modulo
    if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() &&
        call->GetArguments().size() == 2 &&
        call->GetArguments()[0]->GetType()->IsInt32() &&
        call->GetArguments()[1]->GetType()->IsInt32() &&
        IsRecursiveModPowerIdiom(*callee, &modulo)) {
      auto lowered = NewVRegValue(ValueType::I32);
      current_block_->Append(MachineInstr::Opcode::ModPow,
                             {MachineOperand::VReg(lowered.index),
                              ResolveScalarOperand(call->GetArguments()[0], inline_values),
                              ResolveScalarOperand(call->GetArguments()[1], inline_values),
                              MachineOperand::Imm(modulo)});
      if (result_operand != nullptr) {
        *result_operand = MachineOperand::VReg(lowered.index);
      } else {
        values_[call] = lowered;
      }
      return true;
    }
    // Power-of-two digit helper -> DigitExtractPow2 dst, num, pos, #shift
    int digit_base_shift = 0;
    if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() &&
        call->GetArguments().size() == 2 &&
        call->GetArguments()[0]->GetType()->IsInt32() &&
        call->GetArguments()[1]->GetType()->IsInt32() &&
        ir::mathidiom::IsPow2DigitExtractShape(*callee, &digit_base_shift)) {
      auto lowered = NewVRegValue(ValueType::I32);
      current_block_->Append(MachineInstr::Opcode::DigitExtractPow2,
                             {MachineOperand::VReg(lowered.index),
                              ResolveScalarOperand(call->GetArguments()[0], inline_values),
                              ResolveScalarOperand(call->GetArguments()[1], inline_values),
                              MachineOperand::Imm(digit_base_shift)});
      if (result_operand != nullptr) {
        *result_operand = MachineOperand::VReg(lowered.index);
      } else {
        values_[call] = lowered;
      }
      return true;
    }
    const ir::GlobalValue* sqrt_state = nullptr;
    if (callee == nullptr || call->GetType() == nullptr || !call->GetType()->IsFloat() ||
        call->GetArguments().size() != 1 ||

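Review note on the `ModMul` / `ModPow` lowering: the sketch below shows the kind of source-level helpers the recognizers appear to target, next to a loop-free multiply and an iterative power that give the same results. The names `MultiplyRef`, `PowerRef`, `ModMulDirect`, `ModPowLoop`, the value of `kMod`, and the exact bodies of the recursive helpers are assumptions reconstructed from the checks in `IsRecursiveModPowerIdiom`; none of them are taken from the test sources or from the backend. The equivalence is only claimed for operands in `[0, kMod)`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical modulus; the real value comes from the matched multiply helper.
constexpr int kMod = 998244353;

// Assumed shape of the recursive modular multiply (double-and-add).
int MultiplyRef(int a, int b) {
  if (b == 0) return 0;
  if (b == 1) return a % kMod;
  int cur = MultiplyRef(a, b / 2);
  cur = (cur + cur) % kMod;
  return (b % 2 == 1) ? (cur + a) % kMod : cur;
}

// Assumed shape of the recursive fast power: return 1, recurse on b / 2,
// square via multiply, and multiply by the base when b is odd.
int PowerRef(int a, int b) {
  if (b == 0) return 1;
  int cur = PowerRef(a, b / 2);
  cur = MultiplyRef(cur, cur);
  return (b % 2 == 1) ? MultiplyRef(cur, a) : cur;
}

// What a single ModMul could compute: widen to 64 bits, multiply, reduce.
int ModMulDirect(int a, int b, int mod) {
  return static_cast<int>(static_cast<std::int64_t>(a) * b % mod);
}

// What an inlined ModPow loop could compute: iterative binary exponentiation.
int ModPowLoop(int base, int exp, int mod) {
  std::int64_t result = 1 % mod;
  std::int64_t cur = base % mod;
  while (exp > 0) {
    if (exp & 1) result = result * cur % mod;
    cur = cur * cur % mod;
    exp >>= 1;
  }
  return static_cast<int>(result);
}
```

Negative operands are deliberately left out: the recursive helpers and the direct forms can disagree there, so a production recognizer would need either a range proof or a lowering that reproduces the recursive behavior exactly.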
@@ -41,6 +41,9 @@ std::vector<int> MachineInstr::GetDefs() const {
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::ModMul:
case Opcode::ModPow:
case Opcode::DigitExtractPow2:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
@@ -128,6 +131,9 @@ std::vector<int> MachineInstr::GetUses() const {
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::ModMul:
case Opcode::ModPow:
case Opcode::DigitExtractPow2:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:

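Review note on the `GetDefs` / `GetUses` hunks: the new pseudo-ops are grouped with `Mul`, `Div`, and `Rem`, which suggests they share the `{dst, lhs, rhs}` operand layout with one trailing immediate. The toy model below illustrates why each new opcode has to appear in both switch statements. All `Toy*` types and functions are hypothetical and do not mirror the real `MachineInstr` API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy operand: either a virtual register index or an immediate value.
struct ToyOperand {
  enum class Kind { VReg, Imm } kind;
  int value;
};

enum class ToyOpcode { Mul, ModMul, ModPow, DigitExtractPow2, Ret };

struct ToyInstr {
  ToyOpcode opcode;
  std::vector<ToyOperand> operands;  // {dst, lhs, rhs[, imm]}
};

// Opcodes that define operand 0 and read the remaining register operands.
bool IsThreeAddress(ToyOpcode opcode) {
  switch (opcode) {
    case ToyOpcode::Mul:
    case ToyOpcode::ModMul:
    case ToyOpcode::ModPow:
    case ToyOpcode::DigitExtractPow2:
      return true;
    default:
      return false;
  }
}

// Registers written by the instruction.
std::vector<int> ToyDefs(const ToyInstr& inst) {
  if (!IsThreeAddress(inst.opcode) || inst.operands.empty()) return {};
  return {inst.operands[0].value};
}

// Registers read by the instruction; immediates are skipped.
std::vector<int> ToyUses(const ToyInstr& inst) {
  std::vector<int> uses;
  if (!IsThreeAddress(inst.opcode)) return uses;
  for (std::size_t i = 1; i < inst.operands.size(); ++i) {
    if (inst.operands[i].kind == ToyOperand::Kind::VReg) {
      uses.push_back(inst.operands[i].value);
    }
  }
  return uses;
}
```

If an opcode were missing from `IsThreeAddress` in this model, both queries would return empty lists, and a liveness or dead-code pass built on them would treat the instruction as having no inputs and no result.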
@@ -209,6 +209,9 @@ bool RewriteUses(MachineInstr& inst, const AliasMap& aliases) {
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
@@ -771,6 +774,9 @@ bool IsSideEffectFree(const MachineInstr& inst) {
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:

@@ -124,6 +124,9 @@ bool RewriteUses(MachineInstr& inst, const std::unordered_map<int, int>& rename_
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:

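Review note on `DigitExtractPow2`: the commit notes describe lowering the digit helper to a shift, a signed-division correction, and a remainder. The sketch below is one way to express that in C++ and to compare it against the loop form. `DigitRef` and `DigitDirect` are hypothetical names, and the loop body (counter starting at 0, incrementing by 1) is an assumption about the source helper. `base_shift` is assumed to be in `[1, 30]`.

```cpp
#include <cassert>
#include <cstdint>

// Assumed source-level helper: divide by 2^base_shift `pos` times, then take
// the remainder. Both operations truncate toward zero.
int DigitRef(int num, int pos, int base_shift) {
  const int base = 1 << base_shift;
  for (int i = 0; i < pos; ++i) num = num / base;
  return num % base;
}

// Loop-free variant: one biased arithmetic shift replaces the division loop.
int DigitDirect(int num, int pos, int base_shift) {
  const int base = 1 << base_shift;
  if (pos <= 0) return num % base;  // The loop body never runs.
  const std::int64_t total_shift = static_cast<std::int64_t>(pos) * base_shift;
  if (total_shift >= 32) return 0;  // |num| <= 2^31, so the quotient is 0.
  const std::int64_t n = num;
  // Adding 2^k - 1 to negative inputs turns the flooring shift into a
  // truncating division. Relies on arithmetic right shift of negative values,
  // which is guaranteed since C++20.
  const std::int64_t bias = n < 0 ? (std::int64_t{1} << total_shift) - 1 : 0;
  const std::int64_t quotient = (n + bias) >> total_shift;
  return static_cast<int>(quotient % base);
}
```

Collapsing the repeated divisions into one is valid because truncating division by positive divisors composes: dividing by 16 `pos` times gives the same quotient as dividing once by `16^pos`.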