Compare commits

...

6 Commits

@ -0,0 +1,114 @@
# 比赛性能优化记录
日期2026-04-27
## 本轮已落地
### 1. FFT模乘/模幂 idiom lowering
目标用例:`fft1`、`fft0`。
已实现:
- 在 MIR 增加 `ModMul`,识别递归 `multiply(a, b)` 的模乘 idiomlower 成 `smull + sdiv + msub`,消除 `multiply` 递归调用。
- 在 MIR 增加 `ModPow`,识别递归 `power(a, b)` 的快速幂 idiomlower 成后端内联循环,消除 `power` 递归调用。
- `fft1` 汇编中 `bl multiply` / `bl power` 数量降为 0仅保留算法本身的 `fft` 递归。
主要位置:
- `include/mir/MIR.h`
- `src/mir/Lowering.cpp`
- `src/mir/AsmPrinter.cpp`
- `src/mir/MIRInstr.cpp`
- `src/mir/passes/Peephole.cpp`
- `src/mir/passes/SpillReduction.cpp`
验证结果:
- `fft1`输出匹配qemu 本地约 `0.42s`
- `fft0`输出匹配qemu 本地约 `0.23s`
### 2. 03_sort2power-of-two digit extraction
目标用例:`03_sort2`。
已实现:
- 识别 `while (i < pos) num = num / 16; return num % 16;` 这类 power-of-two radix digit helper。
- IR 内联器会跳过该 helper避免把小函数展开成大量循环。
- 后端用 `DigitExtractPow2` 直接 lower 成移位、带符号除法修正和取余序列,消除 `bl getNumPos`
- 修复 GVN/CSE 的常量等价键,避免等值常量因对象地址不同而错过跨块消冗余。
主要位置:
- `src/ir/passes/MathIdiomUtils.h`
- `src/ir/passes/Inline.cpp`
- `src/ir/passes/GVN.cpp`
- `src/ir/passes/CSE.cpp`
- `src/mir/Lowering.cpp`
- `src/mir/AsmPrinter.cpp`
验证结果:
- `03_sort2`输出匹配qemu 本地约 `19.56s`
- 对比此前表中 `31.317s`,该项收益明显。
### 3. matmul / 2025-MYO-20标量基础优化
目标用例:`matmul1/2/3`、`2025-MYO-20`。
已实现:
- 新增 IR `ArithmeticSimplify`,把 `% power_of_two == 0` 化成 bit-test例如 `x % 2 == 0` 变为 `(x & 1) == 0`
- 增强 `LoadStoreElim`,允许安全的跨块 load forwarding解决 `if` 前已加载、then 块重复加载的问题。
- 修复 `DominatorTree` 的 immediate dominator 判定方向,恢复跨块 GVN/LICM/LSE 的基础支配关系。
- `matmul2` 的内层核心从重复 load + 重复 mul 变为复用同一个乘积。
主要位置:
- `src/ir/passes/ArithmeticSimplify.cpp`
- `src/ir/passes/LoadStoreElim.cpp`
- `src/ir/analysis/DominatorTree.cpp`
- `src/ir/passes/PassManager.cpp`
验证结果:
- `matmul2`输出匹配qemu 本地约 `7.09s`
- 对比此前表中 `8.407s`,已有收益。
尚未完成:
- 真正的 NEON 向量化、矩阵 loop interchange/blocking 还没有落地。当前 MIR 没有 SIMD value type、NEON 寄存器类、向量 load/store、向量 arithmetic也没有稳定的 loop-nest interchange/blocking 框架。硬塞样例级重写风险过高,不适合作为通用比赛编译器优化。
### 4. gameoflifestencil 前置优化
目标用例:`gameoflife-*`。
已实现:
- 通过支配树修复和跨块 load forwarding让 stencil 里的重复地址计算和重复 load 有更多被 GVN/LSE 消除的机会。
验证结果:
- `gameoflife-oscillator`输出匹配qemu 本地约 `8.82s`
尚未完成:
- 真正的 stencil NEON/行缓存优化还未落地。需要先补 SIMD MIR 和更明确的二维数组滑窗识别,否则容易做成样例特化。
### 5. 65_color
该用例加速比难看但绝对损失很小,本轮未优先处理。后续应只在大头用例收敛后再看。
## 本轮验证
- `cmake --build build -j`:通过。
- 单例 qemu 对比均做了 stdout + exit code 的规范化 diff。
- 未运行全量测试,避免耗时过长。
## 下一步优先级
1. 为 MIR 增加 NEON value type、向量寄存器类、vector load/store 和基础 i32x4/f32x4 arithmetic。
2. 在 IR 层补 loop-nest 识别,先做安全的矩阵 loop interchange再考虑 blocking。
3. 对 `gameoflife` 做通用 stencil matcher先生成 scalar row-cache再接 NEON。
4. 对 `2025-MYO-20` 单独用 `scripts/analyze_case.sh` 保存 IR/ASM与 GCC 汇编对照后决定是否值得做 matmul micro-kernel lowering。

@ -0,0 +1,104 @@
# Lab3 最新测试结果分析
日期2026-04-29
## 数据源
- 我方测试日志:`output/logs/lab3/lab3_20260429_192016/whole.log`
- 我方计时表:`output/logs/lab3/lab3_20260429_192016/timing.tsv`
- GCC baseline`output/baseline/gcc_timing.tsv`
本轮我方结果:
```text
summary: 214 PASS / 0 FAIL / total 214
build elapsed: 0.72401s
validation elapsed: 632.18659s
total elapsed: 632.91658s
```
GCC baseline 结果:
```text
Summary: 214 DONE / 0 SKIP (cached) / 0 FAIL / total 214
Total elapsed : 484.24024s
Timing TSV : output/baseline/gcc_timing.tsv (213 entries)
```
## 总体结论
本轮功能正确性已经通过,`214/214 PASS`。但性能口径需要分开看:
| 口径 | 我方 | GCC baseline | 差值 |
| --- | ---: | ---: | ---: |
| 脚本整轮墙钟时间 | 632.91658s | 484.24024s | +148.67634s |
| 程序运行时间总和 | 485.95009s | 425.55356s | +60.39653s |
程序运行时间口径下,当前总体 speedup 为:
```text
425.55356 / 485.95009 = 0.8757x
```
也就是说,生成代码运行时间目前整体比 GCC baseline 慢约 `60.40s`。脚本整轮慢约 `148.68s`,其中额外约 `88s` 来自我方逐样例编译、汇编、链接、校验等流程开销,不完全等价于生成代码性能。
补充说明:`timing.tsv` 有 214 行,当前 `gcc_timing.tsv` 有 213 行;额外项是 `class_test_case/functional/05_arr_defn4`。严格汇总时按当前 baseline 文件可精确匹配的 213 条计算,上表采用这个口径。
## 最大亏损样例
这些样例是当前最值得优先优化的对象,按“我方运行时间 - GCC 运行时间”排序:
| 样例 | 我方 | GCC | 慢多少 |
| --- | ---: | ---: | ---: |
| `class_test_case/performance/2025-MYO-20` | 54.01749s | 29.75174s | +24.26575s |
| `test_case/h_performance/h-14-01` | 33.94136s | 26.19856s | +7.74280s |
| `test_case/h_performance/h-11-01` | 60.07281s | 52.58051s | +7.49230s |
| `test_case/h_performance/h-1-01` | 25.46834s | 20.48401s | +4.98433s |
| `test_case/h_performance/h-12-01` | 20.04854s | 15.68926s | +4.35928s |
| `test_case/h_performance/matmul3` | 7.04411s | 2.87407s | +4.17004s |
| `test_case/h_performance/matmul1` | 7.02077s | 2.86589s | +4.15488s |
| `test_case/h_performance/matmul2` | 6.92980s | 2.92273s | +4.00707s |
| `test_case/h_performance/gameoflife-gosper` | 10.77375s | 7.53120s | +3.24255s |
| `test_case/h_performance/gameoflife-oscillator` | 9.72381s | 6.73087s | +2.99294s |
主要问题集中在四类:
- `2025-MYO-20` 是最大单点亏损,单独慢约 `24.27s`,应作为第一分析对象。
- `matmul1/2/3` 合计慢约 `12.33s`,说明矩阵类内核还缺少有效的 NEON、地址递推、缓存友好变换或循环分块。
- `gameoflife*` 合计慢约 `11s+`,说明 stencil 型访问还没有做到行缓存、重复 load 消除或向量化。
- `h-14-01`、`h-11-01`、`h-1-01`、`h-12-01` 总体占比较大,需要逐个看 IR 和汇编,判断是中端 load/store 没消掉,还是后端 spill/address 质量差。
## 最大收益样例
这些样例说明当前已有优化确实生效:
| 样例 | 我方 | GCC | 快多少 |
| --- | ---: | ---: | ---: |
| `test_case/h_performance/fft1` | 0.42533s | 6.63117s | -6.20584s |
| `class_test_case/performance/fft0` | 0.20593s | 3.13259s | -2.92666s |
| `test_case/h_performance/fft0` | 0.21674s | 3.12871s | -2.91198s |
| `test_case/h_performance/h-2-03` | 16.49539s | 18.95248s | -2.45709s |
| `test_case/h_performance/03_sort2` | 20.81900s | 22.92280s | -2.10380s |
| `test_case/h_performance/h-2-02` | 13.54233s | 15.50163s | -1.95930s |
| `test_case/h_performance/h-4-03` | 5.81272s | 7.71534s | -1.90262s |
| `test_case/h_performance/h-2-01` | 13.92343s | 15.55799s | -1.63456s |
| `class_test_case/performance/large_loop_array_2` | 11.65712s | 13.08078s | -1.42366s |
| `test_case/h_performance/if-combine3` | 14.04854s | 15.40252s | -1.35398s |
关键判断:
- `fft0/fft1` 已明显超过 GCC说明模乘/模幂 idiom lowering 的方向正确。
- `03_sort2` 已从明显慢项变成快项,说明 power-of-two digit extract、常数除法/取模 lowering 已经有实际收益。
- `h-2-*`、`h-4-*`、`if-combine*` 的收益说明中端 GVN/LSE/LICM 和部分后端 peephole 已经在某些结构上命中。
## 当前优化优先级
1. 优先分析 `2025-MYO-20`。这个样例单点亏损最大,应使用 `scripts/analyze_case.sh` 保存 IR 和汇编先确认瓶颈是循环结构、内存访问、调用、spill 还是地址计算。
2. 继续做矩阵类内核优化。`matmul1/2/3` 的差距很集中,下一步应优先看循环层次、地址递推、寄存器复用和保守 NEON而不是继续做零散 peephole。
3. 针对 `gameoflife*` 做 stencil 优化。重点是行缓存、邻域 load 复用、局部数组 promotion以及可证明安全的短向量化。
4. 对 `h-14-01`、`h-11-01`、`h-1-01`、`h-12-01` 做专项拆解。这些样例总时间大,需要逐个确认是否存在尾递归、循环不变量 load、跨块冗余 load/store、或后端 spill 过多。
5. `65_color``29_long_line` 比例难看,但绝对亏损小。它们不是性能分第一优先级;`29_long_line` 更应该作为编译耗时风险样例关注。
## 结论
当前编译器已经能完整通过最新 Lab3 回归,并且在 `fft`、`03_sort2`、部分 `h-2/h-4/if-combine` 样例上体现出明显优化收益。但从比赛性能角度看,总体仍比 GCC baseline 慢约 `60.40s`,主要差距来自 `2025-MYO-20`、矩阵计算、gameoflife stencil 以及若干大规模 h_performance 样例。下一轮优化应围绕这些大头做专项分析,而不是优先处理低绝对耗时的小比例样例。

@ -8,6 +8,9 @@ void RunMem2Reg(Module& module);
bool RunConstFold(Module& module);
bool RunConstProp(Module& module);
bool RunFunctionInlining(Module& module);
bool RunTailRecursionElim(Module& module);
bool RunInterproceduralConstProp(Module& module);
bool RunArithmeticSimplify(Module& module);
bool RunCSE(Module& module);
bool RunGVN(Module& module);
bool RunLoadStoreElim(Module& module);
@ -19,6 +22,8 @@ bool RunLoopUnswitch(Module& module);
bool RunLoopStrengthReduction(Module& module);
bool RunLoopUnroll(Module& module);
bool RunLoopFission(Module& module);
bool RunLoopRepeatReduction(Module& module);
bool RunIfConversion(Module& module);
void RunIRPassPipeline(Module& module);
} // namespace ir

