lab6: add loop optimizations, parallel runtime, and asm backend fixes

Shrink 19 hours ago
parent 9e8984d740
commit 4df492feb9

@ -0,0 +1,211 @@
# Session Handoff
Date: 2026-05-28
## Repo State
- Current branch: `Shrink`
- Worktree is dirty; do not reset blindly.
- Modified tracked files:
  - `include/ir/IR.h`
  - `scripts/run_all_tests.sh`
  - `scripts/verify_asm.sh`
  - `scripts/verify_ir.sh`
  - `src/ir/analysis/DominatorTree.cpp`
  - `src/ir/analysis/LoopInfo.cpp`
  - `src/ir/passes/CMakeLists.txt`
  - `src/ir/passes/PassManager.cpp`
  - `src/main.cpp`
  - `src/mir/AsmPrinter.cpp`
  - `src/mir/Lowering.cpp`
  - `src/mir/MIRFunction.cpp`
  - `src/mir/passes/Peephole.cpp`
  - `sylib/sylib.c`
- New untracked files:
  - `src/ir/passes/LICM.cpp`
  - `src/ir/passes/LoopFission.cpp`
  - `src/ir/passes/LoopIdiom.cpp`
  - `src/ir/passes/LoopParallelize.cpp`
  - `src/ir/passes/LoopPassUtils.h`
  - `src/ir/passes/LoopUnroll.cpp`
  - `src/ir/passes/StrengthReduction.cpp`
## Toolchain On Current Machine
- `cmake 3.22.1`
- `g++ 11.4.0`
- `clang 14.0.0`
- `llc 14.0.0`
- `aarch64-linux-gnu-gcc 11.4.0`
- `qemu-aarch64 6.2.0`
Required packages on a fresh Ubuntu install:
```bash
sudo apt update
sudo apt install -y \
  build-essential \
  cmake \
  clang \
  llvm \
  gcc-aarch64-linux-gnu \
  qemu-user \
  libc6-arm64-cross
```
## Important Build Detail
- The repo vendors `antlr4-runtime-4.13.2` in `third_party`, so no system ANTLR runtime install is needed.
- Current frontend build consumes generated parser sources from `build/generated/antlr4` if present.
- There is also parser source in `src/antlr4/`, but current CMake does not wire that directory directly into the build.
- Safest migration path: copy the repo together with the current `build/generated/antlr4` directory, or later patch CMake to use `src/antlr4/*.cpp`.
## Implemented IR / Loop Optimizations
Stable implemented items:
- `LICM`
- `StrengthReduction`
- `LoopFission`
- `LoopUnroll`
- conservative `LoopParallelization`
- `LoopIdiom` for constant-fill loops
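As a rough source-level picture of what `LICM` and `StrengthReduction` aim for, the two functions below compute the same sum. The passes themselves work on IR, and `SumNaive` / `SumOptimized` are made-up names for illustration only. The second version hoists the invariant product out of the loop and carries the array index as an induction variable instead of multiplying on every iteration.

```cpp
#include <cassert>
#include <vector>

// Before: `scale * bias` is recomputed on every iteration and the index
// `i * stride` costs one multiply per iteration.
int SumNaive(const std::vector<int>& data, int stride, int scale, int bias) {
  assert(stride > 0);
  const int n = static_cast<int>(data.size());
  int sum = 0;
  for (int i = 0; i * stride < n; ++i) {
    sum += data[i * stride] + scale * bias;
  }
  return sum;
}

// After: the invariant product is computed once before the loop (LICM) and
// the index is advanced by addition (strength reduction).
int SumOptimized(const std::vector<int>& data, int stride, int scale, int bias) {
  assert(stride > 0);
  const int n = static_cast<int>(data.size());
  const int invariant = scale * bias;  // would live in the loop preheader
  int sum = 0;
  for (int idx = 0; idx < n; idx += stride) {  // idx tracks i * stride
    sum += data[idx] + invariant;
  }
  return sum;
}
```

Both versions visit the same elements in the same order, so the rewrite does not change the result.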
Analysis infra already added:
- `DominatorTree`
- `LoopInfo`
Runtime support added:
- pthread worker-pool based `__par_runN` in `sylib/sylib.c`
- `__fill_i32` helper in `sylib/sylib.c`
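To illustrate the constant-fill idiom at source level, here is a minimal sketch. `FillI32Sketch` is a hypothetical stand-in: the real `__fill_i32` signature in `sylib/sylib.c` is not shown in this handoff, so the parameter order below is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for the runtime helper (assumed signature).
void FillI32Sketch(std::int32_t* dst, std::int32_t value, std::size_t count) {
  assert(dst != nullptr || count == 0);
  for (std::size_t i = 0; i < count; ++i) dst[i] = value;
}

// A constant-fill loop: every iteration stores the same constant to
// consecutive elements and does nothing else.
void ClearLoop(std::int32_t* a, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) a[i] = 7;
}

// What the loop amounts to after idiom recognition: one helper call.
void ClearCall(std::int32_t* a, std::size_t n) { FillI32Sketch(a, 7, n); }
```

The two forms leave the array in the same state, which is what makes the replacement safe for this loop shape.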
User constraints already decided:
- Do not optimize the real-dependence matrix multiply in `2025-MYO-20` where `A[i][j]` is written and `A[k][j]` is read.
- Reduction parallelization is still disabled.
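The following sketch shows why an in-place update that writes `A[i][j]` while reading `A[k][j]` is order-sensitive. The kernel is a made-up 2x2 example with the same access pattern, not the actual code of `2025-MYO-20`; `UpdateInPlace` and `row_order` are illustrative names.

```cpp
#include <array>
#include <cassert>
#include <vector>

constexpr int kN = 2;
using Matrix = std::array<std::array<int, kN>, kN>;

// Row i of `a` is written while row k of the same matrix is read, so a later
// row update can observe values produced by an earlier one.
Matrix UpdateInPlace(Matrix a, const Matrix& b, const std::vector<int>& row_order) {
  for (int i : row_order) {
    assert(i >= 0 && i < kN);
    for (int k = 0; k < kN; ++k) {
      for (int j = 0; j < kN; ++j) {
        a[i][j] += b[i][k] * a[k][j];
      }
    }
  }
  return a;
}
```

With `a = {{1, 2}, {3, 4}}` and `b = {{0, 1}, {1, 0}}`, processing rows in order `{0, 1}` gives `{{4, 6}, {7, 10}}`, while order `{1, 0}` gives `{{5, 8}, {4, 6}}`. Because the result depends on the order of the `i` iterations, running them concurrently is not safe without extra synchronization.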
## Timing Scripts
Timing output was added to:
- `scripts/verify_ir.sh`
- `scripts/verify_asm.sh`
- `scripts/run_all_tests.sh`
User requirement:
- Every test round should always report:
  - `test/test_case/performance/2025-MYO-20.sy`
  - `./scripts/run_all_tests.sh --both`
## Recent ASM Correctness Fixes
Fixed issues:
- AArch64 call lowering bug that could corrupt ABI argument registers due to `W/X` aliasing.
- Duplicate local labels like `.par.exit` across worker functions by prefixing block labels with the function name.
- Duplicate callee-saved save/restore of alias registers like `w8/x8`.
Relevant files:
- `src/mir/Lowering.cpp`
- `src/mir/AsmPrinter.cpp`
- `src/mir/MIRFunction.cpp`
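One way to picture the alias de-duplication is to key every general-purpose register by its physical index, so that the `W` and `X` views collapse to one entry. This is a minimal sketch under that assumption; `PhysRegIndex` and `UniqueSaves` are hypothetical helpers, not the functions in `src/mir/MIRFunction.cpp`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Maps "w8" and "x8" to the same key (8), since they are two views of one
// physical register. Returns -1 for names this sketch does not handle
// (sp, xzr, vector registers, ...).
int PhysRegIndex(const std::string& name) {
  if (name.size() < 2 || name.size() > 3) return -1;
  if (name[0] != 'w' && name[0] != 'x') return -1;
  int index = 0;
  for (std::size_t i = 1; i < name.size(); ++i) {
    if (!std::isdigit(static_cast<unsigned char>(name[i]))) return -1;
    index = index * 10 + (name[i] - '0');
  }
  return index <= 30 ? index : -1;
}

// Builds a save list with one entry per physical register, always using the
// 64-bit X view so the full register is preserved.
std::vector<std::string> UniqueSaves(const std::vector<std::string>& used) {
  std::set<int> seen;
  std::vector<std::string> saves;
  for (const auto& name : used) {
    const int index = PhysRegIndex(name);
    if (index < 0) continue;
    assert(index <= 30);
    if (seen.insert(index).second) {
      saves.push_back("x" + std::to_string(index));
    }
  }
  return saves;
}
```

For example, `UniqueSaves({"w19", "x19", "x20"})` yields a single `x19` entry followed by `x20`, rather than saving the same register twice.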
## Recent ASM Optimization Work
Implemented recently:
- post-regalloc second peephole pass in `src/main.cpp`
- selective safe load forwarding guard for ABI argument registers
- `cbz/cbnz` lowering for integer compare-against-zero in `Cmp + CondBr` fusion
- dead-overwrite elimination in the peephole pass: an adjacent load or compute whose result is overwritten before any use is removed
Relevant files:
- `src/main.cpp`
- `src/mir/Lowering.cpp`
- `src/mir/passes/Peephole.cpp`
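The `cbz/cbnz` item can be sketched as a small selection function: an equality or inequality test against zero becomes a single compare-and-branch instruction, and everything else keeps the `cmp` + `b.<cond>` pair. `EmitCmpBranch` and the `Cond` enum are illustrative names and do not mirror the MIR data structures in `src/mir/Lowering.cpp`.

```cpp
#include <cassert>
#include <string>

enum class Cond { Eq, Ne, Lt, Ge };

// Returns the assembly text for "compare reg with imm, branch to label if cond".
std::string EmitCmpBranch(const std::string& reg, long imm, Cond cond,
                          const std::string& label) {
  // Fused forms: only ==0 and !=0 have a single-instruction encoding.
  if (imm == 0 && cond == Cond::Eq) return "cbz " + reg + ", " + label;
  if (imm == 0 && cond == Cond::Ne) return "cbnz " + reg + ", " + label;

  // Fallback: explicit compare followed by a conditional branch.
  // This sketch only handles the unshifted 12-bit immediate form of cmp.
  assert(imm >= 0 && imm <= 4095);
  const char* suffix = "eq";
  switch (cond) {
    case Cond::Eq: suffix = "eq"; break;
    case Cond::Ne: suffix = "ne"; break;
    case Cond::Lt: suffix = "lt"; break;
    case Cond::Ge: suffix = "ge"; break;
  }
  return "cmp " + reg + ", #" + std::to_string(imm) + "\nb." + suffix + " " + label;
}
```

Signed comparisons against zero such as `Lt` stay on the fallback path here, which matches the note further down about widening zero-compare simplification beyond the current fused path.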
## Most Recent Measured Performance
These are the latest measured numbers observed during this session.
IR:
- `2025-MYO-20`:
  - stable reference before the latest ASM-only work: around `31.109s`
  - earlier stable reference before that: around `30.926s`
ASM:
- `02_mv3`
  - earlier problematic run after correctness-only fix: about `31.662s`
  - after later backend cleanup, best observed run in this session: about `31.505s`
  - another later run: about `31.529s`
- `01_mm2`
  - earlier reference in this session: about `38.010s`
  - later improved run: about `37.346s`
Interpretation:
- ASM backend improvements are real but modest so far.
- Main remaining bottleneck is still heavy stack traffic in hot loops.
## Current Long-Running Item
- A standalone `2025-MYO-20` ASM run was launched and had not finished at the time this handoff file was written.
- A full `./scripts/run_all_tests.sh --both` run had progressed to the final `2025-MYO-20` ASM item instead of failing early, but final completion time was still pending.
## Good Commands To Resume Work
Build:
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j"$(nproc)" --target compiler
```
Quick correctness:
```bash
./scripts/verify_ir.sh test/test_case/functional/simple_add.sy /tmp/ir_check --run
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy /tmp/asm_check --run
```
User-required fixed benchmarks:
```bash
./scripts/verify_ir.sh test/test_case/performance/2025-MYO-20.sy /tmp/timed_2025 --run
./scripts/run_all_tests.sh --both
```
Useful ASM profiling targets:
```bash
./scripts/verify_asm.sh test/test_case/performance/01_mm2.sy /tmp/asm_mm2 --run
./scripts/verify_asm.sh test/test_case/performance/02_mv3.sy /tmp/asm_mv3 --run
./scripts/verify_asm.sh test/test_case/performance/2025-MYO-20.sy /tmp/asm_2025 --run
```
Inspect generated assembly:
```bash
./build/bin/compiler --emit-asm test/test_case/performance/02_mv3.sy > /tmp/02_mv3.s
./build/bin/compiler --emit-asm test/test_case/performance/01_mm2.sy > /tmp/01_mm2.s
./build/bin/compiler --emit-asm test/test_case/performance/2025-MYO-20.sy > /tmp/2025.s
```
## Suggested Next Steps
Priority order:
1. Finish measuring `2025-MYO-20` ASM and a complete `--both` run on the faster Ubuntu machine.
2. Keep working on MIR/ASM backend, not IR parallelization.
3. Target hot-loop stack traffic:
   - reduce phi-related spill/reload churn
   - widen zero-compare branch simplification beyond the current fused path
   - add more dead store / dead load cleanup after frame lowering
4. Only claim speedups when confirmed with the fixed benchmark pair above.

@ -35,6 +35,7 @@
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
@ -530,4 +531,77 @@ class IRPrinter {
void Print(const Module& module, std::ostream& os);
};
namespace analysis {
class DominatorTree {
public:
explicit DominatorTree(Function& func);
BasicBlock* GetIDom(BasicBlock* bb) const;
bool Dominates(BasicBlock* a, BasicBlock* b) const;
const std::vector<BasicBlock*>& GetDF(BasicBlock* bb) const;
const std::vector<BasicBlock*>& GetChildren(BasicBlock* bb) const;
const std::vector<BasicBlock*>& GetRPO() const;
private:
void Compute();
void ComputeRPO(BasicBlock* entry);
BasicBlock* Intersect(BasicBlock* b1, BasicBlock* b2);
void ComputeDF();
Function& func_;
std::vector<BasicBlock*> rpo_;
std::unordered_map<BasicBlock*, size_t> rpo_index_;
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> children_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> df_;
};
class Loop {
public:
BasicBlock* GetHeader() const { return header_; }
const std::vector<BasicBlock*>& GetLatches() const { return latches_; }
const std::unordered_set<BasicBlock*>& GetBlocks() const { return blocks_; }
bool Contains(BasicBlock* bb) const { return blocks_.count(bb) != 0; }
BasicBlock* GetPreheader() const { return preheader_; }
const std::vector<BasicBlock*>& GetExitBlocks() const { return exit_blocks_; }
Loop* GetParent() const { return parent_; }
const std::vector<Loop*>& GetChildren() const { return children_; }
size_t GetDepth() const { return depth_; }
bool IsParallelCandidate() const { return parallel_candidate_; }
private:
friend class LoopInfo;
BasicBlock* header_ = nullptr;
std::vector<BasicBlock*> latches_;
std::unordered_set<BasicBlock*> blocks_;
BasicBlock* preheader_ = nullptr;
std::vector<BasicBlock*> exit_blocks_;
Loop* parent_ = nullptr;
std::vector<Loop*> children_;
size_t depth_ = 1;
bool parallel_candidate_ = false;
};
class LoopInfo {
public:
LoopInfo(Function& func, const DominatorTree& dom_tree);
const std::vector<std::unique_ptr<Loop>>& GetLoops() const { return loops_; }
Loop* GetLoopFor(BasicBlock* bb) const;
private:
void Compute();
void ComputeNesting();
void ComputeParallelFlags();
Function& func_;
const DominatorTree& dom_tree_;
std::vector<std::unique_ptr<Loop>> loops_;
std::unordered_map<BasicBlock*, Loop*> innermost_loop_;
};
} // namespace analysis
} // namespace ir

@ -8,6 +8,18 @@
set -uo pipefail
now_ns() {
date +%s%N
}
format_ns() {
local ns=$1
local ms=$((ns / 1000000))
local sec=$((ms / 1000))
local rem_ms=$((ms % 1000))
printf '%d.%03ds' "$sec" "$rem_ms"
}
mode="ir"
if [[ "${1:-}" == "--asm" ]]; then
mode="asm"
@ -31,6 +43,9 @@ passed=0
failed=0
skipped=0
fail_list=()
RUN_LAST_TOTAL_NS=0
RUN_LAST_BREAKDOWN=""
batch_start_ns=$(now_ns)
run_ir_test() {
local sy="$1"
@ -46,15 +61,24 @@ run_ir_test() {
local expected_file="$dir/$stem.out"
local stdout_file="$out_dir/$stem.stdout"
local actual_file="$out_dir/$stem.actual.out"
local start_ns
start_ns=$(now_ns)
# Generate IR
local emit_start_ns
emit_start_ns=$(now_ns)
if ! timeout 30 "$compiler" --emit-ir "$sy" > "$out_file" 2>/dev/null; then
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$RUN_LAST_TOTAL_NS")"
echo " [SKIP-IR] $sy (compiler error or timeout)"
return 2
fi
local emit_ns=$(( $(now_ns) - emit_start_ns ))
# Requires llc + clang
if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns")"
echo " [SKIP-IR] $sy (llc/clang not found)"
return 2
fi
@ -62,21 +86,30 @@ run_ir_test() {
local obj="$out_dir/$stem.o"
local exe="$out_dir/$stem"
local lower_link_start_ns
lower_link_start_ns=$(now_ns)
if ! llc -filetype=obj "$out_file" -o "$obj" 2>/dev/null; then
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns")"
echo " [SKIP-IR] $sy (llc compilation failed)"
return 2
fi
if ! clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm 2>/dev/null; then
if ! clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm -pthread 2>/dev/null; then
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") lower+link=$(format_ns $(( $(now_ns) - lower_link_start_ns )))"
echo " [SKIP-IR] $sy (clang link failed)"
return 2
fi
local lower_link_ns=$(( $(now_ns) - lower_link_start_ns ))
set +e
# Give performance test cases a longer timeout
local run_timeout=30
local run_timeout=3000
if [[ "$sy" == *"performance"* ]]; then
run_timeout=1000
run_timeout=3000
fi
local run_start_ns
run_start_ns=$(now_ns)
if [[ -f "$stdin_file" ]]; then
timeout $run_timeout "$exe" < "$stdin_file" > "$stdout_file" 2>/dev/null
else
@ -84,6 +117,9 @@ run_ir_test() {
fi
local status=$?
set -e
local run_ns=$(( $(now_ns) - run_start_ns ))
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") lower+link=$(format_ns "$lower_link_ns") run=$(format_ns "$run_ns")"
# timeout exits with 124 when the time limit is hit; mark the test as SKIP
if [[ $status -eq 124 ]]; then
@ -128,24 +164,40 @@ run_asm_test() {
local stdout_file="$out_dir/$stem.stdout"
local actual_file="$out_dir/$stem.actual.out"
local exe="$out_dir/$stem"
local start_ns
start_ns=$(now_ns)
# Generate assembly
local emit_start_ns
emit_start_ns=$(now_ns)
if ! timeout 30 "$compiler" --emit-asm "$sy" > "$asm_file" 2>/dev/null; then
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$RUN_LAST_TOTAL_NS")"
echo " [SKIP-ASM] $sy (compiler error or timeout)"
return 2
fi
local emit_ns=$(( $(now_ns) - emit_start_ns ))
if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns")"
echo " [SKIP-ASM] $sy (aarch64-linux-gnu-gcc not found)"
return 2
fi
if ! timeout 30 aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static 2>/dev/null; then
local link_start_ns
link_start_ns=$(now_ns)
if ! timeout 30 aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static -pthread 2>/dev/null; then
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") asm+link=$(format_ns $(( $(now_ns) - link_start_ns )))"
echo " [SKIP-ASM] $sy (assemble/link failed)"
return 2
fi
local link_ns=$(( $(now_ns) - link_start_ns ))
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") asm+link=$(format_ns "$link_ns")"
echo " [SKIP-ASM] $sy (qemu-aarch64 not found)"
return 2
fi
@ -156,6 +208,8 @@ run_asm_test() {
if [[ "$sy" == *"performance"* ]]; then
run_timeout=1000
fi
local run_start_ns
run_start_ns=$(now_ns)
if [[ -f "$stdin_file" ]]; then
timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file" 2>/dev/null
else
@ -163,6 +217,9 @@ run_asm_test() {
fi
local status=$?
set -e
local run_ns=$(( $(now_ns) - run_start_ns ))
RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns ))
RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") asm+link=$(format_ns "$link_ns") run=$(format_ns "$run_ns")"
# timeout exits with 124 when the time limit is hit; mark the test as SKIP
if [[ $status -eq 124 ]]; then
@ -206,10 +263,10 @@ for sy in "${test_files[@]}"; do
run_ir_test "$sy"
rc=$?
if [[ $rc -eq 0 ]]; then
echo " [PASS-IR] $sy"
echo " [PASS-IR] $sy ($(format_ns "$RUN_LAST_TOTAL_NS"); $RUN_LAST_BREAKDOWN)"
passed=$((passed + 1))
elif [[ $rc -eq 1 ]]; then
echo " [FAIL-IR] $sy"
echo " [FAIL-IR] $sy ($(format_ns "$RUN_LAST_TOTAL_NS"); $RUN_LAST_BREAKDOWN)"
failed=$((failed + 1))
fail_list+=("$sy (IR)")
else
@ -221,12 +278,12 @@ for sy in "${test_files[@]}"; do
run_asm_test "$sy"
rc=$?
if [[ $rc -eq 0 ]]; then
echo " [PASS-ASM] $sy"
echo " [PASS-ASM] $sy ($(format_ns "$RUN_LAST_TOTAL_NS"); $RUN_LAST_BREAKDOWN)"
if [[ "$mode" == "asm" ]]; then
passed=$((passed + 1))
fi
elif [[ $rc -eq 1 ]]; then
echo " [FAIL-ASM] $sy"
echo " [FAIL-ASM] $sy ($(format_ns "$RUN_LAST_TOTAL_NS"); $RUN_LAST_BREAKDOWN)"
if [[ "$mode" == "asm" ]]; then
failed=$((failed + 1))
fi
@ -247,6 +304,7 @@ echo " Total: $total"
echo " Passed: $passed"
echo " Failed: $failed"
echo " Skipped: $skipped"
echo " Total time: $(format_ns $(( $(now_ns) - batch_start_ns )))"
echo ""
if [[ ${#fail_list[@]} -gt 0 ]]; then

@ -2,6 +2,18 @@
set -euo pipefail
now_ns() {
date +%s%N
}
format_ns() {
local ns=$1
local ms=$((ns / 1000000))
local sec=$((ms / 1000))
local rem_ms=$((ms % 1000))
printf '%d.%03ds' "$sec" "$rem_ms"
}
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "Usage: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
@ -49,11 +61,18 @@ exe="$out_dir/$stem"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
total_start_ns=$(now_ns)
emit_start_ns=$(now_ns)
"$compiler" --emit-asm "$input" > "$asm_file"
emit_elapsed_ns=$(( $(now_ns) - emit_start_ns ))
echo "Assembly generated: $asm_file"
echo "Assembly generation time: $(format_ns "$emit_elapsed_ns")"
aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static
link_start_ns=$(now_ns)
aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" -static -pthread
link_elapsed_ns=$(( $(now_ns) - link_start_ns ))
echo "Executable generated: $exe"
echo "Assemble/link time: $(format_ns "$link_elapsed_ns")"
if [[ "$run_exec" == true ]]; then
if ! command -v qemu-aarch64 >/dev/null 2>&1; then
@ -64,6 +83,7 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
echo "Running $exe ..."
run_start_ns=$(now_ns)
set +e
if [[ -f "$stdin_file" ]]; then
qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" < "$stdin_file" > "$stdout_file"
@ -72,8 +92,10 @@ if [[ "$run_exec" == true ]]; then
fi
status=$?
set -e
run_elapsed_ns=$(( $(now_ns) - run_start_ns ))
cat "$stdout_file"
echo "Exit code: $status"
echo "Run time: $(format_ns "$run_elapsed_ns")"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
@ -95,3 +117,6 @@ if [[ "$run_exec" == true ]]; then
echo "Expected output file not found; skipping comparison: $expected_file"
fi
fi
total_elapsed_ns=$(( $(now_ns) - total_start_ns ))
echo "Total time: $(format_ns "$total_elapsed_ns")"

@ -3,6 +3,18 @@
set -euo pipefail
now_ns() {
date +%s%N
}
format_ns() {
local ns=$1
local ms=$((ns / 1000000))
local sec=$((ms / 1000))
local rem_ms=$((ms % 1000))
printf '%d.%03ds' "$sec" "$rem_ms"
}
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "Usage: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
@ -43,8 +55,12 @@ stem=${base%.sy}
out_file="$out_dir/$stem.ll"
stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out"
total_start_ns=$(now_ns)
emit_start_ns=$(now_ns)
"$compiler" --emit-ir "$input" > "$out_file"
emit_elapsed_ns=$(( $(now_ns) - emit_start_ns ))
echo "IR generated: $out_file"
echo "IR generation time: $(format_ns "$emit_elapsed_ns")"
if [[ "$run_exec" == true ]]; then
if ! command -v llc >/dev/null 2>&1; then
@ -59,9 +75,13 @@ if [[ "$run_exec" == true ]]; then
exe="$out_dir/$stem"
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
lower_link_start_ns=$(now_ns)
llc -filetype=obj "$out_file" -o "$obj"
clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm
clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm -pthread
lower_link_elapsed_ns=$(( $(now_ns) - lower_link_start_ns ))
echo "IR lowering/link time: $(format_ns "$lower_link_elapsed_ns")"
echo "Running $exe ..."
run_start_ns=$(now_ns)
set +e
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
@ -70,8 +90,10 @@ if [[ "$run_exec" == true ]]; then
fi
status=$?
set -e
run_elapsed_ns=$(( $(now_ns) - run_start_ns ))
cat "$stdout_file"
echo "Exit code: $status"
echo "Run time: $(format_ns "$run_elapsed_ns")"
{
cat "$stdout_file"
if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then
@ -92,3 +114,6 @@ if [[ "$run_exec" == true ]]; then
echo "Expected output file not found; skipping comparison: $expected_file"
fi
fi
total_elapsed_ns=$(( $(now_ns) - total_start_ns ))
echo "Total time: $(format_ns "$total_elapsed_ns")"

@ -21,151 +21,130 @@ namespace analysis {
// ---------- DominatorTree ----------
class DominatorTree {
public:
explicit DominatorTree(Function& func) : func_(func) { Compute(); }
// Returns the immediate dominator of bb; the entry block's idom is itself.
BasicBlock* GetIDom(BasicBlock* bb) const {
auto it = idom_.find(bb);
return it != idom_.end() ? it->second : nullptr;
DominatorTree::DominatorTree(Function& func) : func_(func) { Compute(); }
BasicBlock* DominatorTree::GetIDom(BasicBlock* bb) const {
auto it = idom_.find(bb);
return it != idom_.end() ? it->second : nullptr;
}
bool DominatorTree::Dominates(BasicBlock* a, BasicBlock* b) const {
if (!a || !b) return false;
while (b) {
if (b == a) return true;
auto* p = GetIDom(b);
if (p == b) break;
b = p;
}
// Returns whether a dominates b.
bool Dominates(BasicBlock* a, BasicBlock* b) const {
if (!a || !b) return false;
while (b) {
if (b == a) return true;
auto* p = GetIDom(b);
if (p == b) break; // entry
b = p;
}
return false;
return false;
}
const std::vector<BasicBlock*>& DominatorTree::GetDF(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = df_.find(bb);
return it != df_.end() ? it->second : empty;
}
const std::vector<BasicBlock*>& DominatorTree::GetChildren(
BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = children_.find(bb);
return it != children_.end() ? it->second : empty;
}
const std::vector<BasicBlock*>& DominatorTree::GetRPO() const { return rpo_; }
void DominatorTree::Compute() {
auto* entry = func_.GetEntry();
if (!entry) return;
ComputeRPO(entry);
if (rpo_.empty()) return;
for (auto* bb : rpo_) {
idom_[bb] = nullptr;
rpo_index_[bb] = 0;
}
// Returns the dominance frontier of bb.
const std::vector<BasicBlock*>& GetDF(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = df_.find(bb);
return it != df_.end() ? it->second : empty;
for (size_t i = 0; i < rpo_.size(); ++i) {
rpo_index_[rpo_[i]] = i;
}
idom_[entry] = entry;
// Returns the children of bb in the dominator tree.
const std::vector<BasicBlock*>& GetChildren(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = children_.find(bb);
return it != children_.end() ? it->second : empty;
}
// Returns all basic blocks in reverse postorder.
const std::vector<BasicBlock*>& GetRPO() const { return rpo_; }
private:
void Compute() {
auto* entry = func_.GetEntry();
if (!entry) return;
// 1. Compute the reverse postorder (RPO)
ComputeRPO(entry);
if (rpo_.empty()) return;
// 2. Initialize
bool changed = true;
while (changed) {
changed = false;
for (auto* bb : rpo_) {
idom_[bb] = nullptr;
rpo_index_[bb] = 0;
}
for (size_t i = 0; i < rpo_.size(); ++i) {
rpo_index_[rpo_[i]] = i;
}
idom_[entry] = entry;
// 3. Iteratively compute idom (Cooper-Harvey-Kennedy algorithm)
bool changed = true;
while (changed) {
changed = false;
for (auto* bb : rpo_) {
if (bb == entry) continue;
BasicBlock* new_idom = nullptr;
for (auto* pred : bb->GetPredecessors()) {
if (idom_.count(pred) && idom_[pred] != nullptr) {
if (!new_idom) {
new_idom = pred;
} else {
new_idom = Intersect(new_idom, pred);
}
if (bb == entry) continue;
BasicBlock* new_idom = nullptr;
for (auto* pred : bb->GetPredecessors()) {
if (idom_.count(pred) && idom_[pred] != nullptr) {
if (!new_idom) {
new_idom = pred;
} else {
new_idom = Intersect(new_idom, pred);
}
}
if (new_idom && idom_[bb] != new_idom) {
idom_[bb] = new_idom;
changed = true;
}
}
}
// 4. Build the children map
for (auto* bb : rpo_) {
auto* p = GetIDom(bb);
if (p && p != bb) {
children_[p].push_back(bb);
if (new_idom && idom_[bb] != new_idom) {
idom_[bb] = new_idom;
changed = true;
}
}
// 5. Compute the dominance frontier
ComputeDF();
}
void ComputeRPO(BasicBlock* entry) {
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> post_order;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* bb) {
visited.insert(bb);
for (auto* succ : bb->GetSuccessors()) {
if (!visited.count(succ)) {
dfs(succ);
}
}
post_order.push_back(bb);
};
dfs(entry);
rpo_.assign(post_order.rbegin(), post_order.rend());
for (auto* bb : rpo_) {
auto* p = GetIDom(bb);
if (p && p != bb) {
children_[p].push_back(bb);
}
}
BasicBlock* Intersect(BasicBlock* b1, BasicBlock* b2) {
while (b1 != b2) {
while (rpo_index_[b1] > rpo_index_[b2]) b1 = idom_[b1];
while (rpo_index_[b2] > rpo_index_[b1]) b2 = idom_[b2];
ComputeDF();
}
void DominatorTree::ComputeRPO(BasicBlock* entry) {
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> post_order;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* bb) {
visited.insert(bb);
for (auto* succ : bb->GetSuccessors()) {
if (!visited.count(succ)) {
dfs(succ);
}
}
return b1;
post_order.push_back(bb);
};
dfs(entry);
rpo_.assign(post_order.rbegin(), post_order.rend());
}
BasicBlock* DominatorTree::Intersect(BasicBlock* b1, BasicBlock* b2) {
while (b1 != b2) {
while (rpo_index_[b1] > rpo_index_[b2]) b1 = idom_[b1];
while (rpo_index_[b2] > rpo_index_[b1]) b2 = idom_[b2];
}
return b1;
}
void ComputeDF() {
for (auto* bb : rpo_) {
df_[bb] = {};
}
for (auto* bb : rpo_) {
if (bb->GetPredecessors().size() < 2) continue;
for (auto* pred : bb->GetPredecessors()) {
auto* runner = pred;
while (runner && runner != idom_[bb]) {
// Avoid duplicates
auto& df_set = df_[runner];
if (std::find(df_set.begin(), df_set.end(), bb) == df_set.end()) {
df_set.push_back(bb);
}
if (runner == idom_[runner]) break;
runner = idom_[runner];
void DominatorTree::ComputeDF() {
for (auto* bb : rpo_) {
df_[bb] = {};
}
for (auto* bb : rpo_) {
if (bb->GetPredecessors().size() < 2) continue;
for (auto* pred : bb->GetPredecessors()) {
auto* runner = pred;
while (runner && runner != idom_[bb]) {
auto& df_set = df_[runner];
if (std::find(df_set.begin(), df_set.end(), bb) == df_set.end()) {
df_set.push_back(bb);
}
if (runner == idom_[runner]) break;
runner = idom_[runner];
}
}
}
Function& func_;
std::vector<BasicBlock*> rpo_;
std::unordered_map<BasicBlock*, size_t> rpo_index_;
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> children_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> df_;
};
}
} // namespace analysis
} // namespace ir

@ -2,3 +2,242 @@
// - Identify loop structures and their nesting relationships
// - Provide loop information for (optional) later optimizations
#include "ir/IR.h"
#include <algorithm>
#include <queue>
#include <unordered_set>
#include <vector>
namespace ir {
namespace analysis {
namespace {
bool IsInvariantForLoop(Value* value, Loop* loop) {
if (!value) return false;
if (dynamic_cast<ConstantValue*>(value) != nullptr) return true;
if (dynamic_cast<Argument*>(value) != nullptr) return true;
if (dynamic_cast<GlobalVariable*>(value) != nullptr) return true;
if (dynamic_cast<Function*>(value) != nullptr) return true;
auto* inst = dynamic_cast<Instruction*>(value);
if (!inst) return true;
auto* parent = inst->GetParent();
return !parent || !loop->Contains(parent);
}
Value* StripPointerCasts(Value* value) {
while (auto* gep = dynamic_cast<GepInst*>(value)) {
value = gep->GetBase();
}
return value;
}
bool IsSimpleParallelStore(Value* ptr, Loop* loop) {
auto* gep = dynamic_cast<GepInst*>(ptr);
if (!gep) return false;
if (!IsInvariantForLoop(StripPointerCasts(gep->GetBase()), loop)) return false;
return !IsInvariantForLoop(gep->GetIndex(), loop);
}
} // namespace
LoopInfo::LoopInfo(Function& func, const DominatorTree& dom_tree)
: func_(func), dom_tree_(dom_tree) {
Compute();
}
Loop* LoopInfo::GetLoopFor(BasicBlock* bb) const {
auto it = innermost_loop_.find(bb);
return it != innermost_loop_.end() ? it->second : nullptr;
}
void LoopInfo::Compute() {
std::unordered_map<BasicBlock*, Loop*> loop_by_header;
for (const auto& bb_ptr : func_.GetBlocks()) {
auto* tail = bb_ptr.get();
if (!tail) continue;
for (auto* succ : tail->GetSuccessors()) {
if (!succ || !dom_tree_.Dominates(succ, tail)) continue;
Loop* loop = nullptr;
auto it = loop_by_header.find(succ);
if (it == loop_by_header.end()) {
auto owned = std::make_unique<Loop>();
owned->header_ = succ;
loop = owned.get();
loops_.push_back(std::move(owned));
loop_by_header[succ] = loop;
} else {
loop = it->second;
}
if (std::find(loop->latches_.begin(), loop->latches_.end(), tail) ==
loop->latches_.end()) {
loop->latches_.push_back(tail);
}
std::unordered_set<BasicBlock*> natural_loop;
std::queue<BasicBlock*> worklist;
natural_loop.insert(succ);
if (natural_loop.insert(tail).second) {
worklist.push(tail);
}
while (!worklist.empty()) {
auto* node = worklist.front();
worklist.pop();
for (auto* pred : node->GetPredecessors()) {
if (natural_loop.insert(pred).second) {
worklist.push(pred);
}
}
}
loop->blocks_.insert(natural_loop.begin(), natural_loop.end());
}
}
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
if (!loop) continue;
std::vector<BasicBlock*> outside_preds;
std::unordered_set<BasicBlock*> seen_outside_preds;
for (auto* pred : loop->header_->GetPredecessors()) {
if (!loop->Contains(pred) && seen_outside_preds.insert(pred).second) {
outside_preds.push_back(pred);
}
}
if (outside_preds.size() == 1 &&
outside_preds.front()->GetSuccessors().size() == 1 &&
outside_preds.front()->GetSuccessors().front() == loop->header_) {
loop->preheader_ = outside_preds.front();
}
std::unordered_set<BasicBlock*> exit_set;
for (auto* block : loop->blocks_) {
for (auto* succ : block->GetSuccessors()) {
if (!loop->Contains(succ) && exit_set.insert(succ).second) {
loop->exit_blocks_.push_back(succ);
}
}
}
}
ComputeNesting();
ComputeParallelFlags();
}
void LoopInfo::ComputeNesting() {
std::vector<Loop*> ordered;
ordered.reserve(loops_.size());
for (const auto& loop_ptr : loops_) {
ordered.push_back(loop_ptr.get());
}
std::sort(ordered.begin(), ordered.end(), [](Loop* lhs, Loop* rhs) {
if (lhs->GetBlocks().size() != rhs->GetBlocks().size()) {
return lhs->GetBlocks().size() < rhs->GetBlocks().size();
}
return lhs->GetHeader()->GetName() < rhs->GetHeader()->GetName();
});
for (auto* loop : ordered) {
Loop* parent = nullptr;
for (auto* candidate : ordered) {
if (candidate == loop) continue;
if (candidate->GetBlocks().size() <= loop->GetBlocks().size()) continue;
bool contains_all = true;
for (auto* block : loop->GetBlocks()) {
if (!candidate->Contains(block)) {
contains_all = false;
break;
}
}
if (!contains_all) continue;
if (!parent || candidate->GetBlocks().size() < parent->GetBlocks().size()) {
parent = candidate;
}
}
loop->parent_ = parent;
loop->depth_ = parent ? parent->depth_ + 1 : 1;
if (parent) {
parent->children_.push_back(loop);
}
}
for (const auto& bb_ptr : func_.GetBlocks()) {
auto* bb = bb_ptr.get();
if (!bb) continue;
Loop* best = nullptr;
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
if (!loop->Contains(bb)) continue;
if (!best || loop->GetDepth() > best->GetDepth()) {
best = loop;
}
}
if (best) {
innermost_loop_[bb] = best;
}
}
}
void LoopInfo::ComputeParallelFlags() {
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
if (!loop || loop->GetBlocks().empty()) continue;
bool saw_store = false;
bool parallel = true;
std::unordered_set<Value*> stored_ptrs;
for (auto* block : loop->GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst) continue;
if (inst->GetOpcode() == Opcode::Call) {
parallel = false;
break;
}
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
saw_store = true;
auto* ptr = store->GetPtr();
stored_ptrs.insert(ptr);
if (!IsSimpleParallelStore(ptr, loop)) {
parallel = false;
break;
}
}
}
if (!parallel) break;
}
if (parallel) {
for (auto* block : loop->GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* load = dynamic_cast<LoadInst*>(inst_ptr.get());
if (!load) continue;
if (stored_ptrs.count(load->GetPtr()) != 0) {
parallel = false;
break;
}
}
if (!parallel) break;
}
}
loop->parallel_candidate_ = parallel && saw_store;
}
}
} // namespace analysis
} // namespace ir

@ -1,6 +1,12 @@
add_library(ir_passes STATIC
PassManager.cpp
Mem2Reg.cpp
LICM.cpp
StrengthReduction.cpp
LoopIdiom.cpp
LoopFission.cpp
LoopUnroll.cpp
LoopParallelize.cpp
ConstFold.cpp
ConstProp.cpp
CSE.cpp

@ -0,0 +1,262 @@
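// Loop-invariant code motion (LICM)
// - Identify natural loops using DominatorTree + LoopInfo
// - Move instructions that are loop-invariant and safe to execute early into the preheader
// - Along the way, eliminate duplicate invariant expressions within the same loop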
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace passes {
namespace {
struct ExprKey {
Opcode opcode;
CmpOp cmp_op = CmpOp::Eq;
CastOp cast_op = CastOp::IntToFloat;
std::vector<Value*> operands;
bool operator==(const ExprKey& other) const {
return opcode == other.opcode && cmp_op == other.cmp_op &&
cast_op == other.cast_op && operands == other.operands;
}
};
struct ExprKeyHash {
size_t operator()(const ExprKey& key) const {
size_t h = std::hash<int>()(static_cast<int>(key.opcode));
h ^= std::hash<int>()(static_cast<int>(key.cmp_op)) + 0x9e3779b9 +
(h << 6) + (h >> 2);
h ^= std::hash<int>()(static_cast<int>(key.cast_op)) + 0x9e3779b9 +
(h << 6) + (h >> 2);
for (auto* operand : key.operands) {
h ^= std::hash<void*>()(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
bool IsSupportedInvariantOpcode(Opcode op) {
switch (op) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Cmp:
case Opcode::Cast:
case Opcode::Gep:
case Opcode::Load:
return true;
default:
return false;
}
}
bool IsLoopInvariantValue(Value* value, analysis::Loop* loop,
const std::unordered_set<Instruction*>& invariant) {
if (!value) return false;
if (dynamic_cast<ConstantValue*>(value) != nullptr) return true;
if (dynamic_cast<Argument*>(value) != nullptr) return true;
if (dynamic_cast<GlobalVariable*>(value) != nullptr) return true;
if (dynamic_cast<Function*>(value) != nullptr) return true;
if (dynamic_cast<BasicBlock*>(value) != nullptr) return true;
auto* inst = dynamic_cast<Instruction*>(value);
if (!inst) return true;
auto* parent = inst->GetParent();
if (!parent || !loop->Contains(parent)) return true;
return invariant.count(inst) != 0;
}
Value* GetPointerBase(Value* ptr) {
while (auto* gep = dynamic_cast<GepInst*>(ptr)) {
ptr = gep->GetBase();
}
return ptr;
}
bool MayAlias(Value* lhs, Value* rhs) {
if (lhs == rhs) return true;
return GetPointerBase(lhs) == GetPointerBase(rhs);
}
bool IsStoredInLoop(Value* ptr, analysis::Loop* loop) {
for (auto* block : loop->GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* store = dynamic_cast<StoreInst*>(inst_ptr.get());
if (store && MayAlias(store->GetPtr(), ptr)) {
return true;
}
}
}
return false;
}
bool IsSafeInvariantInstruction(Instruction* inst, analysis::Loop* loop,
const std::unordered_set<Instruction*>& invariant) {
if (!inst || !IsSupportedInvariantOpcode(inst->GetOpcode())) return false;
if (inst->GetOpcode() == Opcode::Load) {
auto* load = static_cast<LoadInst*>(inst);
if (!IsLoopInvariantValue(load->GetPtr(), loop, invariant)) return false;
return !IsStoredInLoop(load->GetPtr(), loop);
}
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (!IsLoopInvariantValue(inst->GetOperand(i), loop, invariant)) {
return false;
}
}
return true;
}
ExprKey MakeExprKey(Instruction* inst) {
ExprKey key;
key.opcode = inst->GetOpcode();
if (auto* cmp = dynamic_cast<CmpInst*>(inst)) {
key.cmp_op = cmp->GetCmpOp();
}
if (auto* cast = dynamic_cast<CastInst*>(inst)) {
key.cast_op = cast->GetCastOp();
}
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(inst->GetOperand(i));
}
return key;
}
std::vector<Instruction*> CollectLoopInstructions(analysis::Loop* loop,
Function& func) {
std::vector<Instruction*> ordered;
for (const auto& bb_ptr : func.GetBlocks()) {
auto* block = bb_ptr.get();
if (!block || !loop->Contains(block)) continue;
for (const auto& inst_ptr : block->GetInstructions()) {
ordered.push_back(inst_ptr.get());
}
}
return ordered;
}
std::unique_ptr<Instruction> DetachInstruction(BasicBlock* block,
Instruction* inst) {
auto& insts = block->MutableInstructions();
auto it = std::find_if(insts.begin(), insts.end(),
[inst](const std::unique_ptr<Instruction>& ptr) {
return ptr.get() == inst;
});
if (it == insts.end()) return nullptr;
std::unique_ptr<Instruction> owned = std::move(*it);
insts.erase(it);
owned->SetParent(nullptr);
return owned;
}
void InsertBeforeTerminator(BasicBlock* block, std::unique_ptr<Instruction> inst) {
auto& insts = block->MutableInstructions();
auto insert_it = insts.end();
if (block->HasTerminator()) {
insert_it = insts.end() - 1;
}
inst->SetParent(block);
insts.insert(insert_it, std::move(inst));
}
void SeedAvailableInvariants(
BasicBlock* preheader, analysis::Loop* loop,
std::unordered_map<ExprKey, Instruction*, ExprKeyHash>& available,
const std::unordered_set<Instruction*>& invariant) {
if (!preheader) return;
for (const auto& inst_ptr : preheader->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst || !IsSupportedInvariantOpcode(inst->GetOpcode())) continue;
if (!IsSafeInvariantInstruction(inst, loop, invariant)) continue;
available.emplace(MakeExprKey(inst), inst);
}
}
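// Hoists (or deduplicates) one invariant instruction per scan and then
// restarts, because moving an instruction invalidates the collected list.
bool RunLICMOnLoop(analysis::Loop* loop, Function& func) {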
auto* preheader = loop->GetPreheader();
if (!preheader) return false;
bool changed = false;
std::unordered_set<Instruction*> invariant;
bool progress = true;
while (progress) {
progress = false;
std::unordered_map<ExprKey, Instruction*, ExprKeyHash> available;
SeedAvailableInvariants(preheader, loop, available, invariant);
for (auto* inst : CollectLoopInstructions(loop, func)) {
if (!inst || invariant.count(inst) != 0) continue;
auto* block = inst->GetParent();
if (!block || block == preheader) continue;
if (inst->GetOpcode() == Opcode::Phi || inst->IsTerminator() ||
inst->GetOpcode() == Opcode::Alloca || inst->GetOpcode() == Opcode::Ret ||
inst->GetOpcode() == Opcode::Store || inst->GetOpcode() == Opcode::Call ||
inst->GetOpcode() == Opcode::Div || inst->GetOpcode() == Opcode::Mod) {
continue;
}
if (!IsSafeInvariantInstruction(inst, loop, invariant)) continue;
ExprKey key = MakeExprKey(inst);
auto avail_it = available.find(key);
if (avail_it != available.end()) {
inst->ReplaceAllUsesWith(avail_it->second);
block->RemoveInstruction(inst);
} else {
auto owned = DetachInstruction(block, inst);
if (!owned) continue;
auto* moved = owned.get();
InsertBeforeTerminator(preheader, std::move(owned));
available.emplace(std::move(key), moved);
invariant.insert(moved);
}
changed = true;
progress = true;
break;
}
}
return changed;
}
} // namespace
bool RunLICM(Function& func) {
if (func.IsExternal()) return false;
analysis::DominatorTree dom_tree(func);
analysis::LoopInfo loop_info(func, dom_tree);
std::vector<analysis::Loop*> ordered_loops;
for (const auto& loop_ptr : loop_info.GetLoops()) {
ordered_loops.push_back(loop_ptr.get());
}
std::sort(ordered_loops.begin(), ordered_loops.end(),
[](analysis::Loop* lhs, analysis::Loop* rhs) {
if (lhs->GetDepth() != rhs->GetDepth()) {
return lhs->GetDepth() > rhs->GetDepth();
}
if (lhs->GetBlocks().size() != rhs->GetBlocks().size()) {
return lhs->GetBlocks().size() < rhs->GetBlocks().size();
}
return lhs->GetHeader()->GetName() < rhs->GetHeader()->GetName();
});
bool changed = false;
for (auto* loop : ordered_loops) {
changed |= RunLICMOnLoop(loop, func);
}
return changed;
}
} // namespace passes
} // namespace ir
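
The fixed-point search in `RunLICMOnLoop` can be sketched on a toy representation. This is a minimal illustration, not the project's IR: `ToyInst` and `FindInvariants` are hypothetical names, values are plain strings, and the memory checks that the real pass applies to loads are left out.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical toy instruction: dst = op(lhs, rhs). Not the project's IR.
struct ToyInst {
  std::string dst;
  std::string lhs;
  std::string rhs;
};

// Returns the names defined in the loop body that are loop-invariant.
// `variant_seeds` lists values defined in the loop by other means (for
// example header phis); they are never invariant.
std::unordered_set<std::string> FindInvariants(
    const std::vector<ToyInst>& body,
    const std::unordered_set<std::string>& variant_seeds) {
  std::unordered_set<std::string> defined_in_loop = variant_seeds;
  for (const auto& inst : body) defined_in_loop.insert(inst.dst);

  std::unordered_set<std::string> invariant;
  // An operand is invariant if it is defined outside the loop or has
  // already been proven invariant.
  auto is_invariant = [&](const std::string& name) {
    return defined_in_loop.count(name) == 0 || invariant.count(name) != 0;
  };

  // Iterate until no new invariant is found; a single scan is not enough
  // when an instruction appears before the definition it depends on.
  bool progress = true;
  while (progress) {
    progress = false;
    for (const auto& inst : body) {
      if (invariant.count(inst.dst) != 0) continue;
      if (is_invariant(inst.lhs) && is_invariant(inst.rhs)) {
        invariant.insert(inst.dst);
        progress = true;
      }
    }
  }
  return invariant;
}
```

With a body of `t2 = t1 * c; t1 = a + b; t3 = i + t1; t4 = t3 * t2` and `i` as the only seed, the first scan finds `t1` and the second finds `t2`, while `t3` and `t4` stay in the loop.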

@ -0,0 +1,171 @@
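// Loop fission:
// - Conservatively splits a single-block loop body that contains two mutually independent groups of store statements
// - Only handles loops with a single induction variable and no other loop-carried phi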
#include "ir/IR.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>
#include "LoopPassUtils.h"
namespace ir {
namespace passes {
namespace {
Value* StripPointerBase(Value* value) {
while (auto* gep = dynamic_cast<GepInst*>(value)) {
value = gep->GetBase();
}
return value;
}
bool IsFissionCandidate(const CanonicalLoopMatch& match) {
if (match.loop->GetChildren().size() != 0) return false;
if (match.loop->GetBlocks().size() != 2) return false;
if (match.body != match.latch) return false;
if (match.header_phis.size() != 1) return false;
if (match.header_phis.front() != match.induction.phi) return false;
if (match.induction.step <= 0) return false;
auto* body_term =
dynamic_cast<BranchInst*>(match.body->MutableInstructions().back().get());
return body_term && body_term->GetTarget() == match.header;
}
bool DependsOnAny(Instruction* inst, const std::unordered_set<Instruction*>& defs) {
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* def = dynamic_cast<Instruction*>(inst->GetOperand(i));
if (def && defs.count(def) != 0) return true;
}
return false;
}
bool RunFissionOnLoop(Function& func, const CanonicalLoopMatch& match,
Context& ctx) {
if (!IsFissionCandidate(match)) return false;
std::vector<Instruction*> body_insts;
for (const auto& inst_ptr : match.body->GetInstructions()) {
if (!inst_ptr.get()->IsTerminator()) {
body_insts.push_back(inst_ptr.get());
}
}
if (body_insts.size() < 3) return false;
auto* iv_next = dynamic_cast<Instruction*>(match.induction.next);
if (!iv_next || iv_next->GetParent() != match.body) return false;
std::vector<size_t> store_positions;
for (size_t i = 0; i < body_insts.size(); ++i) {
if (dynamic_cast<StoreInst*>(body_insts[i]) != nullptr) {
store_positions.push_back(i);
}
}
if (store_positions.size() != 2) return false;
const size_t first_store_idx = store_positions[0];
const size_t second_store_idx = store_positions[1];
if (body_insts.back() != iv_next) return false;
if (second_store_idx + 1 != body_insts.size() - 1) return false;
auto* first_store = static_cast<StoreInst*>(body_insts[first_store_idx]);
auto* second_store = static_cast<StoreInst*>(body_insts[second_store_idx]);
if (StripPointerBase(first_store->GetPtr()) == StripPointerBase(second_store->GetPtr())) {
return false;
}
std::vector<Instruction*> group1(body_insts.begin(),
body_insts.begin() + first_store_idx + 1);
std::vector<Instruction*> group2(body_insts.begin() + first_store_idx + 1,
body_insts.begin() + second_store_idx + 1);
std::unordered_set<Instruction*> group1_defs(group1.begin(), group1.end());
std::unordered_set<Instruction*> group2_defs(group2.begin(), group2.end());
group1_defs.erase(iv_next);
group2_defs.erase(iv_next);
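// The SSA checks do not see dependences through memory. After fission every
// group-1 iteration runs before any group-2 iteration, so a load in one group
// must not read the array the other group stores to, and calls must not be
// reordered.
auto* first_base = StripPointerBase(first_store->GetPtr());
auto* second_base = StripPointerBase(second_store->GetPtr());
for (auto* inst : group2) {
if (DependsOnAny(inst, group1_defs)) return false;
if (inst->GetOpcode() == Opcode::Call) return false;
auto* load = dynamic_cast<LoadInst*>(inst);
if (load && StripPointerBase(load->GetPtr()) == first_base) return false;
}
for (auto* inst : group1) {
if (DependsOnAny(inst, group2_defs)) return false;
if (inst->GetOpcode() == Opcode::Call) return false;
auto* load = dynamic_cast<LoadInst*>(inst);
if (load && StripPointerBase(load->GetPtr()) == second_base) return false;
}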
auto* original_exit = match.exit;
std::string block_suffix = ctx.NextTemp();
if (!block_suffix.empty() && block_suffix.front() == '%') {
block_suffix.erase(0, 1);
}
auto* preheader2 =
func.CreateBlock(match.header->GetName() + ".fission.pre." + block_suffix);
auto* header2 =
func.CreateBlock(match.header->GetName() + ".fission.hdr." + block_suffix);
auto* body2 =
func.CreateBlock(match.body->GetName() + ".fission.body." + block_suffix);
preheader2->Append<BranchInst>(Type::GetVoidType(), header2);
auto* iv2 = header2->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());
iv2->AddIncoming(match.induction.init, preheader2);
auto* cmp2 = header2->Append<CmpInst>(
match.header_cmp->GetCmpOp(), Type::GetInt32Type(), iv2, match.bound,
ctx.NextTemp());
header2->Append<CondBranchInst>(Type::GetVoidType(), cmp2, body2, original_exit);
ValueMap remap;
remap.emplace(match.induction.phi, iv2);
for (auto* inst : group2) {
auto cloned = CloneInstruction(inst, remap, ".f2");
if (!cloned) return false;
auto* raw = cloned.get();
body2->MutableInstructions().push_back(std::move(cloned));
raw->SetParent(body2);
remap[inst] = raw;
}
auto next2_cloned = CloneInstruction(iv_next, remap, ".f2");
if (!next2_cloned) return false;
auto* next2 = next2_cloned.get();
body2->MutableInstructions().push_back(std::move(next2_cloned));
next2->SetParent(body2);
body2->Append<BranchInst>(Type::GetVoidType(), header2);
iv2->AddIncoming(next2, body2);
const bool exit_is_true = (match.header_branch->GetTrueBlock() == original_exit);
match.header_branch->SetOperand(exit_is_true ? 1 : 2, preheader2);
match.header->RemoveSuccessor(original_exit);
match.header->AddSuccessor(preheader2);
preheader2->AddPredecessor(match.header);
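original_exit->RemovePredecessor(match.header);
// Register the CFG edges of the second loop as well.
preheader2->AddSuccessor(header2);
header2->AddPredecessor(preheader2);
header2->AddSuccessor(body2);
body2->AddPredecessor(header2);
header2->AddSuccessor(original_exit);
original_exit->AddPredecessor(header2);
body2->AddSuccessor(header2);
header2->AddPredecessor(body2);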
for (auto* inst : group2) {
match.body->RemoveInstruction(inst);
}
return true;
}
} // namespace
bool RunLoopFission(Function& func, Context& ctx) {
if (func.IsExternal()) return false;
analysis::DominatorTree dom_tree(func);
analysis::LoopInfo loop_info(func, dom_tree);
for (const auto& loop_ptr : loop_info.GetLoops()) {
auto match = MatchCanonicalLoop(loop_ptr.get());
if (!match.has_value()) continue;
if (RunFissionOnLoop(func, *match, ctx)) {
return true;
}
}
return false;
}
} // namespace passes
} // namespace ir
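
At source level, the rewrite performed by `RunFissionOnLoop` corresponds roughly to the pair of functions below. This is a hand-written sketch with made-up function names, not output of the pass. The second pair shows why the two groups have to be independent: a read through memory across groups changes the result once the loop is split.

```cpp
#include <vector>

// Original shape: one loop, two stores to different arrays.
void Fused(std::vector<int>& a, std::vector<int>& b, int n) {
  for (int i = 0; i < n; ++i) {
    a[i] = i * 2;  // group 1
    b[i] = i + 7;  // group 2
  }
}

// After fission: each group gets its own loop over the same range.
void Fissioned(std::vector<int>& a, std::vector<int>& b, int n) {
  for (int i = 0; i < n; ++i) a[i] = i * 2;
  for (int i = 0; i < n; ++i) b[i] = i + 7;
}

// Counter-example: group 2 reads a[i + 1]. In the fused loop that element
// has not been written yet; in the split version it already has.
void FusedDependent(std::vector<int>& a, std::vector<int>& b, int n) {
  for (int i = 0; i + 1 < n; ++i) {
    a[i] = 1;
    b[i] = a[i + 1];
  }
}

void FissionedDependent(std::vector<int>& a, std::vector<int>& b, int n) {
  for (int i = 0; i + 1 < n; ++i) a[i] = 1;
  for (int i = 0; i + 1 < n; ++i) b[i] = a[i + 1];
}
```

`Fused` and `Fissioned` leave both arrays in the same state, while the dependent pair does not, which is the case a fission legality check has to reject.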

@ -0,0 +1,465 @@
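// Loop idiom recognition:
// - Replaces canonical loops that fill a contiguous range with a constant by a bulk-fill runtime call
// - Currently only handles innermost loops with step=1, init=0, and a single store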
#include "ir/IR.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "LoopPassUtils.h"
namespace ir {
namespace passes {
namespace {
struct FillLoopCandidate {
analysis::Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* exit = nullptr;
PhiInst* induction = nullptr;
Value* bound = nullptr;
Value* base_ptr = nullptr;
Value* offset = nullptr;
int fill_value = 0;
};
struct GuardedRowFillCandidate {
analysis::Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* body = nullptr;
BasicBlock* action = nullptr;
BasicBlock* latch = nullptr;
BasicBlock* exit = nullptr;
PhiInst* induction = nullptr;
Value* bound = nullptr;
PhiInst* linear = nullptr;
Value* linear_init = nullptr;
int linear_step = 0;
Value* base_ptr = nullptr;
Value* threshold = nullptr;
bool prefix = false;
int fill_value = 0;
};
bool ExprDependsOn(Value* value, Value* needle,
std::unordered_set<Value*>& visiting) {
if (value == needle) return true;
auto* inst = dynamic_cast<Instruction*>(value);
if (!inst) return false;
if (!visiting.insert(value).second) return false;
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (ExprDependsOn(inst->GetOperand(i), needle, visiting)) {
return true;
}
}
return false;
}
bool ExprDependsOn(Value* value, Value* needle) {
std::unordered_set<Value*> visiting;
return ExprDependsOn(value, needle, visiting);
}
Value* GetIncomingForBlock(PhiInst* phi, BasicBlock* block) {
if (!phi || !block) return nullptr;
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (phi->GetIncomingBlock(i) == block) {
return phi->GetIncomingValue(i);
}
}
return nullptr;
}
Value* MaterializeInvariantExpr(Value* value, analysis::Loop* loop, IRBuilder& builder,
ValueMap& remap) {
auto it = remap.find(value);
if (it != remap.end()) return it->second;
if (dynamic_cast<ConstantValue*>(value) || dynamic_cast<Argument*>(value) ||
dynamic_cast<GlobalVariable*>(value) || dynamic_cast<Function*>(value)) {
return value;
}
auto* inst = dynamic_cast<Instruction*>(value);
if (!inst || !loop->Contains(inst->GetParent())) return value;
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
remap[operand] = MaterializeInvariantExpr(operand, loop, builder, remap);
}
auto cloned = CloneInstruction(inst, remap, ".idiom");
if (!cloned) return nullptr;
auto* raw = cloned.get();
InsertInstruction(builder.GetInsertBlock(), std::move(cloned));
remap[inst] = raw;
return raw;
}
bool HasOutsideUse(Instruction* inst, analysis::Loop* loop) {
for (const auto& use : inst->GetUses()) {
auto* user = dynamic_cast<Instruction*>(use.GetUser());
if (!user) return true;
if (!user->GetParent() || !loop->Contains(user->GetParent())) {
return true;
}
}
return false;
}
Value* MatchContiguousOffset(Value* index, PhiInst* iv, analysis::Loop* loop) {
if (index == iv) return nullptr;
auto* bin = dynamic_cast<BinaryInst*>(index);
if (!bin || bin->GetOpcode() != Opcode::Add ||
!bin->GetType() || !bin->GetType()->IsInt32()) {
return nullptr;
}
if (bin->GetLhs() == iv && IsLoopInvariantValue(bin->GetRhs(), loop)) {
return bin->GetRhs();
}
if (bin->GetRhs() == iv && IsLoopInvariantValue(bin->GetLhs(), loop)) {
return bin->GetLhs();
}
return nullptr;
}
bool BuildFillLoopCandidate(Function& func, analysis::Loop* loop,
FillLoopCandidate* out) {
(void)func;
auto match = MatchCanonicalLoop(loop);
if (!match.has_value()) return false;
if (match->loop->GetChildren().size() != 0) return false;
if (match->body != match->latch || loop->GetBlocks().size() != 2) return false;
if (match->header_phis.size() != 1 ||
match->header_phis.front() != match->induction.phi) {
return false;
}
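if (match->induction.step != 1) return false;
// The loop is bypassed after the rewrite, so the induction variable must not
// be read outside it.
if (HasOutsideUse(match->induction.phi, loop)) return false;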
if (match->header_cmp->GetCmpOp() != CmpOp::Lt) return false;
auto* init_ci = dynamic_cast<ConstantInt*>(match->induction.init);
if (!init_ci || init_ci->GetValue() != 0) return false;
if (!match->exit->GetInstructions().empty() &&
dynamic_cast<PhiInst*>(match->exit->GetInstructions().front().get()) != nullptr) {
return false;
}
StoreInst* store = nullptr;
std::vector<Instruction*> body_insts;
for (const auto& inst_ptr : match->body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
body_insts.push_back(inst);
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Gep:
case Opcode::Store:
break;
default:
return false;
}
if (inst != match->induction.next && HasOutsideUse(inst, loop)) {
return false;
}
if (auto* maybe_store = dynamic_cast<StoreInst*>(inst)) {
if (store) return false;
store = maybe_store;
}
}
if (!store) return false;
auto* fill_ci = dynamic_cast<ConstantInt*>(store->GetValue());
if (!fill_ci) return false;
auto* gep = dynamic_cast<GepInst*>(store->GetPtr());
if (!gep || !gep->GetBase() || !gep->GetBase()->GetType() ||
!gep->GetBase()->GetType()->IsPtrInt32()) {
return false;
}
Value* offset = MatchContiguousOffset(gep->GetIndex(), match->induction.phi, loop);
if (gep->GetIndex() != match->induction.phi && offset == nullptr) {
return false;
}
out->loop = loop;
out->preheader = match->preheader;
out->header = match->header;
out->exit = match->exit;
out->induction = match->induction.phi;
out->bound = match->bound;
out->base_ptr = gep->GetBase();
out->offset = offset;
out->fill_value = fill_ci->GetValue();
return true;
}
Function* GetOrCreateFillI32(Module& module) {
if (auto* fn = module.FindFunction("__fill_i32")) return fn;
auto* fn = module.CreateFunction("__fill_i32", Type::GetVoidType(),
{Type::GetPtrInt32Type(), Type::GetInt32Type(),
Type::GetInt32Type()});
fn->SetExternal(true);
return fn;
}
Function* GetOrCreateFillRowsI32(Module& module) {
if (auto* fn = module.FindFunction("__fill_rows_i32")) return fn;
auto* fn = module.CreateFunction(
"__fill_rows_i32", Type::GetVoidType(),
{Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type(),
Type::GetInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
fn->SetExternal(true);
return fn;
}
bool BuildGuardedRowFillCandidate(Function& func, analysis::Loop* loop,
GuardedRowFillCandidate* out) {
(void)func;
if (!loop) return false;
if (loop->GetChildren().size() != 0) return false;
if (loop->GetBlocks().size() != 4) return false;
auto* header = loop->GetHeader();
auto* preheader = loop->GetPreheader();
if (!header || !preheader) return false;
if (loop->GetLatches().size() != 1) return false;
auto* latch = loop->GetLatches().front();
if (!latch) return false;
auto* header_term = header->HasTerminator()
? dynamic_cast<CondBranchInst*>(
header->MutableInstructions().back().get())
: nullptr;
if (!header_term) return false;
auto* header_cmp = dynamic_cast<CmpInst*>(header_term->GetCond());
if (!header_cmp || header_cmp->GetCmpOp() != CmpOp::Lt) return false;
auto induction = MatchCanonicalInduction(header, preheader, latch);
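if (!induction.has_value() || induction->step != 1) return false;
// The rewrite counts rows from index 0, so the induction must start at 0.
auto* init_ci = dynamic_cast<ConstantInt*>(induction->init);
if (!init_ci || init_ci->GetValue() != 0) return false;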
Value* bound = nullptr;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (loop->Contains(header_term->GetTrueBlock()) &&
!loop->Contains(header_term->GetFalseBlock())) {
body = header_term->GetTrueBlock();
exit = header_term->GetFalseBlock();
} else if (loop->Contains(header_term->GetFalseBlock()) &&
!loop->Contains(header_term->GetTrueBlock())) {
body = header_term->GetFalseBlock();
exit = header_term->GetTrueBlock();
} else {
return false;
}
if (header_cmp->GetLhs() == induction->phi &&
IsLoopInvariantValue(header_cmp->GetRhs(), loop)) {
bound = header_cmp->GetRhs();
} else if (header_cmp->GetRhs() == induction->phi &&
IsLoopInvariantValue(header_cmp->GetLhs(), loop)) {
bound = header_cmp->GetLhs();
} else {
return false;
}
auto header_phis = CollectHeaderPhis(header);
if (header_phis.size() != 2) return false;
PhiInst* linear_phi = nullptr;
for (auto* phi : header_phis) {
if (phi != induction->phi) {
linear_phi = phi;
break;
}
}
if (!linear_phi || !linear_phi->GetType() || !linear_phi->GetType()->IsInt32()) {
return false;
}
auto* linear_init = GetIncomingForBlock(linear_phi, preheader);
auto* linear_next = GetIncomingForBlock(linear_phi, latch);
auto* linear_next_bin = dynamic_cast<BinaryInst*>(linear_next);
if (!linear_init || !linear_next_bin ||
linear_next_bin->GetOpcode() != Opcode::Add ||
linear_next_bin->GetLhs() != linear_phi) {
return false;
}
auto* linear_step_ci = dynamic_cast<ConstantInt*>(linear_next_bin->GetRhs());
if (!linear_step_ci || linear_step_ci->GetValue() <= 0) return false;
auto* guard = body->HasTerminator()
? dynamic_cast<CondBranchInst*>(body->MutableInstructions().back().get())
: nullptr;
if (!guard) return false;
BasicBlock* action = nullptr;
if (guard->GetTrueBlock() == latch && loop->Contains(guard->GetFalseBlock())) {
action = guard->GetFalseBlock();
} else if (guard->GetFalseBlock() == latch &&
loop->Contains(guard->GetTrueBlock())) {
action = guard->GetTrueBlock();
} else {
return false;
}
auto* action_term =
dynamic_cast<BranchInst*>(action->MutableInstructions().back().get());
if (!action_term || action_term->GetTarget() != latch) return false;
CallInst* fill_call = nullptr;
GepInst* fill_gep = nullptr;
for (const auto& inst_ptr : action->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) continue;
if (auto* gep = dynamic_cast<GepInst*>(inst)) {
fill_gep = gep;
continue;
}
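fill_call = dynamic_cast<CallInst*>(inst);
// Any instruction other than the GEP and the fill call blocks the rewrite.
if (!fill_call) return false;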
}
if (!fill_call || !fill_gep || fill_call->GetNumArgs() != 3) return false;
auto* callee = fill_call->GetCallee();
if (!callee || callee->GetName() != "__fill_i32") return false;
auto* fill_value = dynamic_cast<ConstantInt*>(fill_call->GetArg(2));
if (!fill_value) return false;
if (fill_call->GetArg(0) != fill_gep || fill_call->GetArg(1) != bound) {
return false;
}
if (fill_gep->GetIndex() != linear_phi) return false;
if (!fill_gep->GetBase() || !fill_gep->GetBase()->GetType() ||
!fill_gep->GetBase()->GetType()->IsPtrInt32()) {
return false;
}
auto* guard_cmp = dynamic_cast<CmpInst*>(guard->GetCond());
if (!guard_cmp || !guard_cmp->GetType() || !guard_cmp->GetType()->IsInt32()) {
return false;
}
Value* threshold = nullptr;
bool prefix = false;
bool suffix = false;
if (guard_cmp->GetLhs() == induction->phi &&
!ExprDependsOn(guard_cmp->GetRhs(), induction->phi) &&
!ExprDependsOn(guard_cmp->GetRhs(), linear_phi)) {
threshold = guard_cmp->GetRhs();
if (guard_cmp->GetCmpOp() == CmpOp::Lt && action == guard->GetTrueBlock()) {
prefix = true;
} else if (guard_cmp->GetCmpOp() == CmpOp::Ge &&
action == guard->GetTrueBlock()) {
suffix = true;
} else if (guard_cmp->GetCmpOp() == CmpOp::Lt &&
action == guard->GetFalseBlock()) {
suffix = true;
} else if (guard_cmp->GetCmpOp() == CmpOp::Ge &&
action == guard->GetFalseBlock()) {
prefix = true;
} else {
return false;
}
} else {
return false;
}
out->loop = loop;
out->preheader = preheader;
out->header = header;
out->body = body;
out->action = action;
out->latch = latch;
out->exit = exit;
out->induction = induction->phi;
out->bound = bound;
out->linear = linear_phi;
out->linear_init = linear_init;
out->linear_step = linear_step_ci->GetValue();
out->base_ptr = fill_gep->GetBase();
out->fill_value = fill_value->GetValue();
out->threshold = threshold;
out->prefix = prefix;
return prefix || suffix;
}
bool RunFillLoop(Function& func, const FillLoopCandidate& cand,
Module& module, Context& ctx) {
(void)func;
auto* fill_fn = GetOrCreateFillI32(module);
auto* preheader = cand.preheader;
if (preheader->HasTerminator()) {
preheader->RemoveInstruction(preheader->MutableInstructions().back().get());
}
IRBuilder builder(ctx, preheader);
Value* start_ptr = cand.base_ptr;
if (cand.offset) {
start_ptr = builder.CreateGep(cand.base_ptr, cand.offset, ctx.NextTemp());
}
builder.CreateCall(fill_fn, {start_ptr, cand.bound, ctx.GetConstInt(cand.fill_value)},
"");
preheader->Append<BranchInst>(Type::GetVoidType(), cand.exit);
cand.induction->RemoveIncomingBlock(preheader);
preheader->RemoveSuccessor(cand.header);
cand.header->RemovePredecessor(preheader);
preheader->AddSuccessor(cand.exit);
cand.exit->AddPredecessor(preheader);
return true;
}
bool RunGuardedRowFillLoop(Function& func, const GuardedRowFillCandidate& cand,
Module& module, Context& ctx) {
(void)func;
auto* fill_rows_fn = GetOrCreateFillRowsI32(module);
auto* preheader = cand.preheader;
if (preheader->HasTerminator()) {
preheader->RemoveInstruction(preheader->MutableInstructions().back().get());
}
IRBuilder builder(ctx, preheader);
ValueMap remap;
auto* threshold =
MaterializeInvariantExpr(cand.threshold, cand.loop, builder, remap);
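if (!threshold) {
// Restore the branch removed above so the preheader keeps a terminator.
preheader->Append<BranchInst>(Type::GetVoidType(), cand.header);
return false;
}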
Value* start_index = cand.prefix ? ctx.GetConstInt(0) : threshold;
Value* rows = cand.prefix ? threshold : nullptr;
if (!cand.prefix) {
rows = builder.CreateSub(cand.bound, start_index, ctx.NextTemp());
}
auto* start_offset_mul =
builder.CreateMul(start_index, ctx.GetConstInt(cand.linear_step), ctx.NextTemp());
auto* start_offset =
builder.CreateAdd(cand.linear_init, start_offset_mul, ctx.NextTemp());
builder.CreateCall(fill_rows_fn,
{cand.base_ptr, start_offset, rows,
ctx.GetConstInt(cand.linear_step), cand.bound,
ctx.GetConstInt(cand.fill_value)},
"");
preheader->Append<BranchInst>(Type::GetVoidType(), cand.exit);
cand.induction->RemoveIncomingBlock(preheader);
cand.linear->RemoveIncomingBlock(preheader);
preheader->RemoveSuccessor(cand.header);
cand.header->RemovePredecessor(preheader);
preheader->AddSuccessor(cand.exit);
cand.exit->AddPredecessor(preheader);
return true;
}
} // namespace
bool RunLoopIdiom(Function& func, Module& module, Context& ctx) {
if (func.IsExternal()) return false;
analysis::DominatorTree dom_tree(func);
analysis::LoopInfo loop_info(func, dom_tree);
for (const auto& loop_ptr : loop_info.GetLoops()) {
GuardedRowFillCandidate row_fill;
if (BuildGuardedRowFillCandidate(func, loop_ptr.get(), &row_fill)) {
if (RunGuardedRowFillLoop(func, row_fill, module, ctx)) {
return true;
}
}
FillLoopCandidate cand;
if (!BuildFillLoopCandidate(func, loop_ptr.get(), &cand)) continue;
if (RunFillLoop(func, cand, module, ctx)) {
return true;
}
}
return false;
}
} // namespace passes
} // namespace ir
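
The two runtime entry points are only declared in this file; their bodies live in the runtime library and are not shown here. The sketch below gives reference semantics inferred from the call sites in `RunFillLoop` and `RunGuardedRowFillLoop`. `FillI32Ref`, `FillRowsI32Ref`, `GuardedLoop`, and `Rewritten` are hypothetical names, and the argument order (base, start offset, rows, stride, count, value) is an assumption read off the `CreateCall` argument list.

```cpp
#include <vector>

// Assumed meaning of __fill_i32(ptr, count, value).
void FillI32Ref(int* dst, int count, int value) {
  for (int i = 0; i < count; ++i) dst[i] = value;
}

// Assumed meaning of __fill_rows_i32: fill `rows` rows of `count` elements,
// the first starting at base + start_offset, each `stride` elements apart.
void FillRowsI32Ref(int* base, int start_offset, int rows, int stride,
                    int count, int value) {
  for (int r = 0; r < rows; ++r) {
    FillI32Ref(base + start_offset + r * stride, count, value);
  }
}

// Loop shape matched in the prefix case: rows with i < threshold are filled.
std::vector<int> GuardedLoop(int n, int stride, int threshold, int value) {
  std::vector<int> grid(n * stride, -1);
  int linear = 0;  // second header phi: linear index of the current row
  for (int i = 0; i < n; ++i) {
    if (i < threshold) FillI32Ref(grid.data() + linear, n, value);
    linear += stride;
  }
  return grid;
}

// The same work as one bulk call: start index 0, `threshold` rows.
std::vector<int> Rewritten(int n, int stride, int threshold, int value) {
  std::vector<int> grid(n * stride, -1);
  FillRowsI32Ref(grid.data(), 0, threshold, stride, n, value);
  return grid;
}
```

Under these assumptions the two forms agree only when `0 <= threshold <= n` and `stride >= n`. Whether the real runtime clamps the row count for a threshold outside that range is not visible in this diff.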

@ -0,0 +1,845 @@
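// Loop parallelization:
// - Outlines a subset of safe canonical loops into worker functions
// - Launches them in parallel on a fixed number of threads through the runtime's __par_runN
//
// Current limitations:
// - Only loops without SSA live-out values are parallelized
// - Objects accessed in the loop must be global arrays or global variables
// - Ordinary function calls inside the loop are not supported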
#include "ir/IR.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "LoopPassUtils.h"
namespace ir {
namespace passes {
bool RunSimpleDCE(Function& func);
bool RunCFGSimplify(Function& func, Context& ctx);
namespace {
constexpr int kParallelLoopSlots = 8;
constexpr int kParallelThreads = 4;
enum class ParallelLoopKind {
Pointwise,
ReductionAddI32,
GuardedFillI32,
};
struct LoopContextValue {
Value* original = nullptr;
GlobalVariable* slot = nullptr;
};
bool ExprDependsOn(Value* value, Value* needle,
std::unordered_set<Value*>& visiting) {
if (value == needle) return true;
auto* inst = dynamic_cast<Instruction*>(value);
if (!inst) return false;
if (!visiting.insert(value).second) return false;
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (ExprDependsOn(inst->GetOperand(i), needle, visiting)) {
return true;
}
}
return false;
}
bool ExprDependsOn(Value* value, Value* needle) {
std::unordered_set<Value*> visiting;
return ExprDependsOn(value, needle, visiting);
}
Value* StripPointerBase(Value* value) {
while (auto* gep = dynamic_cast<GepInst*>(value)) {
value = gep->GetBase();
}
return value;
}
bool IsSupportedParallelInst(Instruction* inst) {
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::Cmp:
case Opcode::Cast:
case Opcode::Load:
case Opcode::Store:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Gep:
case Opcode::Phi:
case Opcode::Ret:
return true;
case Opcode::Call:
case Opcode::Alloca:
return false;
}
return false;
}
bool IsScalarContextCandidate(Value* value) {
auto* arg = dynamic_cast<Argument*>(value);
if (arg && (arg->GetType()->IsInt32() || arg->GetType()->IsFloat32())) {
return true;
}
auto* inst = dynamic_cast<Instruction*>(value);
if (inst && inst->GetType() &&
(inst->GetType()->IsInt32() || inst->GetType()->IsFloat32())) {
return true;
}
return false;
}
bool HasOutsideUse(Instruction* inst, analysis::Loop* loop) {
for (const auto& use : inst->GetUses()) {
auto* user = dynamic_cast<Instruction*>(use.GetUser());
if (!user) return true;
if (!user->GetParent() || !loop->Contains(user->GetParent())) {
return true;
}
}
return false;
}
void ReplaceUsesOutsideLoop(Value* value, Value* replacement,
analysis::Loop* loop) {
if (!value || !replacement || !loop) return;
auto uses = value->GetUses();
for (const auto& use : uses) {
auto* user = dynamic_cast<Instruction*>(use.GetUser());
if (!user) continue;
auto* parent = user->GetParent();
if (parent && loop->Contains(parent)) continue;
user->SetOperand(use.GetOperandIndex(), replacement);
}
}
Value* GetIncomingForBlock(PhiInst* phi, BasicBlock* block) {
if (!phi || !block) return nullptr;
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (phi->GetIncomingBlock(i) == block) {
return phi->GetIncomingValue(i);
}
}
return nullptr;
}
struct ParallelLoopCandidate {
Function* parent = nullptr;
analysis::Loop* loop = nullptr;
ParallelLoopKind kind = ParallelLoopKind::Pointwise;
BasicBlock* header = nullptr;
BasicBlock* body = nullptr;
BasicBlock* guard = nullptr;
BasicBlock* action = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* exit = nullptr;
BasicBlock* latch = nullptr;
CmpInst* header_cmp = nullptr;
PhiInst* induction = nullptr;
Value* induction_next = nullptr;
Value* bound = nullptr;
PhiInst* linear = nullptr;
Value* linear_init = nullptr;
Value* linear_next = nullptr;
int linear_step = 0;
PhiInst* reduction = nullptr;
Value* reduction_init = nullptr;
Value* reduction_next = nullptr;
bool has_loads = false;
std::vector<LoopContextValue> contexts;
};
bool BuildGuardedFillCandidate(Function& func, analysis::Loop* loop,
ParallelLoopCandidate* out) {
if (!loop) return false;
auto* header = loop->GetHeader();
auto* preheader = loop->GetPreheader();
if (!header || !preheader) return false;
if (loop->GetChildren().size() != 0) return false;
if (loop->GetBlocks().size() != 4) return false;
if (loop->GetLatches().size() != 1) return false;
auto* latch = loop->GetLatches().front();
if (!latch) return false;
auto* header_term = header->HasTerminator()
? dynamic_cast<CondBranchInst*>(
header->MutableInstructions().back().get())
: nullptr;
if (!header_term) return false;
auto* header_cmp = dynamic_cast<CmpInst*>(header_term->GetCond());
if (!header_cmp || header_cmp->GetCmpOp() != CmpOp::Lt) return false;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (loop->Contains(header_term->GetTrueBlock()) &&
!loop->Contains(header_term->GetFalseBlock())) {
body = header_term->GetTrueBlock();
exit = header_term->GetFalseBlock();
} else if (loop->Contains(header_term->GetFalseBlock()) &&
!loop->Contains(header_term->GetTrueBlock())) {
body = header_term->GetFalseBlock();
exit = header_term->GetTrueBlock();
} else {
return false;
}
auto induction = MatchCanonicalInduction(header, preheader, latch);
if (!induction.has_value() || induction->step != 1) return false;
Value* bound = nullptr;
if (header_cmp->GetLhs() == induction->phi &&
IsLoopInvariantValue(header_cmp->GetRhs(), loop)) {
bound = header_cmp->GetRhs();
} else if (header_cmp->GetRhs() == induction->phi &&
IsLoopInvariantValue(header_cmp->GetLhs(), loop)) {
bound = header_cmp->GetLhs();
} else {
return false;
}
auto header_phis = CollectHeaderPhis(header);
if (header_phis.size() != 2) return false;
PhiInst* linear_phi = nullptr;
for (auto* phi : header_phis) {
if (phi != induction->phi) {
linear_phi = phi;
break;
}
}
if (!linear_phi || !linear_phi->GetType() || !linear_phi->GetType()->IsInt32()) {
return false;
}
auto* linear_init = GetIncomingForBlock(linear_phi, preheader);
auto* linear_next = GetIncomingForBlock(linear_phi, latch);
auto* linear_next_bin = dynamic_cast<BinaryInst*>(linear_next);
if (!linear_init || !linear_next_bin ||
linear_next_bin->GetOpcode() != Opcode::Add ||
linear_next_bin->GetLhs() != linear_phi) {
return false;
}
auto* linear_step_ci = dynamic_cast<ConstantInt*>(linear_next_bin->GetRhs());
if (!linear_step_ci || linear_step_ci->GetValue() <= 0) return false;
auto* guard = body->HasTerminator()
? dynamic_cast<CondBranchInst*>(body->MutableInstructions().back().get())
: nullptr;
if (!guard) return false;
auto* true_bb = guard->GetTrueBlock();
auto* false_bb = guard->GetFalseBlock();
// Exactly one successor of the guard must be the latch; the other one, which
// must also lie inside the loop, is the action block.
BasicBlock* action = nullptr;
if (true_bb == latch && false_bb != latch && loop->Contains(false_bb)) {
action = false_bb;
} else if (false_bb == latch && true_bb != latch &&
loop->Contains(true_bb)) {
action = true_bb;
} else {
return false;
}
if (action == header || action == body) return false;
auto* action_term =
dynamic_cast<BranchInst*>(action->MutableInstructions().back().get());
if (!action_term || action_term->GetTarget() != latch) return false;
CallInst* fill_call = nullptr;
for (const auto& inst_ptr : action->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) continue;
if (auto* gep = dynamic_cast<GepInst*>(inst)) {
if (gep->GetBase() == nullptr || gep->GetIndex() == nullptr) return false;
continue;
}
fill_call = dynamic_cast<CallInst*>(inst);
if (!fill_call) return false;
}
if (!fill_call || fill_call->GetNumArgs() != 3) return false;
auto* callee = fill_call->GetCallee();
if (!callee || callee->GetName() != "__fill_i32") return false;
auto* fill_ptr = dynamic_cast<GepInst*>(fill_call->GetArg(0));
auto* fill_count = fill_call->GetArg(1);
auto* fill_value = dynamic_cast<ConstantInt*>(fill_call->GetArg(2));
if (!fill_ptr || !fill_value) return false;
if (fill_ptr->GetBase() == nullptr || !fill_ptr->GetBase()->GetType() ||
!fill_ptr->GetBase()->GetType()->IsPtrInt32()) {
return false;
}
if (fill_ptr->GetIndex() != linear_phi) return false;
if (fill_count != bound) return false;
std::vector<Value*> context_values;
std::unordered_set<Value*> seen_contexts;
auto collect_contexts = [&](BasicBlock* block) -> bool {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
Value* operand = inst->GetOperand(i);
if (dynamic_cast<ConstantValue*>(operand) || dynamic_cast<Function*>(operand) ||
dynamic_cast<BasicBlock*>(operand) || dynamic_cast<GlobalVariable*>(operand)) {
continue;
}
auto* operand_inst = dynamic_cast<Instruction*>(operand);
if ((operand_inst && loop->Contains(operand_inst->GetParent())) ||
operand == induction->phi || operand == induction->next ||
operand == linear_phi || operand == linear_next) {
continue;
}
if (!IsScalarContextCandidate(operand)) return false;
if (seen_contexts.insert(operand).second) {
context_values.push_back(operand);
}
}
if (inst != linear_next && inst != induction->next &&
HasOutsideUse(inst, loop)) {
return false;
}
}
return true;
};
if (!collect_contexts(body) || !collect_contexts(action) ||
!collect_contexts(latch)) {
return false;
}
if (context_values.size() > 6) return false;
out->parent = &func;
out->loop = loop;
out->kind = ParallelLoopKind::GuardedFillI32;
out->header = header;
out->body = body;
out->guard = body;
out->action = action;
out->preheader = preheader;
out->exit = exit;
out->latch = latch;
out->header_cmp = header_cmp;
out->induction = induction->phi;
out->induction_next = induction->next;
out->bound = bound;
out->linear = linear_phi;
out->linear_init = linear_init;
out->linear_next = linear_next;
out->linear_step = linear_step_ci->GetValue();
out->has_loads = true;
for (Value* value : context_values) {
out->contexts.push_back({value, nullptr});
}
return true;
}
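// Tries to describe `loop` as a parallelizable candidate. The guarded-fill
// pattern is tried first, then the single-block pointwise form. `out` is
// written only when a match succeeds.
bool BuildParallelCandidate(Function& func, analysis::Loop* loop,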
ParallelLoopCandidate* out) {
if (BuildGuardedFillCandidate(func, loop, out)) return true;
auto match = MatchCanonicalLoop(loop);
if (!match.has_value()) return false;
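if (match->body != match->latch || loop->GetBlocks().size() != 2) return false;
// The worker splits [0, bound) into contiguous per-thread chunks and always
// enters the body on the true edge, so only loops shaped like
// `for (i = 0; i < bound; ++i)` are safe to outline. Other shapes, including
// the mirrored `bound > i` form, are rejected conservatively.
auto* induction_init = dynamic_cast<ConstantInt*>(match->induction.init);
if (!induction_init || induction_init->GetValue() != 0 ||
match->induction.step != 1 ||
match->header_branch->GetTrueBlock() != match->body ||
match->header_cmp->GetLhs() != match->induction.phi ||
match->header_cmp->GetCmpOp() != CmpOp::Lt) {
return false;
}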
if (match->exit->GetInstructions().size() > 0 &&
dynamic_cast<PhiInst*>(match->exit->GetInstructions().front().get()) != nullptr) {
return false;
}
PhiInst* reduction_phi = nullptr;
Value* reduction_init = nullptr;
Value* reduction_next = nullptr;
ParallelLoopKind kind = ParallelLoopKind::Pointwise;
if (match->header_phis.size() == 1 &&
match->header_phis.front() == match->induction.phi) {
kind = ParallelLoopKind::Pointwise;
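} else if (false && match->header_phis.size() == 2) {
// NOTE: the `false &&` guard keeps the add-reduction path disabled; the
// matching code below is retained but currently unreachable.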
for (auto* phi : match->header_phis) {
if (phi != match->induction.phi) {
reduction_phi = phi;
break;
}
}
if (!reduction_phi || !reduction_phi->GetType() ||
!reduction_phi->GetType()->IsInt32()) {
return false;
}
reduction_init = GetIncomingForBlock(reduction_phi, match->preheader);
reduction_next = GetIncomingForBlock(reduction_phi, match->latch);
if (!reduction_init || !reduction_next) return false;
auto* reduction_next_inst = dynamic_cast<Instruction*>(reduction_next);
if (!reduction_next_inst || reduction_next_inst->GetParent() != match->body) {
return false;
}
auto* init_ci = dynamic_cast<ConstantInt*>(reduction_init);
if (!init_ci || init_ci->GetValue() != 0) return false;
auto* red_next_bin = dynamic_cast<BinaryInst*>(reduction_next);
if (!red_next_bin || red_next_bin->GetOpcode() != Opcode::Add ||
!red_next_bin->GetType() || !red_next_bin->GetType()->IsInt32()) {
return false;
}
Value* other = nullptr;
if (red_next_bin->GetLhs() == reduction_phi) {
other = red_next_bin->GetRhs();
} else if (red_next_bin->GetRhs() == reduction_phi) {
other = red_next_bin->GetLhs();
} else {
return false;
}
if (ExprDependsOn(other, reduction_phi)) return false;
kind = ParallelLoopKind::ReductionAddI32;
} else {
return false;
}
std::unordered_set<Value*> store_bases;
std::unordered_set<Value*> load_bases;
std::vector<Value*> context_values;
std::unordered_set<Value*> seen_contexts;
for (const auto& bb_ptr : func.GetBlocks()) {
auto* block = bb_ptr.get();
if (!loop->Contains(block)) continue;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!IsSupportedParallelInst(inst)) return false;
if (inst->GetOpcode() == Opcode::Call) return false;
if (kind == ParallelLoopKind::ReductionAddI32 &&
inst->GetOpcode() == Opcode::Store) {
return false;
}
if (inst->GetOpcode() != Opcode::Store && inst->GetOpcode() != Opcode::Br &&
inst->GetOpcode() != Opcode::CondBr && inst->GetOpcode() != Opcode::Phi &&
inst != reduction_phi &&
HasOutsideUse(inst, loop)) {
return false;
}
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
Value* operand = inst->GetOperand(i);
if (dynamic_cast<ConstantValue*>(operand) || dynamic_cast<Function*>(operand) ||
dynamic_cast<BasicBlock*>(operand) || dynamic_cast<GlobalVariable*>(operand)) {
continue;
}
auto* operand_inst = dynamic_cast<Instruction*>(operand);
if ((operand_inst && loop->Contains(operand_inst->GetParent())) ||
operand == match->induction.phi ||
operand == match->induction.next ||
operand == reduction_phi ||
operand == reduction_next) {
continue;
}
if (!IsScalarContextCandidate(operand)) return false;
if (seen_contexts.insert(operand).second) {
context_values.push_back(operand);
}
}
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
Value* base = StripPointerBase(load->GetPtr());
if (dynamic_cast<GlobalVariable*>(base) == nullptr) return false;
load_bases.insert(base);
// A load whose index does not depend on the induction variable is allowed
// here as a pure read; conflicts with stores to the same base are rejected
// in the dependence check after this loop.
} else if (auto* store = dynamic_cast<StoreInst*>(inst)) {
Value* base = StripPointerBase(store->GetPtr());
auto* gv = dynamic_cast<GlobalVariable*>(base);
if (!gv || gv->GetCount() <= 1) return false;
store_bases.insert(base);
auto* gep = dynamic_cast<GepInst*>(store->GetPtr());
if (!gep || !ExprDependsOn(gep->GetIndex(), match->induction.phi)) return false;
}
}
}
for (Value* base : store_bases) {
if (load_bases.count(base) == 0) continue;
for (const auto& bb_ptr : func.GetBlocks()) {
auto* block = bb_ptr.get();
if (!loop->Contains(block)) continue;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* load = dynamic_cast<LoadInst*>(inst_ptr.get());
if (!load) continue;
if (StripPointerBase(load->GetPtr()) != base) continue;
auto* gep = dynamic_cast<GepInst*>(load->GetPtr());
if (!gep || !ExprDependsOn(gep->GetIndex(), match->induction.phi)) return false;
}
}
}
if (context_values.size() > 6) return false;
out->parent = &func;
out->loop = loop;
out->kind = kind;
out->header = match->header;
out->body = match->body;
out->preheader = match->preheader;
out->exit = match->exit;
out->latch = match->latch;
out->header_cmp = match->header_cmp;
out->induction = match->induction.phi;
out->induction_next = match->induction.next;
out->bound = match->bound;
out->reduction = reduction_phi;
out->reduction_init = reduction_init;
out->reduction_next = reduction_next;
out->has_loads = !load_bases.empty();
for (Value* value : context_values) {
out->contexts.push_back({value, nullptr});
}
return true;
}
bool IsGeneratedParallelWorker(const Function& func) {
return func.GetName().rfind("__par_worker", 0) == 0;
}
Function* GetOrCreateRuntimeLauncher(Module& module, int slot) {
const std::string name = "__par_run" + std::to_string(slot);
if (auto* fn = module.FindFunction(name)) return fn;
auto* fn = module.CreateFunction(name, Type::GetVoidType(), {});
fn->SetExternal(true);
return fn;
}
std::string NextWorkerName(int slot) {
return "__par_worker" + std::to_string(slot);
}
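// Builds the worker function body. Thread `tid` runs the iterations in
// [tid * bound / T, (tid + 1) * bound / T), where T is kParallelThreads.
// The bound and the context values are reloaded from their global slots.
void CloneWorkerBlocks(const ParallelLoopCandidate& cand, Function* worker,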
GlobalVariable* bound_slot,
const std::vector<LoopContextValue>& ctx_slots,
GlobalVariable* reduction_slot, Context& ctx) {
if (cand.kind == ParallelLoopKind::GuardedFillI32) {
auto* entry = worker->GetEntry();
auto* tid = worker->GetArgument(0);
auto* header = worker->CreateBlock(cand.header->GetName());
auto* guard = worker->CreateBlock(cand.guard->GetName());
auto* action = worker->CreateBlock(cand.action->GetName());
auto* latch = worker->CreateBlock(cand.latch->GetName());
auto* worker_exit = worker->CreateBlock("par.exit");
IRBuilder builder(ctx, entry);
auto* bound_val = builder.CreateLoad(bound_slot, ctx.NextTemp());
Value* threads_val = ctx.GetConstInt(kParallelThreads);
auto* start_mul = builder.CreateMul(tid, bound_val, ctx.NextTemp());
auto* start = builder.CreateDiv(start_mul, threads_val, ctx.NextTemp());
auto* next_tid = builder.CreateAdd(tid, ctx.GetConstInt(1), ctx.NextTemp());
auto* end_mul = builder.CreateMul(next_tid, bound_val, ctx.NextTemp());
auto* end = builder.CreateDiv(end_mul, threads_val, ctx.NextTemp());
ValueMap remap;
remap[cand.bound] = bound_val;
for (const auto& ctx_value : ctx_slots) {
builder.SetInsertPoint(entry);
auto* loaded = builder.CreateLoad(ctx_value.slot, ctx.NextTemp());
remap[ctx_value.original] = loaded;
}
builder.SetInsertPoint(entry);
auto* start_linear_mul =
builder.CreateMul(start, ctx.GetConstInt(cand.linear_step), ctx.NextTemp());
Value* linear_init = cand.linear_init;
auto it = remap.find(cand.linear_init);
if (it != remap.end()) linear_init = it->second;
auto* start_linear = builder.CreateAdd(linear_init, start_linear_mul, ctx.NextTemp());
builder.CreateBr(header);
auto* new_iv = header->PrependPhi(cand.induction->GetType(), ctx.NextTemp());
auto* new_linear = header->PrependPhi(cand.linear->GetType(), ctx.NextTemp());
remap[cand.induction] = new_iv;
remap[cand.linear] = new_linear;
if (auto cloned_cmp = CloneInstruction(cand.header_cmp, remap, ".par")) {
auto* raw_cmp = static_cast<CmpInst*>(cloned_cmp.get());
header->MutableInstructions().push_back(std::move(cloned_cmp));
raw_cmp->SetParent(header);
remap[cand.header_cmp] = raw_cmp;
if (cand.header_cmp->GetLhs() == cand.induction &&
cand.header_cmp->GetRhs() == cand.bound) {
raw_cmp->SetOperand(0, new_iv);
raw_cmp->SetOperand(1, end);
} else if (cand.header_cmp->GetRhs() == cand.induction &&
cand.header_cmp->GetLhs() == cand.bound) {
raw_cmp->SetOperand(0, end);
raw_cmp->SetOperand(1, new_iv);
}
header->Append<CondBranchInst>(Type::GetVoidType(), raw_cmp, guard, worker_exit);
}
for (const auto& inst_ptr : cand.guard->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
if (auto cloned = CloneInstruction(inst, remap, ".par")) {
auto* raw = cloned.get();
guard->MutableInstructions().push_back(std::move(cloned));
raw->SetParent(guard);
remap[inst] = raw;
}
}
auto* guard_term =
static_cast<CondBranchInst*>(cand.guard->MutableInstructions().back().get());
auto* guard_cond = RemapValue(guard_term->GetCond(), remap);
BasicBlock* true_target =
(guard_term->GetTrueBlock() == cand.action) ? action : latch;
BasicBlock* false_target =
(guard_term->GetFalseBlock() == cand.action) ? action : latch;
guard->Append<CondBranchInst>(Type::GetVoidType(), guard_cond, true_target,
false_target);
for (const auto& inst_ptr : cand.action->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
if (auto cloned = CloneInstruction(inst, remap, ".par")) {
auto* raw = cloned.get();
action->MutableInstructions().push_back(std::move(cloned));
raw->SetParent(action);
remap[inst] = raw;
}
}
action->Append<BranchInst>(Type::GetVoidType(), latch);
for (const auto& inst_ptr : cand.latch->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
if (auto cloned = CloneInstruction(inst, remap, ".par")) {
auto* raw = cloned.get();
latch->MutableInstructions().push_back(std::move(cloned));
raw->SetParent(latch);
remap[inst] = raw;
}
}
latch->Append<BranchInst>(Type::GetVoidType(), header);
new_iv->AddIncoming(start, entry);
new_iv->AddIncoming(RemapValue(cand.induction_next, remap), latch);
new_linear->AddIncoming(start_linear, entry);
new_linear->AddIncoming(RemapValue(cand.linear_next, remap), latch);
worker_exit->Append<ReturnInst>(Type::GetVoidType(), nullptr);
return;
}
auto* entry = worker->GetEntry();
auto* tid = worker->GetArgument(0);
auto* header = worker->CreateBlock(cand.header->GetName());
auto* body = worker->CreateBlock(cand.body->GetName());
auto* worker_exit = worker->CreateBlock("par.exit");
IRBuilder builder(ctx, entry);
auto* bound_val = builder.CreateLoad(bound_slot, ctx.NextTemp());
Value* threads_val = ctx.GetConstInt(kParallelThreads);
auto* start_mul = builder.CreateMul(tid, bound_val, ctx.NextTemp());
auto* start = builder.CreateDiv(start_mul, threads_val, ctx.NextTemp());
auto* next_tid = builder.CreateAdd(tid, ctx.GetConstInt(1), ctx.NextTemp());
auto* end_mul = builder.CreateMul(next_tid, bound_val, ctx.NextTemp());
auto* end = builder.CreateDiv(end_mul, threads_val, ctx.NextTemp());
ValueMap remap;
remap[cand.induction] = start;
remap[cand.bound] = bound_val;
for (const auto& ctx_value : ctx_slots) {
builder.SetInsertPoint(entry);
auto* loaded = builder.CreateLoad(ctx_value.slot, ctx.NextTemp());
remap[ctx_value.original] = loaded;
}
builder.SetInsertPoint(entry);
builder.CreateBr(header);
auto* new_phi = header->PrependPhi(cand.induction->GetType(), ctx.NextTemp());
remap[cand.induction] = new_phi;
PhiInst* new_reduction_phi = nullptr;
if (cand.kind == ParallelLoopKind::ReductionAddI32) {
new_reduction_phi = header->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());
remap[cand.reduction] = new_reduction_phi;
}
if (auto cloned_cmp = CloneInstruction(cand.header_cmp, remap, ".par")) {
auto* raw_cmp = static_cast<CmpInst*>(cloned_cmp.get());
header->MutableInstructions().push_back(std::move(cloned_cmp));
raw_cmp->SetParent(header);
remap[cand.header_cmp] = raw_cmp;
if (cand.header_cmp->GetLhs() == cand.induction &&
cand.header_cmp->GetRhs() == cand.bound) {
raw_cmp->SetOperand(0, new_phi);
raw_cmp->SetOperand(1, end);
} else if (cand.header_cmp->GetRhs() == cand.induction &&
cand.header_cmp->GetLhs() == cand.bound) {
raw_cmp->SetOperand(0, end);
raw_cmp->SetOperand(1, new_phi);
}
header->Append<CondBranchInst>(Type::GetVoidType(), raw_cmp, body, worker_exit);
}
for (const auto& inst_ptr : cand.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dynamic_cast<PhiInst*>(inst) != nullptr) continue;
if (inst->GetOpcode() == Opcode::Br) continue;
if (auto cloned = CloneInstruction(inst, remap, ".par")) {
auto* raw = cloned.get();
body->MutableInstructions().push_back(std::move(cloned));
raw->SetParent(body);
remap[inst] = raw;
}
}
new_phi->AddIncoming(start, entry);
new_phi->AddIncoming(RemapValue(cand.induction_next, remap), body);
if (new_reduction_phi) {
new_reduction_phi->AddIncoming(ctx.GetConstInt(0), entry);
new_reduction_phi->AddIncoming(RemapValue(cand.reduction_next, remap), body);
}
body->Append<BranchInst>(Type::GetVoidType(), header);
if (new_reduction_phi) {
IRBuilder exit_builder(ctx, worker_exit);
auto* partial_ptr = exit_builder.CreateGep(reduction_slot, tid, ctx.NextTemp());
exit_builder.CreateStore(new_reduction_phi, partial_ptr);
}
worker_exit->Append<ReturnInst>(Type::GetVoidType(), nullptr);
}
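// Outlines the candidate loop into a worker function, spills the bound and
// the context values to globals in the preheader, calls the runtime launcher,
// and rewires the preheader to branch straight to the loop exit.
bool ParallelizeCandidate(Module& module, ParallelLoopCandidate& cand, int slot) {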
auto& ctx = module.GetContext();
auto* bound_slot =
module.CreateGlobalVar("__par_bound" + std::to_string(slot), 0, 1,
Type::GetPtrInt32Type());
GlobalVariable* reduction_slot = nullptr;
if (cand.kind == ParallelLoopKind::ReductionAddI32) {
reduction_slot = module.CreateGlobalVar(
"__par_red" + std::to_string(slot), 0, kParallelThreads,
Type::GetPtrInt32Type());
}
for (size_t i = 0; i < cand.contexts.size(); ++i) {
auto& entry = cand.contexts[i];
bool is_float = entry.original->GetType() && entry.original->GetType()->IsFloat32();
entry.slot = module.CreateGlobalVar(
"__par_ctx" + std::to_string(slot) + "_" + std::to_string(i), 0, 1,
is_float ? Type::GetPtrFloat32Type() : Type::GetPtrInt32Type());
}
auto* worker =
module.CreateFunction(NextWorkerName(slot), Type::GetVoidType(),
{Type::GetInt32Type()});
CloneWorkerBlocks(cand, worker, bound_slot, cand.contexts, reduction_slot, ctx);
auto* launcher = GetOrCreateRuntimeLauncher(module, slot);
auto* preheader = cand.preheader;
if (preheader->HasTerminator()) {
preheader->RemoveInstruction(preheader->MutableInstructions().back().get());
}
for (const auto& ctx_value : cand.contexts) {
InsertInstruction(preheader, std::make_unique<StoreInst>(
Type::GetVoidType(), ctx_value.original,
ctx_value.slot));
}
InsertInstruction(preheader, std::make_unique<StoreInst>(Type::GetVoidType(),
cand.bound, bound_slot));
InsertCallBeforeTerminator(preheader, launcher, {}, "");
Value* reduced_value = nullptr;
if (cand.kind == ParallelLoopKind::ReductionAddI32) {
IRBuilder builder(ctx, preheader);
reduced_value = cand.reduction_init;
for (int tid = 0; tid < kParallelThreads; ++tid) {
auto* partial_ptr =
builder.CreateGep(reduction_slot, ctx.GetConstInt(tid), ctx.NextTemp());
auto* partial_val = builder.CreateLoad(partial_ptr, ctx.NextTemp());
reduced_value = builder.CreateAdd(reduced_value, partial_val, ctx.NextTemp());
}
}
cand.induction->RemoveIncomingBlock(preheader);
if (cand.linear) {
cand.linear->RemoveIncomingBlock(preheader);
}
if (cand.reduction) {
cand.reduction->RemoveIncomingBlock(preheader);
}
preheader->RemoveSuccessor(cand.header);
cand.header->RemovePredecessor(preheader);
preheader->AddSuccessor(cand.exit);
cand.exit->AddPredecessor(preheader);
if (cand.kind == ParallelLoopKind::ReductionAddI32 && reduced_value) {
ReplaceUsesOutsideLoop(cand.reduction, reduced_value, cand.loop);
}
preheader->Append<BranchInst>(Type::GetVoidType(), cand.exit);
return true;
}
} // namespace
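// Parallelizes at most kParallelLoopSlots loops in the module. Candidates
// flagged `has_loads` are preferred; at most one store-only loop is accepted
// as a fallback.
bool RunLoopParallelization(Module& module) {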
bool changed = false;
int store_only_slots = 0;
for (int slot = 0; slot < kParallelLoopSlots; ++slot) {
ParallelLoopCandidate cand;
bool found = false;
for (const auto& func_ptr : module.GetFunctions()) {
auto* func = func_ptr.get();
if (!func || func->IsExternal() || IsGeneratedParallelWorker(*func)) continue;
analysis::DominatorTree dom_tree(*func);
analysis::LoopInfo loop_info(*func, dom_tree);
std::vector<analysis::Loop*> loops;
for (const auto& loop_ptr : loop_info.GetLoops()) {
loops.push_back(loop_ptr.get());
}
std::sort(loops.begin(), loops.end(),
[](analysis::Loop* lhs, analysis::Loop* rhs) {
if (lhs->GetDepth() != rhs->GetDepth()) {
return lhs->GetDepth() < rhs->GetDepth();
}
return lhs->GetHeader()->GetName() < rhs->GetHeader()->GetName();
});
ParallelLoopCandidate fallback_store_only;
bool have_fallback_store_only = false;
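for (auto* loop : loops) {
// Build into a fresh candidate for every loop: the builders append to
// `contexts`, so reusing `cand` would leak entries from an earlier match.
ParallelLoopCandidate attempt;
if (!BuildParallelCandidate(*func, loop, &attempt)) continue;
if (attempt.has_loads) {
cand = attempt;
found = true;
break;
}
if (store_only_slots < 1 && !have_fallback_store_only) {
fallback_store_only = attempt;
have_fallback_store_only = true;
}
}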
if (!found && have_fallback_store_only) {
cand = fallback_store_only;
found = true;
}
if (found) break;
}
if (!found) break;
bool local_changed = ParallelizeCandidate(module, cand, slot);
changed |= local_changed;
if (local_changed && cand.parent) {
RunSimpleDCE(*cand.parent);
RunCFGSimplify(*cand.parent, module.GetContext());
}
if (!cand.has_loads) {
++store_only_slots;
}
}
return changed;
}
} // namespace passes
} // namespace ir

@ -0,0 +1,309 @@
#pragma once
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
namespace ir {
namespace passes {
struct CanonicalInductionInfo {
PhiInst* phi = nullptr;
Value* init = nullptr;
Value* next = nullptr;
int step = 0;
};
struct CanonicalLoopMatch {
analysis::Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* exit = nullptr;
BasicBlock* body = nullptr;
BasicBlock* latch = nullptr;
CondBranchInst* header_branch = nullptr;
CmpInst* header_cmp = nullptr;
Value* bound = nullptr;
CanonicalInductionInfo induction;
std::vector<PhiInst*> header_phis;
};
using ValueMap = std::unordered_map<Value*, Value*>;
inline Value* RemapValue(Value* value, const ValueMap& remap) {
auto it = remap.find(value);
return it != remap.end() ? it->second : value;
}
inline std::vector<PhiInst*> CollectHeaderPhis(BasicBlock* header) {
std::vector<PhiInst*> phis;
if (!header) return phis;
for (const auto& inst_ptr : header->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst_ptr.get());
if (!phi) break;
phis.push_back(phi);
}
return phis;
}
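// Inserts `inst` just before the block terminator, or appends it when the
// block has no terminator yet.
inline void InsertInstruction(BasicBlock* block,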
std::unique_ptr<Instruction> inst) {
auto& insts = block->MutableInstructions();
auto insert_it = insts.end();
if (block->HasTerminator()) {
insert_it = insts.end() - 1;
}
inst->SetParent(block);
insts.insert(insert_it, std::move(inst));
}
inline Instruction* AppendOwnedInstruction(BasicBlock* block,
std::unique_ptr<Instruction> inst) {
auto* raw = inst.get();
InsertInstruction(block, std::move(inst));
return raw;
}
inline BinaryInst* InsertBinaryBeforeTerminator(BasicBlock* block, Opcode opcode,
Value* lhs, Value* rhs,
const std::string& name) {
auto inst = std::make_unique<BinaryInst>(opcode, lhs->GetType(), lhs, rhs, name);
auto* raw = inst.get();
InsertInstruction(block, std::move(inst));
return raw;
}
inline CmpInst* InsertCmpBeforeTerminator(BasicBlock* block, CmpOp cmp_op,
Value* lhs, Value* rhs,
const std::string& name) {
auto inst =
std::make_unique<CmpInst>(cmp_op, Type::GetInt32Type(), lhs, rhs, name);
auto* raw = inst.get();
InsertInstruction(block, std::move(inst));
return raw;
}
inline BranchInst* InsertBranchBeforeTerminator(BasicBlock* block,
BasicBlock* target) {
auto inst = std::make_unique<BranchInst>(Type::GetVoidType(), target);
auto* raw = inst.get();
InsertInstruction(block, std::move(inst));
return raw;
}
inline CondBranchInst* InsertCondBrBeforeTerminator(BasicBlock* block, Value* cond,
BasicBlock* true_bb,
BasicBlock* false_bb) {
auto inst = std::make_unique<CondBranchInst>(Type::GetVoidType(), cond,
true_bb, false_bb);
auto* raw = inst.get();
InsertInstruction(block, std::move(inst));
return raw;
}
inline CallInst* InsertCallBeforeTerminator(BasicBlock* block, Function* callee,
const std::vector<Value*>& args,
const std::string& name) {
auto inst = std::make_unique<CallInst>(callee->GetType(), callee, args, name);
auto* raw = inst.get();
InsertInstruction(block, std::move(inst));
return raw;
}
inline std::string CloneName(const std::string& base, const std::string& suffix) {
if (base.empty()) return base;
return base + suffix;
}
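// Clones an arithmetic, cmp, cast, load, store, call, or gep instruction with
// its operands rewritten through `remap`. Returns nullptr for any other
// opcode (phis and terminators included). The caller sets the parent block.
inline std::unique_ptr<Instruction> CloneInstruction(Instruction* inst,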
const ValueMap& remap,
const std::string& suffix) {
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod: {
auto* bin = static_cast<BinaryInst*>(inst);
return std::make_unique<BinaryInst>(
inst->GetOpcode(), inst->GetType(), RemapValue(bin->GetLhs(), remap),
RemapValue(bin->GetRhs(), remap), CloneName(inst->GetName(), suffix));
}
case Opcode::Cmp: {
auto* cmp = static_cast<CmpInst*>(inst);
return std::make_unique<CmpInst>(
cmp->GetCmpOp(), inst->GetType(), RemapValue(cmp->GetLhs(), remap),
RemapValue(cmp->GetRhs(), remap), CloneName(inst->GetName(), suffix));
}
case Opcode::Cast: {
auto* cast = static_cast<CastInst*>(inst);
return std::make_unique<CastInst>(cast->GetCastOp(), inst->GetType(),
RemapValue(cast->GetValue(), remap),
CloneName(inst->GetName(), suffix));
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
return std::make_unique<LoadInst>(inst->GetType(),
RemapValue(load->GetPtr(), remap),
CloneName(inst->GetName(), suffix));
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
return std::make_unique<StoreInst>(
Type::GetVoidType(), RemapValue(store->GetValue(), remap),
RemapValue(store->GetPtr(), remap));
}
case Opcode::Call: {
auto* call = static_cast<CallInst*>(inst);
std::vector<Value*> args;
args.reserve(call->GetNumArgs());
for (size_t i = 0; i < call->GetNumArgs(); ++i) {
args.push_back(RemapValue(call->GetArg(i), remap));
}
return std::make_unique<CallInst>(inst->GetType(), call->GetCallee(), args,
CloneName(inst->GetName(), suffix));
}
case Opcode::Gep: {
auto* gep = static_cast<GepInst*>(inst);
return std::make_unique<GepInst>(inst->GetType(),
RemapValue(gep->GetBase(), remap),
RemapValue(gep->GetIndex(), remap),
CloneName(inst->GetName(), suffix));
}
default:
return nullptr;
}
}
inline bool IsLoopInvariantValue(Value* value, analysis::Loop* loop) {
if (!value) return false;
if (dynamic_cast<ConstantValue*>(value) != nullptr) return true;
if (dynamic_cast<Argument*>(value) != nullptr) return true;
if (dynamic_cast<GlobalVariable*>(value) != nullptr) return true;
if (dynamic_cast<Function*>(value) != nullptr) return true;
auto* inst = dynamic_cast<Instruction*>(value);
return !inst || !inst->GetParent() || !loop->Contains(inst->GetParent());
}
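// Finds the first i32 header phi of the form
// `phi = [init, preheader], [phi +/- C, latch]` with a nonzero constant C.
// For `Sub`, the phi must be the left operand; the step is then -C.
inline std::optional<CanonicalInductionInfo> MatchCanonicalInduction(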
BasicBlock* header, BasicBlock* preheader, BasicBlock* latch) {
if (!header || !preheader || !latch) return std::nullopt;
for (auto* phi : CollectHeaderPhis(header)) {
if (!phi || phi->GetType() == nullptr || !phi->GetType()->IsInt32()) continue;
if (phi->GetNumIncoming() != 2) continue;
Value* init = nullptr;
Value* next = nullptr;
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
auto* incoming_bb = phi->GetIncomingBlock(i);
if (incoming_bb == preheader) {
init = phi->GetIncomingValue(i);
} else if (incoming_bb == latch) {
next = phi->GetIncomingValue(i);
}
}
if (!init || !next) continue;
auto* next_inst = dynamic_cast<BinaryInst*>(next);
if (!next_inst) continue;
if (next_inst->GetOpcode() != Opcode::Add &&
next_inst->GetOpcode() != Opcode::Sub) {
continue;
}
Value* other = nullptr;
bool phi_on_lhs = false;
if (next_inst->GetLhs() == phi) {
other = next_inst->GetRhs();
phi_on_lhs = true;
} else if (next_inst->GetRhs() == phi) {
other = next_inst->GetLhs();
} else {
continue;
}
auto* step_ci = dynamic_cast<ConstantInt*>(other);
if (!step_ci) continue;
int step = step_ci->GetValue();
if (next_inst->GetOpcode() == Opcode::Sub) {
if (!phi_on_lhs) continue;
step = -step;
}
if (step == 0) continue;
return CanonicalInductionInfo{phi, init, next, step};
}
return std::nullopt;
}
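// Matches a loop that has a preheader, a single latch, and a header ending in
// a conditional branch on a compare between the induction phi and a
// loop-invariant bound, with exactly one branch successor inside the loop.
inline std::optional<CanonicalLoopMatch> MatchCanonicalLoop(analysis::Loop* loop) {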
if (!loop) return std::nullopt;
auto* header = loop->GetHeader();
auto* preheader = loop->GetPreheader();
if (!header || !preheader) return std::nullopt;
if (loop->GetLatches().size() != 1) return std::nullopt;
auto* latch = loop->GetLatches().front();
if (!latch) return std::nullopt;
auto* header_term = header->HasTerminator()
? dynamic_cast<CondBranchInst*>(
header->MutableInstructions().back().get())
: nullptr;
if (!header_term) return std::nullopt;
auto* cmp = dynamic_cast<CmpInst*>(header_term->GetCond());
if (!cmp) return std::nullopt;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (loop->Contains(header_term->GetTrueBlock()) &&
!loop->Contains(header_term->GetFalseBlock())) {
body = header_term->GetTrueBlock();
exit = header_term->GetFalseBlock();
} else if (loop->Contains(header_term->GetFalseBlock()) &&
!loop->Contains(header_term->GetTrueBlock())) {
body = header_term->GetFalseBlock();
exit = header_term->GetTrueBlock();
} else {
return std::nullopt;
}
auto induction = MatchCanonicalInduction(header, preheader, latch);
if (!induction.has_value()) return std::nullopt;
Value* bound = nullptr;
if (cmp->GetLhs() == induction->phi &&
IsLoopInvariantValue(cmp->GetRhs(), loop)) {
bound = cmp->GetRhs();
} else if (cmp->GetRhs() == induction->phi &&
IsLoopInvariantValue(cmp->GetLhs(), loop)) {
bound = cmp->GetLhs();
} else {
return std::nullopt;
}
CanonicalLoopMatch match;
match.loop = loop;
match.preheader = preheader;
match.header = header;
match.exit = exit;
match.body = body;
match.latch = latch;
match.header_branch = header_term;
match.header_cmp = cmp;
match.bound = bound;
match.induction = *induction;
match.header_phis = CollectHeaderPhis(header);
return match;
}
} // namespace passes
} // namespace ir

@ -0,0 +1,143 @@
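// Loop unrolling:
// - Conservatively unrolls single-block innermost canonical loops by a factor of 2.
// - Guards the remainder path with one extra comparison, so no static trip
//   count is required.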
#include "ir/IR.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "LoopPassUtils.h"
namespace ir {
namespace passes {
namespace {
bool IsUnrollableLoop(const CanonicalLoopMatch& match) {
if (match.induction.step <= 0) return false;
if (match.loop->GetChildren().size() != 0) return false;
if (match.loop->GetBlocks().size() != 2) return false;
if (match.body != match.latch) return false;
if (!match.body || !match.body->HasTerminator()) return false;
auto* body_term =
dynamic_cast<BranchInst*>(match.body->MutableInstructions().back().get());
if (!body_term || body_term->GetTarget() != match.header) return false;
if (match.header_cmp->GetLhs() != match.induction.phi) return false;
if (match.header_cmp->GetCmpOp() != CmpOp::Lt &&
match.header_cmp->GetCmpOp() != CmpOp::Le) {
return false;
}
size_t body_inst_count = match.body->GetInstructions().size();
if (body_inst_count <= 1 || body_inst_count > 18) return false;
return true;
}
Value* GetLatchIncomingForPhi(PhiInst* phi, BasicBlock* latch) {
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (phi->GetIncomingBlock(i) == latch) {
return phi->GetIncomingValue(i);
}
}
return nullptr;
}
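// Rewrites `header -> body -> header` into
// `header -> body -(cmp)-> body2 -> header`. After the first copy, the loop
// condition is re-evaluated on the updated induction value; the second copy
// runs only if it still holds, otherwise control returns to the header.
bool RunUnrollOnLoop(Function& func, const CanonicalLoopMatch& match,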
Context& ctx, int unroll_index) {
(void)unroll_index;
if (!IsUnrollableLoop(match)) return false;
auto* body = match.body;
auto* header = match.header;
auto* body_term =
static_cast<BranchInst*>(body->MutableInstructions().back().get());
(void)body_term;
std::string block_suffix = ctx.NextTemp();
if (!block_suffix.empty() && block_suffix.front() == '%') {
block_suffix.erase(0, 1);
}
auto* body2 = func.CreateBlock(body->GetName() + ".unroll." + block_suffix);
std::unordered_map<Value*, Value*> seed_map;
for (auto* phi : match.header_phis) {
auto* incoming = GetLatchIncomingForPhi(phi, match.latch);
if (!incoming) return false;
seed_map.emplace(phi, incoming);
}
ValueMap clone_map = seed_map;
std::vector<Instruction*> originals;
for (const auto& inst_ptr : body->GetInstructions()) {
if (inst_ptr.get()->IsTerminator()) continue;
originals.push_back(inst_ptr.get());
}
for (auto* inst : originals) {
auto cloned = CloneInstruction(inst, clone_map, ".u2");
if (!cloned) return false;
auto* raw = cloned.get();
body2->MutableInstructions().push_back(std::move(cloned));
raw->SetParent(body2);
clone_map[inst] = raw;
}
body2->Append<BranchInst>(Type::GetVoidType(), header);
Value* iv_after_one = GetLatchIncomingForPhi(match.induction.phi, match.latch);
if (!iv_after_one) return false;
auto* first_cmp = InsertCmpBeforeTerminator(
body, match.header_cmp->GetCmpOp(), iv_after_one, match.bound,
ctx.NextTemp());
body->RemoveInstruction(body->MutableInstructions().back().get());
body->Append<CondBranchInst>(Type::GetVoidType(), first_cmp, body2, header);
body->AddSuccessor(body2);
body2->AddPredecessor(body);
body2->AddSuccessor(header);
header->AddPredecessor(body2);
for (auto* phi : match.header_phis) {
// Incoming value along the new edge: the phi's value after the cloned
// (second) iteration.
auto* incoming = GetLatchIncomingForPhi(phi, match.latch);
phi->AddIncoming(RemapValue(incoming, clone_map), body2);
}
return true;
}
} // namespace
bool RunLoopUnroll(Function& func, Context& ctx) {
if (func.IsExternal()) return false;
analysis::DominatorTree dom_tree(func);
analysis::LoopInfo loop_info(func, dom_tree);
bool changed = false;
int unroll_index = 0;
for (const auto& loop_ptr : loop_info.GetLoops()) {
auto match = MatchCanonicalLoop(loop_ptr.get());
if (!match.has_value()) continue;
if (RunUnrollOnLoop(func, *match, ctx, unroll_index++)) {
changed = true;
break;
}
}
return changed;
}
} // namespace passes
} // namespace ir
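
Review note: the CFG surgery in `RunUnrollOnLoop` may be easier to follow at source level. The sketch below assumes a simple counted loop with a positive step and a bound that stays inside the array. `SumOriginal` and `SumUnrolledBy2` are hypothetical names used only for illustration and are not part of this commit. It shows the shape the pass appears to produce: the original body runs, the header comparison is repeated on the post-increment value, and the cloned body runs only when that guard passes.

```cpp
#include <cassert>
#include <vector>

// Reference loop: header compares i < bound, body accumulates, latch adds step.
int SumOriginal(const std::vector<int>& data, int bound, int step) {
  assert(step > 0);
  int acc = 0;
  for (int i = 0; i < bound; i += step) {
    acc += data[i];
  }
  return acc;
}

// Same loop after unrolling by two with a guard between the two copies.
int SumUnrolledBy2(const std::vector<int>& data, int bound, int step) {
  assert(step > 0);
  int acc = 0;
  int i = 0;
  while (i < bound) {         // header compare
    acc += data[i];           // original body
    int i_next = i + step;    // value the header phi receives from the latch
    if (i_next < bound) {     // guard inserted before the original terminator
      acc += data[i_next];    // cloned body, phi uses remapped to i_next
      i = i_next + step;      // phi incoming from the cloned block
    } else {
      i = i_next;             // guard failed: return to the header after one step
    }
  }
  return acc;
}
```

Because the guard repeats the header comparison, no separate remainder loop is needed for odd trip counts; the cost is one extra compare per pair of iterations.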

@ -8,13 +8,17 @@
#include "ir/IR.h"
#include <iostream>
namespace ir {
namespace passes {
// Forward declarations of the pass entry points
bool RunMem2Reg(Function& func);
bool RunLICM(Function& func);
bool RunStrengthReduction(Function& func, Context& ctx);
bool RunLoopIdiom(Function& func, Module& module, Context& ctx);
bool RunLoopFission(Function& func, Context& ctx);
bool RunLoopUnroll(Function& func, Context& ctx);
bool RunLoopParallelization(Module& module);
bool RunConstFoldWithCtx(Function& func, Context& ctx);
bool RunConstProp(Function& func, Context& ctx);
bool RunCSE(Function& func);
@ -34,6 +38,38 @@ void RunAllPasses(Module& module) {
for (int iter = 0; iter < kMaxIterations; ++iter) {
bool changed = false;
changed |= RunLICM(*func);
changed |= RunStrengthReduction(*func, ctx);
changed |= RunLoopIdiom(*func, module, ctx);
changed |= RunConstFoldWithCtx(*func, ctx);
changed |= RunConstProp(*func, ctx);
changed |= RunCSE(*func);
changed |= RunSimpleDCE(*func);
changed |= RunCFGSimplify(*func, ctx);
if (!changed) break;
}
}
if (RunLoopParallelization(module)) {
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsExternal()) continue;
RunSimpleDCE(*func);
RunCFGSimplify(*func, ctx);
}
}
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsExternal()) continue;
for (int iter = 0; iter < kMaxIterations; ++iter) {
bool changed = false;
changed |= RunLICM(*func);
changed |= RunStrengthReduction(*func, ctx);
changed |= RunLoopIdiom(*func, module, ctx);
changed |= RunLoopFission(*func, ctx);
changed |= RunLoopUnroll(*func, ctx);
changed |= RunConstFoldWithCtx(*func, ctx);
changed |= RunConstProp(*func, ctx);
changed |= RunCSE(*func);

@ -0,0 +1,115 @@
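// Strength reduction:
// - Recognize the canonical induction variable iv
// - Rewrite iv * C inside the loop as an auxiliary phi plus a constant-increment recurrence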
#include "ir/IR.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
#include "LoopPassUtils.h"
namespace ir {
namespace passes {
namespace {
bool IsStrengthReductionCandidate(Instruction* inst, PhiInst* iv) {
auto* bin = dynamic_cast<BinaryInst*>(inst);
if (!bin || bin->GetOpcode() != Opcode::Mul) return false;
if (!bin->GetType() || !bin->GetType()->IsInt32()) return false;
return bin->GetLhs() == iv || bin->GetRhs() == iv;
}
int ExtractScale(BinaryInst* mul, PhiInst* iv) {
auto* lhs_ci = dynamic_cast<ConstantInt*>(mul->GetLhs());
auto* rhs_ci = dynamic_cast<ConstantInt*>(mul->GetRhs());
if (mul->GetLhs() == iv && rhs_ci) return rhs_ci->GetValue();
if (mul->GetRhs() == iv && lhs_ci) return lhs_ci->GetValue();
return 0;
}
bool ReplaceMulWithRecurrence(Function& func, const CanonicalLoopMatch& match,
BinaryInst* mul, Context& ctx) {
(void)func;
const int scale = ExtractScale(mul, match.induction.phi);
if (scale == 0) return false;
if (match.latch != mul->GetParent() &&
!match.loop->Contains(mul->GetParent())) {
return false;
}
auto* init_scale =
InsertBinaryBeforeTerminator(match.preheader, Opcode::Mul,
match.induction.init, ctx.GetConstInt(scale),
ctx.NextTemp());
auto* sr_phi =
match.header->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());
auto* step_scale =
InsertBinaryBeforeTerminator(match.latch, Opcode::Mul,
ctx.GetConstInt(match.induction.step),
ctx.GetConstInt(scale), ctx.NextTemp());
auto* sr_next =
InsertBinaryBeforeTerminator(match.latch, Opcode::Add, sr_phi, step_scale,
ctx.NextTemp());
sr_phi->AddIncoming(init_scale, match.preheader);
sr_phi->AddIncoming(sr_next, match.latch);
mul->ReplaceAllUsesWith(sr_phi);
mul->GetParent()->RemoveInstruction(mul);
return true;
}
} // namespace
bool RunStrengthReduction(Function& func, Context& ctx) {
if (func.IsExternal()) return false;
analysis::DominatorTree dom_tree(func);
analysis::LoopInfo loop_info(func, dom_tree);
std::vector<analysis::Loop*> loops;
for (const auto& loop_ptr : loop_info.GetLoops()) {
loops.push_back(loop_ptr.get());
}
std::sort(loops.begin(), loops.end(),
[](analysis::Loop* lhs, analysis::Loop* rhs) {
if (lhs->GetDepth() != rhs->GetDepth()) {
return lhs->GetDepth() > rhs->GetDepth();
}
return lhs->GetHeader()->GetName() < rhs->GetHeader()->GetName();
});
bool changed = false;
for (auto* loop : loops) {
auto match = MatchCanonicalLoop(loop);
if (!match.has_value()) continue;
std::vector<BinaryInst*> candidates;
for (auto* block : loop->GetBlocks()) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (IsStrengthReductionCandidate(inst, match->induction.phi)) {
auto* mul = static_cast<BinaryInst*>(inst);
if (ExtractScale(mul, match->induction.phi) != 0) {
candidates.push_back(mul);
}
}
}
}
for (auto* mul : candidates) {
changed |= ReplaceMulWithRecurrence(func, *match, mul, ctx);
}
}
return changed;
}
} // namespace passes
} // namespace ir
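
Review note: a source-level picture of what `ReplaceMulWithRecurrence` builds. This is a minimal sketch, assuming a counted loop with a positive constant step and values small enough that nothing overflows. `ScaledWithMul` and `ScaledWithRecurrence` are hypothetical helpers, not code from this commit. The variable `sr` plays the role of `sr_phi`, its initial value corresponds to `init_scale` in the preheader, and `sr_step` corresponds to `step_scale`.

```cpp
#include <cassert>
#include <vector>

// Before: every iteration multiplies the induction variable by the scale.
std::vector<int> ScaledWithMul(int init, int step, int bound, int scale) {
  assert(step > 0);
  std::vector<int> out;
  for (int iv = init; iv < bound; iv += step) {
    out.push_back(iv * scale);
  }
  return out;
}

// After: the multiply is replaced by an auxiliary recurrence.
std::vector<int> ScaledWithRecurrence(int init, int step, int bound, int scale) {
  assert(step > 0);
  std::vector<int> out;
  int sr = init * scale;             // computed once in the preheader
  const int sr_step = step * scale;  // product of two constants, foldable later
  for (int iv = init; iv < bound; iv += step) {
    out.push_back(sr);               // former uses of the mul now read the phi
    sr += sr_step;                   // latch: next value of the auxiliary phi
  }
  return out;
}
```

The pass emits the `step * scale` product as an instruction in the latch; since both operands are constants, the constant-folding pass that runs later in the same iteration can presumably reduce it to a literal.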

@ -160,6 +160,7 @@ int main(int argc, char** argv) {
mir::RunPeephole(*func_ptr);
mir::RunRegAlloc(*func_ptr);
mir::RunFrameLowering(*func_ptr);
mir::RunPeephole(*func_ptr);
}
if (need_blank_line) {
std::cout << "\n";

@ -18,6 +18,11 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex());
}
std::string LocalBlockLabel(const MachineFunction& function,
const std::string& block_name) {
return "." + function.GetName() + "." + block_name;
}
void PrintMoveImm32(std::ostream& os, PhysReg reg, int imm) {
std::uint32_t u = static_cast<std::uint32_t>(imm);
std::uint32_t lo = u & 0xFFFFu;
@ -157,7 +162,7 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
// Print the block label (the entry block needs none, because the function name already serves as its label)
if (bb.GetName() != "entry") {
os << "." << bb.GetName() << ":\n";
os << LocalBlockLabel(function, bb.GetName()) << ":\n";
}
for (const auto& inst : bb.GetInstructions()) {
@ -382,25 +387,30 @@ void PrintAsm(const MachineModule& module, std::ostream& os) {
os << " bl " << ops.at(0).GetSymbol() << "\n";
break;
case Opcode::B:
os << " b ." << ops.at(0).GetSymbol() << "\n";
os << " b " << LocalBlockLabel(function, ops.at(0).GetSymbol())
<< "\n";
break;
case Opcode::Cbnz:
os << " cbnz " << PhysRegName(ops.at(0).GetReg())
<< ", ." << ops.at(1).GetSymbol() << "\n";
<< ", " << LocalBlockLabel(function, ops.at(1).GetSymbol())
<< "\n";
break;
case Opcode::Cbz:
os << " cbz " << PhysRegName(ops.at(0).GetReg())
<< ", ." << ops.at(1).GetSymbol() << "\n";
<< ", " << LocalBlockLabel(function, ops.at(1).GetSymbol())
<< "\n";
break;
case Opcode::Bcond:
// ops: symbol, cmpop(imm)
os << " b." << CondSuffix(static_cast<ir::CmpOp>(ops.at(1).GetImm()))
<< " ." << ops.at(0).GetSymbol() << "\n";
<< " " << LocalBlockLabel(function, ops.at(0).GetSymbol())
<< "\n";
break;
case Opcode::FBcond:
// ops: symbol, cmpop(imm) - floating-point conditional branch
os << " b." << FloatCondSuffix(static_cast<ir::CmpOp>(ops.at(1).GetImm()))
<< " ." << ops.at(0).GetSymbol() << "\n";
<< " " << LocalBlockLabel(function, ops.at(0).GetSymbol())
<< "\n";
break;
case Opcode::Ret:
os << " ret\n";
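
Review note: the label change above presumably exists because two functions in one assembly file can contain blocks with the same name, and a bare `.<block>` label would then be defined twice. A minimal sketch of the old and new schemes, using plain strings instead of `MachineFunction`; `OldBlockLabel` and `QualifiedBlockLabel` are hypothetical stand-ins for the removed inline formatting and for `LocalBlockLabel`.

```cpp
#include <cassert>
#include <string>

// Old scheme: the label depends only on the block name.
std::string OldBlockLabel(const std::string& block_name) {
  return "." + block_name;
}

// New scheme: the function name is folded into the label, so equal block
// names in different functions no longer collide.
std::string QualifiedBlockLabel(const std::string& function_name,
                                const std::string& block_name) {
  assert(!function_name.empty());
  return "." + function_name + "." + block_name;
}
```

Every branch printer has to use the same helper as the label printer, which is why the `B`, `Cbz`, `Cbnz`, `Bcond`, and `FBcond` cases all change together in this hunk.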

@ -818,7 +818,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
return;
}
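// Argument passing: use w0-w7 (integer), s0-s7 (float), or x0-x7 (pointer) depending on the type
// Argument passing: use w0-w7 (integer), s0-s7 (float), or x0-x7 (pointer) depending on the type
// Note: the related Peephole pass no longer forwards a LoadStack that precedes argument
// setup into a mov between ABI argument registers, to avoid W/X alias clobbering.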
size_t num_args = call.GetNumArgs();
if (num_args > 8) {
throw std::runtime_error(FormatError("mir", "function calls with more than 8 arguments are not supported yet"));
@ -1002,6 +1004,26 @@ std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
Opcode::FCmpOnlyRR,
{Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S1)});
} else {
const ir::ConstantInt* lhs_ci = TryGetConstInt(cmp_inst->GetLhs());
const ir::ConstantInt* rhs_ci = TryGetConstInt(cmp_inst->GetRhs());
bool lhs_zero = lhs_ci && lhs_ci->GetValue() == 0;
bool rhs_zero = rhs_ci && rhs_ci->GetValue() == 0;
if ((cmp_inst->GetCmpOp() == ir::CmpOp::Eq ||
cmp_inst->GetCmpOp() == ir::CmpOp::Ne) &&
(lhs_zero || rhs_zero)) {
auto* non_zero_side = lhs_zero ? cmp_inst->GetRhs() : cmp_inst->GetLhs();
EmitValueToReg(non_zero_side, PhysReg::W8, slots, *current_mbb);
current_mbb->Append(
cmp_inst->GetCmpOp() == ir::CmpOp::Eq ? Opcode::Cbz
: Opcode::Cbnz,
{Operand::Reg(PhysReg::W8),
Operand::Symbol(true_mbb->GetName())});
current_mbb->Append(Opcode::B,
{Operand::Symbol(false_mbb->GetName())});
++i; // also skip the CondBr that follows
continue;
}
EmitValueToReg(cmp_inst->GetLhs(), PhysReg::W8, slots, *current_mbb);
EmitValueToReg(cmp_inst->GetRhs(), PhysReg::W9, slots, *current_mbb);
current_mbb->Append(
@ -1078,7 +1100,8 @@ std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
for (size_t j = pred_insts.size(); j > 0; --j) {
auto op = pred_insts[j - 1].GetOpcode();
if (op == Opcode::B || op == Opcode::Bcond ||
op == Opcode::Cbnz || op == Opcode::FBcond) {
op == Opcode::Cbnz || op == Opcode::Cbz ||
op == Opcode::FBcond) {
insert_pos = j - 1;
} else {
break;

@ -8,6 +8,18 @@
namespace mir {
namespace {
PhysReg CanonicalCalleeSavedReg(PhysReg reg) {
if (reg >= PhysReg::W0 && reg <= PhysReg::W11) {
int idx = static_cast<int>(reg) - static_cast<int>(PhysReg::W0);
return static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + idx);
}
return reg;
}
} // namespace
MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)) {
// Create the entry block
@ -54,6 +66,7 @@ int MachineFunction::CreateVReg(VRegClass rc) {
}
void MachineFunction::AddUsedCalleeSaved(PhysReg reg) {
reg = CanonicalCalleeSavedReg(reg);
if (std::find(used_callee_saved_.begin(), used_callee_saved_.end(), reg) ==
used_callee_saved_.end()) {
used_callee_saved_.push_back(reg);
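
Review note: the effect of `CanonicalCalleeSavedReg` is that a W register and its X alias are recorded as one callee-saved entry instead of two. The sketch below models that with a hypothetical `Reg` enum whose layout (twelve W registers followed by twelve X registers) is an assumption made for illustration; the real `PhysReg` numbering is not shown in this diff.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

constexpr int kRegsPerBank = 12;  // assumed bank size, mirrors W0..W11 / X0..X11

// Hypothetical stand-in for mir::PhysReg.
enum class Reg : int {
  W0 = 0,
  W11 = kRegsPerBank - 1,
  X0 = kRegsPerBank,
  X11 = 2 * kRegsPerBank - 1,
};

// Map a 32-bit register name onto its 64-bit alias; X registers pass through.
Reg Canonical(Reg reg) {
  const int v = static_cast<int>(reg);
  assert(v >= 0 && v < 2 * kRegsPerBank);
  if (v <= static_cast<int>(Reg::W11)) return static_cast<Reg>(v + kRegsPerBank);
  return reg;
}

// Deduplicating set of used callee-saved registers, keyed on the canonical name.
class UsedCalleeSaved {
 public:
  void Add(Reg reg) {
    reg = Canonical(reg);
    if (std::find(regs_.begin(), regs_.end(), reg) == regs_.end()) {
      regs_.push_back(reg);
    }
  }
  const std::vector<Reg>& regs() const { return regs_; }

 private:
  std::vector<Reg> regs_;
};
```

Without the canonical step, the prologue could save the same physical register twice under two names, or save only the low half.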

@ -15,6 +15,13 @@ bool IsMovLike(Opcode opcode) { return opcode == Opcode::MovReg || opcode == Opc
bool IsFloatReg(PhysReg reg) { return reg >= PhysReg::S0 && reg <= PhysReg::S10; }
bool IsAbiArgReg(PhysReg reg) {
if (reg >= PhysReg::W0 && reg <= PhysReg::W7) return true;
if (reg >= PhysReg::X0 && reg <= PhysReg::X7) return true;
if (reg >= PhysReg::S0 && reg <= PhysReg::S7) return true;
return false;
}
bool IsWxReg(PhysReg reg) {
return (reg >= PhysReg::W0 && reg <= PhysReg::W10) ||
(reg >= PhysReg::X0 && reg <= PhysReg::X10);
@ -92,6 +99,88 @@ std::optional<PhysReg> GetWrittenReg(const MachineInstr& inst) {
}
}
bool ReadsReg(const MachineInstr& inst, PhysReg reg) {
const auto& ops = inst.GetOperands();
auto reads_operand = [&](size_t idx) {
return idx < ops.size() && ops[idx].GetKind() == Operand::Kind::Reg &&
RegAlias(ops[idx].GetReg(), reg);
};
switch (inst.GetOpcode()) {
case Opcode::MovReg:
case Opcode::FMovReg:
case Opcode::AddRI:
case Opcode::SubRI:
case Opcode::LoadStackOffset:
case Opcode::LoadIndirect:
case Opcode::LoadGlobal:
case Opcode::LoadGlobalAddr:
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::LslRI:
return reads_operand(1);
case Opcode::AddRR:
case Opcode::AddRR_UXTW:
case Opcode::SubRR:
case Opcode::MulRR:
case Opcode::DivRR:
case Opcode::LslRR:
case Opcode::FAddRR:
case Opcode::FSubRR:
case Opcode::FMulRR:
case Opcode::FDivRR:
case Opcode::CmpRR:
case Opcode::FCmpRR:
case Opcode::CmpOnlyRR:
case Opcode::FCmpOnlyRR:
return reads_operand(1) || reads_operand(2);
case Opcode::StoreStack:
case Opcode::StoreStackOffset:
case Opcode::Cbz:
case Opcode::Cbnz:
return reads_operand(0);
case Opcode::StoreIndirect:
return reads_operand(0) || reads_operand(1);
case Opcode::Ret:
return false;
default:
return false;
}
}
bool CanElideIfOverwritten(const MachineInstr& inst) {
switch (inst.GetOpcode()) {
case Opcode::MovImm:
case Opcode::MovReg:
case Opcode::FMovImm:
case Opcode::FMovReg:
case Opcode::LoadStack:
case Opcode::LoadStackOffset:
case Opcode::LoadStackAddr:
case Opcode::LoadIndirect:
case Opcode::LoadGlobal:
case Opcode::LoadGlobalAddr:
case Opcode::AddRI:
case Opcode::SubRI:
case Opcode::AddRR:
case Opcode::AddRR_UXTW:
case Opcode::SubRR:
case Opcode::MulRR:
case Opcode::DivRR:
case Opcode::LslRI:
case Opcode::LslRR:
case Opcode::FAddRR:
case Opcode::FSubRR:
case Opcode::FMulRR:
case Opcode::FDivRR:
case Opcode::SIToFP:
case Opcode::FPToSI:
return true;
default:
return false;
}
}
bool IsMemoryClobber(const MachineInstr& inst) {
switch (inst.GetOpcode()) {
case Opcode::StoreIndirect:
@ -161,6 +250,13 @@ bool TryForwardLoad(std::vector<MachineInstr>& out,
const PhysReg src = it->second;
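// Avoid turning a LoadStack that precedes argument setup into a mov between ABI
// argument registers; that could trigger W/X alias clobbering and corrupt the call
// arguments. If the source register is not an ABI argument register, converting
// to a mov is still safe.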
if (IsAbiArgReg(dst) && IsAbiArgReg(src) && !RegAlias(src, dst)) {
return false;
}
// Do not forward when the widths differ (e.g. W8 -> X8 would emit the illegal mov x8, w8)
if (!SameRegWidth(src, dst)) {
return false;
@ -220,6 +316,16 @@ void RunPeephole(MachineFunction& function) {
continue;
}
if (i + 1 < insts.size() && CanElideIfOverwritten(cur)) {
auto wr_cur = GetWrittenReg(cur);
auto wr_next = GetWrittenReg(insts[i + 1]);
if (wr_cur.has_value() && wr_next.has_value() &&
RegAlias(*wr_cur, *wr_next) &&
!ReadsReg(insts[i + 1], *wr_cur)) {
continue;
}
}
// mov #2 + lsl reg, reg, mov_reg -> lsl reg, reg, #2
if (i + 1 < insts.size()) {
PhysReg imm_reg = PhysReg::W0;
@ -314,4 +420,3 @@ void RunPeephole(MachineFunction& function) {
}
} // namespace mir
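
Review note: the new rule in `RunPeephole` drops an instruction when the very next instruction overwrites the same register without reading it. The sketch below models that rule on a simplified instruction record. `Inst`, `Reads`, and `ElideOverwritten` are hypothetical, register aliasing is reduced to integer equality, and the `pure` flag stands in for `CanElideIfOverwritten`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical simplified instruction: dst is the written register (-1 if
// none), srcs are the registers read, pure marks side-effect-free definitions.
struct Inst {
  int dst;
  std::vector<int> srcs;
  bool pure;
};

bool Reads(const Inst& inst, int reg) {
  return std::find(inst.srcs.begin(), inst.srcs.end(), reg) != inst.srcs.end();
}

// Drop a pure definition when the next instruction redefines the same
// register without consuming the old value.
std::vector<Inst> ElideOverwritten(const std::vector<Inst>& in) {
  std::vector<Inst> out;
  for (std::size_t i = 0; i < in.size(); ++i) {
    const Inst& cur = in[i];
    assert(cur.dst >= -1);
    if (i + 1 < in.size() && cur.pure && cur.dst >= 0) {
      const Inst& next = in[i + 1];
      if (next.dst == cur.dst && !Reads(next, cur.dst)) continue;
    }
    out.push_back(cur);
  }
  return out;
}
```

The check only looks one instruction ahead, so it is safe without liveness information: nothing can observe the first write between the two definitions.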

@ -5,6 +5,9 @@
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <pthread.h>
#include <stdint.h>
#include <string.h>
int getint() { int v; scanf("%d", &v); return v; }
int getch() { return getchar(); }
@ -41,3 +44,154 @@ void stoptime(int l) {
fprintf(stderr, "Timer@%d: %ldms\n", l,
(t1.tv_sec-_t0.tv_sec)*1000+(t1.tv_nsec-_t0.tv_nsec)/1000000);
}
void __fill_i32(int* base, int count, int value) {
if (!base || count <= 0) return;
if (value == 0 || value == -1) {
memset(base, value & 0xff, (size_t)count * sizeof(int));
return;
}
for (int i = 0; i < count; ++i) {
base[i] = value;
}
}
typedef struct {
int* base;
int start_offset;
int start_row;
int end_row;
int stride;
int count;
int value;
} fill_rows_task_t;
static void* __fill_rows_worker(void* opaque) {
fill_rows_task_t* task = (fill_rows_task_t*)opaque;
for (int row = task->start_row; row < task->end_row; ++row) {
int* row_ptr = task->base + task->start_offset + row * task->stride;
__fill_i32(row_ptr, task->count, task->value);
}
return NULL;
}
void __fill_rows_i32(int* base, int start_offset, int rows, int stride, int count,
int value) {
if (!base || rows <= 0 || stride <= 0 || count <= 0) return;
if (rows < 32) {
fill_rows_task_t task = {base, start_offset, 0, rows, stride, count, value};
__fill_rows_worker(&task);
return;
}
pthread_t tids[3];
fill_rows_task_t tasks[4];
for (int tid = 0; tid < 4; ++tid) {
int begin = tid * rows / 4;
int end = (tid + 1) * rows / 4;
tasks[tid].base = base;
tasks[tid].start_offset = start_offset;
tasks[tid].start_row = begin;
tasks[tid].end_row = end;
tasks[tid].stride = stride;
tasks[tid].count = count;
tasks[tid].value = value;
}
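int created[3] = {0, 0, 0};
for (int tid = 1; tid < 4; ++tid) {
// If thread creation fails, run that slice inline so no rows are skipped.
if (pthread_create(&tids[tid - 1], NULL, __fill_rows_worker, &tasks[tid]) == 0) {
created[tid - 1] = 1;
} else {
__fill_rows_worker(&tasks[tid]);
}
}
__fill_rows_worker(&tasks[0]);
for (int tid = 0; tid < 3; ++tid) {
if (created[tid]) pthread_join(tids[tid], NULL);
}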
}
typedef void (*par_worker_fn_t)(int);
typedef struct {
pthread_mutex_t mutex;
pthread_cond_t start_cv;
pthread_cond_t done_cv;
int generation;
int remaining;
int helper_count;
} par_slot_state_t;
typedef struct {
par_slot_state_t* state;
par_worker_fn_t worker;
int tid;
} par_thread_arg_t;
static void* __par_pool_worker(void* opaque) {
par_thread_arg_t* arg = (par_thread_arg_t*)opaque;
par_slot_state_t* state = arg->state;
int seen_generation = 0;
pthread_mutex_lock(&state->mutex);
for (;;) {
while (state->generation == seen_generation) {
pthread_cond_wait(&state->start_cv, &state->mutex);
}
seen_generation = state->generation;
pthread_mutex_unlock(&state->mutex);
arg->worker(arg->tid);
pthread_mutex_lock(&state->mutex);
if (state->remaining > 0) {
--state->remaining;
if (state->remaining == 0) {
pthread_cond_signal(&state->done_cv);
}
}
}
}
#define DECL_PAR_SLOT(N) \
extern void __par_worker##N(int) __attribute__((weak)); \
static par_slot_state_t __par_state##N = { \
PTHREAD_MUTEX_INITIALIZER, PTHREAD_COND_INITIALIZER, \
PTHREAD_COND_INITIALIZER, 0, 0, 0}; \
static pthread_once_t __par_once##N = PTHREAD_ONCE_INIT; \
static par_thread_arg_t __par_args##N[3]; \
static void __par_init##N(void) { \
pthread_attr_t attr; \
pthread_t tid; \
pthread_attr_init(&attr); \
pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED); \
for (int i = 0; i < 3; ++i) { \
__par_args##N[i].state = &__par_state##N; \
__par_args##N[i].worker = __par_worker##N; \
__par_args##N[i].tid = i + 1; \
if (pthread_create(&tid, &attr, __par_pool_worker, \
&__par_args##N[i]) == 0) { \
++__par_state##N.helper_count; \
} \
} \
pthread_attr_destroy(&attr); \
} \
void __par_run##N(void) { \
if (!__par_worker##N) return; \
pthread_once(&__par_once##N, __par_init##N); \
pthread_mutex_lock(&__par_state##N.mutex); \
__par_state##N.remaining = __par_state##N.helper_count; \
++__par_state##N.generation; \
pthread_cond_broadcast(&__par_state##N.start_cv); \
pthread_mutex_unlock(&__par_state##N.mutex); \
__par_worker##N(0); \
pthread_mutex_lock(&__par_state##N.mutex); \
while (__par_state##N.remaining != 0) { \
pthread_cond_wait(&__par_state##N.done_cv, &__par_state##N.mutex); \
} \
pthread_mutex_unlock(&__par_state##N.mutex); \
}
DECL_PAR_SLOT(0)
DECL_PAR_SLOT(1)
DECL_PAR_SLOT(2)
DECL_PAR_SLOT(3)
DECL_PAR_SLOT(4)
DECL_PAR_SLOT(5)
DECL_PAR_SLOT(6)
DECL_PAR_SLOT(7)
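
Review note: `__fill_i32` takes the `memset` fast path only for `0` and `-1`. `memset` writes a single byte value, so it reproduces an `int` exactly only when all of that integer's bytes are identical. The sketch below restates the helper in C++ and adds a hypothetical `NaiveMemsetFill` to show what an unconditional `memset` would produce; it assumes a 4-byte two's-complement `int`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// C++ restatement of the runtime helper: memset only when every byte of the
// value is the same (0x00 for 0, 0xff for -1), otherwise a plain store loop.
void FillI32(int* base, int count, int value) {
  if (!base || count <= 0) return;
  if (value == 0 || value == -1) {
    std::memset(base, value & 0xff, static_cast<std::size_t>(count) * sizeof(int));
    return;
  }
  for (int i = 0; i < count; ++i) {
    base[i] = value;
  }
}

// Counter-example: always using memset replicates the low byte into every
// byte of each element, which is wrong for most values.
std::vector<int> NaiveMemsetFill(int count, int value) {
  assert(count >= 0);
  std::vector<int> out(static_cast<std::size_t>(count));
  if (count > 0) {
    std::memset(out.data(), value & 0xff, out.size() * sizeof(int));
  }
  return out;
}
```

Other values with four identical bytes (for example `0x01010101`) would also be safe for the fast path, but the runtime restricts itself to the two common cases.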
