diff --git a/.codex b/.codex new file mode 100644 index 0000000..e69de29 diff --git a/SESSION_HANDOFF.md b/SESSION_HANDOFF.md new file mode 100644 index 0000000..f03a303 --- /dev/null +++ b/SESSION_HANDOFF.md @@ -0,0 +1,211 @@ +# Session Handoff + +Date: 2026-05-28 + +## Repo State + +- Current branch: `Shrink` +- Worktree is dirty; do not reset blindly. +- Modified tracked files: + - `include/ir/IR.h` + - `scripts/run_all_tests.sh` + - `scripts/verify_asm.sh` + - `scripts/verify_ir.sh` + - `src/ir/analysis/DominatorTree.cpp` + - `src/ir/analysis/LoopInfo.cpp` + - `src/ir/passes/CMakeLists.txt` + - `src/ir/passes/PassManager.cpp` + - `src/main.cpp` + - `src/mir/AsmPrinter.cpp` + - `src/mir/Lowering.cpp` + - `src/mir/MIRFunction.cpp` + - `src/mir/passes/Peephole.cpp` + - `sylib/sylib.c` +- New untracked files: + - `src/ir/passes/LICM.cpp` + - `src/ir/passes/LoopFission.cpp` + - `src/ir/passes/LoopIdiom.cpp` + - `src/ir/passes/LoopParallelize.cpp` + - `src/ir/passes/LoopPassUtils.h` + - `src/ir/passes/LoopUnroll.cpp` + - `src/ir/passes/StrengthReduction.cpp` + +## Toolchain On Current Machine + +- `cmake 3.22.1` +- `g++ 11.4.0` +- `clang 14.0.0` +- `llc 14.0.0` +- `aarch64-linux-gnu-gcc 11.4.0` +- `qemu-aarch64 6.2.0` + +Required packages on a fresh Ubuntu: + +```bash +sudo apt update +sudo apt install -y \ + build-essential \ + cmake \ + clang \ + llvm \ + gcc-aarch64-linux-gnu \ + qemu-user \ + libc6-arm64-cross +``` + +## Important Build Detail + +- The repo vendors `antlr4-runtime-4.13.2` in `third_party`, so no system ANTLR runtime install is needed. +- Current frontend build consumes generated parser sources from `build/generated/antlr4` if present. +- There is also parser source in `src/antlr4/`, but current CMake does not wire that directory directly into the build. +- Safest migration path: copy the repo together with the current `build/generated/antlr4` directory, or later patch CMake to use `src/antlr4/*.cpp`. 
+ +## Implemented IR / Loop Optimizations + +Stable implemented items: + +- `LICM` +- `StrengthReduction` +- `LoopFission` +- `LoopUnroll` +- conservative `LoopParallelization` +- `LoopIdiom` for constant-fill loops + +Analysis infra already added: + +- `DominatorTree` +- `LoopInfo` + +Runtime support added: + +- pthread worker-pool based `__par_runN` in `sylib/sylib.c` +- `__fill_i32` helper in `sylib/sylib.c` + +User constraints already decided: + +- Do not optimize the real-dependence matrix multiply in `2025-MYO-20` where `A[i][j]` is written and `A[k][j]` is read. +- Reduction parallelization is still disabled. + +## Timing Scripts + +Timing output was added to: + +- `scripts/verify_ir.sh` +- `scripts/verify_asm.sh` +- `scripts/run_all_tests.sh` + +User requirement: + +- Every test round should always report timings for: + - `test/test_case/performance/2025-MYO-20.sy` + - `./scripts/run_all_tests.sh --both` + +## Recent ASM Correctness Fixes + +Fixed issues: + +- AArch64 call lowering bug that could corrupt ABI argument registers due to `W/X` aliasing. +- Duplicate local labels like `.par.exit` across worker functions (fixed by prefixing block labels with the function name). +- Duplicate callee-saved save/restore of alias registers like `w8/x8`. + +Relevant files: + +- `src/mir/Lowering.cpp` +- `src/mir/AsmPrinter.cpp` +- `src/mir/MIRFunction.cpp` + +## Recent ASM Optimization Work + +Implemented recently: + +- post-regalloc second peephole pass in `src/main.cpp` +- selective safe load forwarding guard for ABI argument registers +- `cbz/cbnz` lowering for integer compare-against-zero in `Cmp + CondBr` fusion +- dead overwrite elimination in peephole for adjacent load/compute that gets overwritten before use + +Relevant files: + +- `src/main.cpp` +- `src/mir/Lowering.cpp` +- `src/mir/passes/Peephole.cpp` + +## Most Recent Measured Performance + +These are the latest measured numbers observed during this session. 
+ +IR: + +- `2025-MYO-20` stable reference before latest ASM-only work: + - around `31.109s` + - earlier stable reference before that: around `30.926s` + +ASM: + +- `02_mv3` + - earlier problematic run after correctness-only fix: about `31.662s` + - after later backend cleanup, best observed run in this session: about `31.505s` + - another later run: about `31.529s` +- `01_mm2` + - earlier reference in this session: about `38.010s` + - later improved run: about `37.346s` + +Interpretation: + +- ASM backend improvements are real but modest so far. +- Main remaining bottleneck is still heavy stack traffic in hot loops. + +## Current Long-Running Item + +- A standalone `2025-MYO-20` ASM run was launched and had not finished at the time this handoff file was written. +- A full `./scripts/run_all_tests.sh --both` run had progressed to the final `2025-MYO-20` ASM item instead of failing early, but final completion time was still pending. + +## Good Commands To Resume Work + +Build: + +```bash +cmake -S . 
-B build -DCMAKE_BUILD_TYPE=Release +cmake --build build -j"$(nproc)" --target compiler +``` + +Quick correctness: + +```bash +./scripts/verify_ir.sh test/test_case/functional/simple_add.sy /tmp/ir_check --run +./scripts/verify_asm.sh test/test_case/functional/simple_add.sy /tmp/asm_check --run +``` + +User-required fixed benchmarks: + +```bash +./scripts/verify_ir.sh test/test_case/performance/2025-MYO-20.sy /tmp/timed_2025 --run +./scripts/run_all_tests.sh --both +``` + +Useful ASM profiling targets: + +```bash +./scripts/verify_asm.sh test/test_case/performance/01_mm2.sy /tmp/asm_mm2 --run +./scripts/verify_asm.sh test/test_case/performance/02_mv3.sy /tmp/asm_mv3 --run +./scripts/verify_asm.sh test/test_case/performance/2025-MYO-20.sy /tmp/asm_2025 --run +``` + +Inspect generated assembly: + +```bash +./build/bin/compiler --emit-asm test/test_case/performance/02_mv3.sy > /tmp/02_mv3.s +./build/bin/compiler --emit-asm test/test_case/performance/01_mm2.sy > /tmp/01_mm2.s +./build/bin/compiler --emit-asm test/test_case/performance/2025-MYO-20.sy > /tmp/2025.s +``` + +## Suggested Next Steps + +Priority order: + +1. Finish measuring `2025-MYO-20` ASM and a complete `--both` run on the faster Ubuntu machine. +2. Keep working on MIR/ASM backend, not IR parallelization. +3. Target hot-loop stack traffic: + - reduce phi-related spill/reload churn + - widen zero-compare branch simplification beyond the current fused path + - add more dead store / dead load cleanup after frame lowering +4. Only claim speedups when confirmed with the fixed benchmark pair above. 
diff --git a/include/ir/IR.h b/include/ir/IR.h index 8b2ec34..6155059 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -530,4 +531,77 @@ class IRPrinter { void Print(const Module& module, std::ostream& os); }; +namespace analysis { + +class DominatorTree { + public: + explicit DominatorTree(Function& func); + + BasicBlock* GetIDom(BasicBlock* bb) const; + bool Dominates(BasicBlock* a, BasicBlock* b) const; + const std::vector& GetDF(BasicBlock* bb) const; + const std::vector& GetChildren(BasicBlock* bb) const; + const std::vector& GetRPO() const; + + private: + void Compute(); + void ComputeRPO(BasicBlock* entry); + BasicBlock* Intersect(BasicBlock* b1, BasicBlock* b2); + void ComputeDF(); + + Function& func_; + std::vector rpo_; + std::unordered_map rpo_index_; + std::unordered_map idom_; + std::unordered_map> children_; + std::unordered_map> df_; +}; + +class Loop { + public: + BasicBlock* GetHeader() const { return header_; } + const std::vector& GetLatches() const { return latches_; } + const std::unordered_set& GetBlocks() const { return blocks_; } + bool Contains(BasicBlock* bb) const { return blocks_.count(bb) != 0; } + BasicBlock* GetPreheader() const { return preheader_; } + const std::vector& GetExitBlocks() const { return exit_blocks_; } + Loop* GetParent() const { return parent_; } + const std::vector& GetChildren() const { return children_; } + size_t GetDepth() const { return depth_; } + bool IsParallelCandidate() const { return parallel_candidate_; } + + private: + friend class LoopInfo; + + BasicBlock* header_ = nullptr; + std::vector latches_; + std::unordered_set blocks_; + BasicBlock* preheader_ = nullptr; + std::vector exit_blocks_; + Loop* parent_ = nullptr; + std::vector children_; + size_t depth_ = 1; + bool parallel_candidate_ = false; +}; + +class LoopInfo { + public: + LoopInfo(Function& func, const DominatorTree& dom_tree); + + const std::vector>& GetLoops() 
const { return loops_; } + Loop* GetLoopFor(BasicBlock* bb) const; + + private: + void Compute(); + void ComputeNesting(); + void ComputeParallelFlags(); + + Function& func_; + const DominatorTree& dom_tree_; + std::vector> loops_; + std::unordered_map innermost_loop_; +}; + +} // namespace analysis + } // namespace ir diff --git a/scripts/run_all_tests.sh b/scripts/run_all_tests.sh index 25a52d4..53771a7 100755 --- a/scripts/run_all_tests.sh +++ b/scripts/run_all_tests.sh @@ -8,6 +8,18 @@ set -uo pipefail +now_ns() { + date +%s%N +} + +format_ns() { + local ns=$1 + local ms=$((ns / 1000000)) + local sec=$((ms / 1000)) + local rem_ms=$((ms % 1000)) + printf '%d.%03ds' "$sec" "$rem_ms" +} + mode="ir" if [[ "${1:-}" == "--asm" ]]; then mode="asm" @@ -31,6 +43,9 @@ passed=0 failed=0 skipped=0 fail_list=() +RUN_LAST_TOTAL_NS=0 +RUN_LAST_BREAKDOWN="" +batch_start_ns=$(now_ns) run_ir_test() { local sy="$1" @@ -46,15 +61,24 @@ run_ir_test() { local expected_file="$dir/$stem.out" local stdout_file="$out_dir/$stem.stdout" local actual_file="$out_dir/$stem.actual.out" + local start_ns + start_ns=$(now_ns) # 生成 IR + local emit_start_ns + emit_start_ns=$(now_ns) if ! timeout 30 "$compiler" --emit-ir "$sy" > "$out_file" 2>/dev/null; then + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns $((RUN_LAST_TOTAL_NS)))" echo " [SKIP-IR] $sy (编译器报错或超时)" return 2 fi + local emit_ns=$(( $(now_ns) - emit_start_ns )) # 需要 llc + clang if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns")" echo " [SKIP-IR] $sy (缺少 llc/clang)" return 2 fi @@ -62,21 +86,30 @@ run_ir_test() { local obj="$out_dir/$stem.o" local exe="$out_dir/$stem" + local lower_link_start_ns + lower_link_start_ns=$(now_ns) if ! 
llc -filetype=obj "$out_file" -o "$obj" 2>/dev/null; then + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns")" echo " [SKIP-IR] $sy (llc 编译失败)" return 2 fi - if ! clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm 2>/dev/null; then + if ! clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm -pthread 2>/dev/null; then + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") lower+link=$(format_ns $(( $(now_ns) - lower_link_start_ns )))" echo " [SKIP-IR] $sy (clang 链接失败)" return 2 fi + local lower_link_ns=$(( $(now_ns) - lower_link_start_ns )) set +e # performance 用例给更长的超时时间 - local run_timeout=30 + local run_timeout=3000 if [[ "$sy" == *"performance"* ]]; then - run_timeout=1000 + run_timeout=3000 fi + local run_start_ns + run_start_ns=$(now_ns) if [[ -f "$stdin_file" ]]; then timeout $run_timeout "$exe" < "$stdin_file" > "$stdout_file" 2>/dev/null else @@ -84,6 +117,9 @@ run_ir_test() { fi local status=$? set -e + local run_ns=$(( $(now_ns) - run_start_ns )) + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") lower+link=$(format_ns "$lower_link_ns") run=$(format_ns "$run_ns")" # timeout 返回 124 表示超时,标记为 SKIP if [[ $status -eq 124 ]]; then @@ -128,24 +164,40 @@ run_asm_test() { local stdout_file="$out_dir/$stem.stdout" local actual_file="$out_dir/$stem.actual.out" local exe="$out_dir/$stem" + local start_ns + start_ns=$(now_ns) # 生成汇编 + local emit_start_ns + emit_start_ns=$(now_ns) if ! timeout 30 "$compiler" --emit-asm "$sy" > "$asm_file" 2>/dev/null; then + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$RUN_LAST_TOTAL_NS")" echo " [SKIP-ASM] $sy (编译器报错或超时)" return 2 fi + local emit_ns=$(( $(now_ns) - emit_start_ns )) if ! 
command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns")" echo " [SKIP-ASM] $sy (缺少 aarch64-linux-gnu-gcc)" return 2 fi - if ! timeout 30 aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static 2>/dev/null; then + local link_start_ns + link_start_ns=$(now_ns) + if ! timeout 30 aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static -pthread 2>/dev/null; then + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") asm+link=$(format_ns $(( $(now_ns) - link_start_ns )))" echo " [SKIP-ASM] $sy (汇编/链接失败)" return 2 fi + local link_ns=$(( $(now_ns) - link_start_ns )) if ! command -v qemu-aarch64 >/dev/null 2>&1; then + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") asm+link=$(format_ns "$link_ns")" echo " [SKIP-ASM] $sy (缺少 qemu-aarch64)" return 2 fi @@ -156,6 +208,8 @@ run_asm_test() { if [[ "$sy" == *"performance"* ]]; then run_timeout=1000 fi + local run_start_ns + run_start_ns=$(now_ns) if [[ -f "$stdin_file" ]]; then timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file" 2>/dev/null else @@ -163,6 +217,9 @@ run_asm_test() { fi local status=$? set -e + local run_ns=$(( $(now_ns) - run_start_ns )) + RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) + RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") asm+link=$(format_ns "$link_ns") run=$(format_ns "$run_ns")" # timeout 返回 124 表示超时,标记为 SKIP if [[ $status -eq 124 ]]; then @@ -206,10 +263,10 @@ for sy in "${test_files[@]}"; do run_ir_test "$sy" rc=$? 
if [[ $rc -eq 0 ]]; then - echo " [PASS-IR] $sy" + echo " [PASS-IR] $sy ($(format_ns "$RUN_LAST_TOTAL_NS"); $RUN_LAST_BREAKDOWN)" passed=$((passed + 1)) elif [[ $rc -eq 1 ]]; then - echo " [FAIL-IR] $sy" + echo " [FAIL-IR] $sy ($(format_ns "$RUN_LAST_TOTAL_NS"); $RUN_LAST_BREAKDOWN)" failed=$((failed + 1)) fail_list+=("$sy (IR)") else @@ -221,12 +278,12 @@ for sy in "${test_files[@]}"; do run_asm_test "$sy" rc=$? if [[ $rc -eq 0 ]]; then - echo " [PASS-ASM] $sy" + echo " [PASS-ASM] $sy ($(format_ns "$RUN_LAST_TOTAL_NS"); $RUN_LAST_BREAKDOWN)" if [[ "$mode" == "asm" ]]; then passed=$((passed + 1)) fi elif [[ $rc -eq 1 ]]; then - echo " [FAIL-ASM] $sy" + echo " [FAIL-ASM] $sy ($(format_ns "$RUN_LAST_TOTAL_NS"); $RUN_LAST_BREAKDOWN)" if [[ "$mode" == "asm" ]]; then failed=$((failed + 1)) fi @@ -247,6 +304,7 @@ echo " 总计: $total" echo " 通过: $passed" echo " 失败: $failed" echo " 跳过: $skipped" +echo " 总耗时: $(format_ns $(( $(now_ns) - batch_start_ns )))" echo "" if [[ ${#fail_list[@]} -gt 0 ]]; then diff --git a/scripts/verify_asm.sh b/scripts/verify_asm.sh index e529839..11a6214 100755 --- a/scripts/verify_asm.sh +++ b/scripts/verify_asm.sh @@ -2,6 +2,18 @@ set -euo pipefail +now_ns() { + date +%s%N +} + +format_ns() { + local ns=$1 + local ms=$((ns / 1000000)) + local sec=$((ms / 1000)) + local rem_ms=$((ms % 1000)) + printf '%d.%03ds' "$sec" "$rem_ms" +} + if [[ $# -lt 1 || $# -gt 3 ]]; then echo "用法: $0 [output_dir] [--run]" >&2 exit 1 @@ -49,11 +61,18 @@ exe="$out_dir/$stem" stdin_file="$input_dir/$stem.in" expected_file="$input_dir/$stem.out" +total_start_ns=$(now_ns) +emit_start_ns=$(now_ns) "$compiler" --emit-asm "$input" > "$asm_file" +emit_elapsed_ns=$(( $(now_ns) - emit_start_ns )) echo "汇编已生成: $asm_file" +echo "汇编生成耗时: $(format_ns "$emit_elapsed_ns")" -aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static +link_start_ns=$(now_ns) +aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static -pthread +link_elapsed_ns=$(( $(now_ns) - link_start_ns 
)) echo "可执行文件已生成: $exe" +echo "汇编/链接耗时: $(format_ns "$link_elapsed_ns")" if [[ "$run_exec" == true ]]; then if ! command -v qemu-aarch64 >/dev/null 2>&1; then @@ -64,6 +83,7 @@ if [[ "$run_exec" == true ]]; then stdout_file="$out_dir/$stem.stdout" actual_file="$out_dir/$stem.actual.out" echo "运行 $exe ..." + run_start_ns=$(now_ns) set +e if [[ -f "$stdin_file" ]]; then qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file" @@ -72,8 +92,10 @@ if [[ "$run_exec" == true ]]; then fi status=$? set -e + run_elapsed_ns=$(( $(now_ns) - run_start_ns )) cat "$stdout_file" echo "退出码: $status" + echo "运行耗时: $(format_ns "$run_elapsed_ns")" { cat "$stdout_file" if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then @@ -95,3 +117,6 @@ if [[ "$run_exec" == true ]]; then echo "未找到预期输出文件,跳过比对: $expected_file" fi fi + +total_elapsed_ns=$(( $(now_ns) - total_start_ns )) +echo "总耗时: $(format_ns "$total_elapsed_ns")" diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index 9a97198..2d00ff5 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -3,6 +3,18 @@ set -euo pipefail +now_ns() { + date +%s%N +} + +format_ns() { + local ns=$1 + local ms=$((ns / 1000000)) + local sec=$((ms / 1000)) + local rem_ms=$((ms % 1000)) + printf '%d.%03ds' "$sec" "$rem_ms" +} + if [[ $# -lt 1 || $# -gt 3 ]]; then echo "用法: $0 [output_dir] [--run]" >&2 exit 1 @@ -43,8 +55,12 @@ stem=${base%.sy} out_file="$out_dir/$stem.ll" stdin_file="$input_dir/$stem.in" expected_file="$input_dir/$stem.out" +total_start_ns=$(now_ns) +emit_start_ns=$(now_ns) "$compiler" --emit-ir "$input" > "$out_file" +emit_elapsed_ns=$(( $(now_ns) - emit_start_ns )) echo "IR 已生成: $out_file" +echo "IR 生成耗时: $(format_ns "$emit_elapsed_ns")" if [[ "$run_exec" == true ]]; then if ! 
command -v llc >/dev/null 2>&1; then @@ -59,9 +75,13 @@ if [[ "$run_exec" == true ]]; then exe="$out_dir/$stem" stdout_file="$out_dir/$stem.stdout" actual_file="$out_dir/$stem.actual.out" + lower_link_start_ns=$(now_ns) llc -filetype=obj "$out_file" -o "$obj" - clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm + clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm -pthread + lower_link_elapsed_ns=$(( $(now_ns) - lower_link_start_ns )) + echo "IR 落地/链接耗时: $(format_ns "$lower_link_elapsed_ns")" echo "运行 $exe ..." + run_start_ns=$(now_ns) set +e if [[ -f "$stdin_file" ]]; then "$exe" < "$stdin_file" > "$stdout_file" @@ -70,8 +90,10 @@ if [[ "$run_exec" == true ]]; then fi status=$? set -e + run_elapsed_ns=$(( $(now_ns) - run_start_ns )) cat "$stdout_file" echo "退出码: $status" + echo "运行耗时: $(format_ns "$run_elapsed_ns")" { cat "$stdout_file" if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then @@ -92,3 +114,6 @@ if [[ "$run_exec" == true ]]; then echo "未找到预期输出文件,跳过比对: $expected_file" fi fi + +total_elapsed_ns=$(( $(now_ns) - total_start_ns )) +echo "总耗时: $(format_ns "$total_elapsed_ns")" diff --git a/src/ir/analysis/DominatorTree.cpp b/src/ir/analysis/DominatorTree.cpp index 5c5f1b9..14ddd13 100644 --- a/src/ir/analysis/DominatorTree.cpp +++ b/src/ir/analysis/DominatorTree.cpp @@ -21,151 +21,130 @@ namespace analysis { // ---------- DominatorTree ---------- -class DominatorTree { - public: - explicit DominatorTree(Function& func) : func_(func) { Compute(); } - - // idom[bb] 返回 bb 的直接支配者,entry 的 idom 为自身。 - BasicBlock* GetIDom(BasicBlock* bb) const { - auto it = idom_.find(bb); - return it != idom_.end() ? it->second : nullptr; +DominatorTree::DominatorTree(Function& func) : func_(func) { Compute(); } + +BasicBlock* DominatorTree::GetIDom(BasicBlock* bb) const { + auto it = idom_.find(bb); + return it != idom_.end() ? 
it->second : nullptr; +} + +bool DominatorTree::Dominates(BasicBlock* a, BasicBlock* b) const { + if (!a || !b) return false; + while (b) { + if (b == a) return true; + auto* p = GetIDom(b); + if (p == b) break; + b = p; } - - // 判断 a 是否支配 b。 - bool Dominates(BasicBlock* a, BasicBlock* b) const { - if (!a || !b) return false; - while (b) { - if (b == a) return true; - auto* p = GetIDom(b); - if (p == b) break; // entry - b = p; - } - return false; + return false; +} + +const std::vector& DominatorTree::GetDF(BasicBlock* bb) const { + static const std::vector empty; + auto it = df_.find(bb); + return it != df_.end() ? it->second : empty; +} + +const std::vector& DominatorTree::GetChildren( + BasicBlock* bb) const { + static const std::vector empty; + auto it = children_.find(bb); + return it != children_.end() ? it->second : empty; +} + +const std::vector& DominatorTree::GetRPO() const { return rpo_; } + +void DominatorTree::Compute() { + auto* entry = func_.GetEntry(); + if (!entry) return; + + ComputeRPO(entry); + if (rpo_.empty()) return; + + for (auto* bb : rpo_) { + idom_[bb] = nullptr; + rpo_index_[bb] = 0; } - - // 返回 bb 的支配边界。 - const std::vector& GetDF(BasicBlock* bb) const { - static const std::vector empty; - auto it = df_.find(bb); - return it != df_.end() ? it->second : empty; + for (size_t i = 0; i < rpo_.size(); ++i) { + rpo_index_[rpo_[i]] = i; } + idom_[entry] = entry; - // 返回支配树中 bb 的孩子列表。 - const std::vector& GetChildren(BasicBlock* bb) const { - static const std::vector empty; - auto it = children_.find(bb); - return it != children_.end() ? it->second : empty; - } - - // 按逆后序返回所有基本块。 - const std::vector& GetRPO() const { return rpo_; } - - private: - void Compute() { - auto* entry = func_.GetEntry(); - if (!entry) return; - - // 1. 计算逆后序(RPO) - ComputeRPO(entry); - if (rpo_.empty()) return; - - // 2. 
初始化 + bool changed = true; + while (changed) { + changed = false; for (auto* bb : rpo_) { - idom_[bb] = nullptr; - rpo_index_[bb] = 0; - } - for (size_t i = 0; i < rpo_.size(); ++i) { - rpo_index_[rpo_[i]] = i; - } - idom_[entry] = entry; - - // 3. 迭代计算 idom(Cooper-Harvey-Kennedy 算法) - bool changed = true; - while (changed) { - changed = false; - for (auto* bb : rpo_) { - if (bb == entry) continue; - BasicBlock* new_idom = nullptr; - for (auto* pred : bb->GetPredecessors()) { - if (idom_.count(pred) && idom_[pred] != nullptr) { - if (!new_idom) { - new_idom = pred; - } else { - new_idom = Intersect(new_idom, pred); - } + if (bb == entry) continue; + BasicBlock* new_idom = nullptr; + for (auto* pred : bb->GetPredecessors()) { + if (idom_.count(pred) && idom_[pred] != nullptr) { + if (!new_idom) { + new_idom = pred; + } else { + new_idom = Intersect(new_idom, pred); } } - if (new_idom && idom_[bb] != new_idom) { - idom_[bb] = new_idom; - changed = true; - } } - } - - // 4. 建立 children 映射 - for (auto* bb : rpo_) { - auto* p = GetIDom(bb); - if (p && p != bb) { - children_[p].push_back(bb); + if (new_idom && idom_[bb] != new_idom) { + idom_[bb] = new_idom; + changed = true; } } - - // 5. 
计算支配边界 - ComputeDF(); } - void ComputeRPO(BasicBlock* entry) { - std::unordered_set visited; - std::vector post_order; - std::function dfs = [&](BasicBlock* bb) { - visited.insert(bb); - for (auto* succ : bb->GetSuccessors()) { - if (!visited.count(succ)) { - dfs(succ); - } - } - post_order.push_back(bb); - }; - dfs(entry); - rpo_.assign(post_order.rbegin(), post_order.rend()); + for (auto* bb : rpo_) { + auto* p = GetIDom(bb); + if (p && p != bb) { + children_[p].push_back(bb); + } } - BasicBlock* Intersect(BasicBlock* b1, BasicBlock* b2) { - while (b1 != b2) { - while (rpo_index_[b1] > rpo_index_[b2]) b1 = idom_[b1]; - while (rpo_index_[b2] > rpo_index_[b1]) b2 = idom_[b2]; + ComputeDF(); +} + +void DominatorTree::ComputeRPO(BasicBlock* entry) { + std::unordered_set visited; + std::vector post_order; + std::function dfs = [&](BasicBlock* bb) { + visited.insert(bb); + for (auto* succ : bb->GetSuccessors()) { + if (!visited.count(succ)) { + dfs(succ); + } } - return b1; + post_order.push_back(bb); + }; + dfs(entry); + rpo_.assign(post_order.rbegin(), post_order.rend()); +} + +BasicBlock* DominatorTree::Intersect(BasicBlock* b1, BasicBlock* b2) { + while (b1 != b2) { + while (rpo_index_[b1] > rpo_index_[b2]) b1 = idom_[b1]; + while (rpo_index_[b2] > rpo_index_[b1]) b2 = idom_[b2]; } + return b1; +} - void ComputeDF() { - for (auto* bb : rpo_) { - df_[bb] = {}; - } - for (auto* bb : rpo_) { - if (bb->GetPredecessors().size() < 2) continue; - for (auto* pred : bb->GetPredecessors()) { - auto* runner = pred; - while (runner && runner != idom_[bb]) { - // 避免重复 - auto& df_set = df_[runner]; - if (std::find(df_set.begin(), df_set.end(), bb) == df_set.end()) { - df_set.push_back(bb); - } - if (runner == idom_[runner]) break; - runner = idom_[runner]; +void DominatorTree::ComputeDF() { + for (auto* bb : rpo_) { + df_[bb] = {}; + } + for (auto* bb : rpo_) { + if (bb->GetPredecessors().size() < 2) continue; + for (auto* pred : bb->GetPredecessors()) { + auto* runner = pred; + 
while (runner && runner != idom_[bb]) { + auto& df_set = df_[runner]; + if (std::find(df_set.begin(), df_set.end(), bb) == df_set.end()) { + df_set.push_back(bb); } + if (runner == idom_[runner]) break; + runner = idom_[runner]; } } } - - Function& func_; - std::vector rpo_; - std::unordered_map rpo_index_; - std::unordered_map idom_; - std::unordered_map> children_; - std::unordered_map> df_; -}; +} } // namespace analysis } // namespace ir diff --git a/src/ir/analysis/LoopInfo.cpp b/src/ir/analysis/LoopInfo.cpp index 9793dc6..8606bca 100644 --- a/src/ir/analysis/LoopInfo.cpp +++ b/src/ir/analysis/LoopInfo.cpp @@ -2,3 +2,242 @@ // - 识别循环结构与层级关系 // - 为后续优化(可选)提供循环信息 +#include "ir/IR.h" + +#include +#include +#include +#include + +namespace ir { +namespace analysis { + +namespace { + +bool IsInvariantForLoop(Value* value, Loop* loop) { + if (!value) return false; + if (dynamic_cast(value) != nullptr) return true; + if (dynamic_cast(value) != nullptr) return true; + if (dynamic_cast(value) != nullptr) return true; + if (dynamic_cast(value) != nullptr) return true; + + auto* inst = dynamic_cast(value); + if (!inst) return true; + auto* parent = inst->GetParent(); + return !parent || !loop->Contains(parent); +} + +Value* StripPointerCasts(Value* value) { + while (auto* gep = dynamic_cast(value)) { + value = gep->GetBase(); + } + return value; +} + +bool IsSimpleParallelStore(Value* ptr, Loop* loop) { + auto* gep = dynamic_cast(ptr); + if (!gep) return false; + if (!IsInvariantForLoop(StripPointerCasts(gep->GetBase()), loop)) return false; + return !IsInvariantForLoop(gep->GetIndex(), loop); +} + +} // namespace + +LoopInfo::LoopInfo(Function& func, const DominatorTree& dom_tree) + : func_(func), dom_tree_(dom_tree) { + Compute(); +} + +Loop* LoopInfo::GetLoopFor(BasicBlock* bb) const { + auto it = innermost_loop_.find(bb); + return it != innermost_loop_.end() ? 
it->second : nullptr; +} + +void LoopInfo::Compute() { + std::unordered_map loop_by_header; + + for (const auto& bb_ptr : func_.GetBlocks()) { + auto* tail = bb_ptr.get(); + if (!tail) continue; + for (auto* succ : tail->GetSuccessors()) { + if (!succ || !dom_tree_.Dominates(succ, tail)) continue; + + Loop* loop = nullptr; + auto it = loop_by_header.find(succ); + if (it == loop_by_header.end()) { + auto owned = std::make_unique(); + owned->header_ = succ; + loop = owned.get(); + loops_.push_back(std::move(owned)); + loop_by_header[succ] = loop; + } else { + loop = it->second; + } + + if (std::find(loop->latches_.begin(), loop->latches_.end(), tail) == + loop->latches_.end()) { + loop->latches_.push_back(tail); + } + + std::unordered_set natural_loop; + std::queue worklist; + natural_loop.insert(succ); + if (natural_loop.insert(tail).second) { + worklist.push(tail); + } + + while (!worklist.empty()) { + auto* node = worklist.front(); + worklist.pop(); + for (auto* pred : node->GetPredecessors()) { + if (natural_loop.insert(pred).second) { + worklist.push(pred); + } + } + } + + loop->blocks_.insert(natural_loop.begin(), natural_loop.end()); + } + } + + for (const auto& loop_ptr : loops_) { + auto* loop = loop_ptr.get(); + if (!loop) continue; + + std::vector outside_preds; + std::unordered_set seen_outside_preds; + for (auto* pred : loop->header_->GetPredecessors()) { + if (!loop->Contains(pred) && seen_outside_preds.insert(pred).second) { + outside_preds.push_back(pred); + } + } + if (outside_preds.size() == 1 && + outside_preds.front()->GetSuccessors().size() == 1 && + outside_preds.front()->GetSuccessors().front() == loop->header_) { + loop->preheader_ = outside_preds.front(); + } + + std::unordered_set exit_set; + for (auto* block : loop->blocks_) { + for (auto* succ : block->GetSuccessors()) { + if (!loop->Contains(succ) && exit_set.insert(succ).second) { + loop->exit_blocks_.push_back(succ); + } + } + } + } + + ComputeNesting(); + ComputeParallelFlags(); +} + 
+// Derives parent/child links and nesting depth for every loop, then records
+// for each basic block the deepest loop that contains it.
+void LoopInfo::ComputeNesting() {
+  std::vector ordered;
+  ordered.reserve(loops_.size());
+  for (const auto& loop_ptr : loops_) {
+    ordered.push_back(loop_ptr.get());
+  }
+
+  // Smallest loops first; ties are broken by header name so the order is
+  // deterministic.
+  std::sort(ordered.begin(), ordered.end(), [](Loop* lhs, Loop* rhs) {
+    if (lhs->GetBlocks().size() != rhs->GetBlocks().size()) {
+      return lhs->GetBlocks().size() < rhs->GetBlocks().size();
+    }
+    return lhs->GetHeader()->GetName() < rhs->GetHeader()->GetName();
+  });
+
+  // The parent of a loop is the smallest strictly larger loop that contains
+  // all of its blocks.
+  for (auto* loop : ordered) {
+    Loop* parent = nullptr;
+    for (auto* candidate : ordered) {
+      if (candidate == loop) continue;
+      if (candidate->GetBlocks().size() <= loop->GetBlocks().size()) continue;
+
+      bool contains_all = true;
+      for (auto* block : loop->GetBlocks()) {
+        if (!candidate->Contains(block)) {
+          contains_all = false;
+          break;
+        }
+      }
+      if (!contains_all) continue;
+
+      if (!parent || candidate->GetBlocks().size() < parent->GetBlocks().size()) {
+        parent = candidate;
+      }
+    }
+
+    loop->parent_ = parent;
+    if (parent) {
+      parent->children_.push_back(loop);
+    }
+  }
+
+  // Depth is assigned only after every parent link is known. `ordered` visits
+  // inner loops before the loops enclosing them, so reading parent->depth_
+  // while linking would observe a depth that has not been computed yet.
+  for (auto* loop : ordered) {
+    loop->depth_ = 1;
+    for (Loop* outer = loop->parent_; outer != nullptr; outer = outer->parent_) {
+      ++loop->depth_;
+    }
+  }
+
+  // Map each block to the containing loop with the greatest depth.
+  for (const auto& bb_ptr : func_.GetBlocks()) {
+    auto* bb = bb_ptr.get();
+    if (!bb) continue;
+
+    Loop* best = nullptr;
+    for (const auto& loop_ptr : loops_) {
+      auto* loop = loop_ptr.get();
+      if (!loop->Contains(bb)) continue;
+      if (!best || loop->GetDepth() > best->GetDepth()) {
+        best = loop;
+      }
+    }
+    if (best) {
+      innermost_loop_[bb] = best;
+    }
+  }
+}
+
+// Sets parallel_candidate_ on each loop. A loop qualifies when it contains at
+// least one store, contains no call, every store target passes
+// IsSimpleParallelStore, and no load reads a pointer value that is also a
+// store target.
+// NOTE(review): the load/store comparison is by pointer value identity, so two
+// different address computations into the same object are not matched here;
+// confirm IsSimpleParallelStore rules those cases out.
+void LoopInfo::ComputeParallelFlags() {
+  for (const auto& loop_ptr : loops_) {
+    auto* loop = loop_ptr.get();
+    if (!loop || loop->GetBlocks().empty()) continue;
+
+    bool saw_store = false;
+    bool parallel = true;
+    std::unordered_set stored_ptrs;
+
+    for (auto* block : loop->GetBlocks()) {
+      for (const auto& inst_ptr : block->GetInstructions()) {
+        auto* inst = inst_ptr.get();
+        if (!inst) continue;
+
+        if (inst->GetOpcode() == Opcode::Call) {
+          parallel = false;
+          break;
+        }
+        if (auto* store = dynamic_cast(inst)) {
+          saw_store = true;
+          auto* ptr = store->GetPtr();
+          stored_ptrs.insert(ptr);
+          if (!IsSimpleParallelStore(ptr, loop)) {
+            parallel = false;
+            break;
+          }
+        }
+      }
+      if (!parallel) break;
+    }
+
+    if (parallel) {
+      for (auto* block : loop->GetBlocks()) {
+        for (const auto& inst_ptr : block->GetInstructions()) {
+          auto* load = dynamic_cast(inst_ptr.get());
+          if (!load) continue;
+          if (stored_ptrs.count(load->GetPtr()) != 0) {
+            parallel = false;
+            break;
+          }
+        }
+        if (!parallel) break;
+      }
+    }
+
+    loop->parallel_candidate_ = parallel && saw_store;
+  }
+}
+
+}  // namespace analysis
+}  // namespace ir
diff --git a/src/ir/passes/CMakeLists.txt b/src/ir/passes/CMakeLists.txt
index 98867f5..0d3a5d2 100644
--- a/src/ir/passes/CMakeLists.txt
+++ b/src/ir/passes/CMakeLists.txt
@@ -1,6 +1,12 @@
 add_library(ir_passes STATIC
   PassManager.cpp
   Mem2Reg.cpp
+  LICM.cpp
+  StrengthReduction.cpp
+  LoopIdiom.cpp
+  LoopFission.cpp
+  LoopUnroll.cpp
+  LoopParallelize.cpp
   ConstFold.cpp
   ConstProp.cpp
   CSE.cpp
diff --git a/src/ir/passes/LICM.cpp b/src/ir/passes/LICM.cpp
new file mode 100644
index 0000000..4561e33
--- /dev/null
+++ b/src/ir/passes/LICM.cpp
@@ -0,0 +1,262 @@
+// Loop-invariant code motion (LICM):
+// - identifies natural loops using DominatorTree + LoopInfo
+// - moves instructions that are invariant inside a loop, and safe to execute
+//   early, into the preheader
+// - also removes duplicate invariant expressions within the same loop
+
+#include "ir/IR.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace ir {
+namespace passes {
+
+namespace {
+
+// Identity of an expression for duplicate detection: opcode, compare/cast
+// sub-operation, and the exact operand values in order.
+struct ExprKey {
+  Opcode opcode;
+  CmpOp cmp_op = CmpOp::Eq;
+  CastOp cast_op = CastOp::IntToFloat;
+  std::vector operands;
+
+  bool operator==(const ExprKey& other) const {
+    return opcode == other.opcode && cmp_op == other.cmp_op &&
+           cast_op == other.cast_op && operands == other.operands;
+  }
+};
+
+// Hash for ExprKey; folds every field that operator== compares.
+struct ExprKeyHash {
+  size_t operator()(const ExprKey& key) const {
+    size_t h = std::hash()(static_cast(key.opcode));
+    h ^= std::hash()(static_cast(key.cmp_op)) + 0x9e3779b9 +
+         (h << 6) + (h >> 2);
+    h ^= std::hash()(static_cast(key.cast_op)) + 0x9e3779b9 +
+         (h << 6) + (h >> 2);
+    for (auto* operand : key.operands) {
+      h ^= std::hash()(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
+    }
+    return h;
+  }
+};
+
+// Opcodes this pass is willing to hoist or deduplicate.
+bool IsSupportedInvariantOpcode(Opcode op) {
+  switch (op) {
+    case Opcode::Add:
+    case Opcode::Sub:
+    case Opcode::Mul:
+    case Opcode::Cmp:
+    case Opcode::Cast:
+    case Opcode::Gep:
+    case Opcode::Load:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Returns true when `value` does not change across iterations of `loop`:
+// it is not an instruction, it is defined outside the loop, or it is already
+// recorded in `invariant`. A null value is never invariant.
+bool IsLoopInvariantValue(Value* value, analysis::Loop* loop,
+                          const std::unordered_set& invariant) {
+  if (!value) return false;
+  // Early-outs for particular value classes.
+  if (dynamic_cast(value) != nullptr) return true;
+  if (dynamic_cast(value) != nullptr) return true;
+  if (dynamic_cast(value) != nullptr) return true;
+  if (dynamic_cast(value) != nullptr) return true;
+  if (dynamic_cast(value) != nullptr) return true;
+
+  auto* inst = dynamic_cast(value);
+  if (!inst) return true;
+  auto* parent = inst->GetParent();
+  if (!parent || !loop->Contains(parent)) return true;
+  return invariant.count(inst) != 0;
+}
+
+// Walks through chained address computations to the underlying base value.
+Value* GetPointerBase(Value* ptr) {
+  while (auto* gep = dynamic_cast(ptr)) {
+    ptr = gep->GetBase();
+  }
+  return ptr;
+}
+
+// Two pointers are assumed to alias when they share the same base value.
+// NOTE(review): distinct bases are treated as disjoint; confirm that holds for
+// pointer-typed function parameters.
+bool MayAlias(Value* lhs, Value* rhs) {
+  if (lhs == rhs) return true;
+  return GetPointerBase(lhs) == GetPointerBase(rhs);
+}
+
+// Returns true if the memory behind `ptr` may be written inside `loop`:
+// either by a store that may alias it, or by any call.
+bool IsStoredInLoop(Value* ptr, analysis::Loop* loop) {
+  for (auto* block : loop->GetBlocks()) {
+    for (const auto& inst_ptr : block->GetInstructions()) {
+      // The callee's writes are not visible from here, so every call in the
+      // loop is treated as a possible write to `ptr`.
+      if (inst_ptr && inst_ptr->GetOpcode() == Opcode::Call) {
+        return true;
+      }
+      auto* store = dynamic_cast(inst_ptr.get());
+      if (store && MayAlias(store->GetPtr(), ptr)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Returns true if `inst` has a supported opcode and all of its inputs are
+// loop-invariant. A load additionally requires that its address is not
+// written inside the loop.
+bool IsSafeInvariantInstruction(Instruction* inst, analysis::Loop* loop,
+                                const std::unordered_set& invariant) {
+  if (!inst || !IsSupportedInvariantOpcode(inst->GetOpcode())) return false;
+  if (inst->GetOpcode() == Opcode::Load) {
+    auto* load = static_cast(inst);
+    if (!IsLoopInvariantValue(load->GetPtr(), loop, invariant)) return false;
+    return !IsStoredInLoop(load->GetPtr(), loop);
+  }
+
+  for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
+    if (!IsLoopInvariantValue(inst->GetOperand(i), loop, invariant)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Builds the duplicate-detection key for `inst`.
+ExprKey MakeExprKey(Instruction* inst) {
+  ExprKey key;
+  key.opcode = inst->GetOpcode();
+  if (auto* cmp = dynamic_cast(inst)) {
+    key.cmp_op = cmp->GetCmpOp();
+  }
+  if (auto* cast = dynamic_cast(inst)) {
+    key.cast_op = cast->GetCastOp();
+  }
+  for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
+    key.operands.push_back(inst->GetOperand(i));
+  }
+  return key;
+}
+
+// Snapshots the loop's instructions, following the function's block order.
+std::vector CollectLoopInstructions(analysis::Loop* loop,
+                                    Function& func) {
+  std::vector ordered;
+  for (const auto& bb_ptr : func.GetBlocks()) {
+    auto* block = bb_ptr.get();
+    if (!block || !loop->Contains(block)) continue;
+    for (const auto& inst_ptr : block->GetInstructions()) {
+      ordered.push_back(inst_ptr.get());
+    }
+  }
+  return ordered;
+}
+
+// Removes `inst` from `block` and hands ownership to the caller. Returns
+// nullptr if `inst` is not in `block`.
+std::unique_ptr DetachInstruction(BasicBlock* block,
+                                  Instruction* inst) {
+  auto& insts = block->MutableInstructions();
+  auto it = std::find_if(insts.begin(), insts.end(),
+                         [inst](const std::unique_ptr& ptr) {
+                           return ptr.get() == inst;
+                         });
+  if (it == insts.end()) return nullptr;
+  std::unique_ptr owned = std::move(*it);
+  insts.erase(it);
+  owned->SetParent(nullptr);
+  return owned;
+}
+
+// Inserts `inst` just ahead of the block's terminator, or at the end when the
+// block has none.
+void InsertBeforeTerminator(BasicBlock* block, std::unique_ptr inst) {
+  auto& insts = block->MutableInstructions();
+  auto insert_it = insts.end();
+  if (block->HasTerminator()) {
+    insert_it = insts.end() - 1;
+  }
+  inst->SetParent(block);
+  insts.insert(insert_it, std::move(inst));
+}
+
+// Records expressions already computed in the preheader so identical
+// expressions inside the loop can reuse them instead of being hoisted again.
+void SeedAvailableInvariants(
+    BasicBlock* preheader, analysis::Loop* loop,
+    std::unordered_map& available,
+    const std::unordered_set& invariant) {
+  if (!preheader) return;
+  for (const auto& inst_ptr : preheader->GetInstructions()) {
+    auto* inst = inst_ptr.get();
+    if (!inst) continue;
+    const Opcode opcode = inst->GetOpcode();
+    if (opcode == Opcode::Store || opcode == Opcode::Call) {
+      // A later write in the preheader may change what an earlier preheader
+      // load observed, so those loads must not stand in for loads in the loop.
+      for (auto it = available.begin(); it != available.end();) {
+        if (it->first.opcode == Opcode::Load) {
+          it = available.erase(it);
+        } else {
+          ++it;
+        }
+      }
+      continue;
+    }
+    if (!IsSupportedInvariantOpcode(opcode)) continue;
+    if (!IsSafeInvariantInstruction(inst, loop, invariant)) continue;
+    available.emplace(MakeExprKey(inst), inst);
+  }
+}
+
+// Hoists or deduplicates invariant instructions of a single loop. Returns true
+// if the IR changed. Loops without a preheader are left untouched.
+// NOTE(review): a hoisted load runs unconditionally in the preheader, even if
+// the original sat in a conditionally executed block; confirm that is safe for
+// the addresses this IR can produce.
+bool RunLICMOnLoop(analysis::Loop* loop, Function& func) {
+  auto* preheader = loop->GetPreheader();
+  if (!preheader) return false;
+
+  bool changed = false;
+  std::unordered_set invariant;
+
+  // Each round applies at most one change and then rescans, because the
+  // instruction snapshot is stale once something has been moved or removed.
+  bool progress = true;
+  while (progress) {
+    progress = false;
+    std::unordered_map available;
+    SeedAvailableInvariants(preheader, loop, available, invariant);
+
+    for (auto* inst : CollectLoopInstructions(loop, func)) {
+      if (!inst || invariant.count(inst) != 0) continue;
+      auto* block = inst->GetParent();
+      if (!block || block == preheader) continue;
+      if (inst->GetOpcode() == Opcode::Phi || inst->IsTerminator() ||
+          inst->GetOpcode() == Opcode::Alloca || inst->GetOpcode() == Opcode::Ret ||
+          inst->GetOpcode() == Opcode::Store || inst->GetOpcode() == Opcode::Call ||
+          inst->GetOpcode() == Opcode::Div || inst->GetOpcode() == Opcode::Mod) {
+        continue;
+      }
+      if (!IsSafeInvariantInstruction(inst, loop, invariant)) continue;
+
+      ExprKey key = MakeExprKey(inst);
+      auto avail_it = available.find(key);
+      if (avail_it != available.end()) {
+        // Same expression already lives in the preheader: reuse it.
+        inst->ReplaceAllUsesWith(avail_it->second);
+        block->RemoveInstruction(inst);
+      } else {
+        auto owned = DetachInstruction(block, inst);
+        if (!owned) continue;
+        auto* moved = owned.get();
+        InsertBeforeTerminator(preheader, std::move(owned));
+        available.emplace(std::move(key), moved);
+        invariant.insert(moved);
+      }
+
+      changed = true;
+      progress = true;
+      break;
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace
+
+// Runs LICM over every loop of `func`, deepest loops first. Returns true if
+// anything changed. External functions are skipped.
+bool RunLICM(Function& func) {
+  if (func.IsExternal()) return false;
+
+  analysis::DominatorTree dom_tree(func);
+  analysis::LoopInfo loop_info(func, dom_tree);
+
+  std::vector ordered_loops;
+  for (const auto& loop_ptr : loop_info.GetLoops()) {
+    ordered_loops.push_back(loop_ptr.get());
+  }
+
+  // Greater depth first, then fewer blocks, then header name as a
+  // deterministic tie-break.
+  std::sort(ordered_loops.begin(), ordered_loops.end(),
+            [](analysis::Loop* lhs, analysis::Loop* rhs) {
+              if (lhs->GetDepth() != rhs->GetDepth()) {
+                return lhs->GetDepth() > rhs->GetDepth();
+              }
+              if (lhs->GetBlocks().size() != rhs->GetBlocks().size()) {
+                return lhs->GetBlocks().size() < rhs->GetBlocks().size();
+              }
+              return lhs->GetHeader()->GetName() < rhs->GetHeader()->GetName();
+            });
+
+  bool changed = false;
+  for (auto* loop : ordered_loops) {
+    changed |= RunLICMOnLoop(loop, func);
+  }
+  return changed;
+}
+
+}  // namespace passes
+}  // namespace ir
diff --git a/src/ir/passes/LoopFission.cpp b/src/ir/passes/LoopFission.cpp
new file mode 100644
index 0000000..8833999
--- /dev/null
+++ b/src/ir/passes/LoopFission.cpp
@@ -0,0 +1,171 @@
+// Loop fission:
+// - conservatively splits a single-block loop body that contains two mutually
+//   independent groups of store statements
+// - only handles loops with a single induction variable and no other
+//   loop-carried phi
+
+#include "ir/IR.h"
+
+#include
+#include
+#include
+#include
+
+#include "LoopPassUtils.h"
+
+namespace ir {
+namespace passes {
+
+namespace {
+
+// Walks through chained address computations to the underlying base value.
+Value* StripPointerBase(Value* value) {
+  while (auto* gep = dynamic_cast(value)) {
+    value = gep->GetBase();
+  }
+  return value;
+}
+
+// Shape check: innermost loop made of exactly header + body, the body is also
+// the latch and branches straight back to the header, the induction phi is the
+// only header phi, and the step is positive.
+bool IsFissionCandidate(const CanonicalLoopMatch& match) {
+  if (match.loop->GetChildren().size() != 0) return false;
+  if (match.loop->GetBlocks().size() != 2) return false;
+  if (match.body != match.latch) return false;
+  if (match.header_phis.size() != 1) return false;
+  if (match.header_phis.front() != match.induction.phi) return false;
+  if (match.induction.step <= 0) return false;
+  auto* body_term =
+      dynamic_cast(match.body->MutableInstructions().back().get());
+  return body_term && body_term->GetTarget() == match.header;
+}
+
+// Returns true if any operand of `inst` is an instruction contained in `defs`.
+bool DependsOnAny(Instruction* inst, const std::unordered_set& defs) {
+  for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
+    auto* def = dynamic_cast(inst->GetOperand(i));
+    if (def && defs.count(def) != 0) return true;
+  }
+  return false;
+}
+
+// Returns true if `inst` is a load with an operand whose stripped base is
+// `base`.
+bool ReadsFromBase(Instruction* inst, Value* base) {
+  if (inst->GetOpcode() != Opcode::Load) return false;
+  for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
+    if (StripPointerBase(inst->GetOperand(i)) == base) return true;
+  }
+  return false;
+}
+
+// Splits the loop body at its first store: everything after it, up to and
+// including the second store, moves into a new loop that runs after the
+// original one. Returns true if the split was performed.
+// NOTE(review): the new blocks are created before the clones are attempted, so
+// a CloneInstruction failure returns false with those blocks already added to
+// `func`; confirm a later cleanup pass removes them.
+bool RunFissionOnLoop(Function& func, const CanonicalLoopMatch& match,
+                      Context& ctx) {
+  if (!IsFissionCandidate(match)) return false;
+
+  // The second loop's compare is rebuilt as `iv2 <op> bound`, which matches
+  // the original only when the header compare has that operand order.
+  if (match.header_cmp->GetLhs() != match.induction.phi ||
+      match.header_cmp->GetRhs() != match.bound) {
+    return false;
+  }
+  // The exit edge is moved to a new block below; phis in the exit block would
+  // keep naming the old header as their incoming block.
+  if (!match.exit->GetInstructions().empty() &&
+      match.exit->GetInstructions().front()->GetOpcode() == Opcode::Phi) {
+    return false;
+  }
+
+  std::vector body_insts;
+  for (const auto& inst_ptr : match.body->GetInstructions()) {
+    if (!inst_ptr.get()->IsTerminator()) {
+      body_insts.push_back(inst_ptr.get());
+    }
+  }
+  if (body_insts.size() < 3) return false;
+
+  auto* iv_next = dynamic_cast(match.induction.next);
+  if (!iv_next || iv_next->GetParent() != match.body) return false;
+
+  std::vector store_positions;
+  for (size_t i = 0; i < body_insts.size(); ++i) {
+    if (dynamic_cast(body_insts[i]) != nullptr) {
+      store_positions.push_back(i);
+    }
+  }
+  if (store_positions.size() != 2) return false;
+
+  // Required layout: [group1 ... store1][group2 ... store2][iv_next].
+  const size_t first_store_idx = store_positions[0];
+  const size_t second_store_idx = store_positions[1];
+  if (body_insts.back() != iv_next) return false;
+  if (second_store_idx + 1 != body_insts.size() - 1) return false;
+
+  auto* first_store = static_cast(body_insts[first_store_idx]);
+  auto* second_store = static_cast(body_insts[second_store_idx]);
+  auto* first_base = StripPointerBase(first_store->GetPtr());
+  auto* second_base = StripPointerBase(second_store->GetPtr());
+  if (first_base == second_base) {
+    return false;
+  }
+
+  std::vector group1(body_insts.begin(),
+                     body_insts.begin() + first_store_idx + 1);
+  std::vector group2(body_insts.begin() + first_store_idx + 1,
+                     body_insts.begin() + second_store_idx + 1);
+
+  std::unordered_set group1_defs(group1.begin(), group1.end());
+  std::unordered_set group2_defs(group2.begin(), group2.end());
+  group1_defs.erase(iv_next);
+  group2_defs.erase(iv_next);
+
+  // After the split every group1 iteration runs before any group2 iteration.
+  // That reordering is only safe when neither group reads the object the
+  // other group writes, and when no call could observe the changed order.
+  // NOTE(review): objects are told apart by base value only; confirm distinct
+  // bases cannot refer to the same memory.
+  for (auto* inst : group2) {
+    if (inst->GetOpcode() == Opcode::Call) return false;
+    if (ReadsFromBase(inst, first_base)) return false;
+    if (DependsOnAny(inst, group1_defs)) return false;
+  }
+  for (auto* inst : group1) {
+    if (inst->GetOpcode() == Opcode::Call) return false;
+    if (ReadsFromBase(inst, second_base)) return false;
+    if (DependsOnAny(inst, group2_defs)) return false;
+  }
+
+  auto* original_exit = match.exit;
+  const bool exit_is_true = (match.header_branch->GetTrueBlock() == original_exit);
+  std::string block_suffix = ctx.NextTemp();
+  if (!block_suffix.empty() && block_suffix.front() == '%') {
+    block_suffix.erase(0, 1);
+  }
+  auto* preheader2 =
+      func.CreateBlock(match.header->GetName() + ".fission.pre." + block_suffix);
+  auto* header2 =
+      func.CreateBlock(match.header->GetName() + ".fission.hdr." + block_suffix);
+  auto* body2 =
+      func.CreateBlock(match.body->GetName() + ".fission.body." + block_suffix);
+
+  preheader2->Append(Type::GetVoidType(), header2);
+
+  auto* iv2 = header2->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());
+  iv2->AddIncoming(match.induction.init, preheader2);
+
+  auto* cmp2 = header2->Append(
+      match.header_cmp->GetCmpOp(), Type::GetInt32Type(), iv2, match.bound,
+      ctx.NextTemp());
+  // Keep the original branch polarity: the exit stays on whichever edge the
+  // original header used for it.
+  if (exit_is_true) {
+    header2->Append(Type::GetVoidType(), cmp2, original_exit, body2);
+  } else {
+    header2->Append(Type::GetVoidType(), cmp2, body2, original_exit);
+  }
+
+  // Clone group2 into the new body, reading the new induction phi.
+  ValueMap remap;
+  remap.emplace(match.induction.phi, iv2);
+  for (auto* inst : group2) {
+    auto cloned = CloneInstruction(inst, remap, ".f2");
+    if (!cloned) return false;
+    auto* raw = cloned.get();
+    body2->MutableInstructions().push_back(std::move(cloned));
+    raw->SetParent(body2);
+    remap[inst] = raw;
+  }
+
+  auto next2_cloned = CloneInstruction(iv_next, remap, ".f2");
+  if (!next2_cloned) return false;
+  auto* next2 = next2_cloned.get();
+  body2->MutableInstructions().push_back(std::move(next2_cloned));
+  next2->SetParent(body2);
+  body2->Append(Type::GetVoidType(), header2);
+  iv2->AddIncoming(next2, body2);
+
+  // Redirect the original loop's exit edge into the second loop.
+  match.header_branch->SetOperand(exit_is_true ? 1 : 2, preheader2);
+  match.header->RemoveSuccessor(original_exit);
+  match.header->AddSuccessor(preheader2);
+  preheader2->AddPredecessor(match.header);
+  original_exit->RemovePredecessor(match.header);
+
+  for (auto* inst : group2) {
+    match.body->RemoveInstruction(inst);
+  }
+
+  return true;
+}
+
+}  // namespace
+
+// Applies fission to the first loop that qualifies and returns true. Stops
+// after one split because the dominator and loop analyses are then out of
+// date. External functions are skipped.
+bool RunLoopFission(Function& func, Context& ctx) {
+  if (func.IsExternal()) return false;
+
+  analysis::DominatorTree dom_tree(func);
+  analysis::LoopInfo loop_info(func, dom_tree);
+
+  for (const auto& loop_ptr : loop_info.GetLoops()) {
+    auto match = MatchCanonicalLoop(loop_ptr.get());
+    if (!match.has_value()) continue;
+    if (RunFissionOnLoop(func, *match, ctx)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+}  // namespace passes
+}  // namespace ir
diff --git a/src/ir/passes/LoopIdiom.cpp b/src/ir/passes/LoopIdiom.cpp
new file mode 100644
index 0000000..dd89b6f
--- /dev/null
+++ b/src/ir/passes/LoopIdiom.cpp
@@ -0,0 +1,465 @@
+// Loop idiom optimization:
+// - replaces canonical loops that fill contiguous memory with a constant by a
+//   runtime bulk-fill call
+// - currently only handles innermost loops with step=1, init=0 and a single
+//   store
+
+#include "ir/IR.h"
+
+#include
+#include
+#include
+
+#include "LoopPassUtils.h"
+
+namespace ir {
+namespace passes {
+
+namespace {
+
+// A matched constant-fill loop and the pieces needed to replace it.
+struct FillLoopCandidate {
+  analysis::Loop* loop = nullptr;
+  BasicBlock* preheader = nullptr;
+  BasicBlock* header = nullptr;
+  BasicBlock* exit = nullptr;
+  PhiInst* induction = nullptr;
+  Value* bound = nullptr;
+  // Base pointer of the stored-to address.
+  Value* base_ptr = nullptr;
+  // Loop-invariant addend of the index; nullptr when the index is the
+  // induction variable itself.
+  Value* offset = nullptr;
+  // Constant written by the store.
+  int fill_value = 0;
+};
+
+// A matched loop whose guarded body issues one fill call per iteration.
+struct GuardedRowFillCandidate {
+  analysis::Loop* loop = nullptr;
+  BasicBlock* preheader = nullptr;
+  BasicBlock* header = nullptr;
+  BasicBlock* body = nullptr;
+  BasicBlock* action = nullptr;
+  BasicBlock* latch = nullptr;
+  BasicBlock* exit = nullptr;
+  PhiInst* induction = nullptr;
+  Value* bound = nullptr;
+  // Second header phi, advanced by linear_step each iteration.
+  PhiInst* linear = nullptr;
+  Value* linear_init = nullptr;
+  int linear_step = 0;
+  Value* base_ptr = nullptr;
+  // Value the induction variable is compared against in the guard.
+  Value* threshold = nullptr;
+  // True when the guard selects iterations below the threshold, false when it
+  // selects iterations at or above it.
+  bool prefix = false;
+  int fill_value = 0;
+};
+
+bool ExprDependsOn(Value* value, Value* needle, + std::unordered_set& visiting) { + if (value == needle) return true; + auto* inst = dynamic_cast(value); + if (!inst) return false; + if (!visiting.insert(value).second) return false; + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + if (ExprDependsOn(inst->GetOperand(i), needle, visiting)) { + return true; + } + } + return false; +} + +bool ExprDependsOn(Value* value, Value* needle) { + std::unordered_set visiting; + return ExprDependsOn(value, needle, visiting); +} + +Value* GetIncomingForBlock(PhiInst* phi, BasicBlock* block) { + if (!phi || !block) return nullptr; + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + if (phi->GetIncomingBlock(i) == block) { + return phi->GetIncomingValue(i); + } + } + return nullptr; +} + +Value* MaterializeInvariantExpr(Value* value, analysis::Loop* loop, IRBuilder& builder, + ValueMap& remap) { + auto it = remap.find(value); + if (it != remap.end()) return it->second; + if (dynamic_cast(value) || dynamic_cast(value) || + dynamic_cast(value) || dynamic_cast(value)) { + return value; + } + auto* inst = dynamic_cast(value); + if (!inst || !loop->Contains(inst->GetParent())) return value; + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + auto* operand = inst->GetOperand(i); + remap[operand] = MaterializeInvariantExpr(operand, loop, builder, remap); + } + auto cloned = CloneInstruction(inst, remap, ".idiom"); + if (!cloned) return nullptr; + auto* raw = cloned.get(); + InsertInstruction(builder.GetInsertBlock(), std::move(cloned)); + remap[inst] = raw; + return raw; +} + +bool HasOutsideUse(Instruction* inst, analysis::Loop* loop) { + for (const auto& use : inst->GetUses()) { + auto* user = dynamic_cast(use.GetUser()); + if (!user) return true; + if (!user->GetParent() || !loop->Contains(user->GetParent())) { + return true; + } + } + return false; +} + +Value* MatchContiguousOffset(Value* index, PhiInst* iv, analysis::Loop* loop) { + if (index == iv) return 
nullptr; + auto* bin = dynamic_cast(index); + if (!bin || bin->GetOpcode() != Opcode::Add || + !bin->GetType() || !bin->GetType()->IsInt32()) { + return nullptr; + } + if (bin->GetLhs() == iv && IsLoopInvariantValue(bin->GetRhs(), loop)) { + return bin->GetRhs(); + } + if (bin->GetRhs() == iv && IsLoopInvariantValue(bin->GetLhs(), loop)) { + return bin->GetLhs(); + } + return nullptr; +} + +bool BuildFillLoopCandidate(Function& func, analysis::Loop* loop, + FillLoopCandidate* out) { + (void)func; + auto match = MatchCanonicalLoop(loop); + if (!match.has_value()) return false; + if (match->loop->GetChildren().size() != 0) return false; + if (match->body != match->latch || loop->GetBlocks().size() != 2) return false; + if (match->header_phis.size() != 1 || + match->header_phis.front() != match->induction.phi) { + return false; + } + if (match->induction.step != 1) return false; + if (match->header_cmp->GetCmpOp() != CmpOp::Lt) return false; + auto* init_ci = dynamic_cast(match->induction.init); + if (!init_ci || init_ci->GetValue() != 0) return false; + if (!match->exit->GetInstructions().empty() && + dynamic_cast(match->exit->GetInstructions().front().get()) != nullptr) { + return false; + } + + StoreInst* store = nullptr; + std::vector body_insts; + for (const auto& inst_ptr : match->body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (dynamic_cast(inst) != nullptr || inst->IsTerminator()) continue; + body_insts.push_back(inst); + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Gep: + case Opcode::Store: + break; + default: + return false; + } + if (inst != match->induction.next && HasOutsideUse(inst, loop)) { + return false; + } + if (auto* maybe_store = dynamic_cast(inst)) { + if (store) return false; + store = maybe_store; + } + } + if (!store) return false; + + auto* fill_ci = dynamic_cast(store->GetValue()); + if (!fill_ci) return false; + auto* gep = dynamic_cast(store->GetPtr()); + if (!gep || !gep->GetBase() || 
!gep->GetBase()->GetType() || + !gep->GetBase()->GetType()->IsPtrInt32()) { + return false; + } + Value* offset = MatchContiguousOffset(gep->GetIndex(), match->induction.phi, loop); + if (gep->GetIndex() != match->induction.phi && offset == nullptr) { + return false; + } + + out->loop = loop; + out->preheader = match->preheader; + out->header = match->header; + out->exit = match->exit; + out->induction = match->induction.phi; + out->bound = match->bound; + out->base_ptr = gep->GetBase(); + out->offset = offset; + out->fill_value = fill_ci->GetValue(); + return true; +} + +Function* GetOrCreateFillI32(Module& module) { + if (auto* fn = module.FindFunction("__fill_i32")) return fn; + auto* fn = module.CreateFunction("__fill_i32", Type::GetVoidType(), + {Type::GetPtrInt32Type(), Type::GetInt32Type(), + Type::GetInt32Type()}); + fn->SetExternal(true); + return fn; +} + +Function* GetOrCreateFillRowsI32(Module& module) { + if (auto* fn = module.FindFunction("__fill_rows_i32")) return fn; + auto* fn = module.CreateFunction( + "__fill_rows_i32", Type::GetVoidType(), + {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type(), + Type::GetInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()}); + fn->SetExternal(true); + return fn; +} + +bool BuildGuardedRowFillCandidate(Function& func, analysis::Loop* loop, + GuardedRowFillCandidate* out) { + (void)func; + if (!loop) return false; + if (loop->GetChildren().size() != 0) return false; + if (loop->GetBlocks().size() != 4) return false; + + auto* header = loop->GetHeader(); + auto* preheader = loop->GetPreheader(); + if (!header || !preheader) return false; + if (loop->GetLatches().size() != 1) return false; + auto* latch = loop->GetLatches().front(); + if (!latch) return false; + + auto* header_term = header->HasTerminator() + ? 
dynamic_cast( + header->MutableInstructions().back().get()) + : nullptr; + if (!header_term) return false; + auto* header_cmp = dynamic_cast(header_term->GetCond()); + if (!header_cmp || header_cmp->GetCmpOp() != CmpOp::Lt) return false; + + auto induction = MatchCanonicalInduction(header, preheader, latch); + if (!induction.has_value() || induction->step != 1) return false; + Value* bound = nullptr; + BasicBlock* body = nullptr; + BasicBlock* exit = nullptr; + if (loop->Contains(header_term->GetTrueBlock()) && + !loop->Contains(header_term->GetFalseBlock())) { + body = header_term->GetTrueBlock(); + exit = header_term->GetFalseBlock(); + } else if (loop->Contains(header_term->GetFalseBlock()) && + !loop->Contains(header_term->GetTrueBlock())) { + body = header_term->GetFalseBlock(); + exit = header_term->GetTrueBlock(); + } else { + return false; + } + if (header_cmp->GetLhs() == induction->phi && + IsLoopInvariantValue(header_cmp->GetRhs(), loop)) { + bound = header_cmp->GetRhs(); + } else if (header_cmp->GetRhs() == induction->phi && + IsLoopInvariantValue(header_cmp->GetLhs(), loop)) { + bound = header_cmp->GetLhs(); + } else { + return false; + } + + auto header_phis = CollectHeaderPhis(header); + if (header_phis.size() != 2) return false; + PhiInst* linear_phi = nullptr; + for (auto* phi : header_phis) { + if (phi != induction->phi) { + linear_phi = phi; + break; + } + } + if (!linear_phi || !linear_phi->GetType() || !linear_phi->GetType()->IsInt32()) { + return false; + } + auto* linear_init = GetIncomingForBlock(linear_phi, preheader); + auto* linear_next = GetIncomingForBlock(linear_phi, latch); + auto* linear_next_bin = dynamic_cast(linear_next); + if (!linear_init || !linear_next_bin || + linear_next_bin->GetOpcode() != Opcode::Add || + linear_next_bin->GetLhs() != linear_phi) { + return false; + } + auto* linear_step_ci = dynamic_cast(linear_next_bin->GetRhs()); + if (!linear_step_ci || linear_step_ci->GetValue() <= 0) return false; + + auto* guard = 
body->HasTerminator() + ? dynamic_cast(body->MutableInstructions().back().get()) + : nullptr; + if (!guard) return false; + BasicBlock* action = nullptr; + if (guard->GetTrueBlock() == latch && loop->Contains(guard->GetFalseBlock())) { + action = guard->GetFalseBlock(); + } else if (guard->GetFalseBlock() == latch && + loop->Contains(guard->GetTrueBlock())) { + action = guard->GetTrueBlock(); + } else { + return false; + } + auto* action_term = + dynamic_cast(action->MutableInstructions().back().get()); + if (!action_term || action_term->GetTarget() != latch) return false; + + CallInst* fill_call = nullptr; + GepInst* fill_gep = nullptr; + for (const auto& inst_ptr : action->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator()) continue; + if (auto* gep = dynamic_cast(inst)) { + fill_gep = gep; + continue; + } + fill_call = dynamic_cast(inst); + } + if (!fill_call || !fill_gep || fill_call->GetNumArgs() != 3) return false; + auto* callee = fill_call->GetCallee(); + if (!callee || callee->GetName() != "__fill_i32") return false; + auto* fill_value = dynamic_cast(fill_call->GetArg(2)); + if (!fill_value) return false; + if (fill_call->GetArg(0) != fill_gep || fill_call->GetArg(1) != bound) { + return false; + } + if (fill_gep->GetIndex() != linear_phi) return false; + if (!fill_gep->GetBase() || !fill_gep->GetBase()->GetType() || + !fill_gep->GetBase()->GetType()->IsPtrInt32()) { + return false; + } + + auto* guard_cmp = dynamic_cast(guard->GetCond()); + if (!guard_cmp || !guard_cmp->GetType() || !guard_cmp->GetType()->IsInt32()) { + return false; + } + Value* threshold = nullptr; + bool prefix = false; + bool suffix = false; + if (guard_cmp->GetLhs() == induction->phi && + !ExprDependsOn(guard_cmp->GetRhs(), induction->phi) && + !ExprDependsOn(guard_cmp->GetRhs(), linear_phi)) { + threshold = guard_cmp->GetRhs(); + if (guard_cmp->GetCmpOp() == CmpOp::Lt && action == guard->GetTrueBlock()) { + prefix = true; + } else if 
(guard_cmp->GetCmpOp() == CmpOp::Ge && + action == guard->GetTrueBlock()) { + suffix = true; + } else if (guard_cmp->GetCmpOp() == CmpOp::Lt && + action == guard->GetFalseBlock()) { + suffix = true; + } else if (guard_cmp->GetCmpOp() == CmpOp::Ge && + action == guard->GetFalseBlock()) { + prefix = true; + } else { + return false; + } + } else { + return false; + } + + out->loop = loop; + out->preheader = preheader; + out->header = header; + out->body = body; + out->action = action; + out->latch = latch; + out->exit = exit; + out->induction = induction->phi; + out->bound = bound; + out->linear = linear_phi; + out->linear_init = linear_init; + out->linear_step = linear_step_ci->GetValue(); + out->base_ptr = fill_gep->GetBase(); + out->fill_value = fill_value->GetValue(); + out->threshold = threshold; + out->prefix = prefix; + return prefix || suffix; +} + +bool RunFillLoop(Function& func, const FillLoopCandidate& cand, + Module& module, Context& ctx) { + (void)func; + auto* fill_fn = GetOrCreateFillI32(module); + auto* preheader = cand.preheader; + if (preheader->HasTerminator()) { + preheader->RemoveInstruction(preheader->MutableInstructions().back().get()); + } + + IRBuilder builder(ctx, preheader); + Value* start_ptr = cand.base_ptr; + if (cand.offset) { + start_ptr = builder.CreateGep(cand.base_ptr, cand.offset, ctx.NextTemp()); + } + builder.CreateCall(fill_fn, {start_ptr, cand.bound, ctx.GetConstInt(cand.fill_value)}, + ""); + preheader->Append(Type::GetVoidType(), cand.exit); + + cand.induction->RemoveIncomingBlock(preheader); + preheader->RemoveSuccessor(cand.header); + cand.header->RemovePredecessor(preheader); + preheader->AddSuccessor(cand.exit); + cand.exit->AddPredecessor(preheader); + return true; +} + +bool RunGuardedRowFillLoop(Function& func, const GuardedRowFillCandidate& cand, + Module& module, Context& ctx) { + (void)func; + auto* fill_rows_fn = GetOrCreateFillRowsI32(module); + auto* preheader = cand.preheader; + if (preheader->HasTerminator()) { 
+ preheader->RemoveInstruction(preheader->MutableInstructions().back().get()); + } + + IRBuilder builder(ctx, preheader); + ValueMap remap; + auto* threshold = + MaterializeInvariantExpr(cand.threshold, cand.loop, builder, remap); + if (!threshold) return false; + Value* start_index = cand.prefix ? ctx.GetConstInt(0) : threshold; + Value* rows = cand.prefix ? threshold : nullptr; + if (!cand.prefix) { + rows = builder.CreateSub(cand.bound, start_index, ctx.NextTemp()); + } + auto* start_offset_mul = + builder.CreateMul(start_index, ctx.GetConstInt(cand.linear_step), ctx.NextTemp()); + auto* start_offset = + builder.CreateAdd(cand.linear_init, start_offset_mul, ctx.NextTemp()); + builder.CreateCall(fill_rows_fn, + {cand.base_ptr, start_offset, rows, + ctx.GetConstInt(cand.linear_step), cand.bound, + ctx.GetConstInt(cand.fill_value)}, + ""); + preheader->Append(Type::GetVoidType(), cand.exit); + + cand.induction->RemoveIncomingBlock(preheader); + cand.linear->RemoveIncomingBlock(preheader); + preheader->RemoveSuccessor(cand.header); + cand.header->RemovePredecessor(preheader); + preheader->AddSuccessor(cand.exit); + cand.exit->AddPredecessor(preheader); + return true; +} + +} // namespace + +bool RunLoopIdiom(Function& func, Module& module, Context& ctx) { + if (func.IsExternal()) return false; + + analysis::DominatorTree dom_tree(func); + analysis::LoopInfo loop_info(func, dom_tree); + + for (const auto& loop_ptr : loop_info.GetLoops()) { + GuardedRowFillCandidate row_fill; + if (BuildGuardedRowFillCandidate(func, loop_ptr.get(), &row_fill)) { + if (RunGuardedRowFillLoop(func, row_fill, module, ctx)) { + return true; + } + } + FillLoopCandidate cand; + if (!BuildFillLoopCandidate(func, loop_ptr.get(), &cand)) continue; + if (RunFillLoop(func, cand, module, ctx)) { + return true; + } + } + return false; +} + +} // namespace passes +} // namespace ir diff --git a/src/ir/passes/LoopParallelize.cpp b/src/ir/passes/LoopParallelize.cpp new file mode 100644 index 
0000000..ab45a97 --- /dev/null +++ b/src/ir/passes/LoopParallelize.cpp @@ -0,0 +1,845 @@ +// 循环并行化: +// - 将一部分安全的规范循环抽取成 worker 函数 +// - 通过运行时 __par_runN 启动固定线程数并行执行 +// +// 当前限制: +// - 仅并行化不存在 SSA live-out 的循环 +// - 循环访问对象必须是全局数组/全局变量 +// - 不支持循环中的普通函数调用 + +#include "ir/IR.h" + +#include +#include +#include +#include +#include +#include + +#include "LoopPassUtils.h" + +namespace ir { +namespace passes { + +bool RunSimpleDCE(Function& func); +bool RunCFGSimplify(Function& func, Context& ctx); + +namespace { + +constexpr int kParallelLoopSlots = 8; +constexpr int kParallelThreads = 4; + +enum class ParallelLoopKind { + Pointwise, + ReductionAddI32, + GuardedFillI32, +}; + +struct LoopContextValue { + Value* original = nullptr; + GlobalVariable* slot = nullptr; +}; + +bool ExprDependsOn(Value* value, Value* needle, + std::unordered_set& visiting) { + if (value == needle) return true; + auto* inst = dynamic_cast(value); + if (!inst) return false; + if (!visiting.insert(value).second) return false; + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + if (ExprDependsOn(inst->GetOperand(i), needle, visiting)) { + return true; + } + } + return false; +} + +bool ExprDependsOn(Value* value, Value* needle) { + std::unordered_set visiting; + return ExprDependsOn(value, needle, visiting); +} + +Value* StripPointerBase(Value* value) { + while (auto* gep = dynamic_cast(value)) { + value = gep->GetBase(); + } + return value; +} + +bool IsSupportedParallelInst(Instruction* inst) { + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: + case Opcode::Cmp: + case Opcode::Cast: + case Opcode::Load: + case Opcode::Store: + case Opcode::Br: + case Opcode::CondBr: + case Opcode::Gep: + case Opcode::Phi: + case Opcode::Ret: + return true; + case Opcode::Call: + case Opcode::Alloca: + return false; + } + return false; +} + +bool IsScalarContextCandidate(Value* value) { + auto* arg = dynamic_cast(value); + if 
(arg && (arg->GetType()->IsInt32() || arg->GetType()->IsFloat32())) { + return true; + } + auto* inst = dynamic_cast(value); + if (inst && inst->GetType() && + (inst->GetType()->IsInt32() || inst->GetType()->IsFloat32())) { + return true; + } + return false; +} + +bool HasOutsideUse(Instruction* inst, analysis::Loop* loop) { + for (const auto& use : inst->GetUses()) { + auto* user = dynamic_cast(use.GetUser()); + if (!user) return true; + if (!user->GetParent() || !loop->Contains(user->GetParent())) { + return true; + } + } + return false; +} + +void ReplaceUsesOutsideLoop(Value* value, Value* replacement, + analysis::Loop* loop) { + if (!value || !replacement || !loop) return; + auto uses = value->GetUses(); + for (const auto& use : uses) { + auto* user = dynamic_cast(use.GetUser()); + if (!user) continue; + auto* parent = user->GetParent(); + if (parent && loop->Contains(parent)) continue; + user->SetOperand(use.GetOperandIndex(), replacement); + } +} + +Value* GetIncomingForBlock(PhiInst* phi, BasicBlock* block) { + if (!phi || !block) return nullptr; + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + if (phi->GetIncomingBlock(i) == block) { + return phi->GetIncomingValue(i); + } + } + return nullptr; +} + +struct ParallelLoopCandidate { + Function* parent = nullptr; + analysis::Loop* loop = nullptr; + ParallelLoopKind kind = ParallelLoopKind::Pointwise; + BasicBlock* header = nullptr; + BasicBlock* body = nullptr; + BasicBlock* guard = nullptr; + BasicBlock* action = nullptr; + BasicBlock* preheader = nullptr; + BasicBlock* exit = nullptr; + BasicBlock* latch = nullptr; + CmpInst* header_cmp = nullptr; + PhiInst* induction = nullptr; + Value* induction_next = nullptr; + Value* bound = nullptr; + PhiInst* linear = nullptr; + Value* linear_init = nullptr; + Value* linear_next = nullptr; + int linear_step = 0; + PhiInst* reduction = nullptr; + Value* reduction_init = nullptr; + Value* reduction_next = nullptr; + bool has_loads = false; + std::vector 
contexts; +}; + +bool BuildGuardedFillCandidate(Function& func, analysis::Loop* loop, + ParallelLoopCandidate* out) { + if (!loop) return false; + auto* header = loop->GetHeader(); + auto* preheader = loop->GetPreheader(); + if (!header || !preheader) return false; + if (loop->GetChildren().size() != 0) return false; + if (loop->GetBlocks().size() != 4) return false; + if (loop->GetLatches().size() != 1) return false; + auto* latch = loop->GetLatches().front(); + if (!latch) return false; + + auto* header_term = header->HasTerminator() + ? dynamic_cast( + header->MutableInstructions().back().get()) + : nullptr; + if (!header_term) return false; + auto* header_cmp = dynamic_cast(header_term->GetCond()); + if (!header_cmp || header_cmp->GetCmpOp() != CmpOp::Lt) return false; + + BasicBlock* body = nullptr; + BasicBlock* exit = nullptr; + if (loop->Contains(header_term->GetTrueBlock()) && + !loop->Contains(header_term->GetFalseBlock())) { + body = header_term->GetTrueBlock(); + exit = header_term->GetFalseBlock(); + } else if (loop->Contains(header_term->GetFalseBlock()) && + !loop->Contains(header_term->GetTrueBlock())) { + body = header_term->GetFalseBlock(); + exit = header_term->GetTrueBlock(); + } else { + return false; + } + + auto induction = MatchCanonicalInduction(header, preheader, latch); + if (!induction.has_value() || induction->step != 1) return false; + Value* bound = nullptr; + if (header_cmp->GetLhs() == induction->phi && + IsLoopInvariantValue(header_cmp->GetRhs(), loop)) { + bound = header_cmp->GetRhs(); + } else if (header_cmp->GetRhs() == induction->phi && + IsLoopInvariantValue(header_cmp->GetLhs(), loop)) { + bound = header_cmp->GetLhs(); + } else { + return false; + } + + auto header_phis = CollectHeaderPhis(header); + if (header_phis.size() != 2) return false; + + PhiInst* linear_phi = nullptr; + for (auto* phi : header_phis) { + if (phi != induction->phi) { + linear_phi = phi; + break; + } + } + if (!linear_phi || !linear_phi->GetType() || 
!linear_phi->GetType()->IsInt32()) { + return false; + } + auto* linear_init = GetIncomingForBlock(linear_phi, preheader); + auto* linear_next = GetIncomingForBlock(linear_phi, latch); + auto* linear_next_bin = dynamic_cast(linear_next); + if (!linear_init || !linear_next_bin || + linear_next_bin->GetOpcode() != Opcode::Add || + linear_next_bin->GetLhs() != linear_phi) { + return false; + } + auto* linear_step_ci = dynamic_cast(linear_next_bin->GetRhs()); + if (!linear_step_ci || linear_step_ci->GetValue() <= 0) return false; + + auto* guard = dynamic_cast(body->MutableInstructions().back().get()); + if (!guard) return false; + auto* true_bb = guard->GetTrueBlock(); + auto* false_bb = guard->GetFalseBlock(); + if (!loop->Contains(true_bb) && !loop->Contains(false_bb)) return false; + if (loop->Contains(true_bb) && loop->Contains(false_bb)) { + if (true_bb == latch || false_bb == latch) { + // fine + } else { + return false; + } + } + BasicBlock* action = nullptr; + if (true_bb == latch && false_bb != latch && loop->Contains(false_bb)) { + action = false_bb; + } else if (false_bb == latch && true_bb != latch && + loop->Contains(true_bb)) { + action = true_bb; + } else { + return false; + } + if (action == header || action == body) return false; + auto* action_term = + dynamic_cast(action->MutableInstructions().back().get()); + if (!action_term || action_term->GetTarget() != latch) return false; + + CallInst* fill_call = nullptr; + for (const auto& inst_ptr : action->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator()) continue; + if (auto* gep = dynamic_cast(inst)) { + if (gep->GetBase() == nullptr || gep->GetIndex() == nullptr) return false; + continue; + } + fill_call = dynamic_cast(inst); + if (!fill_call) return false; + } + if (!fill_call || fill_call->GetNumArgs() != 3) return false; + auto* callee = fill_call->GetCallee(); + if (!callee || callee->GetName() != "__fill_i32") return false; + auto* fill_ptr = 
dynamic_cast(fill_call->GetArg(0)); + auto* fill_count = fill_call->GetArg(1); + auto* fill_value = dynamic_cast(fill_call->GetArg(2)); + if (!fill_ptr || !fill_value) return false; + if (fill_ptr->GetBase() == nullptr || !fill_ptr->GetBase()->GetType() || + !fill_ptr->GetBase()->GetType()->IsPtrInt32()) { + return false; + } + if (fill_ptr->GetIndex() != linear_phi) return false; + if (fill_count != bound) return false; + + std::vector context_values; + std::unordered_set seen_contexts; + auto collect_contexts = [&](BasicBlock* block) -> bool { + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (dynamic_cast(inst) != nullptr || inst->IsTerminator()) continue; + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + Value* operand = inst->GetOperand(i); + if (dynamic_cast(operand) || dynamic_cast(operand) || + dynamic_cast(operand) || dynamic_cast(operand)) { + continue; + } + auto* operand_inst = dynamic_cast(operand); + if ((operand_inst && loop->Contains(operand_inst->GetParent())) || + operand == induction->phi || operand == induction->next || + operand == linear_phi || operand == linear_next) { + continue; + } + if (!IsScalarContextCandidate(operand)) return false; + if (seen_contexts.insert(operand).second) { + context_values.push_back(operand); + } + } + if (inst != linear_next && inst != induction->next && + HasOutsideUse(inst, loop)) { + return false; + } + } + return true; + }; + if (!collect_contexts(body) || !collect_contexts(action) || + !collect_contexts(latch)) { + return false; + } + if (context_values.size() > 6) return false; + + out->parent = &func; + out->loop = loop; + out->kind = ParallelLoopKind::GuardedFillI32; + out->header = header; + out->body = body; + out->guard = body; + out->action = action; + out->preheader = preheader; + out->exit = exit; + out->latch = latch; + out->header_cmp = header_cmp; + out->induction = induction->phi; + out->induction_next = induction->next; + out->bound = bound; + 
out->linear = linear_phi; + out->linear_init = linear_init; + out->linear_next = linear_next; + out->linear_step = linear_step_ci->GetValue(); + out->has_loads = true; + for (Value* value : context_values) { + out->contexts.push_back({value, nullptr}); + } + return true; +} + +bool BuildParallelCandidate(Function& func, analysis::Loop* loop, + ParallelLoopCandidate* out) { + if (BuildGuardedFillCandidate(func, loop, out)) return true; + + auto match = MatchCanonicalLoop(loop); + if (!match.has_value()) return false; + + if (match->body != match->latch || loop->GetBlocks().size() != 2) return false; + if (match->exit->GetInstructions().size() > 0 && + dynamic_cast(match->exit->GetInstructions().front().get()) != nullptr) { + return false; + } + + PhiInst* reduction_phi = nullptr; + Value* reduction_init = nullptr; + Value* reduction_next = nullptr; + ParallelLoopKind kind = ParallelLoopKind::Pointwise; + if (match->header_phis.size() == 1 && + match->header_phis.front() == match->induction.phi) { + kind = ParallelLoopKind::Pointwise; + } else if (false && match->header_phis.size() == 2) { + for (auto* phi : match->header_phis) { + if (phi != match->induction.phi) { + reduction_phi = phi; + break; + } + } + if (!reduction_phi || !reduction_phi->GetType() || + !reduction_phi->GetType()->IsInt32()) { + return false; + } + reduction_init = GetIncomingForBlock(reduction_phi, match->preheader); + reduction_next = GetIncomingForBlock(reduction_phi, match->latch); + if (!reduction_init || !reduction_next) return false; + auto* reduction_next_inst = dynamic_cast(reduction_next); + if (!reduction_next_inst || reduction_next_inst->GetParent() != match->body) { + return false; + } + auto* init_ci = dynamic_cast(reduction_init); + if (!init_ci || init_ci->GetValue() != 0) return false; + auto* red_next_bin = dynamic_cast(reduction_next); + if (!red_next_bin || red_next_bin->GetOpcode() != Opcode::Add || + !red_next_bin->GetType() || !red_next_bin->GetType()->IsInt32()) { + 
return false; + } + Value* other = nullptr; + if (red_next_bin->GetLhs() == reduction_phi) { + other = red_next_bin->GetRhs(); + } else if (red_next_bin->GetRhs() == reduction_phi) { + other = red_next_bin->GetLhs(); + } else { + return false; + } + if (ExprDependsOn(other, reduction_phi)) return false; + kind = ParallelLoopKind::ReductionAddI32; + } else { + return false; + } + + std::unordered_set store_bases; + std::unordered_set load_bases; + std::vector context_values; + std::unordered_set seen_contexts; + + for (const auto& bb_ptr : func.GetBlocks()) { + auto* block = bb_ptr.get(); + if (!loop->Contains(block)) continue; + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (!IsSupportedParallelInst(inst)) return false; + if (inst->GetOpcode() == Opcode::Call) return false; + if (kind == ParallelLoopKind::ReductionAddI32 && + inst->GetOpcode() == Opcode::Store) { + return false; + } + if (inst->GetOpcode() != Opcode::Store && inst->GetOpcode() != Opcode::Br && + inst->GetOpcode() != Opcode::CondBr && inst->GetOpcode() != Opcode::Phi && + inst != reduction_phi && + HasOutsideUse(inst, loop)) { + return false; + } + + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + Value* operand = inst->GetOperand(i); + if (dynamic_cast(operand) || dynamic_cast(operand) || + dynamic_cast(operand) || dynamic_cast(operand)) { + continue; + } + auto* operand_inst = dynamic_cast(operand); + if ((operand_inst && loop->Contains(operand_inst->GetParent())) || + operand == match->induction.phi || + operand == match->induction.next || + operand == reduction_phi || + operand == reduction_next) { + continue; + } + if (!IsScalarContextCandidate(operand)) return false; + if (seen_contexts.insert(operand).second) { + context_values.push_back(operand); + } + } + + if (auto* load = dynamic_cast(inst)) { + Value* base = StripPointerBase(load->GetPtr()); + if (dynamic_cast(base) == nullptr) return false; + load_bases.insert(base); + if (auto* gep 
= dynamic_cast(load->GetPtr())) { + if (base == StripPointerBase(load->GetPtr()) && + !ExprDependsOn(gep->GetIndex(), match->induction.phi)) { + // allowed for pure reads; dependence checked later with stores + } + } + } else if (auto* store = dynamic_cast(inst)) { + Value* base = StripPointerBase(store->GetPtr()); + auto* gv = dynamic_cast(base); + if (!gv || gv->GetCount() <= 1) return false; + store_bases.insert(base); + auto* gep = dynamic_cast(store->GetPtr()); + if (!gep || !ExprDependsOn(gep->GetIndex(), match->induction.phi)) return false; + } + } + } + + for (Value* base : store_bases) { + if (load_bases.count(base) == 0) continue; + for (const auto& bb_ptr : func.GetBlocks()) { + auto* block = bb_ptr.get(); + if (!loop->Contains(block)) continue; + for (const auto& inst_ptr : block->GetInstructions()) { + auto* load = dynamic_cast(inst_ptr.get()); + if (!load) continue; + if (StripPointerBase(load->GetPtr()) != base) continue; + auto* gep = dynamic_cast(load->GetPtr()); + if (!gep || !ExprDependsOn(gep->GetIndex(), match->induction.phi)) return false; + } + } + } + + if (context_values.size() > 6) return false; + + out->parent = &func; + out->loop = loop; + out->kind = kind; + out->header = match->header; + out->body = match->body; + out->preheader = match->preheader; + out->exit = match->exit; + out->latch = match->latch; + out->header_cmp = match->header_cmp; + out->induction = match->induction.phi; + out->induction_next = match->induction.next; + out->bound = match->bound; + out->reduction = reduction_phi; + out->reduction_init = reduction_init; + out->reduction_next = reduction_next; + out->has_loads = !load_bases.empty(); + for (Value* value : context_values) { + out->contexts.push_back({value, nullptr}); + } + return true; +} + +bool IsGeneratedParallelWorker(const Function& func) { + return func.GetName().rfind("__par_worker", 0) == 0; +} + +Function* GetOrCreateRuntimeLauncher(Module& module, int slot) { + const std::string name = "__par_run" + 
std::to_string(slot); + if (auto* fn = module.FindFunction(name)) return fn; + auto* fn = module.CreateFunction(name, Type::GetVoidType(), {}); + fn->SetExternal(true); + return fn; +} + +std::string NextWorkerName(int slot) { + return "__par_worker" + std::to_string(slot); +} + +void CloneWorkerBlocks(const ParallelLoopCandidate& cand, Function* worker, + GlobalVariable* bound_slot, + const std::vector& ctx_slots, + GlobalVariable* reduction_slot, Context& ctx) { + if (cand.kind == ParallelLoopKind::GuardedFillI32) { + auto* entry = worker->GetEntry(); + auto* tid = worker->GetArgument(0); + auto* header = worker->CreateBlock(cand.header->GetName()); + auto* guard = worker->CreateBlock(cand.guard->GetName()); + auto* action = worker->CreateBlock(cand.action->GetName()); + auto* latch = worker->CreateBlock(cand.latch->GetName()); + auto* worker_exit = worker->CreateBlock("par.exit"); + + IRBuilder builder(ctx, entry); + auto* bound_val = builder.CreateLoad(bound_slot, ctx.NextTemp()); + Value* threads_val = ctx.GetConstInt(kParallelThreads); + auto* start_mul = builder.CreateMul(tid, bound_val, ctx.NextTemp()); + auto* start = builder.CreateDiv(start_mul, threads_val, ctx.NextTemp()); + auto* next_tid = builder.CreateAdd(tid, ctx.GetConstInt(1), ctx.NextTemp()); + auto* end_mul = builder.CreateMul(next_tid, bound_val, ctx.NextTemp()); + auto* end = builder.CreateDiv(end_mul, threads_val, ctx.NextTemp()); + + ValueMap remap; + remap[cand.bound] = bound_val; + for (const auto& ctx_value : ctx_slots) { + builder.SetInsertPoint(entry); + auto* loaded = builder.CreateLoad(ctx_value.slot, ctx.NextTemp()); + remap[ctx_value.original] = loaded; + } + builder.SetInsertPoint(entry); + auto* start_linear_mul = + builder.CreateMul(start, ctx.GetConstInt(cand.linear_step), ctx.NextTemp()); + Value* linear_init = cand.linear_init; + auto it = remap.find(cand.linear_init); + if (it != remap.end()) linear_init = it->second; + auto* start_linear = builder.CreateAdd(linear_init, 
start_linear_mul, ctx.NextTemp()); + builder.CreateBr(header); + + auto* new_iv = header->PrependPhi(cand.induction->GetType(), ctx.NextTemp()); + auto* new_linear = header->PrependPhi(cand.linear->GetType(), ctx.NextTemp()); + remap[cand.induction] = new_iv; + remap[cand.linear] = new_linear; + + if (auto cloned_cmp = CloneInstruction(cand.header_cmp, remap, ".par")) { + auto* raw_cmp = static_cast(cloned_cmp.get()); + header->MutableInstructions().push_back(std::move(cloned_cmp)); + raw_cmp->SetParent(header); + remap[cand.header_cmp] = raw_cmp; + if (cand.header_cmp->GetLhs() == cand.induction && + cand.header_cmp->GetRhs() == cand.bound) { + raw_cmp->SetOperand(0, new_iv); + raw_cmp->SetOperand(1, end); + } else if (cand.header_cmp->GetRhs() == cand.induction && + cand.header_cmp->GetLhs() == cand.bound) { + raw_cmp->SetOperand(0, end); + raw_cmp->SetOperand(1, new_iv); + } + header->Append(Type::GetVoidType(), raw_cmp, guard, worker_exit); + } + + for (const auto& inst_ptr : cand.guard->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (dynamic_cast(inst) != nullptr || inst->IsTerminator()) continue; + if (auto cloned = CloneInstruction(inst, remap, ".par")) { + auto* raw = cloned.get(); + guard->MutableInstructions().push_back(std::move(cloned)); + raw->SetParent(guard); + remap[inst] = raw; + } + } + auto* guard_term = + static_cast(cand.guard->MutableInstructions().back().get()); + auto* guard_cond = RemapValue(guard_term->GetCond(), remap); + BasicBlock* true_target = + (guard_term->GetTrueBlock() == cand.action) ? action : latch; + BasicBlock* false_target = + (guard_term->GetFalseBlock() == cand.action) ? 
action : latch; + guard->Append(Type::GetVoidType(), guard_cond, true_target, + false_target); + + for (const auto& inst_ptr : cand.action->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (dynamic_cast(inst) != nullptr || inst->IsTerminator()) continue; + if (auto cloned = CloneInstruction(inst, remap, ".par")) { + auto* raw = cloned.get(); + action->MutableInstructions().push_back(std::move(cloned)); + raw->SetParent(action); + remap[inst] = raw; + } + } + action->Append(Type::GetVoidType(), latch); + + for (const auto& inst_ptr : cand.latch->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (dynamic_cast(inst) != nullptr || inst->IsTerminator()) continue; + if (auto cloned = CloneInstruction(inst, remap, ".par")) { + auto* raw = cloned.get(); + latch->MutableInstructions().push_back(std::move(cloned)); + raw->SetParent(latch); + remap[inst] = raw; + } + } + latch->Append(Type::GetVoidType(), header); + + new_iv->AddIncoming(start, entry); + new_iv->AddIncoming(RemapValue(cand.induction_next, remap), latch); + new_linear->AddIncoming(start_linear, entry); + new_linear->AddIncoming(RemapValue(cand.linear_next, remap), latch); + worker_exit->Append(Type::GetVoidType(), nullptr); + return; + } + + auto* entry = worker->GetEntry(); + auto* tid = worker->GetArgument(0); + auto* header = worker->CreateBlock(cand.header->GetName()); + auto* body = worker->CreateBlock(cand.body->GetName()); + auto* worker_exit = worker->CreateBlock("par.exit"); + + IRBuilder builder(ctx, entry); + auto* bound_val = builder.CreateLoad(bound_slot, ctx.NextTemp()); + Value* threads_val = ctx.GetConstInt(kParallelThreads); + auto* start_mul = builder.CreateMul(tid, bound_val, ctx.NextTemp()); + auto* start = builder.CreateDiv(start_mul, threads_val, ctx.NextTemp()); + auto* next_tid = builder.CreateAdd(tid, ctx.GetConstInt(1), ctx.NextTemp()); + auto* end_mul = builder.CreateMul(next_tid, bound_val, ctx.NextTemp()); + auto* end = builder.CreateDiv(end_mul, threads_val, 
ctx.NextTemp()); + + ValueMap remap; + remap[cand.induction] = start; + remap[cand.bound] = bound_val; + for (const auto& ctx_value : ctx_slots) { + builder.SetInsertPoint(entry); + auto* loaded = builder.CreateLoad(ctx_value.slot, ctx.NextTemp()); + remap[ctx_value.original] = loaded; + } + builder.SetInsertPoint(entry); + builder.CreateBr(header); + auto* new_phi = header->PrependPhi(cand.induction->GetType(), ctx.NextTemp()); + remap[cand.induction] = new_phi; + PhiInst* new_reduction_phi = nullptr; + if (cand.kind == ParallelLoopKind::ReductionAddI32) { + new_reduction_phi = header->PrependPhi(Type::GetInt32Type(), ctx.NextTemp()); + remap[cand.reduction] = new_reduction_phi; + } + + if (auto cloned_cmp = CloneInstruction(cand.header_cmp, remap, ".par")) { + auto* raw_cmp = static_cast(cloned_cmp.get()); + header->MutableInstructions().push_back(std::move(cloned_cmp)); + raw_cmp->SetParent(header); + remap[cand.header_cmp] = raw_cmp; + if (cand.header_cmp->GetLhs() == cand.induction && + cand.header_cmp->GetRhs() == cand.bound) { + raw_cmp->SetOperand(0, new_phi); + raw_cmp->SetOperand(1, end); + } else if (cand.header_cmp->GetRhs() == cand.induction && + cand.header_cmp->GetLhs() == cand.bound) { + raw_cmp->SetOperand(0, end); + raw_cmp->SetOperand(1, new_phi); + } + header->Append(Type::GetVoidType(), raw_cmp, body, worker_exit); + } + + for (const auto& inst_ptr : cand.body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (dynamic_cast(inst) != nullptr) continue; + if (inst->GetOpcode() == Opcode::Br) continue; + if (auto cloned = CloneInstruction(inst, remap, ".par")) { + auto* raw = cloned.get(); + body->MutableInstructions().push_back(std::move(cloned)); + raw->SetParent(body); + remap[inst] = raw; + } + } + new_phi->AddIncoming(start, entry); + new_phi->AddIncoming(RemapValue(cand.induction_next, remap), body); + if (new_reduction_phi) { + new_reduction_phi->AddIncoming(ctx.GetConstInt(0), entry); + 
new_reduction_phi->AddIncoming(RemapValue(cand.reduction_next, remap), body); + } + body->Append(Type::GetVoidType(), header); + + if (new_reduction_phi) { + IRBuilder exit_builder(ctx, worker_exit); + auto* partial_ptr = exit_builder.CreateGep(reduction_slot, tid, ctx.NextTemp()); + exit_builder.CreateStore(new_reduction_phi, partial_ptr); + } + worker_exit->Append(Type::GetVoidType(), nullptr); +} + +bool ParallelizeCandidate(Module& module, ParallelLoopCandidate& cand, int slot) { + auto& ctx = module.GetContext(); + auto* bound_slot = + module.CreateGlobalVar("__par_bound" + std::to_string(slot), 0, 1, + Type::GetPtrInt32Type()); + GlobalVariable* reduction_slot = nullptr; + if (cand.kind == ParallelLoopKind::ReductionAddI32) { + reduction_slot = module.CreateGlobalVar( + "__par_red" + std::to_string(slot), 0, kParallelThreads, + Type::GetPtrInt32Type()); + } + for (size_t i = 0; i < cand.contexts.size(); ++i) { + auto& entry = cand.contexts[i]; + bool is_float = entry.original->GetType() && entry.original->GetType()->IsFloat32(); + entry.slot = module.CreateGlobalVar( + "__par_ctx" + std::to_string(slot) + "_" + std::to_string(i), 0, 1, + is_float ? 
Type::GetPtrFloat32Type() : Type::GetPtrInt32Type()); + } + + auto* worker = + module.CreateFunction(NextWorkerName(slot), Type::GetVoidType(), + {Type::GetInt32Type()}); + CloneWorkerBlocks(cand, worker, bound_slot, cand.contexts, reduction_slot, ctx); + + auto* launcher = GetOrCreateRuntimeLauncher(module, slot); + auto* preheader = cand.preheader; + if (preheader->HasTerminator()) { + preheader->RemoveInstruction(preheader->MutableInstructions().back().get()); + } + for (const auto& ctx_value : cand.contexts) { + InsertInstruction(preheader, std::make_unique( + Type::GetVoidType(), ctx_value.original, + ctx_value.slot)); + } + InsertInstruction(preheader, std::make_unique(Type::GetVoidType(), + cand.bound, bound_slot)); + InsertCallBeforeTerminator(preheader, launcher, {}, ""); + Value* reduced_value = nullptr; + if (cand.kind == ParallelLoopKind::ReductionAddI32) { + IRBuilder builder(ctx, preheader); + reduced_value = cand.reduction_init; + for (int tid = 0; tid < kParallelThreads; ++tid) { + auto* partial_ptr = + builder.CreateGep(reduction_slot, ctx.GetConstInt(tid), ctx.NextTemp()); + auto* partial_val = builder.CreateLoad(partial_ptr, ctx.NextTemp()); + reduced_value = builder.CreateAdd(reduced_value, partial_val, ctx.NextTemp()); + } + } + cand.induction->RemoveIncomingBlock(preheader); + if (cand.linear) { + cand.linear->RemoveIncomingBlock(preheader); + } + if (cand.reduction) { + cand.reduction->RemoveIncomingBlock(preheader); + } + preheader->RemoveSuccessor(cand.header); + cand.header->RemovePredecessor(preheader); + preheader->AddSuccessor(cand.exit); + cand.exit->AddPredecessor(preheader); + if (cand.kind == ParallelLoopKind::ReductionAddI32 && reduced_value) { + ReplaceUsesOutsideLoop(cand.reduction, reduced_value, cand.loop); + } + preheader->Append(Type::GetVoidType(), cand.exit); + return true; +} + +} // namespace + +bool RunLoopParallelization(Module& module) { + bool changed = false; + int store_only_slots = 0; + for (int slot = 0; slot < 
kParallelLoopSlots; ++slot) { + ParallelLoopCandidate cand; + bool found = false; + + for (const auto& func_ptr : module.GetFunctions()) { + auto* func = func_ptr.get(); + if (!func || func->IsExternal() || IsGeneratedParallelWorker(*func)) continue; + + analysis::DominatorTree dom_tree(*func); + analysis::LoopInfo loop_info(*func, dom_tree); + std::vector loops; + for (const auto& loop_ptr : loop_info.GetLoops()) { + loops.push_back(loop_ptr.get()); + } + std::sort(loops.begin(), loops.end(), + [](analysis::Loop* lhs, analysis::Loop* rhs) { + if (lhs->GetDepth() != rhs->GetDepth()) { + return lhs->GetDepth() < rhs->GetDepth(); + } + return lhs->GetHeader()->GetName() < rhs->GetHeader()->GetName(); + }); + + ParallelLoopCandidate fallback_store_only; + bool have_fallback_store_only = false; + for (auto* loop : loops) { + if (BuildParallelCandidate(*func, loop, &cand)) { + if (cand.has_loads) { + found = true; + break; + } + if (store_only_slots < 1 && !have_fallback_store_only) { + fallback_store_only = cand; + have_fallback_store_only = true; + } + } + } + if (!found && have_fallback_store_only) { + cand = fallback_store_only; + found = true; + } + if (found) break; + } + + if (!found) break; + bool local_changed = ParallelizeCandidate(module, cand, slot); + changed |= local_changed; + if (local_changed && cand.parent) { + RunSimpleDCE(*cand.parent); + RunCFGSimplify(*cand.parent, module.GetContext()); + } + if (!cand.has_loads) { + ++store_only_slots; + } + } + return changed; +} + +} // namespace passes +} // namespace ir diff --git a/src/ir/passes/LoopPassUtils.h b/src/ir/passes/LoopPassUtils.h new file mode 100644 index 0000000..b96c054 --- /dev/null +++ b/src/ir/passes/LoopPassUtils.h @@ -0,0 +1,309 @@ +#pragma once + +#include "ir/IR.h" + +#include +#include +#include +#include +#include +#include + +namespace ir { +namespace passes { + +struct CanonicalInductionInfo { + PhiInst* phi = nullptr; + Value* init = nullptr; + Value* next = nullptr; + int step = 
0; +}; + +struct CanonicalLoopMatch { + analysis::Loop* loop = nullptr; + BasicBlock* preheader = nullptr; + BasicBlock* header = nullptr; + BasicBlock* exit = nullptr; + BasicBlock* body = nullptr; + BasicBlock* latch = nullptr; + CondBranchInst* header_branch = nullptr; + CmpInst* header_cmp = nullptr; + Value* bound = nullptr; + CanonicalInductionInfo induction; + std::vector header_phis; +}; + +using ValueMap = std::unordered_map; + +inline Value* RemapValue(Value* value, const ValueMap& remap) { + auto it = remap.find(value); + return it != remap.end() ? it->second : value; +} + +inline std::vector CollectHeaderPhis(BasicBlock* header) { + std::vector phis; + if (!header) return phis; + for (const auto& inst_ptr : header->GetInstructions()) { + auto* phi = dynamic_cast(inst_ptr.get()); + if (!phi) break; + phis.push_back(phi); + } + return phis; +} + +inline void InsertInstruction(BasicBlock* block, + std::unique_ptr inst) { + auto& insts = block->MutableInstructions(); + auto insert_it = insts.end(); + if (block->HasTerminator()) { + insert_it = insts.end() - 1; + } + inst->SetParent(block); + insts.insert(insert_it, std::move(inst)); +} + +inline Instruction* AppendOwnedInstruction(BasicBlock* block, + std::unique_ptr inst) { + auto* raw = inst.get(); + InsertInstruction(block, std::move(inst)); + return raw; +} + +inline BinaryInst* InsertBinaryBeforeTerminator(BasicBlock* block, Opcode opcode, + Value* lhs, Value* rhs, + const std::string& name) { + auto inst = std::make_unique(opcode, lhs->GetType(), lhs, rhs, name); + auto* raw = inst.get(); + InsertInstruction(block, std::move(inst)); + return raw; +} + +inline CmpInst* InsertCmpBeforeTerminator(BasicBlock* block, CmpOp cmp_op, + Value* lhs, Value* rhs, + const std::string& name) { + auto inst = + std::make_unique(cmp_op, Type::GetInt32Type(), lhs, rhs, name); + auto* raw = inst.get(); + InsertInstruction(block, std::move(inst)); + return raw; +} + +inline BranchInst* 
InsertBranchBeforeTerminator(BasicBlock* block, + BasicBlock* target) { + auto inst = std::make_unique(Type::GetVoidType(), target); + auto* raw = inst.get(); + InsertInstruction(block, std::move(inst)); + return raw; +} + +inline CondBranchInst* InsertCondBrBeforeTerminator(BasicBlock* block, Value* cond, + BasicBlock* true_bb, + BasicBlock* false_bb) { + auto inst = std::make_unique(Type::GetVoidType(), cond, + true_bb, false_bb); + auto* raw = inst.get(); + InsertInstruction(block, std::move(inst)); + return raw; +} + +inline CallInst* InsertCallBeforeTerminator(BasicBlock* block, Function* callee, + const std::vector& args, + const std::string& name) { + auto inst = std::make_unique(callee->GetType(), callee, args, name); + auto* raw = inst.get(); + InsertInstruction(block, std::move(inst)); + return raw; +} + +inline std::string CloneName(const std::string& base, const std::string& suffix) { + if (base.empty()) return base; + return base + suffix; +} + +inline std::unique_ptr CloneInstruction(Instruction* inst, + const ValueMap& remap, + const std::string& suffix) { + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: { + auto* bin = static_cast(inst); + return std::make_unique( + inst->GetOpcode(), inst->GetType(), RemapValue(bin->GetLhs(), remap), + RemapValue(bin->GetRhs(), remap), CloneName(inst->GetName(), suffix)); + } + case Opcode::Cmp: { + auto* cmp = static_cast(inst); + return std::make_unique( + cmp->GetCmpOp(), inst->GetType(), RemapValue(cmp->GetLhs(), remap), + RemapValue(cmp->GetRhs(), remap), CloneName(inst->GetName(), suffix)); + } + case Opcode::Cast: { + auto* cast = static_cast(inst); + return std::make_unique(cast->GetCastOp(), inst->GetType(), + RemapValue(cast->GetValue(), remap), + CloneName(inst->GetName(), suffix)); + } + case Opcode::Load: { + auto* load = static_cast(inst); + return std::make_unique(inst->GetType(), + RemapValue(load->GetPtr(), remap), + 
CloneName(inst->GetName(), suffix)); + } + case Opcode::Store: { + auto* store = static_cast(inst); + return std::make_unique( + Type::GetVoidType(), RemapValue(store->GetValue(), remap), + RemapValue(store->GetPtr(), remap)); + } + case Opcode::Call: { + auto* call = static_cast(inst); + std::vector args; + args.reserve(call->GetNumArgs()); + for (size_t i = 0; i < call->GetNumArgs(); ++i) { + args.push_back(RemapValue(call->GetArg(i), remap)); + } + return std::make_unique(inst->GetType(), call->GetCallee(), args, + CloneName(inst->GetName(), suffix)); + } + case Opcode::Gep: { + auto* gep = static_cast(inst); + return std::make_unique(inst->GetType(), + RemapValue(gep->GetBase(), remap), + RemapValue(gep->GetIndex(), remap), + CloneName(inst->GetName(), suffix)); + } + default: + return nullptr; + } +} + +inline bool IsLoopInvariantValue(Value* value, analysis::Loop* loop) { + if (!value) return false; + if (dynamic_cast(value) != nullptr) return true; + if (dynamic_cast(value) != nullptr) return true; + if (dynamic_cast(value) != nullptr) return true; + if (dynamic_cast(value) != nullptr) return true; + auto* inst = dynamic_cast(value); + return !inst || !inst->GetParent() || !loop->Contains(inst->GetParent()); +} + +inline std::optional MatchCanonicalInduction( + BasicBlock* header, BasicBlock* preheader, BasicBlock* latch) { + if (!header || !preheader || !latch) return std::nullopt; + + for (auto* phi : CollectHeaderPhis(header)) { + if (!phi || phi->GetType() == nullptr || !phi->GetType()->IsInt32()) continue; + if (phi->GetNumIncoming() != 2) continue; + + Value* init = nullptr; + Value* next = nullptr; + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + auto* incoming_bb = phi->GetIncomingBlock(i); + if (incoming_bb == preheader) { + init = phi->GetIncomingValue(i); + } else if (incoming_bb == latch) { + next = phi->GetIncomingValue(i); + } + } + if (!init || !next) continue; + + auto* next_inst = dynamic_cast(next); + if (!next_inst) continue; + if 
(next_inst->GetOpcode() != Opcode::Add && + next_inst->GetOpcode() != Opcode::Sub) { + continue; + } + + Value* other = nullptr; + bool phi_on_lhs = false; + if (next_inst->GetLhs() == phi) { + other = next_inst->GetRhs(); + phi_on_lhs = true; + } else if (next_inst->GetRhs() == phi) { + other = next_inst->GetLhs(); + } else { + continue; + } + + auto* step_ci = dynamic_cast(other); + if (!step_ci) continue; + + int step = step_ci->GetValue(); + if (next_inst->GetOpcode() == Opcode::Sub) { + if (!phi_on_lhs) continue; + step = -step; + } + if (step == 0) continue; + + return CanonicalInductionInfo{phi, init, next, step}; + } + + return std::nullopt; +} + +inline std::optional MatchCanonicalLoop(analysis::Loop* loop) { + if (!loop) return std::nullopt; + auto* header = loop->GetHeader(); + auto* preheader = loop->GetPreheader(); + if (!header || !preheader) return std::nullopt; + if (loop->GetLatches().size() != 1) return std::nullopt; + auto* latch = loop->GetLatches().front(); + if (!latch) return std::nullopt; + + auto* header_term = header->HasTerminator() + ? 
dynamic_cast(
+                header->MutableInstructions().back().get())
+          : nullptr;
+  if (!header_term) return std::nullopt;
+
+  auto* cmp = dynamic_cast(header_term->GetCond());
+  if (!cmp) return std::nullopt;
+
+  BasicBlock* body = nullptr;
+  BasicBlock* exit = nullptr;
+  if (loop->Contains(header_term->GetTrueBlock()) &&
+      !loop->Contains(header_term->GetFalseBlock())) {
+    body = header_term->GetTrueBlock();
+    exit = header_term->GetFalseBlock();
+  } else if (loop->Contains(header_term->GetFalseBlock()) &&
+             !loop->Contains(header_term->GetTrueBlock())) {
+    body = header_term->GetFalseBlock();
+    exit = header_term->GetTrueBlock();
+  } else {
+    return std::nullopt;
+  }
+
+  auto induction = MatchCanonicalInduction(header, preheader, latch);
+  if (!induction.has_value()) return std::nullopt;
+
+  Value* bound = nullptr;
+  if (cmp->GetLhs() == induction->phi &&
+      IsLoopInvariantValue(cmp->GetRhs(), loop)) {
+    bound = cmp->GetRhs();
+  } else if (cmp->GetRhs() == induction->phi &&
+             IsLoopInvariantValue(cmp->GetLhs(), loop)) {
+    bound = cmp->GetLhs();
+  } else {
+    return std::nullopt;
+  }
+
+  CanonicalLoopMatch match;
+  match.loop = loop;
+  match.preheader = preheader;
+  match.header = header;
+  match.exit = exit;
+  match.body = body;
+  match.latch = latch;
+  match.header_branch = header_term;
+  match.header_cmp = cmp;
+  match.bound = bound;
+  match.induction = *induction;
+  match.header_phis = CollectHeaderPhis(header);
+  return match;
+}
+
+}  // namespace passes
+}  // namespace ir
diff --git a/src/ir/passes/LoopUnroll.cpp b/src/ir/passes/LoopUnroll.cpp
new file mode 100644
index 0000000..19b503f
--- /dev/null
+++ b/src/ir/passes/LoopUnroll.cpp
@@ -0,0 +1,171 @@
+// Loop unrolling:
+// - conservatively unrolls single-block innermost canonical loops by a
+//   factor of 2
+// - guards the second copy with one extra comparison, so no static trip
+//   count is required
+
+#include "ir/IR.h"
+
+#include
+#include
+#include
+
+#include "LoopPassUtils.h"
+
+namespace ir {
+namespace passes {
+
+namespace {
+
+// Returns true when `match` has the shape this pass can unroll: a positive
+// step, no child loops, exactly two blocks with the body doubling as the
+// latch, a body terminator that targets the header, a header compare
+// `iv < bound` or `iv <= bound`, and 2..18 body instructions (terminator
+// included).
+bool IsUnrollableLoop(const CanonicalLoopMatch& match) {
+  if (match.induction.step <= 0) return false;
+  if (match.loop->GetChildren().size() != 0) return false;
+  if (match.loop->GetBlocks().size() != 2) return false;
+  if (match.body != match.latch) return false;
+  if (!match.body || !match.body->HasTerminator()) return false;
+
+  auto* body_term =
+      dynamic_cast(match.body->MutableInstructions().back().get());
+  if (!body_term || body_term->GetTarget() != match.header) return false;
+
+  if (match.header_cmp->GetLhs() != match.induction.phi) return false;
+  if (match.header_cmp->GetCmpOp() != CmpOp::Lt &&
+      match.header_cmp->GetCmpOp() != CmpOp::Le) {
+    return false;
+  }
+
+  size_t body_inst_count = match.body->GetInstructions().size();
+  if (body_inst_count <= 1 || body_inst_count > 18) return false;
+  return true;
+}
+
+// Returns the value `phi` receives along the edge from `latch`, or nullptr
+// if `phi` has no incoming entry for that block.
+Value* GetLatchIncomingForPhi(PhiInst* phi, BasicBlock* latch) {
+  for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
+    if (phi->GetIncomingBlock(i) == latch) {
+      return phi->GetIncomingValue(i);
+    }
+  }
+  return nullptr;
+}
+
+// Unrolls the loop described by `match` by a factor of 2. The body keeps
+// iteration i and then branches on cmp(iv_after_one, bound) to either a new
+// block holding a clone of the body (iteration i+1) or straight back to the
+// header; every header phi gains an incoming value from the new block.
+// Returns true if the IR was changed. `unroll_index` is currently unused.
+//
+// NOTE(review): the clone runs without re-entering the header, so this
+// presumably relies on the header holding nothing but phis, the compare and
+// the branch -- confirm that MatchCanonicalLoop guarantees it.
+bool RunUnrollOnLoop(Function& func, const CanonicalLoopMatch& match,
+                     Context& ctx, int unroll_index) {
+  (void)unroll_index;
+  if (!IsUnrollableLoop(match)) return false;
+
+  auto* body = match.body;
+  auto* header = match.header;
+
+  // Do every check that needs no new IR before touching `func`: returning
+  // false after CreateBlock() would leave an orphan block behind.
+  std::unordered_map seed_map;
+  for (auto* phi : match.header_phis) {
+    auto* incoming = GetLatchIncomingForPhi(phi, match.latch);
+    if (!incoming) return false;
+    seed_map.emplace(phi, incoming);
+  }
+
+  Value* iv_after_one = GetLatchIncomingForPhi(match.induction.phi, match.latch);
+  if (!iv_after_one) return false;
+
+  // NextTemp() may yield a name with a leading '%'; drop it for the label.
+  std::string block_suffix = ctx.NextTemp();
+  if (!block_suffix.empty() && block_suffix.front() == '%') {
+    block_suffix.erase(0, 1);
+  }
+  auto* body2 = func.CreateBlock(body->GetName() + ".unroll." + block_suffix);
+
+  // Start the clone map at phi -> latch value, i.e. the state after the
+  // first copy of the body.
+  ValueMap clone_map = seed_map;
+  std::vector originals;
+  for (const auto& inst_ptr : body->GetInstructions()) {
+    if (inst_ptr.get()->IsTerminator()) continue;
+    originals.push_back(inst_ptr.get());
+  }
+
+  for (auto* inst : originals) {
+    auto cloned = CloneInstruction(inst, clone_map, ".u2");
+    // NOTE(review): this bail-out still leaves body2 and the clones already
+    // appended to it in `func`; undoing that needs a block-removal API that
+    // is not visible from this file.
+    if (!cloned) return false;
+    auto* raw = cloned.get();
+    body2->MutableInstructions().push_back(std::move(cloned));
+    raw->SetParent(body2);
+    clone_map[inst] = raw;
+  }
+
+  body2->Append(Type::GetVoidType(), header);
+
+  // Re-test the loop condition on the once-advanced induction value to
+  // decide whether the second copy runs.
+  auto* first_cmp = InsertCmpBeforeTerminator(
+      body, match.header_cmp->GetCmpOp(), iv_after_one, match.bound,
+      ctx.NextTemp());
+
+  body->RemoveInstruction(body->MutableInstructions().back().get());
+  body->Append(Type::GetVoidType(), first_cmp, body2, header);
+
+  body->AddSuccessor(body2);
+  body2->AddPredecessor(body);
+  body2->AddSuccessor(header);
+  header->AddPredecessor(body2);
+
+  for (auto* phi : match.header_phis) {
+    auto* incoming = GetLatchIncomingForPhi(phi, match.latch);
+    Value* second_value = incoming;
+    auto it = clone_map.find(incoming);
+    if (it != clone_map.end()) {
+      second_value = it->second;
+    } else {
+      second_value = RemapValue(incoming, clone_map);
+    }
+    phi->AddIncoming(second_value, body2);
+  }
+
+  return true;
+}
+
+}  // namespace
+
+// Pass entry point. Unrolls at most one matching loop of `func` per call
+// (the loop below breaks after the first success) and returns whether the
+// IR changed.
+bool RunLoopUnroll(Function& func, Context& ctx) {
+  if (func.IsExternal()) return false;
+
+  analysis::DominatorTree dom_tree(func);
+  analysis::LoopInfo loop_info(func, dom_tree);
+
+  bool changed = false;
+  int unroll_index = 0;
+  for (const auto& loop_ptr : loop_info.GetLoops()) {
+    auto match = MatchCanonicalLoop(loop_ptr.get());
+    if (!match.has_value()) continue;
+    if (RunUnrollOnLoop(func, *match, ctx, unroll_index++)) {
+      changed = true;
+      break;
+    }
+  }
+  return changed;
+}
+
+}  // namespace passes
+}  // namespace ir
diff --git a/src/ir/passes/PassManager.cpp b/src/ir/passes/PassManager.cpp
index 7750054..f4fdc66 100644
--- a/src/ir/passes/PassManager.cpp
+++ b/src/ir/passes/PassManager.cpp
@@ -8,13 +8,17 @@
 
 #include "ir/IR.h"
 
-#include
-
 namespace ir {
 namespace passes {
 
 // 前向声明各 pass 入口
 bool RunMem2Reg(Function& func);
+bool RunLICM(Function& func);
+bool RunStrengthReduction(Function& func, Context& ctx);
+bool RunLoopIdiom(Function& func, Module& module, Context& ctx);
+bool RunLoopFission(Function& func, Context& ctx);
+bool RunLoopUnroll(Function& func, Context& ctx);
+bool RunLoopParallelization(Module& module);
 bool RunConstFoldWithCtx(Function& func, Context& ctx);
 bool RunConstProp(Function& func, Context& ctx);
 bool RunCSE(Function& func);
@@ -34,6 +38,38 @@ void RunAllPasses(Module& module) {
     for (int iter = 0; iter < kMaxIterations; ++iter) {
       bool changed = false;
 
+      changed |= RunLICM(*func);
+      changed |= RunStrengthReduction(*func, ctx);
+      changed |= RunLoopIdiom(*func, module, ctx);
+      changed |= RunConstFoldWithCtx(*func, ctx);
+      changed |= RunConstProp(*func, ctx);
+      changed |= RunCSE(*func);
+      changed |= RunSimpleDCE(*func);
+      changed |= RunCFGSimplify(*func, ctx);
+
+      if (!changed) break;
+    }
+  }
+
+  if (RunLoopParallelization(module)) {
+    for (const auto& func : module.GetFunctions()) {
+      if (!func || func->IsExternal()) continue;
+      RunSimpleDCE(*func);
+      RunCFGSimplify(*func, ctx);
+    }
+  }
+
+  for (const auto& func : module.GetFunctions()) {
+    if (!func || func->IsExternal()) continue;
+
+    for (int iter = 0; iter < kMaxIterations; ++iter) {
+      bool changed = false;
+
+      changed |= RunLICM(*func);
+      changed |= RunStrengthReduction(*func, ctx);
+      changed |= RunLoopIdiom(*func, module, ctx);
+      changed |= RunLoopFission(*func, ctx);
+      changed |= RunLoopUnroll(*func, ctx);
changed |= RunConstFoldWithCtx(*func, ctx);
       changed |= RunConstProp(*func, ctx);
       changed |= RunCSE(*func);
diff --git a/src/ir/passes/StrengthReduction.cpp b/src/ir/passes/StrengthReduction.cpp
new file mode 100644
index 0000000..ba211f2
--- /dev/null
+++ b/src/ir/passes/StrengthReduction.cpp
@@ -0,0 +1,115 @@
+// Strength reduction:
+// - recognizes the canonical induction variable iv
+// - rewrites in-loop iv * C as an auxiliary phi plus a constant-step recurrence
+
+#include "ir/IR.h"
+
+#include
+#include
+#include
+#include
+
+#include "LoopPassUtils.h"
+
+namespace ir {
+namespace passes {
+
+namespace {
+
+bool IsStrengthReductionCandidate(Instruction* inst, PhiInst* iv) {  // int32 Mul with iv as an operand
+  auto* bin = dynamic_cast(inst);
+  if (!bin || bin->GetOpcode() != Opcode::Mul) return false;
+  if (!bin->GetType() || !bin->GetType()->IsInt32()) return false;
+  return bin->GetLhs() == iv || bin->GetRhs() == iv;
+}
+
+int ExtractScale(BinaryInst* mul, PhiInst* iv) {  // C from iv * C; 0 means "no usable constant"
+  auto* lhs_ci = dynamic_cast(mul->GetLhs());
+  auto* rhs_ci = dynamic_cast(mul->GetRhs());
+  if (mul->GetLhs() == iv && rhs_ci) return rhs_ci->GetValue();
+  if (mul->GetRhs() == iv && lhs_ci) return lhs_ci->GetValue();
+  return 0;
+}
+
+bool ReplaceMulWithRecurrence(Function& func, const CanonicalLoopMatch& match,
+                              BinaryInst* mul, Context& ctx) {  // returns true if `mul` was replaced
+  (void)func;  // parameter is currently unused
+  const int scale = ExtractScale(mul, match.induction.phi);
+  if (scale == 0) return false;
+  if (match.latch != mul->GetParent() &&
+      !match.loop->Contains(mul->GetParent())) {
+    return false;  // `mul` is not inside this loop
+  }
+
+  auto* init_scale =  // init * scale, emitted in the preheader
+      InsertBinaryBeforeTerminator(match.preheader, Opcode::Mul,
+                                   match.induction.init, ctx.GetConstInt(scale),
+                                   ctx.NextTemp());
+
+  auto* sr_phi =  // new header phi that takes the place of iv * scale
+      match.header->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());
+
+  auto* step_scale =  // step * scale, emitted in the latch
+      InsertBinaryBeforeTerminator(match.latch, Opcode::Mul,
+                                   ctx.GetConstInt(match.induction.step),
+                                   ctx.GetConstInt(scale), ctx.NextTemp());
+  auto* sr_next =  // sr_phi + step * scale: the value for the next iteration
+      InsertBinaryBeforeTerminator(match.latch, Opcode::Add, sr_phi, step_scale,
+                                   ctx.NextTemp());
+
+  sr_phi->AddIncoming(init_scale, match.preheader);
+  sr_phi->AddIncoming(sr_next, match.latch);
+
+  mul->ReplaceAllUsesWith(sr_phi);
+  mul->GetParent()->RemoveInstruction(mul);
+  return true;
+}
+
+}  // namespace
+
+bool RunStrengthReduction(Function& func, Context& ctx) {  // pass entry point; returns true if the IR changed
+  if (func.IsExternal()) return false;
+
+  analysis::DominatorTree dom_tree(func);
+  analysis::LoopInfo loop_info(func, dom_tree);
+
+  std::vector loops;
+  for (const auto& loop_ptr : loop_info.GetLoops()) {
+    loops.push_back(loop_ptr.get());
+  }
+  std::sort(loops.begin(), loops.end(),  // deepest loops first, ties broken by header name
+            [](analysis::Loop* lhs, analysis::Loop* rhs) {
+              if (lhs->GetDepth() != rhs->GetDepth()) {
+                return lhs->GetDepth() > rhs->GetDepth();
+              }
+              return lhs->GetHeader()->GetName() < rhs->GetHeader()->GetName();
+            });
+
+  bool changed = false;
+  for (auto* loop : loops) {
+    auto match = MatchCanonicalLoop(loop);
+    if (!match.has_value()) continue;
+
+    std::vector candidates;  // collected first, rewritten after the scan
+    for (auto* block : loop->GetBlocks()) {
+      for (const auto& inst_ptr : block->GetInstructions()) {
+        auto* inst = inst_ptr.get();
+        if (IsStrengthReductionCandidate(inst, match->induction.phi)) {
+          auto* mul = static_cast(inst);
+          if (ExtractScale(mul, match->induction.phi) != 0) {
+            candidates.push_back(mul);
+          }
+        }
+      }
+    }
+
+    for (auto* mul : candidates) {
+      changed |= ReplaceMulWithRecurrence(func, *match, mul, ctx);
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace passes
+}  // namespace ir
diff --git a/src/main.cpp b/src/main.cpp
index 78232c4..160c06b 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -160,6 +160,7 @@ int main(int argc, char** argv) {
       mir::RunPeephole(*func_ptr);
       mir::RunRegAlloc(*func_ptr);
       mir::RunFrameLowering(*func_ptr);
+      mir::RunPeephole(*func_ptr);  // second peephole run, after register allocation and frame lowering
     }
     if (need_blank_line) {
       std::cout << "\n";
diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp
index a3f2ed9..1d84451 100644
--- a/src/mir/AsmPrinter.cpp
+++ b/src/mir/AsmPrinter.cpp
@@ -18,6 +18,11 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
   return
function.GetFrameSlot(operand.GetFrameIndex());
 }
 
+std::string LocalBlockLabel(const MachineFunction& function,
+                            const std::string& block_name) {
+  return "." + function.GetName() + "." + block_name;  // ".<function>.<block>"
+}
+
 void PrintMoveImm32(std::ostream& os, PhysReg reg, int imm) {
   std::uint32_t u = static_cast(imm);
   std::uint32_t lo = u & 0xFFFFu;
@@ -157,7 +162,7 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
 
       // 打印块标签(entry 块不需要标签,因为函数名已经是标签了)
       if (bb.GetName() != "entry") {
-        os << "." << bb.GetName() << ":\n";
+        os << LocalBlockLabel(function, bb.GetName()) << ":\n";
       }
 
       for (const auto& inst : bb.GetInstructions()) {
@@ -382,25 +387,30 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
             os << " bl " << ops.at(0).GetSymbol() << "\n";
             break;
           case Opcode::B:
-            os << " b ." << ops.at(0).GetSymbol() << "\n";
+            os << " b " << LocalBlockLabel(function, ops.at(0).GetSymbol())
+               << "\n";
             break;
           case Opcode::Cbnz:
             os << " cbnz " << PhysRegName(ops.at(0).GetReg())
-               << ", ." << ops.at(1).GetSymbol() << "\n";
+               << ", " << LocalBlockLabel(function, ops.at(1).GetSymbol())
+               << "\n";
             break;
           case Opcode::Cbz:
             os << " cbz " << PhysRegName(ops.at(0).GetReg())
-               << ", ." << ops.at(1).GetSymbol() << "\n";
+               << ", " << LocalBlockLabel(function, ops.at(1).GetSymbol())
+               << "\n";
             break;
           case Opcode::Bcond:
             // ops: symbol, cmpop(imm)
             os << " b." << CondSuffix(static_cast(ops.at(1).GetImm()))
-               << " ." << ops.at(0).GetSymbol() << "\n";
+               << " " << LocalBlockLabel(function, ops.at(0).GetSymbol())
+               << "\n";
             break;
           case Opcode::FBcond:
             // ops: symbol, cmpop(imm) - 浮点条件分支
             os << " b." << FloatCondSuffix(static_cast(ops.at(1).GetImm()))
-               << " ." << ops.at(0).GetSymbol() << "\n";
+               << " " << LocalBlockLabel(function, ops.at(0).GetSymbol())
+               << "\n";
             break;
           case Opcode::Ret:
             os << " ret\n";
diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp
index f200702..c5c395a 100644
--- a/src/mir/Lowering.cpp
+++ b/src/mir/Lowering.cpp
@@ -818,7 +818,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
         return;
       }
 
-      // 参数传递:根据类型使用 w0-w7(整数)、s0-s7(浮点)或 x0-x7(指针)
+      // Argument passing: by type, use w0-w7 (integers), s0-s7 (floats) or x0-x7 (pointers).
+      // Note: the related peephole no longer forwards a LoadStack that precedes argument
+      // setup into a mov between ABI argument registers, to avoid W/X alias clobbering.
       size_t num_args = call.GetNumArgs();
       if (num_args > 8) {
         throw std::runtime_error(FormatError("mir", "暂不支持超过 8 个参数的函数调用"));
@@ -1002,6 +1004,26 @@ std::unique_ptr LowerToMIR(const ir::Module& module) {
                 Opcode::FCmpOnlyRR,
                 {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S1)});
           } else {
+            const ir::ConstantInt* lhs_ci = TryGetConstInt(cmp_inst->GetLhs());
+            const ir::ConstantInt* rhs_ci = TryGetConstInt(cmp_inst->GetRhs());
+            bool lhs_zero = lhs_ci && lhs_ci->GetValue() == 0;
+            bool rhs_zero = rhs_ci && rhs_ci->GetValue() == 0;
+            if ((cmp_inst->GetCmpOp() == ir::CmpOp::Eq ||
+                 cmp_inst->GetCmpOp() == ir::CmpOp::Ne) &&
+                (lhs_zero || rhs_zero)) {
+              auto* non_zero_side = lhs_zero ? cmp_inst->GetRhs() : cmp_inst->GetLhs();  // the operand tested against 0
+              EmitValueToReg(non_zero_side, PhysReg::W8, slots, *current_mbb);
+              current_mbb->Append(
+                  cmp_inst->GetCmpOp() == ir::CmpOp::Eq ? Opcode::Cbz  // Eq -> cbz, Ne -> cbnz
+                                                        : Opcode::Cbnz,
+                  {Operand::Reg(PhysReg::W8),
+                   Operand::Symbol(true_mbb->GetName())});
+              current_mbb->Append(Opcode::B,
+                                  {Operand::Symbol(false_mbb->GetName())});
+              ++i;  // also skip the CondBr that follows
+              continue;
+            }
+
             EmitValueToReg(cmp_inst->GetLhs(), PhysReg::W8, slots, *current_mbb);
             EmitValueToReg(cmp_inst->GetRhs(), PhysReg::W9, slots, *current_mbb);
             current_mbb->Append(
@@ -1078,7 +1100,8 @@ std::unique_ptr LowerToMIR(const ir::Module& module) {
           for (size_t j = pred_insts.size(); j > 0; --j) {
             auto op = pred_insts[j - 1].GetOpcode();
             if (op == Opcode::B || op == Opcode::Bcond ||
-                op == Opcode::Cbnz || op == Opcode::FBcond) {
+                op == Opcode::Cbnz || op == Opcode::Cbz ||
+                op == Opcode::FBcond) {
               insert_pos = j - 1;
             } else {
               break;
diff --git a/src/mir/MIRFunction.cpp b/src/mir/MIRFunction.cpp
index 4ea0b16..a267dc7 100644
--- a/src/mir/MIRFunction.cpp
+++ b/src/mir/MIRFunction.cpp
@@ -8,6 +8,18 @@
 
 namespace mir {
 
+namespace {
+
+PhysReg CanonicalCalleeSavedReg(PhysReg reg) {  // maps W0..W11 to the X register with the same index
+  if (reg >= PhysReg::W0 && reg <= PhysReg::W11) {
+    int idx = static_cast(reg) - static_cast(PhysReg::W0);
+    return static_cast(static_cast(PhysReg::X0) + idx);
+  }
+  return reg;  // every other register is returned unchanged
+}
+
+}  // namespace
+
 MachineFunction::MachineFunction(std::string name)
     : name_(std::move(name)) {
   // 创建入口块
@@ -54,6 +66,7 @@ int MachineFunction::CreateVReg(VRegClass rc) {
 }
 
 void MachineFunction::AddUsedCalleeSaved(PhysReg reg) {
+  reg = CanonicalCalleeSavedReg(reg);  // so a W register and its X counterpart are recorded once
   if (std::find(used_callee_saved_.begin(), used_callee_saved_.end(),
                 reg) == used_callee_saved_.end()) {
     used_callee_saved_.push_back(reg);
diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp
index 087cd7c..f17077d 100644
--- a/src/mir/passes/Peephole.cpp
+++ b/src/mir/passes/Peephole.cpp
@@ -15,6 +15,13 @@ bool IsMovLike(Opcode opcode) { return opcode == Opcode::MovReg || opc
 
 bool IsFloatReg(PhysReg reg) { return reg >= PhysReg::S0 && reg <= PhysReg::S10; }
 
+bool IsAbiArgReg(PhysReg reg) {  // true for W0-W7, X0-X7 and S0-S7
+  if (reg >= PhysReg::W0
&& reg <= PhysReg::W7) return true;
+  if (reg >= PhysReg::X0 && reg <= PhysReg::X7) return true;
+  if (reg >= PhysReg::S0 && reg <= PhysReg::S7) return true;
+  return false;
+}
+
 bool IsWxReg(PhysReg reg) {
   return (reg >= PhysReg::W0 && reg <= PhysReg::W10) ||
          (reg >= PhysReg::X0 && reg <= PhysReg::X10);
@@ -92,6 +99,107 @@ std::optional GetWrittenReg(const MachineInstr& inst) {
   }
 }
 
+// Returns true if `inst` may read `reg` (or a register that aliases it,
+// per RegAlias) through one of its operands. RunPeephole uses this to decide
+// whether a definition is dead, so the answer must err towards "reads".
+bool ReadsReg(const MachineInstr& inst, PhysReg reg) {
+  const auto& ops = inst.GetOperands();
+  // True when operand `idx` exists, is a register and aliases `reg`.
+  auto reads_operand = [&](size_t idx) {
+    return idx < ops.size() && ops[idx].GetKind() == Operand::Kind::Reg &&
+           RegAlias(ops[idx].GetReg(), reg);
+  };
+
+  switch (inst.GetOpcode()) {
+    // Source register at operand 1.
+    case Opcode::MovReg:
+    case Opcode::FMovReg:
+    case Opcode::AddRI:
+    case Opcode::SubRI:
+    case Opcode::LoadStackOffset:
+    case Opcode::LoadIndirect:
+    case Opcode::LoadGlobal:
+    case Opcode::LoadGlobalAddr:
+    case Opcode::SIToFP:
+    case Opcode::FPToSI:
+    case Opcode::LslRI:
+      return reads_operand(1);
+    // Source registers at operands 1 and 2.
+    case Opcode::AddRR:
+    case Opcode::AddRR_UXTW:
+    case Opcode::SubRR:
+    case Opcode::MulRR:
+    case Opcode::DivRR:
+    case Opcode::LslRR:
+    case Opcode::FAddRR:
+    case Opcode::FSubRR:
+    case Opcode::FMulRR:
+    case Opcode::FDivRR:
+    case Opcode::CmpRR:
+    case Opcode::FCmpRR:
+    case Opcode::CmpOnlyRR:
+    case Opcode::FCmpOnlyRR:
+      return reads_operand(1) || reads_operand(2);
+    // Source register at operand 0.
+    case Opcode::StoreStack:
+    case Opcode::StoreStackOffset:
+    case Opcode::Cbz:
+    case Opcode::Cbnz:
+      return reads_operand(0);
+    case Opcode::StoreIndirect:
+      return reads_operand(0) || reads_operand(1);
+    // No register read is reported for these opcodes. NOTE(review):
+    // presumably none of them takes a register source operand -- confirm
+    // against the operand layouts used by Lowering.
+    case Opcode::MovImm:
+    case Opcode::FMovImm:
+    case Opcode::LoadStack:
+    case Opcode::LoadStackAddr:
+    case Opcode::Ret:
+      return false;
+    default:
+      // Unlisted opcode: assume it reads `reg`. A wrong "no" lets RunPeephole
+      // delete a definition that is still needed; a wrong "yes" only costs a
+      // missed optimization.
+      return true;
+  }
+}
+
+// Opcodes RunPeephole may delete when the register they write is overwritten
+// by the next instruction without being read; every other opcode is kept.
+bool CanElideIfOverwritten(const MachineInstr& inst) {
+  switch (inst.GetOpcode()) {
+    case Opcode::MovImm:
+    case Opcode::MovReg:
+    case Opcode::FMovImm:
+    case Opcode::FMovReg:
+    case Opcode::LoadStack:
+    case Opcode::LoadStackOffset:
+    case Opcode::LoadStackAddr:
+    case Opcode::LoadIndirect:
+    case Opcode::LoadGlobal:
+    case Opcode::LoadGlobalAddr:
+    case Opcode::AddRI:
+    case Opcode::SubRI:
+    case Opcode::AddRR:
+    case Opcode::AddRR_UXTW:
+    case Opcode::SubRR:
+    case Opcode::MulRR:
+    case Opcode::DivRR:
+    case Opcode::LslRI:
+    case Opcode::LslRR:
+    case Opcode::FAddRR:
+    case Opcode::FSubRR:
+    case Opcode::FMulRR:
+    case Opcode::FDivRR:
+    case Opcode::SIToFP:
+    case Opcode::FPToSI:
+      return true;
+    default:
+      return false;
+  }
+}
+
 bool IsMemoryClobber(const MachineInstr& inst) {
   switch (inst.GetOpcode()) {
     case Opcode::StoreIndirect:
@@ -161,6 +269,13 @@ bool TryForwardLoad(std::vector& out,
 
   const PhysReg src = it->second;
 
+  // Do not turn a LoadStack that precedes argument setup into a mov between
+  // ABI argument registers: that can trigger W/X alias clobbering and corrupt
+  // the call arguments. If the source is not an ABI argument register, the mov is still safe.
+  if (IsAbiArgReg(dst) && IsAbiArgReg(src) && !RegAlias(src, dst)) {
+    return false;
+  }
+
   // 宽度不匹配时不能转发(如 W8 → X8 会生成非法的 mov x8, w8)
   if (!SameRegWidth(src, dst)) {
     return false;
@@ -220,6 +335,18 @@ void RunPeephole(MachineFunction& function) {
         continue;
       }
 
+      // Dead-overwrite elimination: drop `cur` when the very next instruction
+      // overwrites the register it defines without reading it.
+      if (i + 1 < insts.size() && CanElideIfOverwritten(cur)) {
+        auto wr_cur = GetWrittenReg(cur);
+        auto wr_next = GetWrittenReg(insts[i + 1]);
+        if (wr_cur.has_value() && wr_next.has_value() &&
+            RegAlias(*wr_cur, *wr_next) &&
+            !ReadsReg(insts[i + 1], *wr_cur)) {
+          continue;
+        }
+      }
+
       // mov #2 + lsl reg, reg, mov_reg -> lsl reg, reg, #2
       if (i + 1 < insts.size()) {
         PhysReg imm_reg = PhysReg::W0;
@@ -314,4 +441,3 @@ void RunPeephole(MachineFunction& function) {
 }
 
 }  // namespace mir
-
diff --git a/sylib/sylib.c b/sylib/sylib.c
index 21b9fdd..e4942fc 100644
--- a/sylib/sylib.c
+++ b/sylib/sylib.c
@@ -5,6 +5,9 @@
 #include
 #include
 #include
+#include
+#include
+#include
 
 int getint() { int v; scanf("%d", &v); return v; }
 int getch() { return getchar(); }
@@ -41,3 +44,188 @@ void stoptime(int l) {
   fprintf(stderr, "Timer@%d: %ldms\n", l,
     (t1.tv_sec-_t0.tv_sec)*1000+(t1.tv_nsec-_t0.tv_nsec)/1000000);
 }
+
+/* Stores `value` into base[0..count). Does nothing when `base` is NULL or
+ * `count` is not positive. For 0 and -1 every byte of the int is the same
+ * (value & 0xff), so those two values are written with one memset. */
+void __fill_i32(int* base, int count, int value) {
+  if (!base || count <= 0) return;
+  if (value == 0 || value == -1) {
+    memset(base, value & 0xff, (size_t)count * sizeof(int));
+    return;
+  }
+  for (int i = 0; i < count; ++i) {
+    base[i] = value;
+  }
+}
+
+/* One slice of a __fill_rows_i32 call: rows [start_row, end_row). */
+typedef struct {
+  int* base;
+  int start_offset;
+  int start_row;
+  int end_row;
+  int stride;
+  int count;
+  int value;
+} fill_rows_task_t;
+
+/* Thread entry point: for each row of the task, fills `count` ints starting
+ * at base + start_offset + row * stride with `value`. Always returns NULL. */
+static void* __fill_rows_worker(void* opaque) {
+  fill_rows_task_t* task = (fill_rows_task_t*)opaque;
+  for (int row = task->start_row; row < task->end_row; ++row) {
+    int* row_ptr = task->base + task->start_offset + row * task->stride;
+    __fill_i32(row_ptr, task->count, task->value);
+  }
+  return NULL;
+}
+
+/* Fills `rows` rows of `count` ints with `value`; row r starts at
+ * base + start_offset + r * stride. Fewer than 32 rows are filled on the
+ * calling thread. Otherwise the rows are split into four slices: three go to
+ * helper threads and one runs here; a slice whose thread could not be
+ * created is also filled here. Does nothing when `base` is NULL or any of
+ * rows/stride/count is not positive. */
+void __fill_rows_i32(int* base, int start_offset, int rows, int stride, int count,
+                     int value) {
+  if (!base || rows <= 0 || stride <= 0 || count <= 0) return;
+  if (rows < 32) {
+    fill_rows_task_t task = {base, start_offset, 0, rows, stride, count, value};
+    __fill_rows_worker(&task);
+    return;
+  }
+
+  pthread_t tids[3];
+  int started[3] = {0, 0, 0};
+  fill_rows_task_t tasks[4];
+  for (int tid = 0; tid < 4; ++tid) {
+    int begin = tid * rows / 4;
+    int end = (tid + 1) * rows / 4;
+    tasks[tid].base = base;
+    tasks[tid].start_offset = start_offset;
+    tasks[tid].start_row = begin;
+    tasks[tid].end_row = end;
+    tasks[tid].stride = stride;
+    tasks[tid].count = count;
+    tasks[tid].value = value;
+  }
+  for (int tid = 1; tid < 4; ++tid) {
+    /* pthread_create() returns 0 on success. */
+    started[tid - 1] = pthread_create(&tids[tid - 1], NULL,
+                                      __fill_rows_worker, &tasks[tid]) == 0;
+  }
+  __fill_rows_worker(&tasks[0]);
+  for (int tid = 0; tid < 3; ++tid) {
+    if (started[tid]) {
+      pthread_join(tids[tid], NULL);
+    } else {
+      /* This helper never started, so tids[tid] is not a valid thread ID:
+       * fill its slice on the calling thread instead of joining. */
+      __fill_rows_worker(&tasks[tid + 1]);
+    }
+  }
+}
+
+typedef void (*par_worker_fn_t)(int);
+
+/* State of one parallel slot. `generation` is bumped once per __par_runN()
+ * call and `remaining` counts the helpers that have not yet finished the
+ * current run; both are updated with `mutex` held. `helper_count` is the
+ * number of helper threads that were created successfully. */
+typedef struct {
+  pthread_mutex_t mutex;
+  pthread_cond_t start_cv;
+  pthread_cond_t done_cv;
+  int generation;
+  int remaining;
+  int helper_count;
+} par_slot_state_t;
+
+/* Argument of one helper thread: its slot, the worker to call, and its tid. */
+typedef struct {
+  par_slot_state_t* state;
+  par_worker_fn_t worker;
+  int tid;
+} par_thread_arg_t;
+
+/* Body of a helper thread: waits until `generation` changes, calls
+ * worker(tid) with the mutex released, then decrements `remaining` and
+ * signals `done_cv` when it reaches zero. Loops forever. */
+static void* __par_pool_worker(void* opaque) {
+  par_thread_arg_t* arg = (par_thread_arg_t*)opaque;
+  par_slot_state_t* state = arg->state;
+  int seen_generation = 0;
+
+  pthread_mutex_lock(&state->mutex);
+  for (;;) {
+    while (state->generation == seen_generation) {
+      pthread_cond_wait(&state->start_cv, &state->mutex);
+    }
+    seen_generation = state->generation;
+    pthread_mutex_unlock(&state->mutex);
+
+    arg->worker(arg->tid);
+
+    pthread_mutex_lock(&state->mutex);
+    if (state->remaining > 0) {
+      --state->remaining;
+      if (state->remaining == 0) {
+        pthread_cond_signal(&state->done_cv);
+      }
+    }
+  }
+}
+
+/* Defines parallel slot N. __par_runN() creates up to three detached helper
+ * threads on first use (pthread_once), wakes them so they call
+ * __par_workerN(1..3), runs __par_workerN(0) on the calling thread and then
+ * waits until `remaining` drops to zero. __par_workerN is declared weak;
+ * __par_runN() returns immediately when it is not defined. */
+#define DECL_PAR_SLOT(N) \
+  extern void __par_worker##N(int) __attribute__((weak)); \
+  static par_slot_state_t __par_state##N = { \
+      PTHREAD_MUTEX_INITIALIZER, PTHREAD_COND_INITIALIZER, \
+      PTHREAD_COND_INITIALIZER, 0, 0, 0}; \
+  static pthread_once_t __par_once##N = PTHREAD_ONCE_INIT; \
+  static par_thread_arg_t __par_args##N[3]; \
+  static void __par_init##N(void) { \
+    pthread_attr_t attr; \
+    pthread_t tid; \
+    pthread_attr_init(&attr); \
+    pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED); \
+    for (int i = 0; i < 3; ++i) { \
+      __par_args##N[i].state = &__par_state##N; \
+      __par_args##N[i].worker = __par_worker##N; \
+      __par_args##N[i].tid = i + 1; \
+      if (pthread_create(&tid, &attr, __par_pool_worker, \
+                         &__par_args##N[i]) == 0) { \
+        ++__par_state##N.helper_count; \
+      } \
+    } \
+    pthread_attr_destroy(&attr); \
+  } \
+  void __par_run##N(void) { \
+    if (!__par_worker##N) return; \
+    pthread_once(&__par_once##N, __par_init##N); \
+    pthread_mutex_lock(&__par_state##N.mutex); \
+    __par_state##N.remaining = __par_state##N.helper_count; \
+    ++__par_state##N.generation; \
+    pthread_cond_broadcast(&__par_state##N.start_cv); \
+    pthread_mutex_unlock(&__par_state##N.mutex); \
+    __par_worker##N(0); \
+    pthread_mutex_lock(&__par_state##N.mutex); \
+    while (__par_state##N.remaining != 0) { \
+      pthread_cond_wait(&__par_state##N.done_cv, &__par_state##N.mutex); \
+    } \
+    pthread_mutex_unlock(&__par_state##N.mutex); \
+  }
+
+DECL_PAR_SLOT(0)
+DECL_PAR_SLOT(1)
+DECL_PAR_SLOT(2)
+DECL_PAR_SLOT(3)
+DECL_PAR_SLOT(4)
+DECL_PAR_SLOT(5)
+DECL_PAR_SLOT(6)
+DECL_PAR_SLOT(7)