@ -110,20 +110,25 @@ class MachineInstr {
Lea,
Add,
Sub,
Mul,
Div,
Rem,
And,
Or,
Mul,
Div,
Rem,
ModMul,
ModPow,
DigitExtractPow2,
BitTestMask,
And,
Or,
Xor,
Shl,
AShr,
LShr,
FAdd,
FSub,
FMul,
FDiv,
FNeg,
FAdd,
FSub,
FMul,
FDiv,
FSqrt,
FNeg,
ICmp,
FCmp,
ZExt,

@ -0,0 +1,324 @@
#!/usr/bin/env bash
# analyze_case.sh — 单个 .sy 测试用例的全流程编译 + IR/汇编保存脚本
# 用于深度分析单个样例与 GCC 基线之间的差距。
#
# 用法:
# analyze_case.sh <input.sy> [output_dir]
#
# 输出目录(默认 output/analyze/<stem>_<timestamp>)中包含:
# <stem>.ll — 我方编译器输出的 LLVM IR
# <stem>.s — 我方编译器输出的 AArch64 汇编
# <stem>.elf — 我方编译链接后的可执行文件
# <stem>.gcc.s — GCC -O2 输出的 AArch64 汇编
# <stem>.gcc.elf — GCC -O2 链接后的可执行文件
# <stem>.our.time — 我方程序运行耗时(秒)
# <stem>.gcc.time — GCC 程序运行耗时(秒)
# <stem>.our.out — 我方程序实际输出
# <stem>.gcc.out — GCC 程序实际输出
# <stem>.diff — 输出 diff若有差异
# report.txt — 汇总报告IR 行数、汇编行数、耗时、加速比)
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'
# ---------- 参数解析 ----------
if [[ $# -lt 1 || $# -gt 2 ]]; then
printf 'usage: %s <input.sy> [output_dir]\n' "$0" >&2
exit 1
fi
INPUT="$1"
if [[ ! -f "$INPUT" ]]; then
printf 'input file not found: %s\n' "$INPUT" >&2
exit 1
fi
BASE="$(basename "$INPUT")"
STEM="${BASE%.sy}"
INPUT_DIR="$(dirname "$(realpath "$INPUT")")"
TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
# 与 run_baseline.sh 一致的路径键:去掉 test/ 前缀和 .sy 后缀
REL="$(realpath --relative-to="$REPO_ROOT" "$INPUT" 2>/dev/null || echo "$INPUT")"
CASE_KEY="${REL#test/}"
CASE_KEY="${CASE_KEY%.sy}"
if [[ $# -ge 2 ]]; then
OUT_DIR="$2"
else
OUT_DIR="$REPO_ROOT/output/analyze/${STEM}_${TIMESTAMP}"
fi
mkdir -p "$OUT_DIR"
REPORT="$OUT_DIR/report.txt"
: > "$REPORT"
rpt() {
printf '%s\n' "$*" | tee -a "$REPORT"
}
rpt_color() {
local color="$1"; shift
printf '%b%s%b\n' "$color" "$*" "$NC"
printf '%s\n' "$*" >> "$REPORT"
}
rpt "============================================================"
rpt " analyze_case report"
rpt " case : $STEM"
rpt " source : $INPUT"
rpt " output : $OUT_DIR"
rpt " date : $(date)"
rpt "============================================================"
rpt ""
# ---------- 查找编译器 ----------
COMPILER=""
for candidate in \
"$REPO_ROOT/build_lab3/bin/compiler" \
"$REPO_ROOT/build_lab2/bin/compiler" \
"$REPO_ROOT/build/bin/compiler"; do
if [[ -x "$candidate" ]]; then
COMPILER="$candidate"
break
fi
done
if [[ -z "$COMPILER" ]]; then
rpt_color "$RED" "ERROR: compiler not found. Build first:"
rpt " cmake -S $REPO_ROOT -B $REPO_ROOT/build_lab3 && cmake --build $REPO_ROOT/build_lab3 -j"
exit 1
fi
rpt "compiler : $COMPILER"
# ---------- 工具检查 ----------
for tool in aarch64-linux-gnu-gcc qemu-aarch64; do
if ! command -v "$tool" >/dev/null 2>&1; then
rpt_color "$RED" "ERROR: required tool not found: $tool"
exit 1
fi
done
STDIN_FILE="$INPUT_DIR/$STEM.in"
EXPECTED_FILE="$INPUT_DIR/$STEM.out"
# ---------- 1. 生成 IR ----------
rpt ""
rpt "--- [1/5] Generating LLVM IR ---"
IR_FILE="$OUT_DIR/$STEM.ll"
if "$COMPILER" --emit-ir "$INPUT" > "$IR_FILE" 2>"$OUT_DIR/$STEM.ir.err"; then
IR_LINES=$(wc -l < "$IR_FILE")
rpt_color "$GREEN" "IR generated: $IR_FILE ($IR_LINES lines)"
else
rpt_color "$RED" "ERROR: IR generation failed"
cat "$OUT_DIR/$STEM.ir.err" >&2
exit 1
fi
# ---------- 2. 生成我方汇编并链接 ----------
rpt ""
rpt "--- [2/5] Generating our ASM & linking ---"
OUR_ASM="$OUT_DIR/$STEM.s"
OUR_ELF="$OUT_DIR/$STEM.elf"
if "$COMPILER" --emit-asm "$INPUT" > "$OUR_ASM" 2>"$OUT_DIR/$STEM.asm.err"; then
OUR_ASM_LINES=$(wc -l < "$OUR_ASM")
rpt_color "$GREEN" "ASM generated: $OUR_ASM ($OUR_ASM_LINES lines)"
else
rpt_color "$RED" "ERROR: ASM generation failed"
cat "$OUT_DIR/$STEM.asm.err" >&2
exit 1
fi
if aarch64-linux-gnu-gcc "$OUR_ASM" "$REPO_ROOT/sylib/sylib.c" -O2 \
-I "$REPO_ROOT/sylib" -lm -o "$OUR_ELF" 2>"$OUT_DIR/$STEM.link.err"; then
rpt_color "$GREEN" "Linked: $OUR_ELF"
else
rpt_color "$RED" "ERROR: link failed"
cat "$OUT_DIR/$STEM.link.err" >&2
exit 1
fi
# ---------- 3. GCC -O2 基线(从预计算数据读取)----------
rpt ""
rpt "--- [3/5] GCC -O2 baseline (reading from pre-computed data) ---"
BASELINE_DATA_DIR="$REPO_ROOT/output/baseline"
BASELINE_TSV_PATH="$BASELINE_DATA_DIR/gcc_timing.tsv"
GCC_ASM="$OUT_DIR/$STEM.gcc.s"
GCC_OUT="$OUT_DIR/$STEM.gcc.out"
GCC_OK=false
GCC_ASM_LINES=0
GCC_ELAPSED_RAW="" # 秒,无 s 后缀
if [[ -f "$BASELINE_TSV_PATH" ]]; then
GCC_ELAPSED_RAW=$(awk -F'\t' -v s="$CASE_KEY" '$1==s{v=$2} END{if(v!="") print v}' \
"$BASELINE_TSV_PATH" 2>/dev/null || true)
if [[ -n "$GCC_ELAPSED_RAW" ]]; then
GCC_OK=true
rpt_color "$GREEN" "baseline timing: ${GCC_ELAPSED_RAW}s"
else
rpt_color "$YELLOW" "WARNING: no baseline entry for '$CASE_KEY'"
rpt " Run: scripts/run_baseline.sh"
fi
# 复制汇编文件(路径镜像结构)
local_gcc_asm="$BASELINE_DATA_DIR/${CASE_KEY}.gcc.s"
if [[ -f "$local_gcc_asm" ]]; then
cp "$local_gcc_asm" "$GCC_ASM"
GCC_ASM_LINES=$(wc -l < "$GCC_ASM")
rpt "GCC ASM: $GCC_ASM ($GCC_ASM_LINES lines)"
else
rpt_color "$YELLOW" "GCC ASM not found in baseline dir: $local_gcc_asm"
fi
# 复制输出文件供步骤5 diff
local_gcc_out="$BASELINE_DATA_DIR/${CASE_KEY}.gcc.out"
if [[ -f "$local_gcc_out" ]]; then
cp "$local_gcc_out" "$GCC_OUT"
rpt "GCC output: $GCC_OUT"
fi
else
rpt_color "$YELLOW" "WARNING: baseline data not found: $BASELINE_TSV_PATH"
rpt " Run: scripts/run_baseline.sh"
rpt " to pre-compute GCC -O2 baseline for all test cases."
fi
# ---------- 4. 运行并计时(仅我方编译器)----------
rpt ""
rpt "--- [4/5] Running & timing (our compiler) ---"
run_and_time() {
local label="$1"
local exe="$2"
local out_file="$3"
local timeout_sec="${4:-60}"
local stdout_file="$out_file.raw"
local status=0
local _t0 _t1 _ns
_t0=$(date +%s%N)
set +e
if [[ -f "$STDIN_FILE" ]]; then
timeout "$timeout_sec" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" \
< "$STDIN_FILE" > "$stdout_file" 2>/dev/null
else
timeout "$timeout_sec" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" \
> "$stdout_file" 2>/dev/null
fi
status=$?
_t1=$(date +%s%N)
_ns=$((_t1 - _t0))
set -e
# 将 stdout + exit_code 合并为 .out与 verify_asm.sh 格式一致)
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$out_file"
rm -f "$stdout_file"
local elapsed
if [[ $status -eq 124 ]]; then
elapsed="timeout"
rpt_color "$YELLOW" "$label: TIMEOUT (>${timeout_sec}s)" >&2
else
elapsed=$(awk "BEGIN{printf \"%.5f\", $_ns / 1000000000}")
if [[ $status -ne 0 ]]; then
rpt_color "$YELLOW" "$label: exit $status elapsed=${elapsed}s" >&2
else
rpt_color "$GREEN" "$label: OK elapsed=${elapsed}s" >&2
fi
fi
echo "$elapsed"
}
OUR_OUT="$OUT_DIR/$STEM.our.out"
TIMEOUT_SEC=60
[[ "$INPUT" == *"/performance/"* || "$INPUT" == *"/h_performance/"* ]] && TIMEOUT_SEC=300
OUR_ELAPSED=$(run_and_time "our compiler" "$OUR_ELF" "$OUR_OUT" "$TIMEOUT_SEC")
# GCC 耗时直接读取基线数据,不重新运行
GCC_ELAPSED="N/A"
if [[ "$GCC_OK" == true && -n "$GCC_ELAPSED_RAW" ]]; then
GCC_ELAPSED="${GCC_ELAPSED_RAW}s"
rpt_color "$GREEN" "gcc -O2: ${GCC_ELAPSED} (from pre-computed baseline)"
fi
# ---------- 5. 输出对比 ----------
rpt ""
rpt "--- [5/5] Output comparison ---"
normalize_out() {
awk '{ sub(/\r$/, ""); print }' "$1"
}
if [[ -f "$EXPECTED_FILE" ]]; then
DIFF_FILE="$OUT_DIR/$STEM.diff"
if diff <(normalize_out "$EXPECTED_FILE") <(normalize_out "$OUR_OUT") > "$DIFF_FILE" 2>&1; then
rpt_color "$GREEN" "our output: MATCH expected"
rm -f "$DIFF_FILE"
else
rpt_color "$RED" "our output: MISMATCH — diff saved to $DIFF_FILE"
fi
if [[ "$GCC_OK" == true && -f "$GCC_OUT" ]]; then
GCC_DIFF_FILE="$OUT_DIR/$STEM.gcc.diff"
if diff <(normalize_out "$EXPECTED_FILE") <(normalize_out "$GCC_OUT") > "$GCC_DIFF_FILE" 2>&1; then
rpt_color "$GREEN" "gcc output: MATCH expected"
rm -f "$GCC_DIFF_FILE"
else
rpt_color "$YELLOW" "gcc output: MISMATCH — diff saved to $GCC_DIFF_FILE"
fi
fi
else
rpt_color "$YELLOW" "no expected output file found, skipping diff"
fi
# ---------- 汇总报告 ----------
rpt ""
rpt "============================================================"
rpt_color "$BOLD" " Summary"
rpt "============================================================"
rpt "$(printf '%-20s %s' 'IR lines:' "$IR_LINES")"
rpt "$(printf '%-20s %s' 'Our ASM lines:' "$OUR_ASM_LINES")"
if [[ "$GCC_OK" == true && $GCC_ASM_LINES -gt 0 ]]; then
rpt "$(printf '%-20s %s' 'GCC ASM lines:' "$GCC_ASM_LINES")"
rpt "$(printf '%-20s %s' 'ASM ratio (ours/gcc):' \
"$(awk "BEGIN{if($GCC_ASM_LINES>0) printf \"%.2f\", $OUR_ASM_LINES/$GCC_ASM_LINES; else print \"N/A\"}")")"
fi
rpt "$(printf '%-20s %s' 'Our time:' "$OUR_ELAPSED")"
rpt "$(printf '%-20s %s' 'GCC time:' "$GCC_ELAPSED")"
if [[ "$GCC_ELAPSED" != "N/A" && "$GCC_ELAPSED" != "timeout" && "$OUR_ELAPSED" != "timeout" ]]; then
OUR_S="${OUR_ELAPSED%s}"
GCC_S="${GCC_ELAPSED%s}"
SPEEDUP=$(awk "BEGIN{if($OUR_S>0) printf \"%.5f\", $GCC_S/$OUR_S; else print \"inf\"}")
rpt "$(printf '%-20s %sx' 'Speedup (gcc/ours):' "$SPEEDUP")"
fi
rpt ""
rpt "Output directory: $OUT_DIR"
rpt "============================================================"
printf '\n%bReport saved to: %s%b\n' "$CYAN" "$REPORT" "$NC"

@ -0,0 +1,170 @@
#!/usr/bin/env bash
# clean_outputs.sh — 清理编译输出与日志垃圾文件
#
# 用法:
# clean_outputs.sh [选项]
#
# 选项:
# --logs 清理 output/logs/ 下的运行日志(保留 last_run.txt / last_failed.txt
# --analyze 清理 output/analyze/ 下的单用例分析结果
# --build 清理 build_lab*/ 构建目录
# --test-result 清理 test/test_result/ 下的测试产物
# --all 清理以上全部
# --dry-run 只打印将要删除的内容,不实际删除
# --yes 跳过确认提示,直接删除(配合 --logs / --all 等使用)
#
# 不带任何选项时交互式选择。
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
NC='\033[0m'
DO_LOGS=false
DO_ANALYZE=false
DO_BUILD=false
DO_TEST_RESULT=false
DRY_RUN=false
AUTO_YES=false
if [[ $# -eq 0 ]]; then
# 交互模式
printf '%bclean_outputs.sh — interactive mode%b\n' "$CYAN" "$NC"
printf 'Select what to clean (space-separated numbers, e.g. "1 3"):\n'
printf ' 1) output/logs/ — run logs\n'
printf ' 2) output/analyze/ — single-case analysis results\n'
printf ' 3) build_lab*/ — CMake build directories\n'
printf ' 4) test/test_result/ — test artifacts\n'
printf ' 0) cancel\n'
read -r -p 'choice: ' choices
for c in $choices; do
case "$c" in
1) DO_LOGS=true ;;
2) DO_ANALYZE=true ;;
3) DO_BUILD=true ;;
4) DO_TEST_RESULT=true ;;
0) printf 'cancelled.\n'; exit 0 ;;
*) printf '%bunknown option: %s (ignored)%b\n' "$YELLOW" "$c" "$NC" ;;
esac
done
fi
while [[ $# -gt 0 ]]; do
case "$1" in
--logs) DO_LOGS=true ;;
--analyze) DO_ANALYZE=true ;;
--build) DO_BUILD=true ;;
--test-result) DO_TEST_RESULT=true ;;
--all) DO_LOGS=true; DO_ANALYZE=true; DO_BUILD=true; DO_TEST_RESULT=true ;;
--dry-run) DRY_RUN=true ;;
--yes|-y) AUTO_YES=true ;;
*)
printf '%bunknown option: %s%b\n' "$YELLOW" "$1" "$NC" >&2
;;
esac
shift
done
if [[ "$DO_LOGS" == false && "$DO_ANALYZE" == false && \
"$DO_BUILD" == false && "$DO_TEST_RESULT" == false ]]; then
printf 'nothing selected. use --help or run without arguments for interactive mode.\n' >&2
exit 0
fi
# ---------- 收集要删除的路径 ----------
declare -a TARGETS=()
if [[ "$DO_LOGS" == true ]]; then
LOG_ROOT="$REPO_ROOT/output/logs"
if [[ -d "$LOG_ROOT" ]]; then
# 删除所有子目录(即每次的 run dir保留 last_run.txt / last_failed.txt
while IFS= read -r -d '' d; do
TARGETS+=("$d")
done < <(find "$LOG_ROOT" -mindepth 2 -maxdepth 2 -type d -print0 2>/dev/null)
fi
fi
if [[ "$DO_ANALYZE" == true ]]; then
ANALYZE_ROOT="$REPO_ROOT/output/analyze"
if [[ -d "$ANALYZE_ROOT" ]]; then
while IFS= read -r -d '' d; do
TARGETS+=("$d")
done < <(find "$ANALYZE_ROOT" -mindepth 1 -maxdepth 1 -print0 2>/dev/null)
fi
fi
if [[ "$DO_BUILD" == true ]]; then
while IFS= read -r -d '' d; do
TARGETS+=("$d")
done < <(find "$REPO_ROOT" -maxdepth 1 -type d -name 'build_lab*' -print0 2>/dev/null)
fi
if [[ "$DO_TEST_RESULT" == true ]]; then
TR_ROOT="$REPO_ROOT/test/test_result"
if [[ -d "$TR_ROOT" ]]; then
TARGETS+=("$TR_ROOT")
fi
fi
if [[ ${#TARGETS[@]} -eq 0 ]]; then
printf '%bNothing to clean — target directories are already empty or do not exist.%b\n' "$GREEN" "$NC"
exit 0
fi
# ---------- 打印列表 ----------
printf '\n%bThe following will be %s:%b\n' "$YELLOW" \
"$([[ "$DRY_RUN" == true ]] && echo "listed (dry-run)" || echo "DELETED")" "$NC"
TOTAL_SIZE=0
for t in "${TARGETS[@]}"; do
SIZE=$(du -sh "$t" 2>/dev/null | cut -f1 || echo "?")
printf ' [%s] %s\n' "$SIZE" "$t"
done
printf '\n'
if [[ "$DRY_RUN" == true ]]; then
printf '%bDry-run mode: nothing deleted.%b\n' "$CYAN" "$NC"
exit 0
fi
# ---------- 确认 ----------
if [[ "$AUTO_YES" == false ]]; then
read -r -p "Proceed with deletion? [y/N] " confirm
case "$confirm" in
[yY][eE][sS]|[yY]) ;;
*)
printf 'cancelled.\n'
exit 0
;;
esac
fi
# ---------- 删除 ----------
DELETED=0
ERRORS=0
for t in "${TARGETS[@]}"; do
if rm -rf "$t" 2>/dev/null; then
printf '%b deleted: %s%b\n' "$GREEN" "$t" "$NC"
DELETED=$((DELETED + 1))
else
printf '%b ERROR deleting: %s%b\n' "$RED" "$t" "$NC"
ERRORS=$((ERRORS + 1))
fi
done
printf '\n'
if [[ $ERRORS -eq 0 ]]; then
printf '%bDone. %d item(s) deleted.%b\n' "$GREEN" "$DELETED" "$NC"
else
printf '%bDone. %d deleted, %d errors.%b\n' "$YELLOW" "$DELETED" "$ERRORS" "$NC"
exit 1
fi

@ -19,6 +19,7 @@ FALLBACK_TO_FULL=false
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
NC='\033[0m'
TEST_DIRS=()
@ -104,8 +105,8 @@ now_ns() {
format_duration_ns() {
local ns="$1"
local sec=$((ns / 1000000000))
local ms=$(((ns % 1000000000) / 1000000))
printf '%d.%03ds' "$sec" "$ms"
local us10=$(((ns % 1000000000) / 10000))
printf '%d.%05ds' "$sec" "$us10"
}
is_transient_io_failure() {
@ -116,9 +117,34 @@ is_transient_io_failure() {
"$log_file"
}
# ---------- baseline 读取 & timing ----------
# 共享基线数据(由 run_baseline.sh 生成)
BASELINE_TSV="$REPO_ROOT/output/baseline/gcc_timing.tsv"
# 本次运行的我方计时 TSVstem<TAB>our_ns<TAB>gcc_s
TIMING_TSV="$RUN_DIR/timing.tsv"
# 从共享 TSV 查找某 stem 的 GCC 基线耗时(秒),找不到返回 N/A
lookup_gcc_s() {
local stem="$1"
local val="N/A"
if [[ -f "$BASELINE_TSV" ]]; then
val=$(awk -F'\t' -v s="$stem" '$1==s{v=$2} END{if(v!="") print v; else print "N/A"}' "$BASELINE_TSV")
fi
echo "$val"
}
record_timing() {
local stem="$1"
local our_ns="$2"
local gcc_s="${3:-N/A}"
printf '%s\t%s\t%s\n' "$stem" "$our_ns" "$gcc_s" >> "$TIMING_TSV"
}
test_one() {
local sy_file="$1"
local rel="$2"
local timing_out="${3:-}"
local safe_name="${rel//\//_}"
local case_key="${safe_name%.sy}"
local tmp_dir="$RUN_DIR/.tmp/$case_key"
@ -132,7 +158,10 @@ test_one() {
cleanup_tmp_dir "$tmp_dir"
mkdir -p "$tmp_dir"
if "$VERIFY_SCRIPT" "$sy_file" "$tmp_dir" --run > "$case_log" 2>&1; then
local verify_args=("$sy_file" "$tmp_dir" --run)
[[ -n "$timing_out" ]] && verify_args+=(--timing-out "$timing_out")
if "$VERIFY_SCRIPT" "${verify_args[@]}" > "$case_log" 2>&1; then
cleanup_tmp_dir "$tmp_dir"
return 0
fi
@ -156,19 +185,35 @@ run_case() {
local sy_file="$1"
local rel
local case_start_ns
local case_end_ns
local case_elapsed_ns
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
case_start_ns=$(now_ns)
if test_one "$sy_file" "$rel"; then
case_end_ns=$(now_ns)
case_elapsed_ns=$((case_end_ns - case_start_ns))
log_color "$GREEN" "PASS $rel [$(format_duration_ns "$case_elapsed_ns")]"
local base stem case_key
base="$(basename "$sy_file")"
stem="${base%.sy}"
# 与 run_baseline.sh 保持一致:去掉 test/ 前缀和 .sy 后缀
case_key="${rel#test/}"
case_key="${case_key%.sy}"
local timing_file
timing_file="$(mktemp)"
if test_one "$sy_file" "$rel" "$timing_file"; then
local compile_ns=0 run_ns=0
if [[ -f "$timing_file" ]]; then
compile_ns=$(grep '^compile_ns=' "$timing_file" | cut -d= -f2 || echo 0)
run_ns=$(grep '^run_ns=' "$timing_file" | cut -d= -f2 || echo 0)
fi
rm -f "$timing_file"
log_color "$GREEN" "PASS $rel [compile=$(format_duration_ns "$compile_ns") run=$(format_duration_ns "$run_ns")]"
PASS=$((PASS + 1))
local gcc_s
gcc_s=$(lookup_gcc_s "$case_key")
record_timing "$case_key" "$run_ns" "$gcc_s"
else
case_end_ns=$(now_ns)
case_elapsed_ns=$((case_end_ns - case_start_ns))
rm -f "$timing_file"
local case_elapsed_ns=$(( $(now_ns) - case_start_ns ))
log_color "$RED" "FAIL $rel [$(format_duration_ns "$case_elapsed_ns")]"
FAIL=$((FAIL + 1))
FAIL_LIST+=("$rel")
@ -176,6 +221,7 @@ run_case() {
}
TOTAL_START_NS=$(now_ns)
: > "$TIMING_TSV"
if [[ "$FAILED_ONLY" == true ]]; then
if [[ -f "$LAST_FAILED_FILE" ]]; then
@ -209,6 +255,11 @@ fi
if [[ "$FALLBACK_TO_FULL" == true ]]; then
log_color "$YELLOW" "No cached failed cases found, fallback to full suite."
fi
if [[ -f "$BASELINE_TSV" ]]; then
log_plain "Baseline TSV: $BASELINE_TSV (speedup ratios will be computed)"
else
log_color "$CYAN" "Tip: run scripts/run_baseline.sh first to enable GCC -O2 speedup analysis."
fi
if [[ ! -f "$VERIFY_SCRIPT" ]]; then
log_color "$RED" "missing verify script: $VERIFY_SCRIPT"
@ -280,6 +331,93 @@ log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
log_plain "build elapsed: $(format_duration_ns "$BUILD_ELAPSED_NS")"
log_plain "validation elapsed: $(format_duration_ns "$VALIDATION_ELAPSED_NS")"
log_plain "total elapsed: $(format_duration_ns "$TOTAL_ELAPSED_NS")"
# ---------- 计时与加速比分析 ----------
if [[ -s "$TIMING_TSV" ]]; then
log_plain ""
log_plain "==> Timing & Speedup Analysis"
# 检查本次结果中是否有任何 GCC 基线数据
HAS_BASELINE=false
if grep -qv $'\tN/A$' "$TIMING_TSV" 2>/dev/null; then
HAS_BASELINE=true
fi
if [[ "$HAS_BASELINE" == true ]]; then
# 将 TSV 展开为含计算值的临时文件case_key, our_s, gcc_s, speedup
_tmp_timing="$RUN_DIR/timing_computed.tsv"
while IFS=$'\t' read -r case_key our_ns gcc_s; do
our_s=$(awk "BEGIN{printf \"%.5f\", $our_ns / 1000000000}")
if [[ "$gcc_s" == "N/A" ]]; then
speedup="N/A"
else
speedup=$(awk "BEGIN{if($our_s>0) printf \"%.5f\", $gcc_s/$our_s; else print \"inf\"}")
fi
printf '%s\t%s\t%s\t%s\n' "$case_key" "$our_s" "$gcc_s" "$speedup"
done < "$TIMING_TSV" > "$_tmp_timing"
# 排序1加速比升序N/A 排最后)
log_plain ""
log_plain "--- [Sort 1] Speedup ratio ascending (worst speedup first) ---"
log_plain "$(printf '%-40s %10s %10s %10s' 'case' 'our(s)' 'gcc(s)' 'speedup')"
log_plain "$(printf '%0.s-' {1..76})"
{
grep -v $'\tN/A$' "$_tmp_timing" | sort -t$'\t' -k4 -n || true
grep $'\tN/A$' "$_tmp_timing" | sort -t$'\t' -k1 || true
} | while IFS=$'\t' read -r case_key our_s gcc_s speedup; do
disp="${case_key##*/}"
if [[ "$speedup" == "N/A" ]]; then
log_plain "$(printf '%-40s %10s %10s %10s' "$disp" "${our_s}s" "N/A" "N/A")"
else
log_plain "$(printf '%-40s %10s %10s %9sx' "$disp" "${our_s}s" "${gcc_s}s" "$speedup")"
fi
done
# 排序2我方总用时降序
log_plain ""
log_plain "--- [Sort 2] Our elapsed time descending (slowest first) ---"
log_plain "$(printf '%-40s %10s %10s %10s' 'case' 'our(s)' 'gcc(s)' 'speedup')"
log_plain "$(printf '%0.s-' {1..76})"
sort -t$'\t' -k2 -rn "$_tmp_timing" | \
while IFS=$'\t' read -r case_key our_s gcc_s speedup; do
disp="${case_key##*/}"
if [[ "$speedup" == "N/A" ]]; then
log_plain "$(printf '%-40s %10s %10s %10s' "$disp" "${our_s}s" "N/A" "N/A")"
else
log_plain "$(printf '%-40s %10s %10s %9sx' "$disp" "${our_s}s" "${gcc_s}s" "$speedup")"
fi
done
rm -f "$_tmp_timing"
else
# 无基线:只输出总用时降序
log_plain ""
log_plain "--- [Sort] Our elapsed time descending (slowest first) ---"
log_plain "$(printf '%-40s %10s' 'case' 'our(s)')"
log_plain "$(printf '%0.s-' {1..54})"
while IFS=$'\t' read -r case_key our_ns _; do
our_s=$(awk "BEGIN{printf \"%.5f\", $our_ns / 1000000000}")
printf '%s\t%s\n' "$case_key" "$our_s"
done < "$TIMING_TSV" | \
sort -t$'\t' -k2 -rn | \
while IFS=$'\t' read -r case_key our_s; do
disp="${case_key##*/}"
log_plain "$(printf '%-40s %10ss' "$disp" "$our_s")"
done
log_plain ""
log_color "$CYAN" "Tip: run scripts/run_baseline.sh to compute GCC -O2 baseline for speedup analysis."
fi
log_plain ""
log_plain "timing data saved to: $TIMING_TSV"
fi
# ---------- 失败用例列表 ----------
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
log_plain "failed cases:"
for f in "${FAIL_LIST[@]}"; do

@ -0,0 +1,326 @@
#!/usr/bin/env bash
# run_baseline.sh — 批量编译 GCC -O2 基线并保存汇编、输出与运行时间
#
# 数据统一保存在 output/baseline/
# gcc_timing.tsv — stem<TAB>gcc_elapsed_s (所有脚本的共享数据源)
# <stem>.gcc.s — GCC -O2 AArch64 汇编(供 analyze_case.sh 对比)
# <stem>.gcc.out — GCC 程序实际输出 stdout+exit_code供 analyze_case.sh 对比)
#
# 用法:
# run_baseline.sh [--update] [test_dir|file ...]
#
# --update 重新计算所有条目(默认跳过 gcc_timing.tsv 中已有的 stem
#
# 若不指定测试目录/文件,自动扫描 test/test_case 和 test/class_test_case
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
BASELINE_DIR="$REPO_ROOT/output/baseline"
TIMING_TSV="$BASELINE_DIR/gcc_timing.tsv"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
NC='\033[0m'
UPDATE=false
TEST_DIRS=()
TEST_FILES=()
while [[ $# -gt 0 ]]; do
case "$1" in
--update) UPDATE=true ;;
*)
if [[ -f "$1" ]]; then
TEST_FILES+=("$1")
else
TEST_DIRS+=("$1")
fi
;;
esac
shift
done
# ---------- 工具检查 ----------
for tool in aarch64-linux-gnu-gcc qemu-aarch64; do
if ! command -v "$tool" >/dev/null 2>&1; then
printf '%bERROR: required tool not found: %s%b\n' "$RED" "$tool" "$NC" >&2
exit 1
fi
done
if [[ ! -x /usr/bin/time ]]; then
printf '%bERROR: /usr/bin/time not found%b\n' "$RED" "$NC" >&2
exit 1
fi
mkdir -p "$BASELINE_DIR"
# 是否已存在某 stem 的基线数据(直接查 TSV 文件,避免关联数组兼容性问题)
stem_is_cached() {
local key="$1"
[[ -f "$TIMING_TSV" ]] && grep -qF "${key} " "$TIMING_TSV" 2>/dev/null
}
stem_cached_time() {
local key="$1"
awk -F'\t' -v s="$key" '$1==s{print $2; exit}' "$TIMING_TSV" 2>/dev/null || true
}
# ---------- 测试用例发现 ----------
discover_default_test_dirs() {
local roots=(
"$REPO_ROOT/test/test_case"
"$REPO_ROOT/test/class_test_case"
)
local root
for root in "${roots[@]}"; do
[[ -d "$root" ]] || continue
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
done | sort -z
}
if [[ ${#TEST_DIRS[@]} -eq 0 && ${#TEST_FILES[@]} -eq 0 ]]; then
while IFS= read -r -d '' d; do
TEST_DIRS+=("$d")
done < <(discover_default_test_dirs)
fi
# ---------- 计时工具 ----------
now_ns() { date +%s%N; }
format_duration_ns() {
local ns="$1"
printf '%d.%05ds' "$((ns / 1000000000))" "$(((ns % 1000000000) / 10000))"
}
# ---------- 处理单个用例 ----------
PASS=0
SKIP=0
FAIL=0
process_case() {
local sy_file="$1"
local base stem input_dir stdin_file
base="$(basename "$sy_file")"
stem="${base%.sy}"
input_dir="$(dirname "$sy_file")"
stdin_file="$input_dir/$stem.in"
local rel
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
# 路径键:去掉 test/ 前缀和 .sy 后缀,保留完整目录结构
# 例test/class_test_case/h_functional/11_BST.sy → class_test_case/h_functional/11_BST
local case_key
case_key="${rel#test/}"
case_key="${case_key%.sy}"
local case_start_ns
case_start_ns=$(now_ns)
# 已有数据且不强制更新 → 跳过
if [[ "$UPDATE" == false ]] && stem_is_cached "$case_key"; then
printf '%b SKIP %s (cached: %ss)%b\n' \
"$CYAN" "$rel" "$(stem_cached_time "$case_key")" "$NC"
SKIP=$((SKIP + 1))
return 0
fi
# 输出目录镜像源路径结构
local case_out_dir
case_out_dir="$BASELINE_DIR/$(dirname "$case_key")"
mkdir -p "$case_out_dir"
local gcc_elf gcc_asm gcc_out gcc_err
gcc_elf="$case_out_dir/$stem.gcc.elf"
gcc_asm="$case_out_dir/$stem.gcc.s"
gcc_out="$case_out_dir/$stem.gcc.out"
gcc_err="$case_out_dir/$stem.gcc.err"
# 预处理:把 "const int NAME = EXPR;" 转为 "#define NAME ((int)(EXPR))"
# 同时处理多声明符const int A=1, B=2; → #define A ((int)(1))\n#define B ((int)(2))
# 原因SysY const int 是编译期常量C 模式下不能用于全局数组维度,#define 可以
local tmp_sy
tmp_sy="$(mktemp /tmp/sysy_XXXXXX.c)"
python3 - "$sy_file" "$tmp_sy" << 'PYEOF'
import re, sys
pat = re.compile(
r'^(\s*)const\s+int\s+((?:[A-Za-z_]\w*\s*=\s*[^,;]+)(?:,\s*[A-Za-z_]\w*\s*=\s*[^,;]+)*)\s*;',
re.MULTILINE
)
def replace(m):
indent = m.group(1)
decls = re.split(r',\s*(?=[A-Za-z_])', m.group(2))
lines = []
for d in decls:
name, _, val = d.partition('=')
lines.append(f'{indent}#define {name.strip()} ((int)({val.strip()}))')
return '\n'.join(lines)
with open(sys.argv[1]) as f:
src = f.read()
with open(sys.argv[2], 'w') as f:
f.write(pat.sub(replace, src))
PYEOF
# 步骤1编译链接C 模式,用于运行计时)
# -x c允许 delete/new/class 等作为标识符
# -include sylib.h强制注入 SysY 运行时声明(.sy 无 #include
# 无名称修饰,直接链接同为 C 编译的 sylib.o
if ! aarch64-linux-gnu-gcc -O2 \
-x c -include "$REPO_ROOT/sylib/sylib.h" \
-I "$REPO_ROOT/sylib" \
"$tmp_sy" -x none "$SYLIB_OBJ" \
-lm -o "$gcc_elf" > "$gcc_err" 2>&1; then
rm -f "$tmp_sy"
printf '%b FAIL %s (GCC compile error — see %s)%b\n' \
"$RED" "$rel" "$gcc_err" "$NC"
FAIL=$((FAIL + 1))
return 0
fi
# 步骤2生成汇编单独 -S仅针对 .sy 文件本身)
aarch64-linux-gnu-gcc -O2 \
-x c -include "$REPO_ROOT/sylib/sylib.h" \
-I "$REPO_ROOT/sylib" \
"$tmp_sy" -S -o "$gcc_asm" 2>/dev/null || true
rm -f "$tmp_sy"
# 步骤3运行并计时手动 ns 计时,精度 5 位小数)
local stdout_file="$case_out_dir/$stem.gcc.stdout"
local status=0
local timeout_sec=60
[[ "$sy_file" == *"/performance/"* || "$sy_file" == *"/h_performance/"* ]] && timeout_sec=300
local run_start_ns run_end_ns run_elapsed_ns
run_start_ns=$(now_ns)
set +e
if [[ -f "$stdin_file" ]]; then
timeout "$timeout_sec" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$gcc_elf" \
< "$stdin_file" > "$stdout_file" 2>/dev/null
else
timeout "$timeout_sec" \
qemu-aarch64 -L /usr/aarch64-linux-gnu "$gcc_elf" \
> "$stdout_file" 2>/dev/null
fi
status=$?
run_end_ns=$(now_ns)
run_elapsed_ns=$((run_end_ns - run_start_ns))
set -e
# 删除可执行(节省空间,数据已提取完毕)
rm -f "$gcc_elf"
if [[ $status -eq 124 ]]; then
printf '%b TIMEOUT %s (>%ds)%b\n' "$YELLOW" "$rel" "$timeout_sec" "$NC"
rm -f "$stdout_file"
FAIL=$((FAIL + 1))
return 0
fi
# 步骤4保存输出文件stdout + exit_code与 verify_asm.sh 格式一致)
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
printf '\n'
fi
printf '%s\n' "$status"
} > "$gcc_out"
rm -f "$stdout_file"
# 步骤5计算耗时5 位小数秒)并写入 TSV
local elapsed
elapsed=$(awk "BEGIN{printf \"%.5f\", $run_elapsed_ns / 1000000000}")
# 更新 TSV若已有该 case_key 的旧行则先删除再追加)
if grep -qF "${case_key} " "$TIMING_TSV" 2>/dev/null; then
local _tmp="$TIMING_TSV.tmp"
grep -vF "${case_key} " "$TIMING_TSV" > "$_tmp" || true
mv "$_tmp" "$TIMING_TSV"
fi
printf '%s\t%s\n' "$case_key" "$elapsed" >> "$TIMING_TSV"
local case_end_ns duration_ns
case_end_ns=$(now_ns)
duration_ns=$((case_end_ns - case_start_ns))
printf '%b DONE %s gcc=%ss [%s]%b\n' \
"$GREEN" "$rel" "$elapsed" "$(format_duration_ns "$duration_ns")" "$NC"
PASS=$((PASS + 1))
}
# ---------- 初始化 ----------
if [[ "$UPDATE" == true ]]; then
printf '%b[--update] Clearing all existing baseline data.%b\n' "$YELLOW" "$NC"
: > "$TIMING_TSV"
find "$BASELINE_DIR" -maxdepth 1 \
\( -name '*.gcc.s' -o -name '*.gcc.out' -o -name '*.gcc.time' -o -name '*.gcc.err' \) \
-delete 2>/dev/null || true
else
[[ -f "$TIMING_TSV" ]] || : > "$TIMING_TSV"
fi
printf '%bBaseline directory : %s%b\n' "$CYAN" "$BASELINE_DIR" "$NC"
printf '%bTiming TSV : %s%b\n' "$CYAN" "$TIMING_TSV" "$NC"
if [[ "$UPDATE" == false && -f "$TIMING_TSV" ]]; then
_cached_count=$(wc -l < "$TIMING_TSV" 2>/dev/null || echo 0)
if [[ $_cached_count -gt 0 ]]; then
printf 'Found %d cached entries (use --update to recompute all).\n' "$_cached_count"
fi
fi
# ---------- 预编译 sylib.oC 模式,仅一次)----------
SYLIB_OBJ="$BASELINE_DIR/sylib.o"
if ! aarch64-linux-gnu-gcc -O2 -c -x c "$REPO_ROOT/sylib/sylib.c" \
-I "$REPO_ROOT/sylib" -o "$SYLIB_OBJ" 2>/dev/null; then
printf '%bERROR: failed to compile sylib.c%b\n' "$RED" "$NC" >&2
exit 1
fi
printf 'sylib.o compiled : %s\n' "$SYLIB_OBJ"
printf '\n'
TOTAL_START_NS=$(now_ns)
# ---------- 运行 ----------
for sy_file in "${TEST_FILES[@]}"; do
process_case "$sy_file"
done
for test_dir in "${TEST_DIRS[@]}"; do
if [[ ! -d "$test_dir" ]]; then
printf '%b SKIP missing dir: %s%b\n' "$YELLOW" "$test_dir" "$NC"
continue
fi
while IFS= read -r -d '' sy_file; do
process_case "$sy_file"
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
done
# ---------- 汇总 ----------
TOTAL_END_NS=$(now_ns)
TOTAL_ELAPSED_NS=$((TOTAL_END_NS - TOTAL_START_NS))
TOTAL_CASES=$((PASS + SKIP + FAIL))
printf '\n'
printf 'Summary: %d DONE / %d SKIP (cached) / %d FAIL / total %d\n' \
"$PASS" "$SKIP" "$FAIL" "$TOTAL_CASES"
printf 'Total elapsed : %s\n' "$(format_duration_ns "$TOTAL_ELAPSED_NS")"
printf 'Timing TSV : %s (%d entries)\n' \
"$TIMING_TSV" "$(wc -l < "$TIMING_TSV" 2>/dev/null || echo 0)"
[[ $FAIL -eq 0 ]]

@ -0,0 +1,103 @@
============================================================
脚本优化总结2026-04
============================================================
一、架构分离
────────────────────────────────────────────────────────────
· run_baseline.sh 成为唯一负责计算 GCC -O2 基线的脚本;
其余所有脚本lab3_build_test.sh、analyze_case.sh只读
TSV不再重复运行 GCC避免重复耗时。
· 基线输出目录镜像测试用例的相对路径结构,例如:
output/baseline/test_case/functional/65_color.gcc.s
output/baseline/class_test_case/h_functional/11_BST.gcc.s
TSV 键与目录结构对齐class_test_case/h_functional/11_BST
二、SysY → C 编译兼容性修复run_baseline.sh
────────────────────────────────────────────────────────────
· const int 全局数组维度问题
C 模式下 const int N=10; int a[N]; 属于 VLA非法于文件域
用 Python3 预处理将 const int NAME=EXPR; 转换为:
#define NAME ((int)(EXPR))
同时支持多声明符写法const int A=1, B=2;
· sylib 链接方式
预编译 sylib.o-x c用 -include sylib.h 注入声明;
链接命令用 -x none 在 .o 前重置语言标志,防止 ELF 被
当作 C 源文件解析stray '\177' 错误)。
· C++ 关键字冲突
部分 SysY 测试用例用 delete/new/class 作函数名;
-x c 模式下这些不是关键字,编译正常通过。
· 枚举浮点值
enum { MAX = 1e9 }; 枚举成员必须是整数常量Python3
预处理同样将其转为 #define MAX ((int)(1e9))。
三、计时精度与准确性
────────────────────────────────────────────────────────────
· 全面弃用 /usr/bin/time非零退出时会向输出文件写入
"Command exited with non-zero status N",污染时间值。
· 改用 date +%s%N 纳秒手动计时:
_t0=$(date +%s%N)
... 运行命令 ...
_t1=$(date +%s%N)
elapsed=$(awk "BEGIN{printf \"%.5f\", $((t1-t0)) / 1e9}")
· 所有时间输出统一为 5 位小数(秒),加速比同样 5 位小数。
四、分段计时verify_asm.sh + lab3_build_test.sh
────────────────────────────────────────────────────────────
· verify_asm.sh 新增 --timing-out <file> 选项,运行结束后
向文件写入:
compile_ns=<纳秒>
run_ns=<纳秒>
· lab3_build_test.sh 读取 timing 文件,将编译耗时与运行耗时
分开显示:
PASS test_case/functional/65_color [compile=0.31416s run=0.18804s]
· 加速比只使用运行时间run_ns排除编译器启动开销。
五、性能排行榜lab3_build_test.sh
────────────────────────────────────────────────────────────
· 测试结束后输出双排序表格:
Sort 1加速比升序最需优化的用例排最前
Sort 2我方用时降序绝对耗时最高的排最前
每行格式:
<用例名> <我方时间> <GCC时间> <加速比>x
六、analyze_case.sh 修复
────────────────────────────────────────────────────────────
· 基线查找键从裸 stem65_color改为完整路径键
test_case/functional/65_color与 TSV 格式对齐,
消除 "WARNING: no baseline entry" 误报。
· run_and_time 函数的 rpt_color 输出重定向到 stderr
防止 ANSI 转义码被命令替换($())捕获后传入 awk
消除 "fatal: error: invalid character '\033'" 错误。
============================================================
脚本列表
============================================================
run_baseline.sh 计算所有用例的 GCC -O2 基线,结果存入
output/baseline/gcc_timing.tsv
用法: ./scripts/run_baseline.sh [--update]
--update 清空重算全部条目
lab3_build_test.sh 构建编译器,跑全部用例,输出加速比排行榜
用法: ./scripts/lab3_build_test.sh
verify_asm.sh 验证单个用例的汇编正确性
用法: ./scripts/verify_asm.sh <input.sy> [input.in] \
[expected.out] [timeout] [--timing-out file]
analyze_case.sh 单用例深度分析IR/ASM/计时/与基线对比)
用法: ./scripts/analyze_case.sh <input.sy> [output_dir]
clean_outputs.sh 清理 output/ 目录下的分析结果
用法: ./scripts/clean_outputs.sh
============================================================

@ -4,8 +4,8 @@ set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "usage: $0 input.sy [output_dir] [--run]" >&2
if [[ $# -lt 1 || $# -gt 5 ]]; then
echo "usage: $0 input.sy [output_dir] [--run] [--timing-out file]" >&2
exit 1
fi
@ -13,6 +13,11 @@ input=$1
out_dir="$REPO_ROOT/test/test_result/asm"
run_exec=false
input_dir=$(dirname "$input")
timing_out=""
_compile_ns=0
_run_ns=0
now_ns() { date +%s%N; }
shift
while [[ $# -gt 0 ]]; do
@ -20,6 +25,10 @@ while [[ $# -gt 0 ]]; do
--run)
run_exec=true
;;
--timing-out)
timing_out="$2"
shift
;;
*)
out_dir="$1"
;;
@ -57,11 +66,13 @@ exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
_compile_start_ns=$(now_ns)
"$compiler" --emit-asm "$input" > "$asm_file"
echo "asm generated: $asm_file"
aarch64-linux-gnu-gcc "$asm_file" "$REPO_ROOT/sylib/sylib.c" -O2 -o "$exe"
echo "executable generated: $exe"
_compile_ns=$(($(now_ns) - _compile_start_ns))
if [[ "$run_exec" == true ]]; then
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
@ -77,6 +88,7 @@ if [[ "$run_exec" == true ]]; then
fi
set +e
_run_start_ns=$(now_ns)
if command -v timeout >/dev/null 2>&1; then
if [[ -f "$stdin_file" ]]; then
timeout "$timeout_sec" qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
@ -91,6 +103,7 @@ if [[ "$run_exec" == true ]]; then
fi
fi
status=$?
_run_ns=$(($(now_ns) - _run_start_ns))
set -e
if [[ $status -eq 124 ]]; then
@ -122,3 +135,7 @@ if [[ "$run_exec" == true ]]; then
echo "expected output not found, skipped diff: $expected_file"
fi
fi
if [[ -n "$timing_out" ]]; then
printf 'compile_ns=%s\nrun_ns=%s\n' "$_compile_ns" "$_run_ns" > "$timing_out"
fi

@ -82,27 +82,27 @@ void DominatorTree::Recalculate() {
}
}
std::vector<std::size_t> dom_depth(num_blocks, 0);
for (std::size_t i = 0; i < num_blocks; ++i) {
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
if (dominates_[i][candidate]) {
++dom_depth[i];
}
}
}
for (std::size_t i = 1; i < num_blocks; ++i) {
auto* block = reverse_post_order_[i];
BasicBlock* idom = nullptr;
std::size_t best_depth = 0;
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
if (candidate == i || !dominates_[i][candidate]) {
continue;
}
auto* candidate_block = reverse_post_order_[candidate];
bool immediate = true;
for (std::size_t other = 0; other < num_blocks; ++other) {
if (other == i || other == candidate || !dominates_[i][other]) {
continue;
}
if (Dominates(reverse_post_order_[other], candidate_block)) {
immediate = false;
break;
}
}
if (immediate) {
if (idom == nullptr || dom_depth[candidate] > best_depth) {
idom = candidate_block;
break;
best_depth = dom_depth[candidate];
}
}
immediate_dominator_.emplace(block, idom);

@ -0,0 +1,137 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>
namespace ir {
namespace {
bool IsPowerOfTwoPositive(int value) {
return value > 0 && (value & (value - 1)) == 0;
}
std::size_t FindInstructionIndex(BasicBlock* block, Instruction* inst) {
if (!block || !inst) {
return 0;
}
auto& instructions = block->GetInstructions();
for (std::size_t i = 0; i < instructions.size(); ++i) {
if (instructions[i].get() == inst) {
return i;
}
}
return instructions.size();
}
bool IsZero(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return ci->GetValue() == 0;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return !cb->GetValue();
}
return false;
}
Value* OtherCompareOperand(BinaryInst* cmp, Value* value) {
if (!cmp || cmp->GetNumOperands() != 2) {
return nullptr;
}
if (cmp->GetLhs() == value) {
return cmp->GetRhs();
}
if (cmp->GetRhs() == value) {
return cmp->GetLhs();
}
return nullptr;
}
bool SimplifyPowerOfTwoRemTests(Function& function) {
bool changed = false;
std::vector<Instruction*> dead_rems;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (!block) {
continue;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* rem = dyncast<BinaryInst>(inst_ptr.get());
if (!rem || rem->GetOpcode() != Opcode::Rem) {
continue;
}
auto* divisor = dyncast<ConstantInt>(rem->GetRhs());
if (!divisor || !IsPowerOfTwoPositive(divisor->GetValue())) {
continue;
}
const int mask_value = divisor->GetValue() - 1;
if (mask_value == 0) {
rem->ReplaceAllUsesWith(looputils::ConstInt(0));
dead_rems.push_back(rem);
changed = true;
continue;
}
std::vector<BinaryInst*> compare_uses;
bool all_uses_are_zero_tests = !rem->GetUses().empty();
for (const auto& use : rem->GetUses()) {
auto* cmp = dyncast<BinaryInst>(dynamic_cast<Value*>(use.GetUser()));
if (!cmp || (cmp->GetOpcode() != Opcode::ICmpEQ &&
cmp->GetOpcode() != Opcode::ICmpNE) ||
!IsZero(OtherCompareOperand(cmp, rem))) {
all_uses_are_zero_tests = false;
break;
}
compare_uses.push_back(cmp);
}
if (!all_uses_are_zero_tests || compare_uses.empty()) {
continue;
}
const auto insert_index = FindInstructionIndex(block, rem) + 1;
auto* masked = block->Insert<BinaryInst>(
insert_index, Opcode::And, Type::GetInt32Type(), rem->GetLhs(),
looputils::ConstInt(mask_value), nullptr,
looputils::NextSyntheticName(function, "pow2.mask."));
for (auto* cmp : compare_uses) {
if (cmp->GetLhs() == rem) {
cmp->SetOperand(0, masked);
}
if (cmp->GetRhs() == rem) {
cmp->SetOperand(1, masked);
}
}
dead_rems.push_back(rem);
changed = true;
}
}
for (auto* rem : dead_rems) {
if (rem->GetUses().empty() && rem->GetParent()) {
rem->GetParent()->EraseInstruction(rem);
}
}
return changed;
}
} // namespace
bool RunArithmeticSimplify(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (!function || function->IsExternal()) {
continue;
}
changed |= SimplifyPowerOfTwoRemTests(*function);
}
return changed;
}
} // namespace ir

@ -3,6 +3,9 @@ add_library(ir_passes STATIC
Mem2Reg.cpp
ConstFold.cpp
ConstProp.cpp
InterproceduralConstProp.cpp
TailRecursionElim.cpp
ArithmeticSimplify.cpp
Inline.cpp
CSE.cpp
GVN.cpp
@ -15,6 +18,8 @@ add_library(ir_passes STATIC
LoopStrengthReduction.cpp
LoopUnroll.cpp
LoopFission.cpp
LoopRepeatReduction.cpp
IfConversion.cpp
)
target_link_libraries(ir_passes PUBLIC

@ -14,7 +14,15 @@ namespace {
struct ExprKey {
Opcode opcode = Opcode::Add;
std::vector<std::uintptr_t> operands;
struct OperandKey {
int kind = 0;
std::intptr_t value = 0;
bool operator==(const OperandKey& rhs) const {
return kind == rhs.kind && value == rhs.value;
}
};
std::vector<OperandKey> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && operands == rhs.operands;
@ -25,12 +33,26 @@ struct ExprKeyHash {
std::size_t operator()(const ExprKey& key) const {
std::size_t h = static_cast<std::size_t>(key.opcode);
for (auto operand : key.operands) {
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<int>{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::intptr_t>{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
ExprKey::OperandKey BuildOperandKey(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {1, ci->GetValue()};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {2, cb->GetValue() ? 1 : 0};
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return {3, static_cast<std::intptr_t>(passutils::FloatBits(cf->GetValue()))};
}
return {0, reinterpret_cast<std::intptr_t>(value)};
}
bool IsSupportedCSEInstruction(Instruction* inst) {
if (!inst || inst->IsVoid()) {
return false;
@ -81,11 +103,12 @@ ExprKey BuildExprKey(Instruction* inst) {
key.opcode = inst->GetOpcode();
key.operands.reserve(inst->GetNumOperands());
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(
reinterpret_cast<std::uintptr_t>(inst->GetOperand(i)));
key.operands.push_back(BuildOperandKey(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 && passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
key.operands[1] < key.operands[0]) {
(key.operands[1].kind < key.operands[0].kind ||
(key.operands[1].kind == key.operands[0].kind &&
key.operands[1].value < key.operands[0].value))) {
std::swap(key.operands[0], key.operands[1]);
}
return key;

@ -17,7 +17,15 @@ struct ExprKey {
Opcode opcode = Opcode::Add;
std::uintptr_t result_type = 0;
std::uintptr_t aux_type = 0;
std::vector<std::uintptr_t> operands;
struct OperandKey {
int kind = 0;
std::intptr_t value = 0;
bool operator==(const OperandKey& rhs) const {
return kind == rhs.kind && value == rhs.value;
}
};
std::vector<OperandKey> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && result_type == rhs.result_type &&
@ -33,12 +41,26 @@ struct ExprKeyHash {
h ^= std::hash<std::uintptr_t>{}(key.aux_type) + 0x9e3779b9 + (h << 6) +
(h >> 2);
for (auto operand : key.operands) {
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<int>{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::intptr_t>{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
ExprKey::OperandKey BuildOperandKey(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {1, ci->GetValue()};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {2, cb->GetValue() ? 1 : 0};
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
return {3, static_cast<std::intptr_t>(passutils::FloatBits(cf->GetValue()))};
}
return {0, reinterpret_cast<std::intptr_t>(value)};
}
struct ScopedExpr {
ExprKey key;
Value* previous = nullptr;
@ -103,12 +125,13 @@ ExprKey BuildExprKey(Instruction* inst) {
}
key.operands.reserve(inst->GetNumOperands());
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(
reinterpret_cast<std::uintptr_t>(inst->GetOperand(i)));
key.operands.push_back(BuildOperandKey(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 &&
passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
key.operands[1] < key.operands[0]) {
(key.operands[1].kind < key.operands[0].kind ||
(key.operands[1].kind == key.operands[0].kind &&
key.operands[1].value < key.operands[0].value))) {
std::swap(key.operands[0], key.operands[1]);
}
return key;

@ -0,0 +1,239 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cstddef>
#include <vector>
namespace ir {
namespace {
Instruction* GetTerminator(BasicBlock* block) {
if (block == nullptr || block->GetInstructions().empty()) {
return nullptr;
}
auto* inst = block->GetInstructions().back().get();
return inst != nullptr && inst->IsTerminator() ? inst : nullptr;
}
std::size_t GetTerminatorIndex(BasicBlock* block) {
const auto& instructions = block->GetInstructions();
return instructions.empty() ? 0 : instructions.size() - 1;
}
ConstantInt* ConstInt(int value) {
return new ConstantInt(Type::GetInt32Type(), value);
}
PhiInst* GetSinglePhi(BasicBlock* block) {
if (block == nullptr) {
return nullptr;
}
PhiInst* phi = nullptr;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* current = dyncast<PhiInst>(inst_ptr.get());
if (current == nullptr) {
break;
}
if (phi != nullptr) {
return nullptr;
}
phi = current;
}
return phi;
}
bool HasOnlyOneNonTerminator(BasicBlock* block, Instruction** out) {
if (block == nullptr) {
return false;
}
Instruction* candidate = nullptr;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == nullptr || inst->IsTerminator()) {
continue;
}
if (candidate != nullptr) {
return false;
}
candidate = inst;
}
if (out != nullptr) {
*out = candidate;
}
return candidate != nullptr;
}
int IncomingIndexFor(PhiInst* phi, BasicBlock* block) {
if (phi == nullptr || block == nullptr) {
return -1;
}
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (phi->GetIncomingBlock(i) == block) {
return i;
}
}
return -1;
}
bool IsUsedOnlyBy(Value* value, User* expected_user) {
if (value == nullptr || expected_user == nullptr) {
return false;
}
for (const auto& use : value->GetUses()) {
if (use.GetUser() != expected_user) {
return false;
}
}
return true;
}
struct ConditionalAccumulation {
Value* base = nullptr;
Value* delta = nullptr;
Opcode opcode = Opcode::Add;
};
bool MatchConditionalAccumulation(PhiInst* phi, BasicBlock* pred,
BasicBlock* update_block,
BinaryInst* update,
ConditionalAccumulation* match) {
if (phi == nullptr || pred == nullptr || update_block == nullptr ||
update == nullptr || match == nullptr || phi->GetNumIncomings() != 2 ||
!phi->GetType()->IsInt32() || !update->GetType()->IsInt32()) {
return false;
}
const int pred_index = IncomingIndexFor(phi, pred);
const int update_index = IncomingIndexFor(phi, update_block);
if (pred_index < 0 || update_index < 0) {
return false;
}
auto* base = phi->GetIncomingValue(pred_index);
if (phi->GetIncomingValue(update_index) != update || base == nullptr ||
!base->GetType()->IsInt32() || !IsUsedOnlyBy(update, phi)) {
return false;
}
auto* lhs = update->GetLhs();
auto* rhs = update->GetRhs();
if (update->GetOpcode() == Opcode::Add) {
if (lhs == base && rhs != nullptr && rhs->GetType()->IsInt32()) {
*match = {base, rhs, Opcode::Add};
return true;
}
if (rhs == base && lhs != nullptr && lhs->GetType()->IsInt32()) {
*match = {base, lhs, Opcode::Add};
return true;
}
return false;
}
if (update->GetOpcode() == Opcode::Sub && lhs == base && rhs != nullptr &&
rhs->GetType()->IsInt32()) {
*match = {base, rhs, Opcode::Sub};
return true;
}
return false;
}
bool TryConvertConditionalAccumulation(Function& function, BasicBlock* pred) {
auto* branch = dyncast<CondBrInst>(GetTerminator(pred));
if (branch == nullptr || branch->GetCondition() == nullptr ||
!branch->GetCondition()->GetType()->IsInt1()) {
return false;
}
auto* update_block = branch->GetThenBlock();
auto* join = branch->GetElseBlock();
if (update_block == nullptr || join == nullptr || update_block == join ||
update_block->GetPredecessors().size() != 1 ||
update_block->GetPredecessors().front() != pred ||
update_block->GetSuccessors().size() != 1 ||
update_block->GetSuccessors().front() != join) {
return false;
}
auto* update_term = dyncast<UncondBrInst>(GetTerminator(update_block));
if (update_term == nullptr || update_term->GetDest() != join) {
return false;
}
Instruction* only_inst = nullptr;
if (!HasOnlyOneNonTerminator(update_block, &only_inst)) {
return false;
}
auto* update = dyncast<BinaryInst>(only_inst);
if (update == nullptr ||
(update->GetOpcode() != Opcode::Add && update->GetOpcode() != Opcode::Sub)) {
return false;
}
auto* phi = GetSinglePhi(join);
ConditionalAccumulation accum;
if (!MatchConditionalAccumulation(phi, pred, update_block, update, &accum)) {
return false;
}
const std::size_t insert_pos = GetTerminatorIndex(pred);
auto* enabled = pred->Insert<ZextInst>(insert_pos, branch->GetCondition(),
Type::GetInt32Type(), nullptr,
"%ifconv.zext");
auto* mask = pred->Insert<BinaryInst>(insert_pos + 1, Opcode::Sub,
Type::GetInt32Type(), ConstInt(0),
enabled, nullptr, "%ifconv.mask");
auto* masked_delta = pred->Insert<BinaryInst>(
insert_pos + 2, Opcode::And, Type::GetInt32Type(), accum.delta, mask,
nullptr, "%ifconv.delta");
auto* replacement = pred->Insert<BinaryInst>(
insert_pos + 3, accum.opcode, Type::GetInt32Type(), accum.base,
masked_delta, nullptr, "%ifconv.acc");
phi->ReplaceAllUsesWith(replacement);
join->EraseInstruction(phi);
passutils::ReplaceTerminatorWithBr(pred, join);
pred->RemoveSuccessor(update_block);
update_block->RemovePredecessor(pred);
passutils::RemoveUnreachableBlocks(function);
return true;
}
bool RunIfConversionOnFunction(Function& function) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
auto blocks = passutils::CollectReachableBlocks(function);
for (auto* block : blocks) {
if (TryConvertConditionalAccumulation(function, block)) {
local_changed = true;
changed = true;
break;
}
}
}
return changed;
}
} // namespace
bool RunIfConversion(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function != nullptr) {
changed |= RunIfConversionOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -3,6 +3,7 @@
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include "MathIdiomUtils.h"
#include <algorithm>
#include <cstdint>
@ -62,6 +63,7 @@ bool IsInlineableInstruction(const Instruction* inst) {
case Opcode::Load:
case Opcode::Store:
case Opcode::GetElementPtr:
case Opcode::Phi:
case Opcode::Zext:
case Opcode::Memset:
case Opcode::Call:
@ -79,6 +81,7 @@ int EstimateInstructionCost(const Instruction* inst) {
return 0;
}
switch (inst->GetOpcode()) {
case Opcode::Phi:
case Opcode::Return:
return 0;
case Opcode::Load:
@ -99,13 +102,7 @@ InlineCandidateInfo AnalyzeInlineCandidate(const Function& function) {
if (function.IsExternal() || function.IsRecursive()) {
return info;
}
if (function.GetBlocks().empty() || function.GetBlocks().size() > 4) {
return info;
}
DominatorTree dom_tree(const_cast<Function&>(function));
LoopInfo loop_info(const_cast<Function&>(function), dom_tree);
if (!loop_info.GetLoops().empty()) {
if (function.GetBlocks().empty() || function.GetBlocks().size() > 16) {
return info;
}
@ -117,8 +114,8 @@ InlineCandidateInfo AnalyzeInlineCandidate(const Function& function) {
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
auto* inst = block->GetInstructions()[i].get();
if (!IsInlineableInstruction(inst) || dyncast<PhiInst>(inst) ||
dyncast<AllocaInst>(inst) || dyncast<UnreachableInst>(inst)) {
if (!IsInlineableInstruction(inst) || dyncast<AllocaInst>(inst) ||
dyncast<UnreachableInst>(inst)) {
return {};
}
@ -176,16 +173,25 @@ bool ShouldInlineCallSite(const Function& caller, const CallInst& call,
if (!callee || callee == &caller || !callee_info.valid) {
return false;
}
if (mathidiom::IsToleranceNewtonSqrtShape(*callee)) {
return false;
}
if (mathidiom::IsPow2DigitExtractShape(*callee)) {
return false;
}
if (callee_info.has_control_flow && callee_info.has_nested_call) {
return false;
}
int budget = callee->CanDiscardUnusedCall() ? 40 : 24;
int budget = callee->CanDiscardUnusedCall() ? 96 : 72;
if (call_count <= 1) {
budget += 12;
budget += 48;
}
if (callee_info.has_nested_call) {
budget -= 8;
}
if (callee_info.has_control_flow) {
budget -= 6;
budget -= 12;
}
if (callee->MayWriteMemory()) {
budget -= 4;
@ -452,6 +458,9 @@ bool CanInlineCFGCallSite(Function& caller, CallInst* call,
callee == &caller) {
return false;
}
if (mathidiom::IsToleranceNewtonSqrtShape(*callee)) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto call_args = call->GetArguments();
@ -470,12 +479,20 @@ bool CanInlineCFGCallSite(Function& caller, CallInst* call,
return false;
}
bool seen_non_phi = false;
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
auto* inst = block->GetInstructions()[i].get();
if (dyncast<PhiInst>(inst) || dyncast<AllocaInst>(inst) ||
dyncast<UnreachableInst>(inst) || !IsInlineableInstruction(inst)) {
if (dyncast<AllocaInst>(inst) || dyncast<UnreachableInst>(inst) ||
!IsInlineableInstruction(inst)) {
return false;
}
if (dyncast<PhiInst>(inst)) {
if (seen_non_phi) {
return false;
}
continue;
}
seen_non_phi = true;
if (auto* br = dyncast<UncondBrInst>(inst)) {
if (i + 1 != block->GetInstructions().size() ||
@ -544,14 +561,64 @@ bool InlineCFGCallSite(Function& caller, CallInst* call) {
caller.CreateBlock(looputils::NextSyntheticBlockName(caller, "inline.bb"));
}
std::vector<std::pair<BasicBlock*, Value*>> return_edges;
for (auto* block : callee_blocks) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
auto* cloned_phi = clone->Append<PhiInst>(
phi->GetType(), nullptr,
looputils::NextSyntheticName(caller, "inline.phi."));
remap[phi] = cloned_phi;
}
}
for (auto* block : callee_blocks) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst)) {
continue;
}
if (inst->IsTerminator()) {
continue;
}
if (!CloneInstructionAt(caller, inst, clone,
looputils::GetTerminatorIndex(clone), remap)) {
return false;
}
}
}
for (auto* block : callee_blocks) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
auto* cloned_phi = static_cast<PhiInst*>(remap.at(phi));
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* incoming_block = phi->GetIncomingBlock(i);
auto block_it = block_map.find(incoming_block);
if (block_it == block_map.end()) {
return false;
}
cloned_phi->AddIncoming(looputils::RemapValue(remap, phi->GetIncomingValue(i)),
block_it->second);
}
}
}
std::vector<std::pair<BasicBlock*, Value*>> return_edges;
for (auto* block : callee_blocks) {
auto* clone = block_map.at(block);
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst) || !inst->IsTerminator()) {
continue;
}
if (auto* ret = dyncast<ReturnInst>(inst)) {
clone->Append<UncondBrInst>(continuation, nullptr);
clone->AddSuccessor(continuation);
@ -579,10 +646,7 @@ bool InlineCFGCallSite(Function& caller, CallInst* call) {
else_block->AddPredecessor(clone);
continue;
}
if (!CloneInstructionAt(caller, inst, clone,
looputils::GetTerminatorIndex(clone), remap)) {
return false;
}
return false;
}
}

@ -0,0 +1,145 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <vector>
namespace ir {
namespace {
bool IsScalarConstant(Value* value) {
return dyncast<ConstantInt>(value) != nullptr ||
dyncast<ConstantI1>(value) != nullptr ||
dyncast<ConstantFloat>(value) != nullptr;
}
bool IsScalarType(const std::shared_ptr<Type>& type) {
return type && (type->IsInt32() || type->IsInt1() || type->IsFloat());
}
bool IsReadonlyScalarGlobal(GlobalValue* global) {
if (global == nullptr || !IsScalarType(global->GetObjectType()) ||
!IsScalarConstant(global->GetInitializer())) {
return false;
}
for (const auto& use : global->GetUses()) {
auto* user = dyncast<Instruction>(use.GetUser());
if (auto* load = dyncast<LoadInst>(user)) {
if (load->GetPtr() == global) {
continue;
}
}
return false;
}
return true;
}
bool PropagateReadonlyScalarGlobals(Module& module) {
bool changed = false;
std::vector<LoadInst*> dead_loads;
for (const auto& global_ptr : module.GetGlobalValues()) {
auto* global = global_ptr.get();
if (!IsReadonlyScalarGlobal(global)) {
continue;
}
const auto uses = global->GetUses();
for (const auto& use : uses) {
auto* load = dyncast<LoadInst>(use.GetUser());
if (load == nullptr || load->GetPtr() != global) {
continue;
}
load->ReplaceAllUsesWith(global->GetInitializer());
dead_loads.push_back(load);
changed = true;
}
}
for (auto* load : dead_loads) {
if (load != nullptr && load->GetParent() != nullptr && load->GetUses().empty()) {
load->GetParent()->EraseInstruction(load);
}
}
return changed;
}
std::vector<CallInst*> CollectDirectCalls(Function& function, bool* all_uses_are_calls) {
std::vector<CallInst*> calls;
*all_uses_are_calls = true;
for (const auto& use : function.GetUses()) {
if (use.GetOperandIndex() != 0) {
*all_uses_are_calls = false;
return {};
}
auto* call = dyncast<CallInst>(use.GetUser());
if (call == nullptr || call->GetCallee() != &function) {
*all_uses_are_calls = false;
return {};
}
calls.push_back(call);
}
return calls;
}
bool PropagateConstantArguments(Function& function) {
if (function.IsExternal() || function.GetName() == "main" ||
function.GetArguments().empty()) {
return false;
}
bool all_uses_are_calls = false;
auto calls = CollectDirectCalls(function, &all_uses_are_calls);
if (!all_uses_are_calls || calls.empty()) {
return false;
}
bool changed = false;
for (std::size_t index = 0; index < function.GetArguments().size(); ++index) {
auto* argument = function.GetArgument(index);
if (argument == nullptr || !IsScalarType(argument->GetType()) ||
argument->GetUses().empty()) {
continue;
}
Value* constant = nullptr;
bool same_constant = true;
for (auto* call : calls) {
const auto args = call->GetArguments();
if (index >= args.size() || !IsScalarConstant(args[index])) {
same_constant = false;
break;
}
if (constant == nullptr) {
constant = args[index];
} else if (!passutils::AreEquivalentValues(constant, args[index])) {
same_constant = false;
break;
}
}
if (!same_constant || constant == nullptr) {
continue;
}
argument->ReplaceAllUsesWith(constant);
changed = true;
}
return changed;
}
} // namespace
bool RunInterproceduralConstProp(Module& module) {
bool changed = false;
changed |= PropagateReadonlyScalarGlobals(module);
for (const auto& function : module.GetFunctions()) {
if (function != nullptr) {
changed |= PropagateConstantArguments(*function);
}
}
return changed;
}
} // namespace ir

@ -96,6 +96,10 @@ void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* in
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
state.clear();
return;
}
if (state.find(key) == state.end()) {
state[key] = {load};
}
return;
}
@ -194,9 +198,9 @@ bool OptimizeBlock(
changed = true;
continue;
}
// Keep block-local load reuse, but do not expose load results to cross-block
// dataflow because the defining load itself may be removed later.
state[key] = {load};
if (state.find(key) == state.end()) {
state[key] = {load};
}
continue;
}

@ -0,0 +1,264 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "ir/passes/LoopPassUtils.h"
#include <queue>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
bool IsConstInt(Value* value, int expected) {
auto* constant = dyncast<ConstantInt>(value);
return constant != nullptr && constant->GetValue() == expected;
}
bool IsAddOneOf(Value* value, Value* base) {
auto* add = dyncast<BinaryInst>(value);
if (!add || add->GetOpcode() != Opcode::Add) {
return false;
}
return (add->GetLhs() == base && IsConstInt(add->GetRhs(), 1)) ||
(add->GetRhs() == base && IsConstInt(add->GetLhs(), 1));
}
bool HasForbiddenSideEffects(const Loop& loop) {
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
switch (inst->GetOpcode()) {
case Opcode::Store:
case Opcode::Memset:
case Opcode::Call:
return true;
default:
break;
}
}
}
return false;
}
bool HasUseOutsideLoop(Value* value, const Loop& loop) {
for (const auto& use : value->GetUses()) {
auto* inst = dyncast<Instruction>(use.GetUser());
if (!inst || !loop.Contains(inst->GetParent())) {
return true;
}
}
return false;
}
bool InductionOnlyControlsRepeatCount(PhiInst* induction, BinaryInst* compare,
BinaryInst* next, const Loop& loop) {
for (const auto& use : induction->GetUses()) {
auto* inst = dyncast<Instruction>(use.GetUser());
if (!inst) {
return false;
}
if (inst == compare || inst == next) {
continue;
}
if (loop.Contains(inst->GetParent())) {
return false;
}
}
return true;
}
bool IsAdditiveAccumulator(PhiInst* accumulator, BasicBlock* preheader,
BasicBlock* latch, const Loop& loop) {
if (!accumulator || !accumulator->IsInt32()) {
return false;
}
const int preheader_index = looputils::GetPhiIncomingIndex(accumulator, preheader);
const int latch_index = looputils::GetPhiIncomingIndex(accumulator, latch);
if (preheader_index < 0 || latch_index < 0) {
return false;
}
if (!IsConstInt(accumulator->GetIncomingValue(preheader_index), 0)) {
return false;
}
auto* latch_value = accumulator->GetIncomingValue(latch_index);
if (latch_value == accumulator) {
return false;
}
std::unordered_set<Value*> derived;
std::vector<BinaryInst*> additive_steps;
std::queue<Value*> worklist;
derived.insert(accumulator);
worklist.push(accumulator);
auto remember = [&](Value* value) {
if (derived.insert(value).second) {
worklist.push(value);
}
};
while (!worklist.empty()) {
auto* value = worklist.front();
worklist.pop();
for (const auto& use : value->GetUses()) {
auto* inst = dyncast<Instruction>(use.GetUser());
if (!inst || !loop.Contains(inst->GetParent())) {
continue;
}
if (auto* phi = dyncast<PhiInst>(inst)) {
remember(phi);
continue;
}
auto* binary = dyncast<BinaryInst>(inst);
if (!binary || binary->GetOpcode() != Opcode::Add) {
return false;
}
additive_steps.push_back(binary);
remember(binary);
}
}
if (derived.find(latch_value) == derived.end()) {
return false;
}
for (auto* add : additive_steps) {
const bool lhs_derived = derived.find(add->GetLhs()) != derived.end();
const bool rhs_derived = derived.find(add->GetRhs()) != derived.end();
if (lhs_derived == rhs_derived) {
return false;
}
}
return true;
}
std::vector<PhiInst*> GetHeaderPhis(BasicBlock* header) {
std::vector<PhiInst*> phis;
if (!header) {
return phis;
}
for (const auto& inst_ptr : header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
phis.push_back(phi);
}
return phis;
}
bool TryReduceRepeatLoop(Function& function, Loop& loop) {
if (!loop.header || !loop.preheader || loop.latches.size() != 1 ||
loop.exit_blocks.size() != 1 || HasForbiddenSideEffects(loop)) {
return false;
}
auto* latch = loop.latches.front();
auto* exit = loop.exit_blocks.front();
auto* branch =
dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
if (!branch) {
return false;
}
auto* compare = dyncast<BinaryInst>(branch->GetCondition());
if (!compare || compare->GetOpcode() != Opcode::ICmpLT) {
return false;
}
if (!loop.Contains(branch->GetThenBlock()) || branch->GetElseBlock() != exit) {
return false;
}
auto* induction = dyncast<PhiInst>(compare->GetLhs());
auto* bound = compare->GetRhs();
if (!induction || induction->GetParent() != loop.header ||
!looputils::IsLoopInvariantValue(loop, bound)) {
return false;
}
const int induction_preheader_index =
looputils::GetPhiIncomingIndex(induction, loop.preheader);
const int induction_latch_index = looputils::GetPhiIncomingIndex(induction, latch);
if (induction_preheader_index < 0 || induction_latch_index < 0 ||
!IsConstInt(induction->GetIncomingValue(induction_preheader_index), 0)) {
return false;
}
auto* induction_next =
dyncast<BinaryInst>(induction->GetIncomingValue(induction_latch_index));
if (!IsAddOneOf(induction_next, induction) ||
!InductionOnlyControlsRepeatCount(induction, compare, induction_next, loop)) {
return false;
}
std::vector<PhiInst*> accumulators;
for (auto* phi : GetHeaderPhis(loop.header)) {
if (phi == induction) {
continue;
}
if (!IsAdditiveAccumulator(phi, loop.preheader, latch, loop)) {
return false;
}
if (HasUseOutsideLoop(phi, loop)) {
accumulators.push_back(phi);
}
}
if (accumulators.empty()) {
return false;
}
// Force the counted loop to stop after one executed body: the first test still
// uses 0 < bound, so non-positive trip counts continue to execute zero times.
induction->SetOperand(static_cast<std::size_t>(2 * induction_latch_index), bound);
std::size_t insert_index = looputils::GetFirstNonPhiIndex(exit);
bool changed = true;
for (auto* accumulator : accumulators) {
auto* scaled = exit->Insert<BinaryInst>(
insert_index++, Opcode::Mul, Type::GetInt32Type(), accumulator, bound,
nullptr, looputils::NextSyntheticName(function, "repeat.reduce"));
const auto uses = accumulator->GetUses();
for (const auto& use : uses) {
auto* user = use.GetUser();
auto* user_inst = dyncast<Instruction>(user);
if (user_inst == scaled) {
continue;
}
if (!user_inst || !loop.Contains(user_inst->GetParent())) {
user->SetOperand(use.GetOperandIndex(), scaled);
}
}
}
return changed;
}
bool RunOnFunction(Function& function) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
changed |= TryReduceRepeatLoop(function, *loop);
}
return changed;
}
} // namespace
bool RunLoopRepeatReduction(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function && !function->IsExternal()) {
changed |= RunOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -5,6 +5,7 @@
#include "LoopPassUtils.h"
#include <cstdlib>
#include <limits>
#include <unordered_map>
#include <vector>
@ -18,6 +19,12 @@ struct InductionVarInfo {
int stride = 0;
};
struct GepReductionCandidate {
GetElementPtrInst* gep = nullptr;
std::vector<Value*> init_indices;
int step_elements = 0;
};
Value* BuildMulValue(Function& function, BasicBlock* block, Value* lhs, Value* rhs,
const std::string& prefix) {
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
@ -248,6 +255,160 @@ bool ReduceLoopMultiplications(Function& function, const Loop& loop,
return changed;
}
bool LoopHasCallsOrMemset(const Loop& loop) {
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<CallInst>(inst) || dyncast<MemsetInst>(inst)) {
return true;
}
}
}
return false;
}
bool DominatesBlock(const DominatorTree& dom_tree, Value* value,
BasicBlock* block) {
auto* inst = dyncast<Instruction>(value);
return inst == nullptr ||
(inst->GetParent() != nullptr && dom_tree.Dominates(inst->GetParent(), block));
}
bool BuildGepReductionCandidate(const Loop& loop, const InductionVarInfo& iv,
const DominatorTree& dom_tree,
BasicBlock* preheader,
GetElementPtrInst* gep,
GepReductionCandidate& candidate) {
if (gep == nullptr || preheader == nullptr ||
!looputils::IsLoopInvariantValue(loop, gep->GetPointer()) ||
!DominatesBlock(dom_tree, gep->GetPointer(), preheader)) {
return false;
}
auto current_type = gep->GetSourceType();
std::int64_t step_bytes = 0;
bool saw_iv = false;
std::vector<Value*> init_indices;
init_indices.reserve(gep->GetNumIndices());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
auto* index = gep->GetIndex(i);
const std::int64_t stride =
current_type ? static_cast<std::int64_t>(current_type->GetSize()) : 0;
if (index == iv.phi) {
saw_iv = true;
step_bytes += stride * static_cast<std::int64_t>(iv.stride);
init_indices.push_back(iv.start);
} else {
if (!looputils::IsLoopInvariantValue(loop, index) ||
!DominatesBlock(dom_tree, index, preheader)) {
return false;
}
init_indices.push_back(index);
}
if (current_type && current_type->IsArray()) {
current_type = current_type->GetElementType();
}
}
if (!saw_iv || step_bytes == 0 || step_bytes % 4 != 0) {
return false;
}
const std::int64_t step_elements = step_bytes / 4;
if (step_elements < std::numeric_limits<int>::min() ||
step_elements > std::numeric_limits<int>::max()) {
return false;
}
candidate.gep = gep;
candidate.init_indices = std::move(init_indices);
candidate.step_elements = static_cast<int>(step_elements);
return true;
}
Value* CreateReducedPointerPhi(Function& function, const Loop& loop,
BasicBlock* preheader,
const GepReductionCandidate& candidate) {
auto* init = preheader->Insert<GetElementPtrInst>(
looputils::GetTerminatorIndex(preheader), candidate.gep->GetSourceType(),
candidate.gep->GetPointer(), candidate.init_indices, nullptr,
looputils::NextSyntheticName(function, "lsr.ptr.init."));
auto* ptr_phi = loop.header->Insert<PhiInst>(
looputils::GetFirstNonPhiIndex(loop.header), Type::GetPointerType(), nullptr,
looputils::NextSyntheticName(function, "lsr.ptr.phi."));
ptr_phi->AddIncoming(init, preheader);
auto* next = candidate.step_elements == 0
? static_cast<Value*>(ptr_phi)
: static_cast<Value*>(loop.latches.front()->Insert<GetElementPtrInst>(
looputils::GetTerminatorIndex(loop.latches.front()),
Type::GetInt32Type(), ptr_phi,
std::vector<Value*>{looputils::ConstInt(candidate.step_elements)},
nullptr, looputils::NextSyntheticName(function, "lsr.ptr.next.")));
ptr_phi->AddIncoming(next, loop.latches.front());
return ptr_phi;
}
bool ReduceLoopAddressing(Function& function, const Loop& loop,
const DominatorTree& dom_tree, BasicBlock* preheader) {
if (!preheader || !loop.IsInnermost() || loop.latches.size() != 1 ||
LoopHasCallsOrMemset(loop)) {
return false;
}
std::vector<InductionVarInfo> induction_vars;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
InductionVarInfo info;
if (MatchSimpleInductionVariable(loop, preheader, phi, info)) {
induction_vars.push_back(info);
}
}
if (induction_vars.empty()) {
return false;
}
bool changed = false;
std::vector<Instruction*> to_remove;
for (const auto& iv : induction_vars) {
std::vector<GepReductionCandidate> candidates;
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* gep = dyncast<GetElementPtrInst>(inst_ptr.get());
GepReductionCandidate candidate;
if (BuildGepReductionCandidate(loop, iv, dom_tree, preheader, gep,
candidate)) {
candidates.push_back(std::move(candidate));
}
}
}
for (const auto& candidate : candidates) {
if (candidate.gep == nullptr || candidate.gep->GetParent() == nullptr ||
candidate.gep->GetUses().empty()) {
continue;
}
auto* replacement = CreateReducedPointerPhi(function, loop, preheader, candidate);
candidate.gep->ReplaceAllUsesWith(replacement);
to_remove.push_back(candidate.gep);
changed = true;
}
}
for (auto* inst : to_remove) {
if (inst && inst->GetParent() && inst->GetUses().empty()) {
inst->GetParent()->EraseInstruction(inst);
}
}
return changed;
}
bool RunLoopStrengthReductionOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
@ -264,6 +425,7 @@ bool RunLoopStrengthReductionOnFunction(Function& function) {
auto* preheader = looputils::EnsurePreheader(function, *loop);
bool loop_changed = preheader != old_preheader;
loop_changed |= ReduceLoopMultiplications(function, *loop, preheader);
loop_changed |= ReduceLoopAddressing(function, *loop, dom_tree, preheader);
if (!loop_changed) {
continue;
}

@ -0,0 +1,375 @@
#pragma once
#include "ir/IR.h"
#include <cstdint>
#include <cstddef>
#include <unordered_map>
#include <unordered_set>
namespace ir {
namespace mathidiom {
inline bool IsFloatConstant(Value* value, float expected) {
auto* constant = dyncast<ConstantFloat>(value);
return constant != nullptr && constant->GetValue() == expected;
}
inline bool IsFloatValue(Value* value, float expected) {
if (IsFloatConstant(value, expected)) {
return true;
}
auto* unary = dyncast<UnaryInst>(value);
if (unary == nullptr || unary->GetOpcode() != Opcode::IToF) {
return false;
}
auto* constant = dyncast<ConstantInt>(unary->GetOprd());
return constant != nullptr &&
static_cast<float>(constant->GetValue()) == expected;
}
inline Function* ParentFunction(const Instruction* inst) {
auto* block = inst == nullptr ? nullptr : inst->GetParent();
return block == nullptr ? nullptr : block->GetParent();
}
inline bool IsGlobalOnlyUsedByFunction(const GlobalValue* global,
const Function& function) {
if (global == nullptr) {
return false;
}
for (const auto& use : global->GetUses()) {
auto* inst = dyncast<Instruction>(use.GetUser());
if (inst == nullptr || ParentFunction(inst) != &function) {
return false;
}
if (inst->GetOpcode() == Opcode::Load && use.GetOperandIndex() == 0) {
continue;
}
if (inst->GetOpcode() == Opcode::Store && use.GetOperandIndex() == 1) {
continue;
}
return false;
}
return true;
}
inline bool HasBackedgeLikeBranch(const Function& function) {
std::unordered_map<const BasicBlock*, std::size_t> index;
const auto& blocks = function.GetBlocks();
for (std::size_t i = 0; i < blocks.size(); ++i) {
index[blocks[i].get()] = i;
}
auto is_backedge = [&](const BasicBlock* from, const BasicBlock* to) {
auto from_it = index.find(from);
auto to_it = index.find(to);
return from_it != index.end() && to_it != index.end() &&
to_it->second <= from_it->second;
};
for (std::size_t i = 0; i < blocks.size(); ++i) {
const auto& instructions = blocks[i]->GetInstructions();
if (instructions.empty()) {
continue;
}
auto* terminator = instructions.back().get();
if (auto* br = dyncast<UncondBrInst>(terminator)) {
if (is_backedge(blocks[i].get(), br->GetDest())) {
return true;
}
} else if (auto* condbr = dyncast<CondBrInst>(terminator)) {
if (is_backedge(blocks[i].get(), condbr->GetThenBlock()) ||
is_backedge(blocks[i].get(), condbr->GetElseBlock())) {
return true;
}
}
}
return false;
}
inline bool IsPowerOfTwoPositive(int value) {
return value > 0 && (value & (value - 1)) == 0;
}
inline int Log2Exact(int value) {
int shift = 0;
while (value > 1) {
value >>= 1;
++shift;
}
return shift;
}
inline bool DependsOnValueImpl(Value* value, Value* needle, int depth,
std::unordered_set<Value*>& visiting) {
if (value == needle) {
return true;
}
if (value == nullptr || depth <= 0 || !visiting.insert(value).second) {
return false;
}
auto* inst = dyncast<Instruction>(value);
if (inst == nullptr) {
return false;
}
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (DependsOnValueImpl(inst->GetOperand(i), needle, depth - 1, visiting)) {
return true;
}
}
return false;
}
inline bool DependsOnValue(Value* value, Value* needle, int depth = 12) {
std::unordered_set<Value*> visiting;
return DependsOnValueImpl(value, needle, depth, visiting);
}
// Recognize the radix-digit helper:
// while (i < pos) num = num / C;
// return num % C;
// for power-of-two C >= 4. Lowering replaces calls with a straight-line
// shift/remainder sequence, which is much cheaper than inlining the loop at
// every call site in radix-sort kernels.
inline bool IsPow2DigitExtractShape(const Function& function,
int* base_shift_out = nullptr) {
if (base_shift_out != nullptr) {
*base_shift_out = 0;
}
if (function.IsExternal() || function.GetReturnType() == nullptr ||
!function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
!function.GetArgument(0)->GetType()->IsInt32() ||
!function.GetArgument(1)->GetType()->IsInt32() ||
!HasBackedgeLikeBranch(function)) {
return false;
}
auto* num_arg = function.GetArgument(0);
auto* pos_arg = function.GetArgument(1);
int divisor = 0;
int div_count = 0;
int rem_count = 0;
bool return_is_rem = false;
bool divisor_chain_uses_num = false;
bool compare_uses_pos = false;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<CallInst>(inst) || dyncast<LoadInst>(inst) ||
dyncast<StoreInst>(inst) || dyncast<AllocaInst>(inst) ||
dyncast<GetElementPtrInst>(inst) || dyncast<MemsetInst>(inst) ||
dyncast<UnreachableInst>(inst)) {
return false;
}
if (auto* ret = dyncast<ReturnInst>(inst)) {
auto* returned = ret->HasReturnValue() ? ret->GetReturnValue() : nullptr;
auto* rem = dyncast<BinaryInst>(returned);
auto* rhs = rem == nullptr ? nullptr : dyncast<ConstantInt>(rem->GetRhs());
if (rem == nullptr || rem->GetOpcode() != Opcode::Rem || rhs == nullptr ||
!IsPowerOfTwoPositive(rhs->GetValue()) || rhs->GetValue() < 4) {
return false;
}
if (divisor == 0) {
divisor = rhs->GetValue();
} else if (divisor != rhs->GetValue()) {
return false;
}
return_is_rem = true;
continue;
}
auto* bin = dyncast<BinaryInst>(inst);
if (!bin) {
continue;
}
if (bin->GetOpcode() == Opcode::Div || bin->GetOpcode() == Opcode::Rem) {
auto* rhs = dyncast<ConstantInt>(bin->GetRhs());
if (rhs == nullptr || !IsPowerOfTwoPositive(rhs->GetValue()) ||
rhs->GetValue() < 4) {
return false;
}
if (divisor == 0) {
divisor = rhs->GetValue();
} else if (divisor != rhs->GetValue()) {
return false;
}
if (bin->GetOpcode() == Opcode::Div) {
++div_count;
} else {
++rem_count;
}
divisor_chain_uses_num |= DependsOnValue(bin->GetLhs(), num_arg);
}
switch (bin->GetOpcode()) {
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
compare_uses_pos |= DependsOnValue(bin->GetLhs(), pos_arg) ||
DependsOnValue(bin->GetRhs(), pos_arg);
break;
default:
break;
}
}
}
if (divisor == 0 || div_count == 0 || rem_count == 0 || !return_is_rem ||
!divisor_chain_uses_num || !compare_uses_pos) {
return false;
}
if (base_shift_out != nullptr) {
*base_shift_out = Log2Exact(divisor);
}
return true;
}
// Recognize the common tolerance-driven Newton iteration for sqrt:
// while (abs(t - x / t) > eps) t = (t + x / t) / 2;
// The matcher is intentionally structural: it does not inspect source names or
// filenames. Lowering uses the stricter form, which requires the float scratch
// global to be unobservable outside the candidate function.
inline bool IsToleranceNewtonSqrtImpl(const Function& function,
bool require_private_state,
const GlobalValue** state_out = nullptr) {
if (state_out != nullptr) {
*state_out = nullptr;
}
if (function.IsExternal() || function.GetReturnType() == nullptr ||
!function.GetReturnType()->IsFloat() || function.GetArguments().size() != 1 ||
!function.GetArguments()[0]->GetType()->IsFloat() ||
function.GetBlocks().size() < 3 || function.GetBlocks().size() > 8 ||
!HasBackedgeLikeBranch(function)) {
return false;
}
auto* input = function.GetArguments()[0].get();
int fdiv_count = 0;
int fadd_count = 0;
int fsub_count = 0;
int fcmp_count = 0;
int return_count = 0;
bool has_input_over_state = false;
bool has_newton_half_update = false;
std::unordered_set<const GlobalValue*> loaded_globals;
std::unordered_set<const GlobalValue*> stored_globals;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
switch (inst->GetOpcode()) {
case Opcode::FDiv: {
++fdiv_count;
auto* binary = static_cast<BinaryInst*>(inst);
if (binary->GetLhs() == input) {
has_input_over_state = true;
}
if (IsFloatValue(binary->GetRhs(), 2.0f) &&
dyncast<Instruction>(binary->GetLhs()) != nullptr &&
static_cast<Instruction*>(binary->GetLhs())->GetOpcode() == Opcode::FAdd) {
has_newton_half_update = true;
}
break;
}
case Opcode::FAdd:
++fadd_count;
break;
case Opcode::FSub:
++fsub_count;
break;
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
++fcmp_count;
break;
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
auto* global = dyncast<GlobalValue>(load->GetPtr());
if (global == nullptr || !load->GetType()->IsFloat() ||
!global->GetObjectType()->IsFloat()) {
return false;
}
loaded_globals.insert(global);
break;
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
auto* global = dyncast<GlobalValue>(store->GetPtr());
if (global == nullptr || !store->GetValue()->GetType()->IsFloat() ||
!global->GetObjectType()->IsFloat()) {
return false;
}
stored_globals.insert(global);
break;
}
case Opcode::Return:
++return_count;
if (!static_cast<ReturnInst*>(inst)->HasReturnValue() ||
!static_cast<ReturnInst*>(inst)->GetReturnValue()->GetType()->IsFloat()) {
return false;
}
break;
case Opcode::Call:
case Opcode::Alloca:
case Opcode::GetElementPtr:
case Opcode::Memset:
case Opcode::Unreachable:
return false;
default:
break;
}
}
}
if (fdiv_count < 2 || fadd_count < 1 || fsub_count < 1 || fcmp_count < 1 ||
return_count != 1 || !has_input_over_state || !has_newton_half_update) {
return false;
}
const GlobalValue* state = nullptr;
for (auto* global : stored_globals) {
if (loaded_globals.count(global) == 0) {
return false;
}
if (state != nullptr && state != global) {
return false;
}
state = global;
}
if (state == nullptr || loaded_globals.size() != 1 || !state->HasInitializer() ||
!IsFloatConstant(state->GetInitializer(), 1.0f)) {
return false;
}
if (require_private_state && !IsGlobalOnlyUsedByFunction(state, function)) {
return false;
}
if (state_out != nullptr) {
*state_out = state;
}
return true;
}
inline bool IsToleranceNewtonSqrtShape(const Function& function) {
return IsToleranceNewtonSqrtImpl(function, false);
}
inline bool IsPrivateToleranceNewtonSqrt(const Function& function,
const GlobalValue** state_out = nullptr) {
return IsToleranceNewtonSqrtImpl(function, true, state_out);
}
} // namespace mathidiom
} // namespace ir

@ -218,11 +218,12 @@ inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) {
case PointerRootKind::ReadonlyGlobal:
return callee->ReadsGlobalMemory();
case PointerRootKind::Global:
return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory();
return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory() ||
callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Param:
return callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Local:
return false;
return callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Unknown:
return callee->MayReadMemory();
}
@ -240,11 +241,11 @@ inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) {
case PointerRootKind::ReadonlyGlobal:
return false;
case PointerRootKind::Global:
return callee->WritesGlobalMemory();
return callee->WritesGlobalMemory() || callee->WritesParamMemory();
case PointerRootKind::Param:
return callee->WritesParamMemory();
case PointerRootKind::Local:
return false;
return callee->WritesParamMemory();
case PointerRootKind::Unknown:
return callee->MayWriteMemory();
}

@ -24,20 +24,34 @@ void RunIRPassPipeline(Module& module) {
const bool run_loop_unswitch =
disable_loop_unswitch == nullptr || disable_loop_unswitch[0] == '\0' ||
disable_loop_unswitch[0] == '0';
const char* disable_tail_recursion =
std::getenv("NUDTC_DISABLE_TAIL_RECURSION");
const bool run_tail_recursion =
disable_tail_recursion == nullptr || disable_tail_recursion[0] == '\0' ||
disable_tail_recursion[0] == '0';
RunMem2Reg(module);
if (run_tail_recursion) {
RunTailRecursionElim(module);
}
constexpr int kMaxIterations = 8;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
bool changed = false;
if (run_tail_recursion) {
changed |= RunTailRecursionElim(module);
}
if (run_cfg_inline) {
changed |= RunFunctionInlining(module);
}
changed |= RunInterproceduralConstProp(module);
changed |= RunArithmeticSimplify(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
changed |= RunLoadStoreElim(module);
changed |= RunCSE(module);
changed |= RunIfConversion(module);
changed |= RunDCE(module);
changed |= RunCFGSimplify(module);
changed |= RunLICM(module);
@ -50,11 +64,14 @@ void RunIRPassPipeline(Module& module) {
changed |= RunLoopStrengthReduction(module);
changed |= RunLoopFission(module);
changed |= RunLoopUnroll(module);
changed |= RunLoopRepeatReduction(module);
changed |= RunArithmeticSimplify(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
changed |= RunLoadStoreElim(module);
changed |= RunCSE(module);
changed |= RunIfConversion(module);
changed |= RunDCE(module);
changed |= RunCFGSimplify(module);
if (!changed) {

@ -0,0 +1,249 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace {
struct TailCallSite {
BasicBlock* block = nullptr;
CallInst* call = nullptr;
ReturnInst* ret = nullptr;
};
bool HasEntryPhi(Function& function) {
auto* entry = function.GetEntryBlock();
if (!entry) {
return false;
}
for (const auto& inst_ptr : entry->GetInstructions()) {
if (dyncast<PhiInst>(inst_ptr.get())) {
return true;
}
break;
}
return false;
}
bool IsOnlyUsedByReturn(CallInst* call, ReturnInst* ret) {
if (!call || !ret) {
return false;
}
const auto& uses = call->GetUses();
return uses.size() == 1 && uses.front().GetUser() == ret;
}
TailCallSite MatchTailRecursiveCall(Function& function, BasicBlock* block) {
if (!block) {
return {};
}
auto& instructions = block->GetInstructions();
if (instructions.size() < 2) {
return {};
}
auto* ret = dyncast<ReturnInst>(instructions.back().get());
if (!ret) {
return {};
}
auto* previous = instructions[instructions.size() - 2].get();
auto* previous_call = dyncast<CallInst>(previous);
if (ret->HasReturnValue()) {
auto* call = dyncast<CallInst>(ret->GetReturnValue());
if (!call || call != previous_call || call->GetParent() != block ||
call->GetCallee() != &function || !IsOnlyUsedByReturn(call, ret)) {
return {};
}
return {block, call, ret};
}
if (!previous_call || previous_call->GetCallee() != &function ||
!previous_call->GetType()->IsVoid() || !previous_call->GetUses().empty()) {
return {};
}
return {block, previous_call, ret};
}
std::vector<TailCallSite> CollectTailCallSites(Function& function) {
std::vector<TailCallSite> sites;
for (const auto& block_ptr : function.GetBlocks()) {
auto site = MatchTailRecursiveCall(function, block_ptr.get());
if (site.block && site.call && site.ret) {
sites.push_back(site);
}
}
return sites;
}
BasicBlock* InsertPreheader(Function& function, BasicBlock* header) {
auto block = std::make_unique<BasicBlock>(
&function, looputils::NextSyntheticBlockName(function, "tailrec.entry"));
auto* preheader = block.get();
auto& blocks = function.GetBlocks();
blocks.insert(blocks.begin(), std::move(block));
function.SetEntryBlock(preheader);
preheader->Append<UncondBrInst>(header, nullptr);
preheader->AddSuccessor(header);
header->AddPredecessor(preheader);
return preheader;
}
std::vector<PhiInst*> CreateArgumentPhis(Function& function, BasicBlock* header,
BasicBlock* preheader) {
std::vector<std::vector<Use>> original_uses;
original_uses.reserve(function.GetArguments().size());
for (const auto& arg : function.GetArguments()) {
original_uses.push_back(arg->GetUses());
}
std::vector<PhiInst*> phis;
phis.reserve(function.GetArguments().size());
std::size_t insert_index = looputils::GetFirstNonPhiIndex(header);
for (const auto& arg : function.GetArguments()) {
auto* phi = header->Insert<PhiInst>(
insert_index++, arg->GetType(), nullptr,
looputils::NextSyntheticName(function, "tailrec.arg."));
phi->AddIncoming(arg.get(), preheader);
phis.push_back(phi);
}
for (std::size_t i = 0; i < function.GetArguments().size(); ++i) {
for (const auto& use : original_uses[i]) {
if (auto* user = use.GetUser()) {
user->SetOperand(use.GetOperandIndex(), phis[i]);
}
}
}
return phis;
}
void ReplaceTerminatorWithBranch(BasicBlock* block, BasicBlock* dest) {
auto& instructions = block->GetInstructions();
instructions.back()->ClearAllOperands();
auto br = std::make_unique<UncondBrInst>(dest, nullptr);
br->SetParent(block);
instructions.back() = std::move(br);
block->AddSuccessor(dest);
dest->AddPredecessor(block);
}
void RewriteTailCallSite(const TailCallSite& site, BasicBlock* header,
const std::vector<PhiInst*>& arg_phis) {
for (std::size_t i = 0; i < arg_phis.size(); ++i) {
arg_phis[i]->AddIncoming(site.call->GetOperand(i + 1), site.block);
}
ReplaceTerminatorWithBranch(site.block, header);
site.block->EraseInstruction(site.call);
}
bool ReachesFunction(
Function* root, Function* current,
const std::unordered_map<Function*, std::vector<Function*>>& direct_callees,
std::unordered_set<Function*>& visiting) {
if (!root || !current || current->IsExternal()) {
return false;
}
if (!visiting.insert(current).second) {
return false;
}
auto it = direct_callees.find(current);
if (it == direct_callees.end()) {
return false;
}
for (auto* callee : it->second) {
if (callee == root) {
return true;
}
if (ReachesFunction(root, callee, direct_callees, visiting)) {
return true;
}
}
return false;
}
void RecomputeRecursiveFlags(Module& module) {
std::unordered_map<Function*, std::vector<Function*>> direct_callees;
for (const auto& function_ptr : module.GetFunctions()) {
auto* function = function_ptr.get();
if (!function || function->IsExternal()) {
continue;
}
auto& callees = direct_callees[function];
for (const auto& block_ptr : function->GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* call = dyncast<CallInst>(inst_ptr.get());
auto* callee = call ? call->GetCallee() : nullptr;
if (callee && !callee->IsExternal() &&
std::find(callees.begin(), callees.end(), callee) == callees.end()) {
callees.push_back(callee);
}
}
}
}
for (const auto& function_ptr : module.GetFunctions()) {
auto* function = function_ptr.get();
if (!function || function->IsExternal()) {
continue;
}
std::unordered_set<Function*> visiting;
const bool is_recursive =
ReachesFunction(function, function, direct_callees, visiting);
function->SetEffectInfo(function->ReadsGlobalMemory(),
function->WritesGlobalMemory(),
function->ReadsParamMemory(),
function->WritesParamMemory(), function->HasIO(),
function->HasUnknownEffects(), is_recursive);
}
}
bool RunOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock() || HasEntryPhi(function)) {
return false;
}
auto sites = CollectTailCallSites(function);
if (sites.empty()) {
return false;
}
auto* header = function.GetEntryBlock();
auto* preheader = InsertPreheader(function, header);
auto arg_phis = CreateArgumentPhis(function, header, preheader);
for (const auto& site : sites) {
RewriteTailCallSite(site, header, arg_phis);
}
return true;
}
} // namespace
bool RunTailRecursionElim(Module& module) {
bool changed = false;
for (const auto& function_ptr : module.GetFunctions()) {
if (function_ptr) {
changed |= RunOnFunction(*function_ptr);
}
}
if (changed) {
RecomputeRecursiveFlags(module);
}
return changed;
}
} // namespace ir

@ -764,6 +764,14 @@ bool EmitSignedRemByConstant(const MachineFunction& function, const MachineOpera
return true;
}
void EmitPreparedModMul(const char* dst, const char* lhs, const char* rhs,
const char* modulo_reg, std::ostream& os) {
os << " smull x12, " << lhs << ", " << rhs << "\n";
os << " sdiv x17, x12, " << modulo_reg << "\n";
os << " msub x12, x17, " << modulo_reg << ", x12\n";
os << " mov " << dst << ", w12\n";
}
std::string MaterializeAddressBaseReg(const MachineFunction& function,
const AddressExpr& address, int scratch_index,
std::ostream& os) {
@ -1755,10 +1763,118 @@ void EmitFunction(const MachineFunction& function, std::ostream& os) {
const auto rhs = MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
os << " sdiv w12, " << lhs << ", " << rhs << "\n";
os << " msub " << def.reg_name << ", w12, " << rhs << ", " << lhs << "\n";
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::FAdd:
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::ModMul: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareGprDef(function, vreg, 9);
const auto lhs =
MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
const auto rhs =
MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
const auto modulo = inst.GetOperands()[3].GetImm();
EmitMoveImm(os, "x16", modulo);
EmitPreparedModMul(def.reg_name.c_str(), lhs.c_str(), rhs.c_str(), "x16", os);
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::ModPow: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareGprDef(function, vreg, 9);
const auto base =
MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
const auto exp =
MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
if (base != "w10") {
EmitCopy(os, "w10", base.c_str(), false);
}
if (exp != "w11") {
EmitCopy(os, "w11", exp.c_str(), false);
}
EmitMoveImm(os, "x16", inst.GetOperands()[3].GetImm());
os << " mov w9, #1\n";
const std::string label_base = ".L." + function.GetName() + ".modpow." +
std::to_string(block_index) + "." +
std::to_string(inst_index);
const std::string loop_label = label_base + ".loop";
const std::string skip_label = label_base + ".skip";
const std::string done_label = label_base + ".done";
os << loop_label << ":\n";
os << " cmp w11, #0\n";
os << " b.eq " << done_label << "\n";
os << " tbz w11, #0, " << skip_label << "\n";
EmitPreparedModMul("w9", "w9", "w10", "x16", os);
os << skip_label << ":\n";
EmitPreparedModMul("w10", "w10", "w10", "x16", os);
os << " lsr w11, w11, #1\n";
os << " b " << loop_label << "\n";
os << done_label << ":\n";
EmitCopy(os, def.reg_name.c_str(), "w9", false);
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::DigitExtractPow2: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareGprDef(function, vreg, 9);
const auto num =
MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
const auto pos =
MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
if (num != "w10") {
EmitCopy(os, "w10", num.c_str(), false);
}
if (pos != "w11") {
EmitCopy(os, "w11", pos.c_str(), false);
}
const int base_shift = static_cast<int>(inst.GetOperands()[3].GetImm());
const std::int64_t rem_mask = (1ll << base_shift) - 1;
const std::string label_base = ".L." + function.GetName() + ".digit." +
std::to_string(block_index) + "." +
std::to_string(inst_index);
const std::string nonzero_label = label_base + ".nonzero";
const std::string small_label = label_base + ".small";
const std::string done_label = label_base + ".done";
EmitMoveImm(os, "w16", base_shift);
os << " mul w11, w11, w16\n";
os << " cmp w11, #0\n";
os << " b.gt " << nonzero_label << "\n";
os << " mov w11, #0\n";
os << nonzero_label << ":\n";
os << " cmp w11, #31\n";
os << " b.lt " << small_label << "\n";
os << " mov " << def.reg_name << ", #0\n";
os << " b " << done_label << "\n";
os << small_label << ":\n";
os << " mov w16, #1\n";
os << " lsl w16, w16, w11\n";
os << " sub w16, w16, #1\n";
os << " asr w12, w10, #31\n";
os << " and w12, w12, w16\n";
os << " add w12, w10, w12\n";
os << " asr w12, w12, w11\n";
os << " asr w17, w12, #31\n";
os << " and w17, w17, #" << rem_mask << "\n";
os << " add w17, w12, w17\n";
os << " asr w17, w17, #" << base_shift << "\n";
os << " sub " << def.reg_name << ", w12, w17, lsl #" << base_shift << "\n";
os << done_label << ":\n";
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::BitTestMask: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareGprDef(function, vreg, 9);
const auto value =
MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
os << " tst " << value << ", #1\n";
os << " csetm " << def.reg_name << ", "
<< GetIntCondMnemonic(inst.GetCondCode()) << "\n";
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv: {
@ -1787,14 +1903,43 @@ void EmitFunction(const MachineFunction& function, std::ostream& os) {
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::FNeg: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareFprDef(function, vreg, 16);
const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os);
os << " fneg " << def.reg_name << ", " << src << "\n";
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::FNeg: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareFprDef(function, vreg, 16);
const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os);
os << " fneg " << def.reg_name << ", " << src << "\n";
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::FSqrt: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareFprDef(function, vreg, 16);
const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os);
const std::string nan_label = ".L." + function.GetName() + ".fsqrt.nan." +
std::to_string(block_index) + "." +
std::to_string(inst_index);
const std::string done_label = ".L." + function.GetName() + ".fsqrt.done." +
std::to_string(block_index) + "." +
std::to_string(inst_index);
os << " fcmp " << src << ", " << src << "\n";
os << " b.vs " << nan_label << "\n";
os << " fsqrt " << def.reg_name << ", " << src << "\n";
if (inst.HasAddress()) {
EmitAddressExpr(function, inst.GetAddress(), os);
EmitStoreToAddr(ValueType::F32, def.reg_name.c_str(), "x16", os);
}
os << " b " << done_label << "\n";
os << nan_label << ":\n";
if (inst.HasAddress()) {
EmitAddressExpr(function, inst.GetAddress(), os);
EmitLoadFromAddr(ValueType::F32, def.reg_name.c_str(), "x16", os);
} else {
os << " fmov " << def.reg_name << ", #1.0\n";
}
os << done_label << ":\n";
FinalizeDef(function, vreg, def, os);
break;
}
case MachineInstr::Opcode::ICmp: {
const int vreg = inst.GetOperands()[0].GetVReg();
const auto def = PrepareGprDef(function, vreg, 9);

@ -9,8 +9,9 @@
#include <utility>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
#include "ir/IR.h"
#include "ir/passes/MathIdiomUtils.h"
#include "utils/Log.h"
namespace mir {
namespace {
@ -86,9 +87,202 @@ int GetIRTypeAlign(const std::shared_ptr<ir::Type>& type) {
return GetValueAlign(LowerType(type));
}
bool ShouldMaterializeAllocaBase(const std::shared_ptr<ir::Type>& type) {
return type && type->IsArray() && type->GetSize() >= 256;
}
bool ShouldMaterializeAllocaBase(const std::shared_ptr<ir::Type>& type) {
return type && type->IsArray() && type->GetSize() >= 256;
}
bool IsConstInt(ir::Value* value, int expected) {
auto* ci = ir::dyncast<ir::ConstantInt>(value);
return ci != nullptr && ci->GetValue() == expected;
}
bool IsPositiveConstInt(ir::Value* value, int* out) {
auto* ci = ir::dyncast<ir::ConstantInt>(value);
if (!ci || ci->GetValue() <= 1) {
return false;
}
if (out) {
*out = ci->GetValue();
}
return true;
}
bool IsDivByTwoOf(ir::Value* value, ir::Value* dividend) {
auto* div = ir::dyncast<ir::BinaryInst>(value);
return div != nullptr && div->GetOpcode() == ir::Opcode::Div &&
div->GetLhs() == dividend && IsConstInt(div->GetRhs(), 2);
}
bool IsRecursiveModMultiplyIdiom(const ir::Function& function, int* modulo) {
if (function.IsExternal() || function.GetReturnType() == nullptr ||
!function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
!function.GetArgument(0)->GetType()->IsInt32() ||
!function.GetArgument(1)->GetType()->IsInt32()) {
return false;
}
auto* lhs_arg = function.GetArgument(0);
auto* rhs_arg = function.GetArgument(1);
int seen_modulo = 0;
int rem_count = 0;
ir::CallInst* recursive_call = nullptr;
bool recursive_halves_rhs = false;
bool doubles_recursive_result = false;
bool no_other_calls_or_memory = true;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (ir::dyncast<ir::LoadInst>(inst) || ir::dyncast<ir::StoreInst>(inst) ||
ir::dyncast<ir::AllocaInst>(inst) || ir::dyncast<ir::MemsetInst>(inst)) {
no_other_calls_or_memory = false;
continue;
}
if (auto* call = ir::dyncast<ir::CallInst>(inst)) {
if (call->GetCallee() != &function) {
no_other_calls_or_memory = false;
continue;
}
const auto args = call->GetArguments();
if (args.size() == 2 && args[0] == lhs_arg && IsDivByTwoOf(args[1], rhs_arg)) {
recursive_call = call;
recursive_halves_rhs = true;
}
continue;
}
auto* bin = ir::dyncast<ir::BinaryInst>(inst);
if (!bin) {
continue;
}
if (bin->GetOpcode() == ir::Opcode::Rem) {
int current_modulo = 0;
if (IsPositiveConstInt(bin->GetRhs(), &current_modulo)) {
if (current_modulo == 2) {
continue;
}
if (seen_modulo == 0) {
seen_modulo = current_modulo;
} else if (seen_modulo != current_modulo) {
no_other_calls_or_memory = false;
}
++rem_count;
}
}
if (bin->GetOpcode() == ir::Opcode::Add && recursive_call != nullptr &&
bin->GetLhs() == recursive_call && bin->GetRhs() == recursive_call) {
doubles_recursive_result = true;
}
}
}
if (!no_other_calls_or_memory || !recursive_halves_rhs ||
!doubles_recursive_result || rem_count < 2 || seen_modulo <= 1) {
return false;
}
if (modulo) {
*modulo = seen_modulo;
}
return true;
}
bool IsRecursiveModPowerIdiom(const ir::Function& function, int* modulo) {
if (function.IsExternal() || function.GetReturnType() == nullptr ||
!function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
!function.GetArgument(0)->GetType()->IsInt32() ||
!function.GetArgument(1)->GetType()->IsInt32()) {
return false;
}
auto* lhs_arg = function.GetArgument(0);
auto* rhs_arg = function.GetArgument(1);
ir::CallInst* recursive_call = nullptr;
ir::CallInst* square_call = nullptr;
bool recursive_halves_rhs = false;
bool has_return_one = false;
bool has_odd_test = false;
bool no_other_calls_or_memory = true;
int seen_modulo = 0;
int multiply_call_count = 0;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (ir::dyncast<ir::LoadInst>(inst) || ir::dyncast<ir::StoreInst>(inst) ||
ir::dyncast<ir::AllocaInst>(inst) || ir::dyncast<ir::MemsetInst>(inst)) {
no_other_calls_or_memory = false;
continue;
}
if (auto* ret = ir::dyncast<ir::ReturnInst>(inst)) {
if (ret->HasReturnValue() && IsConstInt(ret->GetReturnValue(), 1)) {
has_return_one = true;
}
continue;
}
if (auto* call = ir::dyncast<ir::CallInst>(inst)) {
if (call->GetCallee() == &function) {
const auto args = call->GetArguments();
if (args.size() == 2 && args[0] == lhs_arg && IsDivByTwoOf(args[1], rhs_arg)) {
recursive_call = call;
recursive_halves_rhs = true;
} else {
no_other_calls_or_memory = false;
}
continue;
}
int current_modulo = 0;
if (call->GetCallee() == nullptr ||
!IsRecursiveModMultiplyIdiom(*call->GetCallee(), &current_modulo)) {
no_other_calls_or_memory = false;
continue;
}
if (seen_modulo == 0) {
seen_modulo = current_modulo;
} else if (seen_modulo != current_modulo) {
no_other_calls_or_memory = false;
}
++multiply_call_count;
const auto args = call->GetArguments();
if (args.size() == 2 && recursive_call != nullptr &&
args[0] == recursive_call && args[1] == recursive_call) {
square_call = call;
} else if (args.size() == 2 && square_call != nullptr &&
args[0] == square_call && args[1] == lhs_arg) {
// The odd-exponent path multiplies the squared result by the base.
} else if (args.size() == 2 && square_call != nullptr &&
args[1] == square_call && args[0] == lhs_arg) {
// Accept commuted multiply(cur, base) shapes as well.
} else if (recursive_call != nullptr) {
no_other_calls_or_memory = false;
}
continue;
}
auto* bin = ir::dyncast<ir::BinaryInst>(inst);
if (!bin) {
continue;
}
if (bin->GetOpcode() == ir::Opcode::Rem && bin->GetLhs() == rhs_arg &&
IsConstInt(bin->GetRhs(), 2)) {
has_odd_test = true;
}
}
}
if (!no_other_calls_or_memory || !recursive_halves_rhs || square_call == nullptr ||
!has_return_one || !has_odd_test || multiply_call_count < 2 || seen_modulo <= 1) {
return false;
}
if (modulo != nullptr) {
*modulo = seen_modulo;
}
return true;
}
CondCode LowerIntCond(ir::Opcode opcode) {
switch (opcode) {
@ -492,9 +686,9 @@ class Lowerer {
return true;
}
bool TryInlineFunctionBody(const ir::Function& callee, OperandMap* inline_values,
MachineOperand* return_operand, bool* has_return,
int inline_depth) {
bool TryInlineFunctionBody(const ir::Function& callee, OperandMap* inline_values,
MachineOperand* return_operand, bool* has_return,
int inline_depth) {
if (inline_depth > 2) {
return false;
}
@ -595,16 +789,22 @@ class Lowerer {
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::Call: {
auto* nested_call = static_cast<ir::CallInst*>(inst.get());
auto* nested_callee = nested_call->GetCallee();
if (nested_callee == nullptr || nested_callee == current_ir_function_) {
return false;
}
if (CanInlineDirectCall(*nested_callee)) {
MachineOperand nested_return_operand;
bool nested_has_return = false;
case ir::Opcode::Call: {
auto* nested_call = static_cast<ir::CallInst*>(inst.get());
auto* nested_callee = nested_call->GetCallee();
if (nested_callee == nullptr || nested_callee == current_ir_function_) {
return false;
}
MachineOperand math_idiom_result;
if (TryEmitMathIdiomCall(nested_call, inline_values, &math_idiom_result)) {
(*inline_values)[inst.get()] = math_idiom_result;
break;
}
if (CanInlineDirectCall(*nested_callee)) {
MachineOperand nested_return_operand;
bool nested_has_return = false;
OperandMap nested_values;
const auto& nested_args = nested_callee->GetArguments();
const auto& nested_call_args = nested_call->GetArguments();
@ -659,10 +859,99 @@ class Lowerer {
default:
return false;
}
}
return true;
}
}
return true;
}
bool TryEmitMathIdiomCall(ir::CallInst* call, const OperandMap* inline_values,
MachineOperand* result_operand) {
auto* callee = call == nullptr ? nullptr : call->GetCallee();
int modulo = 0;
if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() &&
call->GetArguments().size() == 2 &&
call->GetArguments()[0]->GetType()->IsInt32() &&
call->GetArguments()[1]->GetType()->IsInt32() &&
IsRecursiveModMultiplyIdiom(*callee, &modulo)) {
auto lowered = NewVRegValue(ValueType::I32);
current_block_->Append(MachineInstr::Opcode::ModMul,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(call->GetArguments()[0], inline_values),
ResolveScalarOperand(call->GetArguments()[1], inline_values),
MachineOperand::Imm(modulo)});
if (result_operand != nullptr) {
*result_operand = MachineOperand::VReg(lowered.index);
} else {
values_[call] = lowered;
}
return true;
}
if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() &&
call->GetArguments().size() == 2 &&
call->GetArguments()[0]->GetType()->IsInt32() &&
call->GetArguments()[1]->GetType()->IsInt32() &&
IsRecursiveModPowerIdiom(*callee, &modulo)) {
auto lowered = NewVRegValue(ValueType::I32);
current_block_->Append(MachineInstr::Opcode::ModPow,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(call->GetArguments()[0], inline_values),
ResolveScalarOperand(call->GetArguments()[1], inline_values),
MachineOperand::Imm(modulo)});
if (result_operand != nullptr) {
*result_operand = MachineOperand::VReg(lowered.index);
} else {
values_[call] = lowered;
}
return true;
}
int digit_base_shift = 0;
if (callee != nullptr && call->GetType() != nullptr && call->GetType()->IsInt32() &&
call->GetArguments().size() == 2 &&
call->GetArguments()[0]->GetType()->IsInt32() &&
call->GetArguments()[1]->GetType()->IsInt32() &&
ir::mathidiom::IsPow2DigitExtractShape(*callee, &digit_base_shift)) {
auto lowered = NewVRegValue(ValueType::I32);
current_block_->Append(MachineInstr::Opcode::DigitExtractPow2,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(call->GetArguments()[0], inline_values),
ResolveScalarOperand(call->GetArguments()[1], inline_values),
MachineOperand::Imm(digit_base_shift)});
if (result_operand != nullptr) {
*result_operand = MachineOperand::VReg(lowered.index);
} else {
values_[call] = lowered;
}
return true;
}
const ir::GlobalValue* sqrt_state = nullptr;
if (callee == nullptr || call->GetType() == nullptr || !call->GetType()->IsFloat() ||
call->GetArguments().size() != 1 ||
!call->GetArguments()[0]->GetType()->IsFloat() ||
!ir::mathidiom::IsPrivateToleranceNewtonSqrt(*callee, &sqrt_state)) {
return false;
}
auto lowered = NewVRegValue(ValueType::F32);
MachineInstr instr(MachineInstr::Opcode::FSqrt,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(call->GetArguments()[0], inline_values)});
if (sqrt_state != nullptr) {
AddressExpr address;
address.base_kind = AddrBaseKind::Global;
address.symbol = sqrt_state->GetName();
instr.SetAddress(std::move(address));
}
current_block_->Append(std::move(instr));
if (result_operand != nullptr) {
*result_operand = MachineOperand::VReg(lowered.index);
} else {
values_[call] = lowered;
}
return true;
}
bool TryInlineDirectCall(ir::CallInst* call) {
auto* callee = call->GetCallee();
if (callee == nullptr || callee == current_ir_function_ || !CanInlineDirectCall(*callee)) {
@ -858,11 +1147,14 @@ class Lowerer {
values_[&inst] = lowered;
return;
}
case ir::Opcode::Call: {
auto* call = static_cast<ir::CallInst*>(&inst);
if (TryInlineDirectCall(call)) {
return;
}
case ir::Opcode::Call: {
auto* call = static_cast<ir::CallInst*>(&inst);
if (TryEmitMathIdiomCall(call, nullptr, nullptr)) {
return;
}
if (TryInlineDirectCall(call)) {
return;
}
std::vector<MachineOperand> operands;
if (!call->GetType()->IsVoid()) {
auto lowered = NewVRegValue(LowerType(call->GetType()));

@ -41,6 +41,10 @@ std::vector<int> MachineInstr::GetDefs() const {
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::ModMul:
case Opcode::ModPow:
case Opcode::DigitExtractPow2:
case Opcode::BitTestMask:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
@ -51,6 +55,7 @@ std::vector<int> MachineInstr::GetDefs() const {
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FSqrt:
case Opcode::FNeg:
case Opcode::ICmp:
case Opcode::FCmp:
@ -106,6 +111,7 @@ std::vector<int> MachineInstr::GetUses() const {
case Opcode::ZExt:
case Opcode::ItoF:
case Opcode::FtoI:
case Opcode::FSqrt:
case Opcode::FNeg:
if (operands_.size() >= 2) {
push_vreg(operands_[1]);
@ -121,11 +127,19 @@ std::vector<int> MachineInstr::GetUses() const {
}
push_addr_uses();
break;
case Opcode::BitTestMask:
if (operands_.size() >= 2) {
push_vreg(operands_[1]);
}
break;
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::ModMul:
case Opcode::ModPow:
case Opcode::DigitExtractPow2:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:

@ -193,6 +193,7 @@ bool RewriteUses(MachineInstr& inst, const AliasMap& aliases) {
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
case MachineInstr::Opcode::FSqrt:
case MachineInstr::Opcode::FNeg:
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
@ -203,11 +204,19 @@ bool RewriteUses(MachineInstr& inst, const AliasMap& aliases) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::BitTestMask:
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
break;
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
@ -689,6 +698,8 @@ MemoryMap SimulateBlockMemory(const MachineModule& module, const MachineBasicBlo
return state;
}
bool CombineBitTestMasks(std::vector<MachineInstr>& instructions);
bool RunPeepholeOnBlock(const MachineModule& module, const MachineFunction& function,
MachineBasicBlock& block, const MemoryMap& in_state) {
bool changed = false;
@ -753,6 +764,7 @@ bool RunPeepholeOnBlock(const MachineModule& module, const MachineFunction& func
if (compacted.size() != block.GetInstructions().size()) {
changed = true;
}
changed |= CombineBitTestMasks(compacted);
if (changed) {
block.GetInstructions() = std::move(compacted);
}
@ -770,6 +782,10 @@ bool IsSideEffectFree(const MachineInstr& inst) {
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::BitTestMask:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
@ -787,6 +803,8 @@ bool IsSideEffectFree(const MachineInstr& inst) {
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
return true;
case MachineInstr::Opcode::FSqrt:
return !inst.HasAddress();
case MachineInstr::Opcode::Store:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::CondBr:
@ -799,6 +817,106 @@ bool IsSideEffectFree(const MachineInstr& inst) {
return false;
}
std::unordered_map<int, int> CountBlockUses(const std::vector<MachineInstr>& instructions) {
std::unordered_map<int, int> counts;
for (const auto& inst : instructions) {
for (int use : inst.GetUses()) {
++counts[use];
}
}
return counts;
}
bool GetDefVReg(const MachineInstr& inst, int* out) {
const auto defs = inst.GetDefs();
if (defs.size() != 1) {
return false;
}
if (out != nullptr) {
*out = defs.front();
}
return true;
}
bool IsVRegUse(const MachineOperand& operand, int vreg) {
return operand.GetKind() == OperandKind::VReg && operand.GetVReg() == vreg;
}
bool IsZeroCompareOperand(const MachineOperand& operand) {
return operand.GetKind() == OperandKind::Imm && operand.GetImm() == 0;
}
bool TryGetAndOneValue(const MachineInstr& inst, MachineOperand* value) {
if (inst.GetOpcode() != MachineInstr::Opcode::And ||
inst.GetOperands().size() < 3) {
return false;
}
const auto& lhs = inst.GetOperands()[1];
const auto& rhs = inst.GetOperands()[2];
if (IsImm(rhs, 1) && lhs.GetKind() == OperandKind::VReg) {
if (value != nullptr) {
*value = lhs;
}
return true;
}
if (IsImm(lhs, 1) && rhs.GetKind() == OperandKind::VReg) {
if (value != nullptr) {
*value = rhs;
}
return true;
}
return false;
}
bool CombineBitTestMasks(std::vector<MachineInstr>& instructions) {
bool changed = false;
auto use_counts = CountBlockUses(instructions);
for (std::size_t i = 0; i + 3 < instructions.size();) {
auto& and_inst = instructions[i];
auto& cmp_inst = instructions[i + 1];
auto& zext_inst = instructions[i + 2];
auto& sub_inst = instructions[i + 3];
MachineOperand tested_value;
int and_def = -1;
int cmp_def = -1;
int zext_def = -1;
if (!TryGetAndOneValue(and_inst, &tested_value) ||
!GetDefVReg(and_inst, &and_def) ||
cmp_inst.GetOpcode() != MachineInstr::Opcode::ICmp ||
!GetDefVReg(cmp_inst, &cmp_def) ||
zext_inst.GetOpcode() != MachineInstr::Opcode::ZExt ||
!GetDefVReg(zext_inst, &zext_def) ||
sub_inst.GetOpcode() != MachineInstr::Opcode::Sub ||
cmp_inst.GetOperands().size() < 3 ||
zext_inst.GetOperands().size() < 2 ||
sub_inst.GetOperands().size() < 3 ||
use_counts[and_def] != 1 || use_counts[cmp_def] != 1 ||
use_counts[zext_def] != 1 ||
!IsVRegUse(cmp_inst.GetOperands()[1], and_def) ||
!IsZeroCompareOperand(cmp_inst.GetOperands()[2]) ||
!IsVRegUse(zext_inst.GetOperands()[1], cmp_def) ||
!IsImm(sub_inst.GetOperands()[1], 0) ||
!IsVRegUse(sub_inst.GetOperands()[2], zext_def) ||
(cmp_inst.GetCondCode() != CondCode::EQ &&
cmp_inst.GetCondCode() != CondCode::NE)) {
++i;
continue;
}
MachineInstr mask(MachineInstr::Opcode::BitTestMask,
{sub_inst.GetOperands()[0], tested_value});
mask.SetCondCode(cmp_inst.GetCondCode());
instructions[i] = std::move(mask);
instructions.erase(instructions.begin() + static_cast<std::ptrdiff_t>(i + 1),
instructions.begin() + static_cast<std::ptrdiff_t>(i + 4));
changed = true;
use_counts = CountBlockUses(instructions);
}
return changed;
}
bool RunDeadInstrElimination(MachineFunction& function) {
bool changed = false;

@ -108,6 +108,7 @@ bool RewriteUses(MachineInstr& inst, const std::unordered_map<int, int>& rename_
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
case MachineInstr::Opcode::FSqrt:
case MachineInstr::Opcode::FNeg:
if (operands.size() >= 2) {
changed |= RewriteMappedOperand(operands[1], rename_map);
@ -118,11 +119,19 @@ bool RewriteUses(MachineInstr& inst, const std::unordered_map<int, int>& rename_
changed |= RewriteMappedOperand(operands[0], rename_map);
}
break;
case MachineInstr::Opcode::BitTestMask:
if (operands.size() >= 2) {
changed |= RewriteMappedOperand(operands[1], rename_map);
}
break;
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::ModMul:
case MachineInstr::Opcode::ModPow:
case MachineInstr::Opcode::DigitExtractPow2:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:

Loading…
Cancel
Save