forked from NUDT-compiler/nudt-compiler-cpp
parent
9e8984d740
commit
4df492feb9
@ -0,0 +1,211 @@
|
||||
# Session Handoff
|
||||
|
||||
Date: 2026-05-28
|
||||
|
||||
## Repo State
|
||||
|
||||
- Current branch: `Shrink`
|
||||
- Worktree is dirty; do not reset blindly.
|
||||
- Modified tracked files:
|
||||
- `include/ir/IR.h`
|
||||
- `scripts/run_all_tests.sh`
|
||||
- `scripts/verify_asm.sh`
|
||||
- `scripts/verify_ir.sh`
|
||||
- `src/ir/analysis/DominatorTree.cpp`
|
||||
- `src/ir/analysis/LoopInfo.cpp`
|
||||
- `src/ir/passes/CMakeLists.txt`
|
||||
- `src/ir/passes/PassManager.cpp`
|
||||
- `src/main.cpp`
|
||||
- `src/mir/AsmPrinter.cpp`
|
||||
- `src/mir/Lowering.cpp`
|
||||
- `src/mir/MIRFunction.cpp`
|
||||
- `src/mir/passes/Peephole.cpp`
|
||||
- `sylib/sylib.c`
|
||||
- New untracked files:
|
||||
- `src/ir/passes/LICM.cpp`
|
||||
- `src/ir/passes/LoopFission.cpp`
|
||||
- `src/ir/passes/LoopIdiom.cpp`
|
||||
- `src/ir/passes/LoopParallelize.cpp`
|
||||
- `src/ir/passes/LoopPassUtils.h`
|
||||
- `src/ir/passes/LoopUnroll.cpp`
|
||||
- `src/ir/passes/StrengthReduction.cpp`
|
||||
|
||||
## Toolchain On Current Machine
|
||||
|
||||
- `cmake 3.22.1`
|
||||
- `g++ 11.4.0`
|
||||
- `clang 14.0.0`
|
||||
- `llc 14.0.0`
|
||||
- `aarch64-linux-gnu-gcc 11.4.0`
|
||||
- `qemu-aarch64 6.2.0`
|
||||
|
||||
Required packages on a fresh Ubuntu:
|
||||
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install -y \
|
||||
build-essential \
|
||||
cmake \
|
||||
clang \
|
||||
llvm \
|
||||
gcc-aarch64-linux-gnu \
|
||||
qemu-user \
|
||||
libc6-arm64-cross
|
||||
```
|
||||
|
||||
## Important Build Detail
|
||||
|
||||
- The repo vendors `antlr4-runtime-4.13.2` in `third_party`, so no system ANTLR runtime install is needed.
|
||||
- Current frontend build consumes generated parser sources from `build/generated/antlr4` if present.
|
||||
- There is also parser source in `src/antlr4/`, but current CMake does not wire that directory directly into the build.
|
||||
- Safest migration path: copy the repo together with the current `build/generated/antlr4` directory, or later patch CMake to use `src/antlr4/*.cpp`.
|
||||
|
||||
## Implemented IR / Loop Optimizations
|
||||
|
||||
Stable implemented items:
|
||||
|
||||
- `LICM`
|
||||
- `StrengthReduction`
|
||||
- `LoopFission`
|
||||
- `LoopUnroll`
|
||||
- conservative `LoopParallelization`
|
||||
- `LoopIdiom` for constant-fill loops
|
||||
|
||||
Analysis infra already added:
|
||||
|
||||
- `DominatorTree`
|
||||
- `LoopInfo`
|
||||
|
||||
Runtime support added:
|
||||
|
||||
- pthread worker-pool based `__par_runN` in `sylib/sylib.c`
|
||||
- `__fill_i32` helper in `sylib/sylib.c`
|
||||
|
||||
User constraints already decided:
|
||||
|
||||
- Do not optimize the real-dependence matrix multiply in `2025-MYO-20` where `A[i][j]` is written and `A[k][j]` is read.
|
||||
- Reduction parallelization is still disabled.
|
||||
|
||||
## Timing Scripts
|
||||
|
||||
Timing output was added to:
|
||||
|
||||
- `scripts/verify_ir.sh`
|
||||
- `scripts/verify_asm.sh`
|
||||
- `scripts/run_all_tests.sh`
|
||||
|
||||
User requirement:
|
||||
|
||||
- Every test round should always report timing results for both of the following:
|
||||
- `test/test_case/performance/2025-MYO-20.sy`
|
||||
- `./scripts/run_all_tests.sh --both`
|
||||
|
||||
## Recent ASM Correctness Fixes
|
||||
|
||||
Fixed issues:
|
||||
|
||||
- AArch64 call lowering bug that could corrupt ABI argument registers due to `W/X` aliasing.
|
||||
- Duplicate local labels like `.par.exit` across worker functions by prefixing block labels with the function name.
|
||||
- Duplicate callee-saved save/restore of alias registers like `w8/x8`.
|
||||
|
||||
Relevant files:
|
||||
|
||||
- `src/mir/Lowering.cpp`
|
||||
- `src/mir/AsmPrinter.cpp`
|
||||
- `src/mir/MIRFunction.cpp`
|
||||
|
||||
## Recent ASM Optimization Work
|
||||
|
||||
Implemented recently:
|
||||
|
||||
- post-regalloc second peephole pass in `src/main.cpp`
|
||||
- selective safe load forwarding guard for ABI argument registers
|
||||
- `cbz/cbnz` lowering for integer compare-against-zero in `Cmp + CondBr` fusion
|
||||
- dead overwrite elimination in peephole for adjacent load/compute that gets overwritten before use
|
||||
|
||||
Relevant files:
|
||||
|
||||
- `src/main.cpp`
|
||||
- `src/mir/Lowering.cpp`
|
||||
- `src/mir/passes/Peephole.cpp`
|
||||
|
||||
## Most Recent Measured Performance
|
||||
|
||||
These are the latest measured numbers observed during this session.
|
||||
|
||||
IR:
|
||||
|
||||
- `2025-MYO-20` stable reference before latest ASM-only work:
|
||||
- around `31.109s`
|
||||
- earlier stable reference before that: around `30.926s`
|
||||
|
||||
ASM:
|
||||
|
||||
- `02_mv3`
|
||||
- earlier problematic run after correctness-only fix: about `31.662s`
|
||||
- after later backend cleanup, best observed run in this session: about `31.505s`
|
||||
- another later run: about `31.529s`
|
||||
- `01_mm2`
|
||||
- earlier reference in this session: about `38.010s`
|
||||
- later improved run: about `37.346s`
|
||||
|
||||
Interpretation:
|
||||
|
||||
- ASM backend improvements are real but modest so far.
|
||||
- Main remaining bottleneck is still heavy stack traffic in hot loops.
|
||||
|
||||
## Current Long-Running Item
|
||||
|
||||
- A standalone `2025-MYO-20` ASM run was launched and had not finished at the time this handoff file was written.
|
||||
- A full `./scripts/run_all_tests.sh --both` run had progressed to the final `2025-MYO-20` ASM item instead of failing early, but final completion time was still pending.
|
||||
|
||||
## Good Commands To Resume Work
|
||||
|
||||
Build:
|
||||
|
||||
```bash
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build -j"$(nproc)" --target compiler
|
||||
```
|
||||
|
||||
Quick correctness:
|
||||
|
||||
```bash
|
||||
./scripts/verify_ir.sh test/test_case/functional/simple_add.sy /tmp/ir_check --run
|
||||
./scripts/verify_asm.sh test/test_case/functional/simple_add.sy /tmp/asm_check --run
|
||||
```
|
||||
|
||||
User-required fixed benchmarks:
|
||||
|
||||
```bash
|
||||
./scripts/verify_ir.sh test/test_case/performance/2025-MYO-20.sy /tmp/timed_2025 --run
|
||||
./scripts/run_all_tests.sh --both
|
||||
```
|
||||
|
||||
Useful ASM profiling targets:
|
||||
|
||||
```bash
|
||||
./scripts/verify_asm.sh test/test_case/performance/01_mm2.sy /tmp/asm_mm2 --run
|
||||
./scripts/verify_asm.sh test/test_case/performance/02_mv3.sy /tmp/asm_mv3 --run
|
||||
./scripts/verify_asm.sh test/test_case/performance/2025-MYO-20.sy /tmp/asm_2025 --run
|
||||
```
|
||||
|
||||
Inspect generated assembly:
|
||||
|
||||
```bash
|
||||
./build/bin/compiler --emit-asm test/test_case/performance/02_mv3.sy > /tmp/02_mv3.s
|
||||
./build/bin/compiler --emit-asm test/test_case/performance/01_mm2.sy > /tmp/01_mm2.s
|
||||
./build/bin/compiler --emit-asm test/test_case/performance/2025-MYO-20.sy > /tmp/2025.s
|
||||
```
|
||||
|
||||
## Suggested Next Steps
|
||||
|
||||
Priority order:
|
||||
|
||||
1. Finish measuring `2025-MYO-20` ASM and a complete `--both` run on the faster Ubuntu machine.
|
||||
2. Keep working on MIR/ASM backend, not IR parallelization.
|
||||
3. Target hot-loop stack traffic:
|
||||
- reduce phi-related spill/reload churn
|
||||
- widen zero-compare branch simplification beyond the current fused path
|
||||
- add more dead store / dead load cleanup after frame lowering
|
||||
4. Only claim speedups when confirmed with the fixed benchmark pair above.
|
||||
@ -0,0 +1,171 @@
|
||||
// 循环分裂:
|
||||
// - 针对单块循环中两段彼此独立的 store 语句组做保守分裂
|
||||
// - 仅处理单归纳变量、无其他 loop-carried phi 的情形
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
namespace ir {
|
||||
namespace passes {
|
||||
|
||||
namespace {
|
||||
|
||||
// Follows a chain of GEP instructions upwards and returns the first value
// that is not itself a GEP, i.e. the value the address computation starts from.
Value* StripPointerBase(Value* value) {
  for (;;) {
    auto* as_gep = dynamic_cast<GepInst*>(value);
    if (as_gep == nullptr) return value;
    value = as_gep->GetBase();
  }
}
|
||||
|
||||
bool IsFissionCandidate(const CanonicalLoopMatch& match) {
|
||||
if (match.loop->GetChildren().size() != 0) return false;
|
||||
if (match.loop->GetBlocks().size() != 2) return false;
|
||||
if (match.body != match.latch) return false;
|
||||
if (match.header_phis.size() != 1) return false;
|
||||
if (match.header_phis.front() != match.induction.phi) return false;
|
||||
if (match.induction.step <= 0) return false;
|
||||
auto* body_term =
|
||||
dynamic_cast<BranchInst*>(match.body->MutableInstructions().back().get());
|
||||
return body_term && body_term->GetTarget() == match.header;
|
||||
}
|
||||
|
||||
// Returns true when at least one direct operand of `inst` is an instruction
// contained in `defs`. Only immediate operands are inspected.
bool DependsOnAny(Instruction* inst, const std::unordered_set<Instruction*>& defs) {
  size_t operand_index = 0;
  while (operand_index < inst->GetNumOperands()) {
    auto* producer = dynamic_cast<Instruction*>(inst->GetOperand(operand_index));
    if (producer != nullptr && defs.find(producer) != defs.end()) {
      return true;
    }
    ++operand_index;
  }
  return false;
}
|
||||
|
||||
bool RunFissionOnLoop(Function& func, const CanonicalLoopMatch& match,
|
||||
Context& ctx) {
|
||||
if (!IsFissionCandidate(match)) return false;
|
||||
|
||||
std::vector<Instruction*> body_insts;
|
||||
for (const auto& inst_ptr : match.body->GetInstructions()) {
|
||||
if (!inst_ptr.get()->IsTerminator()) {
|
||||
body_insts.push_back(inst_ptr.get());
|
||||
}
|
||||
}
|
||||
if (body_insts.size() < 3) return false;
|
||||
|
||||
auto* iv_next = dynamic_cast<Instruction*>(match.induction.next);
|
||||
if (!iv_next || iv_next->GetParent() != match.body) return false;
|
||||
|
||||
std::vector<size_t> store_positions;
|
||||
for (size_t i = 0; i < body_insts.size(); ++i) {
|
||||
if (dynamic_cast<StoreInst*>(body_insts[i]) != nullptr) {
|
||||
store_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
if (store_positions.size() != 2) return false;
|
||||
|
||||
const size_t first_store_idx = store_positions[0];
|
||||
const size_t second_store_idx = store_positions[1];
|
||||
if (body_insts.back() != iv_next) return false;
|
||||
if (second_store_idx + 1 != body_insts.size() - 1) return false;
|
||||
|
||||
auto* first_store = static_cast<StoreInst*>(body_insts[first_store_idx]);
|
||||
auto* second_store = static_cast<StoreInst*>(body_insts[second_store_idx]);
|
||||
if (StripPointerBase(first_store->GetPtr()) == StripPointerBase(second_store->GetPtr())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<Instruction*> group1(body_insts.begin(),
|
||||
body_insts.begin() + first_store_idx + 1);
|
||||
std::vector<Instruction*> group2(body_insts.begin() + first_store_idx + 1,
|
||||
body_insts.begin() + second_store_idx + 1);
|
||||
|
||||
std::unordered_set<Instruction*> group1_defs(group1.begin(), group1.end());
|
||||
std::unordered_set<Instruction*> group2_defs(group2.begin(), group2.end());
|
||||
group1_defs.erase(iv_next);
|
||||
group2_defs.erase(iv_next);
|
||||
|
||||
for (auto* inst : group2) {
|
||||
if (DependsOnAny(inst, group1_defs)) return false;
|
||||
}
|
||||
for (auto* inst : group1) {
|
||||
if (DependsOnAny(inst, group2_defs)) return false;
|
||||
}
|
||||
|
||||
auto* original_exit = match.exit;
|
||||
std::string block_suffix = ctx.NextTemp();
|
||||
if (!block_suffix.empty() && block_suffix.front() == '%') {
|
||||
block_suffix.erase(0, 1);
|
||||
}
|
||||
auto* preheader2 =
|
||||
func.CreateBlock(match.header->GetName() + ".fission.pre." + block_suffix);
|
||||
auto* header2 =
|
||||
func.CreateBlock(match.header->GetName() + ".fission.hdr." + block_suffix);
|
||||
auto* body2 =
|
||||
func.CreateBlock(match.body->GetName() + ".fission.body." + block_suffix);
|
||||
|
||||
preheader2->Append<BranchInst>(Type::GetVoidType(), header2);
|
||||
|
||||
auto* iv2 = header2->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());
|
||||
iv2->AddIncoming(match.induction.init, preheader2);
|
||||
|
||||
auto* cmp2 = header2->Append<CmpInst>(
|
||||
match.header_cmp->GetCmpOp(), Type::GetInt32Type(), iv2, match.bound,
|
||||
ctx.NextTemp());
|
||||
header2->Append<CondBranchInst>(Type::GetVoidType(), cmp2, body2, original_exit);
|
||||
|
||||
ValueMap remap;
|
||||
remap.emplace(match.induction.phi, iv2);
|
||||
for (auto* inst : group2) {
|
||||
auto cloned = CloneInstruction(inst, remap, ".f2");
|
||||
if (!cloned) return false;
|
||||
auto* raw = cloned.get();
|
||||
body2->MutableInstructions().push_back(std::move(cloned));
|
||||
raw->SetParent(body2);
|
||||
remap[inst] = raw;
|
||||
}
|
||||
|
||||
auto next2_cloned = CloneInstruction(iv_next, remap, ".f2");
|
||||
if (!next2_cloned) return false;
|
||||
auto* next2 = next2_cloned.get();
|
||||
body2->MutableInstructions().push_back(std::move(next2_cloned));
|
||||
next2->SetParent(body2);
|
||||
body2->Append<BranchInst>(Type::GetVoidType(), header2);
|
||||
iv2->AddIncoming(next2, body2);
|
||||
|
||||
const bool exit_is_true = (match.header_branch->GetTrueBlock() == original_exit);
|
||||
match.header_branch->SetOperand(exit_is_true ? 1 : 2, preheader2);
|
||||
match.header->RemoveSuccessor(original_exit);
|
||||
match.header->AddSuccessor(preheader2);
|
||||
preheader2->AddPredecessor(match.header);
|
||||
original_exit->RemovePredecessor(match.header);
|
||||
|
||||
for (auto* inst : group2) {
|
||||
match.body->RemoveInstruction(inst);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: builds dominator and loop information for `func` and
// applies fission to the first loop that accepts it. At most one loop is
// rewritten per call; returns true when a rewrite happened.
bool RunLoopFission(Function& func, Context& ctx) {
  if (func.IsExternal()) {
    return false;
  }

  analysis::DominatorTree dom_tree(func);
  analysis::LoopInfo loop_info(func, dom_tree);

  for (const auto& candidate : loop_info.GetLoops()) {
    auto canonical = MatchCanonicalLoop(candidate.get());
    if (canonical.has_value() && RunFissionOnLoop(func, *canonical, ctx)) {
      // The analyses above are stale after a rewrite, so stop here.
      return true;
    }
  }
  return false;
}
|
||||
|
||||
} // namespace passes
|
||||
} // namespace ir
|
||||
@ -0,0 +1,465 @@
|
||||
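// Loop idiom optimization:
// - Replaces canonical loops that fill contiguous memory with a constant by a
//   bulk runtime fill call (`__fill_i32`). This pattern only handles innermost
//   loops with step=1, init=0 and a single store.
// - Also collapses a guarded row-fill loop (a loop whose conditional block
//   calls `__fill_i32` once per row) into a single `__fill_rows_i32` call.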
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
namespace ir {
|
||||
namespace passes {
|
||||
|
||||
namespace {
|
||||
|
||||
struct FillLoopCandidate {
|
||||
analysis::Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
PhiInst* induction = nullptr;
|
||||
Value* bound = nullptr;
|
||||
Value* base_ptr = nullptr;
|
||||
Value* offset = nullptr;
|
||||
int fill_value = 0;
|
||||
};
|
||||
|
||||
struct GuardedRowFillCandidate {
|
||||
analysis::Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* action = nullptr;
|
||||
BasicBlock* latch = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
PhiInst* induction = nullptr;
|
||||
Value* bound = nullptr;
|
||||
PhiInst* linear = nullptr;
|
||||
Value* linear_init = nullptr;
|
||||
int linear_step = 0;
|
||||
Value* base_ptr = nullptr;
|
||||
Value* threshold = nullptr;
|
||||
bool prefix = false;
|
||||
int fill_value = 0;
|
||||
};
|
||||
|
||||
// Depth-first search through the operand graph of `value` looking for
// `needle`. `visiting` records instructions already entered so that shared
// sub-expressions and cycles (e.g. through phis) are explored only once.
bool ExprDependsOn(Value* value, Value* needle,
                   std::unordered_set<Value*>& visiting) {
  if (value == needle) {
    return true;
  }
  auto* producer = dynamic_cast<Instruction*>(value);
  if (producer == nullptr) {
    return false;  // Not an instruction: it has no operands to follow.
  }
  const bool first_visit = visiting.insert(value).second;
  if (!first_visit) {
    return false;
  }
  size_t operand_index = 0;
  while (operand_index < producer->GetNumOperands()) {
    if (ExprDependsOn(producer->GetOperand(operand_index), needle, visiting)) {
      return true;
    }
    ++operand_index;
  }
  return false;
}
|
||||
|
||||
// Convenience overload: reports whether `needle` is reachable from `value`
// through instruction operands, starting with an empty visited set.
bool ExprDependsOn(Value* value, Value* needle) {
  std::unordered_set<Value*> seen;
  const bool reachable = ExprDependsOn(value, needle, seen);
  return reachable;
}
|
||||
|
||||
// Returns the value `phi` receives along the edge from `block`, or nullptr
// when either argument is null or `block` is not an incoming block. The
// first matching entry wins.
Value* GetIncomingForBlock(PhiInst* phi, BasicBlock* block) {
  if (phi == nullptr || block == nullptr) {
    return nullptr;
  }
  size_t slot = 0;
  while (slot < phi->GetNumIncoming()) {
    if (phi->GetIncomingBlock(slot) == block) {
      return phi->GetIncomingValue(slot);
    }
    ++slot;
  }
  return nullptr;
}
|
||||
|
||||
// Makes `value` available at the builder's insertion block.
//
// Constants, arguments, globals, functions and instructions defined outside
// `loop` are returned unchanged. An instruction defined inside `loop` is
// cloned (after recursively materializing its operands) and inserted into the
// builder's current block; `remap` caches original -> materialized mappings.
//
// Returns nullptr when an instruction, or any operand it depends on, cannot
// be cloned. A failure is propagated immediately, so no null entry is stored
// in `remap` and no clone is attempted with a missing operand.
//
// The caller is responsible for only passing expressions that are actually
// invariant in `loop`; this function does not check that.
Value* MaterializeInvariantExpr(Value* value, analysis::Loop* loop, IRBuilder& builder,
                                ValueMap& remap) {
  auto it = remap.find(value);
  if (it != remap.end()) return it->second;
  if (dynamic_cast<ConstantValue*>(value) || dynamic_cast<Argument*>(value) ||
      dynamic_cast<GlobalVariable*>(value) || dynamic_cast<Function*>(value)) {
    return value;
  }
  auto* inst = dynamic_cast<Instruction*>(value);
  if (!inst || !loop->Contains(inst->GetParent())) return value;
  for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
    auto* operand = inst->GetOperand(i);
    auto* materialized = MaterializeInvariantExpr(operand, loop, builder, remap);
    // Stop on the first operand that could not be materialized.
    if (!materialized) return nullptr;
    remap[operand] = materialized;
  }
  auto cloned = CloneInstruction(inst, remap, ".idiom");
  if (!cloned) return nullptr;
  auto* raw = cloned.get();
  InsertInstruction(builder.GetInsertBlock(), std::move(cloned));
  remap[inst] = raw;
  return raw;
}
|
||||
|
||||
// Returns true when `inst` has a use that is not an instruction placed in a
// block of `loop`: a non-instruction user, a user without a parent block, or
// a user in a block outside the loop.
bool HasOutsideUse(Instruction* inst, analysis::Loop* loop) {
  for (const auto& use : inst->GetUses()) {
    auto* consumer = dynamic_cast<Instruction*>(use.GetUser());
    const bool stays_inside = consumer != nullptr &&
                              consumer->GetParent() != nullptr &&
                              loop->Contains(consumer->GetParent());
    if (!stays_inside) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Matches `index` against `iv + X` or `X + iv` (int32 add) where X is
// loop-invariant, and returns X. Returns nullptr when `index` is the
// induction variable itself or does not have that form; callers distinguish
// the two cases by comparing `index` with `iv` themselves.
Value* MatchContiguousOffset(Value* index, PhiInst* iv, analysis::Loop* loop) {
  if (index == iv) {
    return nullptr;
  }
  auto* sum = dynamic_cast<BinaryInst*>(index);
  if (sum == nullptr) {
    return nullptr;
  }
  if (sum->GetOpcode() != Opcode::Add) {
    return nullptr;
  }
  if (!sum->GetType() || !sum->GetType()->IsInt32()) {
    return nullptr;
  }

  auto* left = sum->GetLhs();
  auto* right = sum->GetRhs();
  if (left == iv) {
    if (IsLoopInvariantValue(right, loop)) return right;
  }
  if (right == iv) {
    if (IsLoopInvariantValue(left, loop)) return left;
  }
  return nullptr;
}
|
||||
|
||||
// Recognises a loop of the form
//     for (i = 0; i < bound; i += 1) base[i (+ offset)] = <int constant>;
// and, on success, describes it in `*out`.
//
// Requirements checked here:
//   - canonical innermost loop with two blocks (header, body == latch) whose
//     only header phi is the induction variable;
//   - step 1, constant init 0, `Lt` header comparison;
//   - the exit block does not start with a phi;
//   - the induction phi has no use outside the loop (the rewrite bypasses the
//     loop, so such a use would refer to a value that is never computed);
//   - the body holds only Add / Gep / Store instructions, exactly one store,
//     and no body value other than the induction increment is used outside
//     the loop;
//   - the store writes a ConstantInt through a GEP on an int32 pointer whose
//     index is the induction variable, optionally plus an invariant offset;
//   - base pointer, offset and bound are not defined inside the loop, because
//     the replacement call is emitted in the preheader.
//
// `func` is unused. Returns false, leaving `*out` untouched, when the loop
// does not match.
bool BuildFillLoopCandidate(Function& func, analysis::Loop* loop,
                            FillLoopCandidate* out) {
  (void)func;
  auto match = MatchCanonicalLoop(loop);
  if (!match.has_value()) return false;
  if (match->loop->GetChildren().size() != 0) return false;
  if (match->body != match->latch || loop->GetBlocks().size() != 2) return false;
  if (match->header_phis.size() != 1 ||
      match->header_phis.front() != match->induction.phi) {
    return false;
  }
  if (match->induction.step != 1) return false;
  if (match->header_cmp->GetCmpOp() != CmpOp::Lt) return false;
  auto* init_ci = dynamic_cast<ConstantInt*>(match->induction.init);
  if (!init_ci || init_ci->GetValue() != 0) return false;
  if (!match->exit->GetInstructions().empty() &&
      dynamic_cast<PhiInst*>(match->exit->GetInstructions().front().get()) != nullptr) {
    return false;
  }
  // The loop becomes unreachable after the rewrite; a later use of the
  // induction variable would have no defining path.
  if (HasOutsideUse(match->induction.phi, loop)) return false;

  StoreInst* store = nullptr;
  for (const auto& inst_ptr : match->body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
    switch (inst->GetOpcode()) {
      case Opcode::Add:
      case Opcode::Gep:
      case Opcode::Store:
        break;
      default:
        return false;
    }
    if (inst != match->induction.next && HasOutsideUse(inst, loop)) {
      return false;
    }
    if (auto* maybe_store = dynamic_cast<StoreInst*>(inst)) {
      if (store) return false;
      store = maybe_store;
    }
  }
  if (!store) return false;

  auto* fill_ci = dynamic_cast<ConstantInt*>(store->GetValue());
  if (!fill_ci) return false;
  auto* gep = dynamic_cast<GepInst*>(store->GetPtr());
  if (!gep || !gep->GetBase() || !gep->GetBase()->GetType() ||
      !gep->GetBase()->GetType()->IsPtrInt32()) {
    return false;
  }
  Value* offset = MatchContiguousOffset(gep->GetIndex(), match->induction.phi, loop);
  if (gep->GetIndex() != match->induction.phi && offset == nullptr) {
    return false;
  }

  // Everything the replacement call consumes must already exist in the
  // preheader, i.e. must not be computed by an instruction inside the loop.
  const auto defined_in_loop = [loop](Value* v) {
    auto* def = dynamic_cast<Instruction*>(v);
    return def != nullptr && loop->Contains(def->GetParent());
  };
  if (defined_in_loop(gep->GetBase())) return false;
  if (offset != nullptr && defined_in_loop(offset)) return false;
  if (defined_in_loop(match->bound)) return false;

  out->loop = loop;
  out->preheader = match->preheader;
  out->header = match->header;
  out->exit = match->exit;
  out->induction = match->induction.phi;
  out->bound = match->bound;
  out->base_ptr = gep->GetBase();
  out->offset = offset;
  out->fill_value = fill_ci->GetValue();
  return true;
}
|
||||
|
||||
// Returns the module's declaration of the runtime helper `__fill_i32`,
// creating it as an external void function taking (int32*, int32, int32)
// when the module does not have one yet.
Function* GetOrCreateFillI32(Module& module) {
  auto* helper = module.FindFunction("__fill_i32");
  if (helper == nullptr) {
    helper = module.CreateFunction("__fill_i32", Type::GetVoidType(),
                                   {Type::GetPtrInt32Type(), Type::GetInt32Type(),
                                    Type::GetInt32Type()});
    helper->SetExternal(true);
  }
  return helper;
}
|
||||
|
||||
// Returns the module's declaration of the runtime helper `__fill_rows_i32`,
// creating it as an external void function taking one int32 pointer followed
// by five int32 values when the module does not have one yet.
Function* GetOrCreateFillRowsI32(Module& module) {
  auto* helper = module.FindFunction("__fill_rows_i32");
  if (helper == nullptr) {
    helper = module.CreateFunction(
        "__fill_rows_i32", Type::GetVoidType(),
        {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type(),
         Type::GetInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
    helper->SetExternal(true);
  }
  return helper;
}
|
||||
|
||||
// Recognises a four-block loop of the form
//
//     for (i = 0, lin = lin_init; i < bound; i += 1, lin += step) {
//       if (i < threshold)        // or i >= threshold
//         __fill_i32(&base[lin], bound, <int constant>);
//     }
//
// (header -> body(guard) -> action -> latch) and, on success, describes it in
// `*out`. `out->prefix` is true when the iterations below the threshold fill
// and false when the iterations at or above it fill.
//
// Compared with a purely structural match this function also rejects loops
// that the rewrite in RunGuardedRowFillLoop cannot represent:
//   - the row counter must start at constant 0 (the rewrite derives the start
//     row and start offset from that assumption);
//   - the header test must literally be `i < bound` with the loop body on the
//     true edge; `bound < i` or an inverted branch is a different loop;
//   - header, body and latch must not contain stores or calls, and the action
//     block must contain exactly one call and no store, since everything in
//     the loop is bypassed by the rewrite;
//   - no value defined in the loop may be used outside it, and the exit block
//     must not start with a phi.
//
// NOTE(review): the threshold is only checked not to depend on the two header
// phis; a threshold computed from a load inside the loop is still accepted.
//
// `func` is unused. Returns false when the loop does not match; `*out` is
// written only on success.
bool BuildGuardedRowFillCandidate(Function& func, analysis::Loop* loop,
                                  GuardedRowFillCandidate* out) {
  (void)func;
  if (!loop) return false;
  if (loop->GetChildren().size() != 0) return false;
  if (loop->GetBlocks().size() != 4) return false;

  auto* header = loop->GetHeader();
  auto* preheader = loop->GetPreheader();
  if (!header || !preheader) return false;
  if (loop->GetLatches().size() != 1) return false;
  auto* latch = loop->GetLatches().front();
  if (!latch) return false;

  auto* header_term = header->HasTerminator()
                          ? dynamic_cast<CondBranchInst*>(
                                header->MutableInstructions().back().get())
                          : nullptr;
  if (!header_term) return false;
  auto* header_cmp = dynamic_cast<CmpInst*>(header_term->GetCond());
  if (!header_cmp || header_cmp->GetCmpOp() != CmpOp::Lt) return false;

  auto induction = MatchCanonicalInduction(header, preheader, latch);
  if (!induction.has_value() || induction->step != 1) return false;

  // The rewrite numbers rows from 0, so the counter must start at 0.
  auto* induction_init = dynamic_cast<ConstantInt*>(
      GetIncomingForBlock(induction->phi, preheader));
  if (!induction_init || induction_init->GetValue() != 0) return false;

  // Only `i < bound` with the body on the true edge describes the loop the
  // rewrite assumes.
  BasicBlock* body = header_term->GetTrueBlock();
  BasicBlock* exit = header_term->GetFalseBlock();
  if (!body || !exit) return false;
  if (!loop->Contains(body) || loop->Contains(exit)) return false;
  if (header_cmp->GetLhs() != induction->phi ||
      !IsLoopInvariantValue(header_cmp->GetRhs(), loop)) {
    return false;
  }
  Value* bound = header_cmp->GetRhs();

  // The preheader becomes a new predecessor of the exit block; exit phis are
  // not rewritten, so refuse them.
  if (!exit->GetInstructions().empty() &&
      dynamic_cast<PhiInst*>(exit->GetInstructions().front().get()) != nullptr) {
    return false;
  }

  auto header_phis = CollectHeaderPhis(header);
  if (header_phis.size() != 2) return false;
  PhiInst* linear_phi = nullptr;
  for (auto* phi : header_phis) {
    if (phi != induction->phi) {
      linear_phi = phi;
      break;
    }
  }
  if (!linear_phi || !linear_phi->GetType() || !linear_phi->GetType()->IsInt32()) {
    return false;
  }
  auto* linear_init = GetIncomingForBlock(linear_phi, preheader);
  auto* linear_next = GetIncomingForBlock(linear_phi, latch);
  auto* linear_next_bin = dynamic_cast<BinaryInst*>(linear_next);
  if (!linear_init || !linear_next_bin ||
      linear_next_bin->GetOpcode() != Opcode::Add ||
      linear_next_bin->GetLhs() != linear_phi) {
    return false;
  }
  auto* linear_step_ci = dynamic_cast<ConstantInt*>(linear_next_bin->GetRhs());
  if (!linear_step_ci || linear_step_ci->GetValue() <= 0) return false;

  auto* guard = body->HasTerminator()
                    ? dynamic_cast<CondBranchInst*>(body->MutableInstructions().back().get())
                    : nullptr;
  if (!guard) return false;
  BasicBlock* action = nullptr;
  if (guard->GetTrueBlock() == latch && loop->Contains(guard->GetFalseBlock())) {
    action = guard->GetFalseBlock();
  } else if (guard->GetFalseBlock() == latch &&
             loop->Contains(guard->GetTrueBlock())) {
    action = guard->GetTrueBlock();
  } else {
    return false;
  }

  // With four loop blocks in total, four pairwise distinct blocks are exactly
  // the loop; anything else is a shape this matcher does not understand.
  if (body == header || latch == header || action == header ||
      latch == body || action == body || action == latch) {
    return false;
  }
  if (!action->HasTerminator()) return false;
  auto* action_term =
      dynamic_cast<BranchInst*>(action->MutableInstructions().back().get());
  if (!action_term || action_term->GetTarget() != latch) return false;

  // Side effects outside the action block would be lost by the rewrite.
  BasicBlock* const effect_free_blocks[] = {header, body, latch};
  for (BasicBlock* block : effect_free_blocks) {
    for (const auto& inst_ptr : block->GetInstructions()) {
      const auto opcode = inst_ptr.get()->GetOpcode();
      if (opcode == Opcode::Store || opcode == Opcode::Call) return false;
    }
  }

  CallInst* fill_call = nullptr;
  GepInst* fill_gep = nullptr;
  int call_count = 0;
  for (const auto& inst_ptr : action->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator()) continue;
    if (inst->GetOpcode() == Opcode::Store) return false;
    if (auto* gep = dynamic_cast<GepInst*>(inst)) {
      fill_gep = gep;
      continue;
    }
    auto* call = dynamic_cast<CallInst*>(inst);
    // Only one call can be folded into the bulk helper.
    if (call && ++call_count > 1) return false;
    // As before, the last non-GEP instruction must be the fill call.
    fill_call = call;
  }
  if (!fill_call || !fill_gep || fill_call->GetNumArgs() != 3) return false;
  auto* callee = fill_call->GetCallee();
  if (!callee || callee->GetName() != "__fill_i32") return false;
  auto* fill_value = dynamic_cast<ConstantInt*>(fill_call->GetArg(2));
  if (!fill_value) return false;
  if (fill_call->GetArg(0) != fill_gep || fill_call->GetArg(1) != bound) {
    return false;
  }
  if (fill_gep->GetIndex() != linear_phi) return false;
  if (!fill_gep->GetBase() || !fill_gep->GetBase()->GetType() ||
      !fill_gep->GetBase()->GetType()->IsPtrInt32()) {
    return false;
  }

  // No loop-defined value may escape: the loop is unreachable afterwards.
  BasicBlock* const loop_blocks[] = {header, body, action, latch};
  for (BasicBlock* block : loop_blocks) {
    for (const auto& inst_ptr : block->GetInstructions()) {
      if (HasOutsideUse(inst_ptr.get(), loop)) return false;
    }
  }

  auto* guard_cmp = dynamic_cast<CmpInst*>(guard->GetCond());
  if (!guard_cmp || !guard_cmp->GetType() || !guard_cmp->GetType()->IsInt32()) {
    return false;
  }
  if (guard_cmp->GetLhs() != induction->phi ||
      ExprDependsOn(guard_cmp->GetRhs(), induction->phi) ||
      ExprDependsOn(guard_cmp->GetRhs(), linear_phi)) {
    return false;
  }
  Value* threshold = guard_cmp->GetRhs();

  // `i < T` taken on the true edge, or `i >= T` taken on the false edge,
  // selects the rows below T (prefix); the two mirrored cases select the
  // rows from T upwards (suffix).
  const bool action_on_true = (action == guard->GetTrueBlock());
  bool prefix = false;
  if (guard_cmp->GetCmpOp() == CmpOp::Lt) {
    prefix = action_on_true;
  } else if (guard_cmp->GetCmpOp() == CmpOp::Ge) {
    prefix = !action_on_true;
  } else {
    return false;
  }

  out->loop = loop;
  out->preheader = preheader;
  out->header = header;
  out->body = body;
  out->action = action;
  out->latch = latch;
  out->exit = exit;
  out->induction = induction->phi;
  out->bound = bound;
  out->linear = linear_phi;
  out->linear_init = linear_init;
  out->linear_step = linear_step_ci->GetValue();
  out->base_ptr = fill_gep->GetBase();
  out->fill_value = fill_value->GetValue();
  out->threshold = threshold;
  out->prefix = prefix;
  return true;
}
|
||||
|
||||
// Rewrites a matched constant-fill loop: the preheader's terminator is
// replaced by a `__fill_i32(start, bound, value)` call followed by a branch
// straight to the loop exit, and the preheader -> header edge is removed from
// the phi and from the CFG bookkeeping. `func` is unused. Always returns true.
bool RunFillLoop(Function& func, const FillLoopCandidate& cand,
                 Module& module, Context& ctx) {
  (void)func;
  auto* runtime_fill = GetOrCreateFillI32(module);
  auto* entry = cand.preheader;

  // Drop the old jump into the loop so new code can be appended.
  if (entry->HasTerminator()) {
    auto* old_jump = entry->MutableInstructions().back().get();
    entry->RemoveInstruction(old_jump);
  }

  IRBuilder builder(ctx, entry);
  // Without an offset the fill starts at the base pointer itself.
  Value* start_ptr = nullptr;
  if (cand.offset == nullptr) {
    start_ptr = cand.base_ptr;
  } else {
    start_ptr = builder.CreateGep(cand.base_ptr, cand.offset, ctx.NextTemp());
  }
  builder.CreateCall(runtime_fill,
                     {start_ptr, cand.bound, ctx.GetConstInt(cand.fill_value)},
                     "");
  entry->Append<BranchInst>(Type::GetVoidType(), cand.exit);

  // Detach the loop from the preheader ...
  cand.induction->RemoveIncomingBlock(entry);
  entry->RemoveSuccessor(cand.header);
  cand.header->RemovePredecessor(entry);
  // ... and connect the preheader to the exit instead.
  entry->AddSuccessor(cand.exit);
  cand.exit->AddPredecessor(entry);
  return true;
}
|
||||
|
||||
// Rewrites a matched guarded row-fill loop into one call
//     __fill_rows_i32(base, start_offset, rows, linear_step, bound, value)
// emitted in the preheader, which then branches straight to the loop exit.
//
//   prefix form: start row = 0,         rows = threshold
//   suffix form: start row = threshold, rows = bound - threshold
//   start_offset = linear_init + start_row * linear_step
//
// If the threshold expression cannot be materialized in the preheader, the
// branch into the loop is restored and false is returned, so the preheader is
// never left without a terminator. `func` is unused.
//
// NOTE(review): `rows` and the start row are passed unclamped (threshold may
// exceed bound in the prefix form, or be negative in the suffix form).
// Whether `__fill_rows_i32` clamps them is not visible here; confirm against
// the runtime.
bool RunGuardedRowFillLoop(Function& func, const GuardedRowFillCandidate& cand,
                           Module& module, Context& ctx) {
  (void)func;
  auto* fill_rows_fn = GetOrCreateFillRowsI32(module);
  auto* preheader = cand.preheader;
  const bool had_terminator = preheader->HasTerminator();
  if (had_terminator) {
    preheader->RemoveInstruction(preheader->MutableInstructions().back().get());
  }

  IRBuilder builder(ctx, preheader);
  ValueMap remap;
  auto* threshold =
      MaterializeInvariantExpr(cand.threshold, cand.loop, builder, remap);
  if (!threshold) {
    // Undo the terminator removal; the CFG edges were not touched yet, so
    // re-entering the loop keeps the function well-formed.
    if (had_terminator) {
      preheader->Append<BranchInst>(Type::GetVoidType(), cand.header);
    }
    return false;
  }
  Value* start_index = cand.prefix ? ctx.GetConstInt(0) : threshold;
  Value* rows = cand.prefix ? threshold : nullptr;
  if (!cand.prefix) {
    rows = builder.CreateSub(cand.bound, start_index, ctx.NextTemp());
  }
  auto* start_offset_mul =
      builder.CreateMul(start_index, ctx.GetConstInt(cand.linear_step), ctx.NextTemp());
  auto* start_offset =
      builder.CreateAdd(cand.linear_init, start_offset_mul, ctx.NextTemp());
  builder.CreateCall(fill_rows_fn,
                     {cand.base_ptr, start_offset, rows,
                      ctx.GetConstInt(cand.linear_step), cand.bound,
                      ctx.GetConstInt(cand.fill_value)},
                     "");
  preheader->Append<BranchInst>(Type::GetVoidType(), cand.exit);

  // Detach the loop from the preheader and connect the preheader to the exit.
  cand.induction->RemoveIncomingBlock(preheader);
  cand.linear->RemoveIncomingBlock(preheader);
  preheader->RemoveSuccessor(cand.header);
  cand.header->RemovePredecessor(preheader);
  preheader->AddSuccessor(cand.exit);
  cand.exit->AddPredecessor(preheader);
  return true;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: for each loop of `func`, first tries the guarded
// row-fill pattern and then the plain constant-fill pattern. Stops after the
// first successful rewrite and returns true; returns false when nothing
// changed or `func` is external.
bool RunLoopIdiom(Function& func, Module& module, Context& ctx) {
  if (func.IsExternal()) {
    return false;
  }

  analysis::DominatorTree dom_tree(func);
  analysis::LoopInfo loop_info(func, dom_tree);

  for (const auto& loop_ptr : loop_info.GetLoops()) {
    auto* current = loop_ptr.get();

    GuardedRowFillCandidate guarded;
    if (BuildGuardedRowFillCandidate(func, current, &guarded) &&
        RunGuardedRowFillLoop(func, guarded, module, ctx)) {
      return true;
    }

    FillLoopCandidate plain;
    if (BuildFillLoopCandidate(func, current, &plain) &&
        RunFillLoop(func, plain, module, ctx)) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
} // namespace passes
|
||||
} // namespace ir
|
||||
@ -0,0 +1,845 @@
|
||||
// 循环并行化:
|
||||
// - 将一部分安全的规范循环抽取成 worker 函数
|
||||
// - 通过运行时 __par_runN 启动固定线程数并行执行
|
||||
//
|
||||
// 当前限制:
|
||||
// - 仅并行化不存在 SSA live-out 的循环
|
||||
// - 循环访问对象必须是全局数组/全局变量
|
||||
// - 不支持循环中的普通函数调用
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
namespace ir {
|
||||
namespace passes {
|
||||
|
||||
// Declared here, defined in another pass source file. Called on the parent
// function right after one of its loops has been outlined.
bool RunSimpleDCE(Function& func);
|
||||
// Declared here, defined in another pass source file. Called after
// RunSimpleDCE on the parent function of an outlined loop.
bool RunCFGSimplify(Function& func, Context& ctx);
|
||||
|
||||
namespace {
|
||||
|
||||
// Maximum number of loops outlined per module. The slot index is appended to
// the generated __par_run / __par_worker / __par_bound / __par_ctx / __par_red
// symbol names.
constexpr int kParallelLoopSlots = 8;
|
||||
// Number of chunks a worker divides the iteration range into
// (tid * bound / kParallelThreads) and the element count of the per-slot
// reduction array.
// NOTE(review): presumably must equal the thread count used by the runtime
// __par_runN launchers - confirm against the runtime library.
constexpr int kParallelThreads = 4;
|
||||
|
||||
// Loop shapes the parallelizer distinguishes; selects how the worker function
// is generated.
enum class ParallelLoopKind {
  // Header + single body/latch block, induction phi only.
  Pointwise = 0,
  // Pointwise plus one i32 add-reduction phi (matching is currently disabled
  // in BuildParallelCandidate).
  ReductionAddI32 = 1,
  // Header / guard / action / latch loop whose action calls __fill_i32.
  GuardedFillI32 = 2,
};
|
||||
|
||||
struct LoopContextValue {
|
||||
Value* original = nullptr;
|
||||
GlobalVariable* slot = nullptr;
|
||||
};
|
||||
|
||||
// Returns true when `needle` can be reached from `value` by following
// instruction operands. `visiting` records instructions already expanded so
// that cyclic operand graphs (e.g. through phis) terminate; a value that was
// already expanded is reported as "not depending".
bool ExprDependsOn(Value* value, Value* needle,
                   std::unordered_set<Value*>& visiting) {
  if (value == needle) {
    return true;
  }
  auto* as_inst = dynamic_cast<Instruction*>(value);
  if (as_inst == nullptr) {
    return false;
  }
  const bool first_visit = visiting.insert(value).second;
  if (!first_visit) {
    return false;
  }
  bool depends = false;
  for (size_t idx = 0; !depends && idx < as_inst->GetNumOperands(); ++idx) {
    depends = ExprDependsOn(as_inst->GetOperand(idx), needle, visiting);
  }
  return depends;
}
|
||||
|
||||
// Convenience overload: starts the operand walk with an empty visited set.
bool ExprDependsOn(Value* value, Value* needle) {
  std::unordered_set<Value*> seen{};
  const bool reachable = ExprDependsOn(value, needle, seen);
  return reachable;
}
|
||||
|
||||
// Follows a chain of GEP instructions through their base operands and returns
// the first value that is not a GEP.
Value* StripPointerBase(Value* value) {
  for (auto* gep = dynamic_cast<GepInst*>(value); gep != nullptr;
       gep = dynamic_cast<GepInst*>(value)) {
    value = gep->GetBase();
  }
  return value;
}
|
||||
|
||||
// Whitelist of opcodes allowed inside a loop that is outlined into a worker.
// Calls, allocas and any opcode not listed are rejected.
bool IsSupportedParallelInst(Instruction* inst) {
  const Opcode opcode = inst->GetOpcode();
  switch (opcode) {
    // arithmetic / comparison / conversion
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::Div:
    case Opcode::Mod:
    case Opcode::Cmp:
    case Opcode::Cast:
    // memory access and addressing
    case Opcode::Load:
    case Opcode::Store:
    case Opcode::Gep:
    // control flow and phis
    case Opcode::Br:
    case Opcode::CondBr:
    case Opcode::Phi:
    case Opcode::Ret:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
// A value can be forwarded to the worker through a global slot only when it is
// a function argument or an instruction result of type i32 or f32.
bool IsScalarContextCandidate(Value* value) {
  const auto is_i32_or_f32 = [](const auto& type) {
    return type->IsInt32() || type->IsFloat32();
  };

  if (auto* as_arg = dynamic_cast<Argument*>(value)) {
    if (is_i32_or_f32(as_arg->GetType())) {
      return true;
    }
  }
  if (auto* as_inst = dynamic_cast<Instruction*>(value)) {
    const auto& type = as_inst->GetType();
    // Instructions may carry no type; treat those as non-scalar.
    if (type && is_i32_or_f32(type)) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Returns true when `inst` has a user that is not an instruction, has no
// parent block, or lives in a block outside `loop` (i.e. the value is live-out).
bool HasOutsideUse(Instruction* inst, analysis::Loop* loop) {
  const auto& uses = inst->GetUses();
  return std::any_of(std::begin(uses), std::end(uses), [loop](const auto& use) {
    auto* user = dynamic_cast<Instruction*>(use.GetUser());
    if (user == nullptr) {
      return true;
    }
    auto* owner = user->GetParent();
    return owner == nullptr || !loop->Contains(owner);
  });
}
|
||||
|
||||
// Rewrites every instruction operand that refers to `value` from a block
// outside `loop` (or from an instruction without a parent block) so that it
// refers to `replacement` instead. Users inside the loop and non-instruction
// users are left untouched. Does nothing if any argument is null.
void ReplaceUsesOutsideLoop(Value* value, Value* replacement,
                            analysis::Loop* loop) {
  const bool args_ok =
      value != nullptr && replacement != nullptr && loop != nullptr;
  if (!args_ok) {
    return;
  }

  // Iterate over a copy of the use list: SetOperand is called while walking.
  const auto snapshot = value->GetUses();
  for (const auto& use : snapshot) {
    auto* user = dynamic_cast<Instruction*>(use.GetUser());
    if (user == nullptr) {
      continue;
    }
    auto* owner = user->GetParent();
    const bool inside_loop = owner != nullptr && loop->Contains(owner);
    if (!inside_loop) {
      user->SetOperand(use.GetOperandIndex(), replacement);
    }
  }
}
|
||||
|
||||
// Returns the value `phi` receives along the edge from `block`, or nullptr
// when either argument is null or `block` is not an incoming block. If the
// block is listed more than once the first entry wins.
Value* GetIncomingForBlock(PhiInst* phi, BasicBlock* block) {
  if (phi == nullptr || block == nullptr) {
    return nullptr;
  }
  const size_t incoming = phi->GetNumIncoming();
  size_t slot = 0;
  while (slot < incoming && phi->GetIncomingBlock(slot) != block) {
    ++slot;
  }
  if (slot == incoming) {
    return nullptr;
  }
  return phi->GetIncomingValue(slot);
}
|
||||
|
||||
struct ParallelLoopCandidate {
|
||||
Function* parent = nullptr;
|
||||
analysis::Loop* loop = nullptr;
|
||||
ParallelLoopKind kind = ParallelLoopKind::Pointwise;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* guard = nullptr;
|
||||
BasicBlock* action = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
BasicBlock* latch = nullptr;
|
||||
CmpInst* header_cmp = nullptr;
|
||||
PhiInst* induction = nullptr;
|
||||
Value* induction_next = nullptr;
|
||||
Value* bound = nullptr;
|
||||
PhiInst* linear = nullptr;
|
||||
Value* linear_init = nullptr;
|
||||
Value* linear_next = nullptr;
|
||||
int linear_step = 0;
|
||||
PhiInst* reduction = nullptr;
|
||||
Value* reduction_init = nullptr;
|
||||
Value* reduction_next = nullptr;
|
||||
bool has_loads = false;
|
||||
std::vector<LoopContextValue> contexts;
|
||||
};
|
||||
|
||||
// Recognises an innermost "guarded fill" loop and, on success, overwrites
// *out with a GuardedFillI32 candidate. *out is left untouched on failure.
//
// Accepted shape (exactly four loop blocks, one latch):
//   header: iv  = phi [0, preheader], [iv + 1, latch]
//           lin = phi [init, preheader], [lin + C, latch]   ; C > 0 constant
//           br (iv < bound) ? guard : exit                  ; bound invariant
//   guard:  ... ; br cond ? action : latch   (either polarity)
//   action: GEPs and exactly one call
//           __fill_i32(gep(i32* base, lin), bound, <const int>); br latch
//   latch:  ... ; br header
//
// The restrictions mirror what CloneWorkerBlocks emits for this kind: each
// worker runs `for (iv = start; iv < end; ++iv)` over a slice of [0, bound)
// and branches into the loop on the *true* edge of the compare, and only the
// header phis and compare are re-created in the worker header. No value
// defined in the loop may be used outside it, and at most six outside scalars
// may be referenced.
bool BuildGuardedFillCandidate(Function& func, analysis::Loop* loop,
                               ParallelLoopCandidate* out) {
  if (!loop || !out) return false;
  auto* header = loop->GetHeader();
  auto* preheader = loop->GetPreheader();
  if (!header || !preheader) return false;
  if (loop->GetChildren().size() != 0) return false;
  if (loop->GetBlocks().size() != 4) return false;
  if (loop->GetLatches().size() != 1) return false;
  auto* latch = loop->GetLatches().front();
  if (!latch) return false;

  auto* header_term =
      header->HasTerminator()
          ? dynamic_cast<CondBranchInst*>(
                header->MutableInstructions().back().get())
          : nullptr;
  if (!header_term) return false;
  auto* header_cmp = dynamic_cast<CmpInst*>(header_term->GetCond());
  if (!header_cmp || header_cmp->GetCmpOp() != CmpOp::Lt) return false;

  // The worker always emits `br cmp ? loop : exit`, so the inverted polarity
  // (loop entered on the false edge) cannot be reproduced and is rejected.
  BasicBlock* body = header_term->GetTrueBlock();
  BasicBlock* exit = header_term->GetFalseBlock();
  if (!body || !exit || !loop->Contains(body) || loop->Contains(exit)) {
    return false;
  }

  auto induction = MatchCanonicalInduction(header, preheader, latch);
  if (!induction.has_value() || induction->step != 1) return false;

  // Workers partition [0, bound); a loop that starts anywhere else would gain
  // extra iterations.
  auto* iv_start = dynamic_cast<ConstantInt*>(
      GetIncomingForBlock(induction->phi, preheader));
  if (!iv_start || iv_start->GetValue() != 0) return false;

  // Only `iv < bound` is meaningful for an up-counting loop; `bound < iv`
  // would be mistranslated by the worker's compare rewrite.
  if (header_cmp->GetLhs() != induction->phi ||
      !IsLoopInvariantValue(header_cmp->GetRhs(), loop)) {
    return false;
  }
  Value* bound = header_cmp->GetRhs();

  auto header_phis = CollectHeaderPhis(header);
  if (header_phis.size() != 2) return false;

  PhiInst* linear_phi = nullptr;
  for (auto* phi : header_phis) {
    if (phi != induction->phi) {
      linear_phi = phi;
      break;
    }
  }
  if (!linear_phi || !linear_phi->GetType() ||
      !linear_phi->GetType()->IsInt32()) {
    return false;
  }

  // Anything in the header besides phis, the compare and the branch is never
  // cloned into the worker, and no header value may be live after the loop
  // because the loop becomes unreachable in the caller.
  for (const auto& inst_ptr : header->GetInstructions()) {
    auto* inst = inst_ptr.get();
    const bool is_phi = dynamic_cast<PhiInst*>(inst) != nullptr;
    if (!is_phi && inst != header_cmp && inst != header_term) return false;
    if (HasOutsideUse(inst, loop)) return false;
  }

  auto* linear_init = GetIncomingForBlock(linear_phi, preheader);
  auto* linear_next = GetIncomingForBlock(linear_phi, latch);
  auto* linear_next_bin = dynamic_cast<BinaryInst*>(linear_next);
  if (!linear_init || !linear_next_bin ||
      linear_next_bin->GetOpcode() != Opcode::Add ||
      linear_next_bin->GetLhs() != linear_phi) {
    return false;
  }
  auto* linear_step_ci = dynamic_cast<ConstantInt*>(linear_next_bin->GetRhs());
  if (!linear_step_ci || linear_step_ci->GetValue() <= 0) return false;

  // Guard block: conditional branch with one edge to the latch and the other
  // to a distinct in-loop action block.
  if (body->GetInstructions().size() == 0) return false;
  auto* guard =
      dynamic_cast<CondBranchInst*>(body->MutableInstructions().back().get());
  if (!guard) return false;
  auto* true_bb = guard->GetTrueBlock();
  auto* false_bb = guard->GetFalseBlock();
  BasicBlock* action = nullptr;
  if (true_bb == latch && false_bb != latch && loop->Contains(false_bb)) {
    action = false_bb;
  } else if (false_bb == latch && true_bb != latch &&
             loop->Contains(true_bb)) {
    action = true_bb;
  } else {
    return false;
  }
  if (action == header || action == body) return false;
  if (action->GetInstructions().size() == 0) return false;
  auto* action_term =
      dynamic_cast<BranchInst*>(action->MutableInstructions().back().get());
  if (!action_term || action_term->GetTarget() != latch) return false;

  // Action block: GEPs plus exactly one call. A second call would otherwise
  // slip through unchecked because only one call is validated below.
  CallInst* fill_call = nullptr;
  for (const auto& inst_ptr : action->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator()) continue;
    if (auto* gep = dynamic_cast<GepInst*>(inst)) {
      if (gep->GetBase() == nullptr || gep->GetIndex() == nullptr) return false;
      continue;
    }
    auto* call = dynamic_cast<CallInst*>(inst);
    if (!call || fill_call) return false;
    fill_call = call;
  }
  if (!fill_call || fill_call->GetNumArgs() != 3) return false;
  auto* callee = fill_call->GetCallee();
  if (!callee || callee->GetName() != "__fill_i32") return false;
  auto* fill_ptr = dynamic_cast<GepInst*>(fill_call->GetArg(0));
  auto* fill_count = fill_call->GetArg(1);
  auto* fill_value = dynamic_cast<ConstantInt*>(fill_call->GetArg(2));
  if (!fill_ptr || !fill_value) return false;
  if (fill_ptr->GetBase() == nullptr || !fill_ptr->GetBase()->GetType() ||
      !fill_ptr->GetBase()->GetType()->IsPtrInt32()) {
    return false;
  }
  if (fill_ptr->GetIndex() != linear_phi) return false;
  if (fill_count != bound) return false;

  // Collect scalars defined outside the loop (they are passed to the worker
  // through global slots) and reject values that are live after the loop.
  std::vector<Value*> context_values;
  std::unordered_set<Value*> seen_contexts;
  auto collect_contexts = [&](BasicBlock* block, bool allow_calls) -> bool {
    for (const auto& inst_ptr : block->GetInstructions()) {
      auto* inst = inst_ptr.get();
      if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) {
        continue;
      }
      // Same opcode whitelist as the pointwise matcher; only the action block
      // (already validated above) may contain its single runtime call.
      if (!allow_calls && !IsSupportedParallelInst(inst)) return false;
      for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
        Value* operand = inst->GetOperand(i);
        if (dynamic_cast<ConstantValue*>(operand) ||
            dynamic_cast<Function*>(operand) ||
            dynamic_cast<BasicBlock*>(operand) ||
            dynamic_cast<GlobalVariable*>(operand)) {
          continue;
        }
        auto* operand_inst = dynamic_cast<Instruction*>(operand);
        if ((operand_inst && loop->Contains(operand_inst->GetParent())) ||
            operand == induction->phi || operand == induction->next ||
            operand == linear_phi || operand == linear_next) {
          continue;
        }
        if (!IsScalarContextCandidate(operand)) return false;
        if (seen_contexts.insert(operand).second) {
          context_values.push_back(operand);
        }
      }
      // No exemptions: after outlining the loop body is unreachable in the
      // caller, so any use outside the loop would reference dead code.
      if (HasOutsideUse(inst, loop)) return false;
    }
    return true;
  };
  if (!collect_contexts(body, /*allow_calls=*/false) ||
      !collect_contexts(action, /*allow_calls=*/true) ||
      !collect_contexts(latch, /*allow_calls=*/false)) {
    return false;
  }
  if (context_values.size() > 6) return false;

  // Build into a fresh object so nothing from a previous use of *out leaks
  // into this candidate.
  ParallelLoopCandidate result;
  result.parent = &func;
  result.loop = loop;
  result.kind = ParallelLoopKind::GuardedFillI32;
  result.header = header;
  result.body = body;
  result.guard = body;
  result.action = action;
  result.preheader = preheader;
  result.exit = exit;
  result.latch = latch;
  result.header_cmp = header_cmp;
  result.induction = induction->phi;
  result.induction_next = induction->next;
  result.bound = bound;
  result.linear = linear_phi;
  result.linear_init = linear_init;
  result.linear_next = linear_next;
  result.linear_step = linear_step_ci->GetValue();
  // Marked as "has loads" so RunLoopParallelization selects it immediately
  // instead of treating it as a store-only fallback.
  result.has_loads = true;
  result.contexts.reserve(context_values.size());
  for (Value* value : context_values) {
    result.contexts.push_back({value, nullptr});
  }
  *out = std::move(result);
  return true;
}
|
||||
|
||||
// Decides whether `loop` can be outlined and, if so, describes it in *out.
//
// The guarded-fill shape is tried first; otherwise the loop must be a
// two-block canonical loop (header + combined body/latch) whose only header
// phi is the induction variable. Inside the loop only whitelisted opcodes are
// allowed, every load/store must address a global, stores must target a
// global array through a GEP whose index depends on the induction variable,
// and at most six outside scalars may be referenced.
//
// *out is reset on entry, so on failure it holds a default-constructed
// candidate and on success it never contains data from an earlier call.
//
// NOTE(review): "index depends on the induction variable" is a heuristic, not
// a dependence proof - e.g. `a[i] = a[i - 1]` or `a[i / 2] = ...` pass these
// checks although their iterations are not independent.
bool BuildParallelCandidate(Function& func, analysis::Loop* loop,
                            ParallelLoopCandidate* out) {
  if (!out) return false;
  // The caller reuses one candidate object for many loops; start clean so
  // `contexts` and the guarded-fill/reduction fields cannot carry over.
  *out = ParallelLoopCandidate{};
  if (BuildGuardedFillCandidate(func, loop, out)) return true;

  auto match = MatchCanonicalLoop(loop);
  if (!match.has_value()) return false;

  if (match->body != match->latch || loop->GetBlocks().size() != 2) return false;
  if (match->exit->GetInstructions().size() > 0 &&
      dynamic_cast<PhiInst*>(match->exit->GetInstructions().front().get()) !=
          nullptr) {
    return false;
  }

  // CloneWorkerBlocks rebuilds the loop as `for (iv = start; iv < end; ++iv)`
  // over a slice of [0, bound) and enters the body on the true edge of the
  // compare. MatchCanonicalLoop is defined elsewhere, so re-check that shape
  // here rather than rely on it.
  auto* iv_start = dynamic_cast<ConstantInt*>(
      GetIncomingForBlock(match->induction.phi, match->preheader));
  if (!iv_start || iv_start->GetValue() != 0 || match->induction.step != 1) {
    return false;
  }
  if (!match->header_cmp || match->header_cmp->GetCmpOp() != CmpOp::Lt ||
      match->header_cmp->GetLhs() != match->induction.phi ||
      match->header_cmp->GetRhs() != match->bound) {
    return false;
  }
  auto* header_term =
      match->header->HasTerminator()
          ? dynamic_cast<CondBranchInst*>(
                match->header->MutableInstructions().back().get())
          : nullptr;
  if (!header_term || header_term->GetTrueBlock() != match->body) return false;

  PhiInst* reduction_phi = nullptr;
  Value* reduction_init = nullptr;
  Value* reduction_next = nullptr;
  ParallelLoopKind kind = ParallelLoopKind::Pointwise;
  if (match->header_phis.size() == 1 &&
      match->header_phis.front() == match->induction.phi) {
    kind = ParallelLoopKind::Pointwise;
  } else if (false && match->header_phis.size() == 2) {
    // Add-reduction matching; deliberately disabled by the `false &&` above.
    for (auto* phi : match->header_phis) {
      if (phi != match->induction.phi) {
        reduction_phi = phi;
        break;
      }
    }
    if (!reduction_phi || !reduction_phi->GetType() ||
        !reduction_phi->GetType()->IsInt32()) {
      return false;
    }
    reduction_init = GetIncomingForBlock(reduction_phi, match->preheader);
    reduction_next = GetIncomingForBlock(reduction_phi, match->latch);
    if (!reduction_init || !reduction_next) return false;
    auto* reduction_next_inst = dynamic_cast<Instruction*>(reduction_next);
    if (!reduction_next_inst ||
        reduction_next_inst->GetParent() != match->body) {
      return false;
    }
    auto* init_ci = dynamic_cast<ConstantInt*>(reduction_init);
    if (!init_ci || init_ci->GetValue() != 0) return false;
    auto* red_next_bin = dynamic_cast<BinaryInst*>(reduction_next);
    if (!red_next_bin || red_next_bin->GetOpcode() != Opcode::Add ||
        !red_next_bin->GetType() || !red_next_bin->GetType()->IsInt32()) {
      return false;
    }
    Value* other = nullptr;
    if (red_next_bin->GetLhs() == reduction_phi) {
      other = red_next_bin->GetRhs();
    } else if (red_next_bin->GetRhs() == reduction_phi) {
      other = red_next_bin->GetLhs();
    } else {
      return false;
    }
    if (ExprDependsOn(other, reduction_phi)) return false;
    kind = ParallelLoopKind::ReductionAddI32;
  } else {
    return false;
  }

  std::unordered_set<Value*> store_bases;
  std::unordered_set<Value*> load_bases;
  // Bases read through an address that does not vary with the induction
  // variable; such a base must not also be written inside the loop.
  std::unordered_set<Value*> iv_independent_load_bases;
  std::vector<Value*> context_values;
  std::unordered_set<Value*> seen_contexts;

  for (const auto& bb_ptr : func.GetBlocks()) {
    auto* block = bb_ptr.get();
    if (!loop->Contains(block)) continue;
    const bool in_header = block == match->header;
    for (const auto& inst_ptr : block->GetInstructions()) {
      auto* inst = inst_ptr.get();
      // Rejects calls and allocas as well as unknown opcodes.
      if (!IsSupportedParallelInst(inst)) return false;
      // Only phis and the exit compare are re-created in the worker header.
      if (in_header && dynamic_cast<PhiInst*>(inst) == nullptr &&
          inst != match->header_cmp && !inst->IsTerminator()) {
        return false;
      }
      if (kind == ParallelLoopKind::ReductionAddI32 &&
          inst->GetOpcode() == Opcode::Store) {
        return false;
      }
      // No SSA live-outs: the loop becomes unreachable in the caller. This
      // includes the induction phi; only the reduction phi is exempt because
      // its outside uses are rewritten to the combined partial sums.
      if (inst != reduction_phi && HasOutsideUse(inst, loop)) return false;

      for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
        Value* operand = inst->GetOperand(i);
        if (dynamic_cast<ConstantValue*>(operand) ||
            dynamic_cast<Function*>(operand) ||
            dynamic_cast<BasicBlock*>(operand) ||
            dynamic_cast<GlobalVariable*>(operand)) {
          continue;
        }
        auto* operand_inst = dynamic_cast<Instruction*>(operand);
        if ((operand_inst && loop->Contains(operand_inst->GetParent())) ||
            operand == match->induction.phi ||
            operand == match->induction.next || operand == reduction_phi ||
            operand == reduction_next) {
          continue;
        }
        if (!IsScalarContextCandidate(operand)) return false;
        if (seen_contexts.insert(operand).second) {
          context_values.push_back(operand);
        }
      }

      if (auto* load = dynamic_cast<LoadInst*>(inst)) {
        Value* base = StripPointerBase(load->GetPtr());
        if (dynamic_cast<GlobalVariable*>(base) == nullptr) return false;
        load_bases.insert(base);
        auto* gep = dynamic_cast<GepInst*>(load->GetPtr());
        if (!gep || !ExprDependsOn(gep->GetIndex(), match->induction.phi)) {
          iv_independent_load_bases.insert(base);
        }
      } else if (auto* store = dynamic_cast<StoreInst*>(inst)) {
        Value* base = StripPointerBase(store->GetPtr());
        auto* gv = dynamic_cast<GlobalVariable*>(base);
        if (!gv || gv->GetCount() <= 1) return false;
        store_bases.insert(base);
        auto* gep = dynamic_cast<GepInst*>(store->GetPtr());
        if (!gep || !ExprDependsOn(gep->GetIndex(), match->induction.phi)) {
          return false;
        }
      }
    }
  }

  // A global that is written in the loop may only be read at addresses that
  // vary with the induction variable.
  for (Value* base : store_bases) {
    if (iv_independent_load_bases.count(base) != 0) return false;
  }

  if (context_values.size() > 6) return false;

  out->parent = &func;
  out->loop = loop;
  out->kind = kind;
  out->header = match->header;
  out->body = match->body;
  out->preheader = match->preheader;
  out->exit = match->exit;
  out->latch = match->latch;
  out->header_cmp = match->header_cmp;
  out->induction = match->induction.phi;
  out->induction_next = match->induction.next;
  out->bound = match->bound;
  out->reduction = reduction_phi;
  out->reduction_init = reduction_init;
  out->reduction_next = reduction_next;
  out->has_loads = !load_bases.empty();
  out->contexts.reserve(context_values.size());
  for (Value* value : context_values) {
    out->contexts.push_back({value, nullptr});
  }
  return true;
}
|
||||
|
||||
bool IsGeneratedParallelWorker(const Function& func) {
|
||||
return func.GetName().rfind("__par_worker", 0) == 0;
|
||||
}
|
||||
|
||||
// Looks up the external launcher `__par_run<slot>` in the module, declaring it
// (void return, no parameters, marked external) on first use.
Function* GetOrCreateRuntimeLauncher(Module& module, int slot) {
  const std::string launcher_name =
      std::string("__par_run") + std::to_string(slot);
  Function* launcher = module.FindFunction(launcher_name);
  if (launcher == nullptr) {
    launcher = module.CreateFunction(launcher_name, Type::GetVoidType(), {});
    launcher->SetExternal(true);
  }
  return launcher;
}
|
||||
|
||||
// Name of the outlined worker function for the given slot:
// "__par_worker" followed by the decimal slot number.
std::string NextWorkerName(int slot) {
  std::string name("__par_worker");
  name += std::to_string(slot);
  return name;
}
|
||||
|
||||
void CloneWorkerBlocks(const ParallelLoopCandidate& cand, Function* worker,
|
||||
GlobalVariable* bound_slot,
|
||||
const std::vector<LoopContextValue>& ctx_slots,
|
||||
GlobalVariable* reduction_slot, Context& ctx) {
|
||||
if (cand.kind == ParallelLoopKind::GuardedFillI32) {
|
||||
auto* entry = worker->GetEntry();
|
||||
auto* tid = worker->GetArgument(0);
|
||||
auto* header = worker->CreateBlock(cand.header->GetName());
|
||||
auto* guard = worker->CreateBlock(cand.guard->GetName());
|
||||
auto* action = worker->CreateBlock(cand.action->GetName());
|
||||
auto* latch = worker->CreateBlock(cand.latch->GetName());
|
||||
auto* worker_exit = worker->CreateBlock("par.exit");
|
||||
|
||||
IRBuilder builder(ctx, entry);
|
||||
auto* bound_val = builder.CreateLoad(bound_slot, ctx.NextTemp());
|
||||
Value* threads_val = ctx.GetConstInt(kParallelThreads);
|
||||
auto* start_mul = builder.CreateMul(tid, bound_val, ctx.NextTemp());
|
||||
auto* start = builder.CreateDiv(start_mul, threads_val, ctx.NextTemp());
|
||||
auto* next_tid = builder.CreateAdd(tid, ctx.GetConstInt(1), ctx.NextTemp());
|
||||
auto* end_mul = builder.CreateMul(next_tid, bound_val, ctx.NextTemp());
|
||||
auto* end = builder.CreateDiv(end_mul, threads_val, ctx.NextTemp());
|
||||
|
||||
ValueMap remap;
|
||||
remap[cand.bound] = bound_val;
|
||||
for (const auto& ctx_value : ctx_slots) {
|
||||
builder.SetInsertPoint(entry);
|
||||
auto* loaded = builder.CreateLoad(ctx_value.slot, ctx.NextTemp());
|
||||
remap[ctx_value.original] = loaded;
|
||||
}
|
||||
builder.SetInsertPoint(entry);
|
||||
auto* start_linear_mul =
|
||||
builder.CreateMul(start, ctx.GetConstInt(cand.linear_step), ctx.NextTemp());
|
||||
Value* linear_init = cand.linear_init;
|
||||
auto it = remap.find(cand.linear_init);
|
||||
if (it != remap.end()) linear_init = it->second;
|
||||
auto* start_linear = builder.CreateAdd(linear_init, start_linear_mul, ctx.NextTemp());
|
||||
builder.CreateBr(header);
|
||||
|
||||
auto* new_iv = header->PrependPhi(cand.induction->GetType(), ctx.NextTemp());
|
||||
auto* new_linear = header->PrependPhi(cand.linear->GetType(), ctx.NextTemp());
|
||||
remap[cand.induction] = new_iv;
|
||||
remap[cand.linear] = new_linear;
|
||||
|
||||
if (auto cloned_cmp = CloneInstruction(cand.header_cmp, remap, ".par")) {
|
||||
auto* raw_cmp = static_cast<CmpInst*>(cloned_cmp.get());
|
||||
header->MutableInstructions().push_back(std::move(cloned_cmp));
|
||||
raw_cmp->SetParent(header);
|
||||
remap[cand.header_cmp] = raw_cmp;
|
||||
if (cand.header_cmp->GetLhs() == cand.induction &&
|
||||
cand.header_cmp->GetRhs() == cand.bound) {
|
||||
raw_cmp->SetOperand(0, new_iv);
|
||||
raw_cmp->SetOperand(1, end);
|
||||
} else if (cand.header_cmp->GetRhs() == cand.induction &&
|
||||
cand.header_cmp->GetLhs() == cand.bound) {
|
||||
raw_cmp->SetOperand(0, end);
|
||||
raw_cmp->SetOperand(1, new_iv);
|
||||
}
|
||||
header->Append<CondBranchInst>(Type::GetVoidType(), raw_cmp, guard, worker_exit);
|
||||
}
|
||||
|
||||
for (const auto& inst_ptr : cand.guard->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
|
||||
if (auto cloned = CloneInstruction(inst, remap, ".par")) {
|
||||
auto* raw = cloned.get();
|
||||
guard->MutableInstructions().push_back(std::move(cloned));
|
||||
raw->SetParent(guard);
|
||||
remap[inst] = raw;
|
||||
}
|
||||
}
|
||||
auto* guard_term =
|
||||
static_cast<CondBranchInst*>(cand.guard->MutableInstructions().back().get());
|
||||
auto* guard_cond = RemapValue(guard_term->GetCond(), remap);
|
||||
BasicBlock* true_target =
|
||||
(guard_term->GetTrueBlock() == cand.action) ? action : latch;
|
||||
BasicBlock* false_target =
|
||||
(guard_term->GetFalseBlock() == cand.action) ? action : latch;
|
||||
guard->Append<CondBranchInst>(Type::GetVoidType(), guard_cond, true_target,
|
||||
false_target);
|
||||
|
||||
for (const auto& inst_ptr : cand.action->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
|
||||
if (auto cloned = CloneInstruction(inst, remap, ".par")) {
|
||||
auto* raw = cloned.get();
|
||||
action->MutableInstructions().push_back(std::move(cloned));
|
||||
raw->SetParent(action);
|
||||
remap[inst] = raw;
|
||||
}
|
||||
}
|
||||
action->Append<BranchInst>(Type::GetVoidType(), latch);
|
||||
|
||||
for (const auto& inst_ptr : cand.latch->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dynamic_cast<PhiInst*>(inst) != nullptr || inst->IsTerminator()) continue;
|
||||
if (auto cloned = CloneInstruction(inst, remap, ".par")) {
|
||||
auto* raw = cloned.get();
|
||||
latch->MutableInstructions().push_back(std::move(cloned));
|
||||
raw->SetParent(latch);
|
||||
remap[inst] = raw;
|
||||
}
|
||||
}
|
||||
latch->Append<BranchInst>(Type::GetVoidType(), header);
|
||||
|
||||
new_iv->AddIncoming(start, entry);
|
||||
new_iv->AddIncoming(RemapValue(cand.induction_next, remap), latch);
|
||||
new_linear->AddIncoming(start_linear, entry);
|
||||
new_linear->AddIncoming(RemapValue(cand.linear_next, remap), latch);
|
||||
worker_exit->Append<ReturnInst>(Type::GetVoidType(), nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
auto* entry = worker->GetEntry();
|
||||
auto* tid = worker->GetArgument(0);
|
||||
auto* header = worker->CreateBlock(cand.header->GetName());
|
||||
auto* body = worker->CreateBlock(cand.body->GetName());
|
||||
auto* worker_exit = worker->CreateBlock("par.exit");
|
||||
|
||||
IRBuilder builder(ctx, entry);
|
||||
auto* bound_val = builder.CreateLoad(bound_slot, ctx.NextTemp());
|
||||
Value* threads_val = ctx.GetConstInt(kParallelThreads);
|
||||
auto* start_mul = builder.CreateMul(tid, bound_val, ctx.NextTemp());
|
||||
auto* start = builder.CreateDiv(start_mul, threads_val, ctx.NextTemp());
|
||||
auto* next_tid = builder.CreateAdd(tid, ctx.GetConstInt(1), ctx.NextTemp());
|
||||
auto* end_mul = builder.CreateMul(next_tid, bound_val, ctx.NextTemp());
|
||||
auto* end = builder.CreateDiv(end_mul, threads_val, ctx.NextTemp());
|
||||
|
||||
ValueMap remap;
|
||||
remap[cand.induction] = start;
|
||||
remap[cand.bound] = bound_val;
|
||||
for (const auto& ctx_value : ctx_slots) {
|
||||
builder.SetInsertPoint(entry);
|
||||
auto* loaded = builder.CreateLoad(ctx_value.slot, ctx.NextTemp());
|
||||
remap[ctx_value.original] = loaded;
|
||||
}
|
||||
builder.SetInsertPoint(entry);
|
||||
builder.CreateBr(header);
|
||||
auto* new_phi = header->PrependPhi(cand.induction->GetType(), ctx.NextTemp());
|
||||
remap[cand.induction] = new_phi;
|
||||
PhiInst* new_reduction_phi = nullptr;
|
||||
if (cand.kind == ParallelLoopKind::ReductionAddI32) {
|
||||
new_reduction_phi = header->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());
|
||||
remap[cand.reduction] = new_reduction_phi;
|
||||
}
|
||||
|
||||
if (auto cloned_cmp = CloneInstruction(cand.header_cmp, remap, ".par")) {
|
||||
auto* raw_cmp = static_cast<CmpInst*>(cloned_cmp.get());
|
||||
header->MutableInstructions().push_back(std::move(cloned_cmp));
|
||||
raw_cmp->SetParent(header);
|
||||
remap[cand.header_cmp] = raw_cmp;
|
||||
if (cand.header_cmp->GetLhs() == cand.induction &&
|
||||
cand.header_cmp->GetRhs() == cand.bound) {
|
||||
raw_cmp->SetOperand(0, new_phi);
|
||||
raw_cmp->SetOperand(1, end);
|
||||
} else if (cand.header_cmp->GetRhs() == cand.induction &&
|
||||
cand.header_cmp->GetLhs() == cand.bound) {
|
||||
raw_cmp->SetOperand(0, end);
|
||||
raw_cmp->SetOperand(1, new_phi);
|
||||
}
|
||||
header->Append<CondBranchInst>(Type::GetVoidType(), raw_cmp, body, worker_exit);
|
||||
}
|
||||
|
||||
for (const auto& inst_ptr : cand.body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dynamic_cast<PhiInst*>(inst) != nullptr) continue;
|
||||
if (inst->GetOpcode() == Opcode::Br) continue;
|
||||
if (auto cloned = CloneInstruction(inst, remap, ".par")) {
|
||||
auto* raw = cloned.get();
|
||||
body->MutableInstructions().push_back(std::move(cloned));
|
||||
raw->SetParent(body);
|
||||
remap[inst] = raw;
|
||||
}
|
||||
}
|
||||
new_phi->AddIncoming(start, entry);
|
||||
new_phi->AddIncoming(RemapValue(cand.induction_next, remap), body);
|
||||
if (new_reduction_phi) {
|
||||
new_reduction_phi->AddIncoming(ctx.GetConstInt(0), entry);
|
||||
new_reduction_phi->AddIncoming(RemapValue(cand.reduction_next, remap), body);
|
||||
}
|
||||
body->Append<BranchInst>(Type::GetVoidType(), header);
|
||||
|
||||
if (new_reduction_phi) {
|
||||
IRBuilder exit_builder(ctx, worker_exit);
|
||||
auto* partial_ptr = exit_builder.CreateGep(reduction_slot, tid, ctx.NextTemp());
|
||||
exit_builder.CreateStore(new_reduction_phi, partial_ptr);
|
||||
}
|
||||
worker_exit->Append<ReturnInst>(Type::GetVoidType(), nullptr);
|
||||
}
|
||||
|
||||
// Outlines the candidate loop into `__par_worker<slot>` and rewires the
// caller: the preheader stores the context scalars and the bound into
// per-slot globals, calls the `__par_run<slot>` launcher and then branches
// straight to the loop exit. For reductions the per-thread partial sums are
// added up in the preheader and replace the reduction phi outside the loop.
// Always returns true.
bool ParallelizeCandidate(Module& module, ParallelLoopCandidate& cand, int slot) {
  auto& ctx = module.GetContext();
  const std::string slot_suffix = std::to_string(slot);
  const bool is_reduction = cand.kind == ParallelLoopKind::ReductionAddI32;

  // ---- Per-slot globals. ----
  auto* bound_slot = module.CreateGlobalVar("__par_bound" + slot_suffix, 0, 1,
                                            Type::GetPtrInt32Type());
  GlobalVariable* reduction_slot = nullptr;
  if (is_reduction) {
    reduction_slot =
        module.CreateGlobalVar("__par_red" + slot_suffix, 0, kParallelThreads,
                               Type::GetPtrInt32Type());
  }
  size_t ctx_index = 0;
  for (auto& ctx_entry : cand.contexts) {
    const auto& value_type = ctx_entry.original->GetType();
    const bool is_float = value_type && value_type->IsFloat32();
    ctx_entry.slot = module.CreateGlobalVar(
        "__par_ctx" + slot_suffix + "_" + std::to_string(ctx_index), 0, 1,
        is_float ? Type::GetPtrFloat32Type() : Type::GetPtrInt32Type());
    ++ctx_index;
  }

  // ---- Worker function. ----
  auto* worker = module.CreateFunction(NextWorkerName(slot), Type::GetVoidType(),
                                       {Type::GetInt32Type()});
  CloneWorkerBlocks(cand, worker, bound_slot, cand.contexts, reduction_slot,
                    ctx);

  // ---- Caller side: publish inputs and launch. ----
  auto* launcher = GetOrCreateRuntimeLauncher(module, slot);
  auto* preheader = cand.preheader;
  if (preheader->HasTerminator()) {
    auto* old_terminator = preheader->MutableInstructions().back().get();
    preheader->RemoveInstruction(old_terminator);
  }
  for (const auto& ctx_entry : cand.contexts) {
    InsertInstruction(preheader,
                      std::make_unique<StoreInst>(Type::GetVoidType(),
                                                  ctx_entry.original,
                                                  ctx_entry.slot));
  }
  InsertInstruction(preheader,
                    std::make_unique<StoreInst>(Type::GetVoidType(), cand.bound,
                                                bound_slot));
  InsertCallBeforeTerminator(preheader, launcher, {}, "");

  // Combine the per-thread partial sums, starting from the original init.
  Value* reduced_value = nullptr;
  if (is_reduction) {
    IRBuilder builder(ctx, preheader);
    reduced_value = cand.reduction_init;
    for (int thread = 0; thread < kParallelThreads; ++thread) {
      auto* partial_ptr = builder.CreateGep(
          reduction_slot, ctx.GetConstInt(thread), ctx.NextTemp());
      auto* partial_val = builder.CreateLoad(partial_ptr, ctx.NextTemp());
      reduced_value =
          builder.CreateAdd(reduced_value, partial_val, ctx.NextTemp());
    }
  }

  // ---- Detach the original loop from the preheader. ----
  cand.induction->RemoveIncomingBlock(preheader);
  for (PhiInst* optional_phi : {cand.linear, cand.reduction}) {
    if (optional_phi) {
      optional_phi->RemoveIncomingBlock(preheader);
    }
  }
  preheader->RemoveSuccessor(cand.header);
  cand.header->RemovePredecessor(preheader);
  preheader->AddSuccessor(cand.exit);
  cand.exit->AddPredecessor(preheader);

  if (is_reduction && reduced_value) {
    ReplaceUsesOutsideLoop(cand.reduction, reduced_value, cand.loop);
  }
  preheader->Append<BranchInst>(Type::GetVoidType(), cand.exit);
  return true;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Module pass: outlines up to kParallelLoopSlots loops into worker functions.
//
// For every slot the functions are scanned in module order and their loops in
// (depth, header name) order. The first candidate that reads memory
// (`has_loads`) is taken immediately; a store-only candidate is only kept as a
// fallback for its function, and only while no store-only loop has been
// outlined yet. After each rewrite the parent function is cleaned up with
// RunSimpleDCE and RunCFGSimplify. Returns true if any loop was outlined.
bool RunLoopParallelization(Module& module) {
  bool changed = false;
  int store_only_slots = 0;
  for (int slot = 0; slot < kParallelLoopSlots; ++slot) {
    ParallelLoopCandidate cand;
    bool found = false;

    for (const auto& func_ptr : module.GetFunctions()) {
      auto* func = func_ptr.get();
      if (!func || func->IsExternal() || IsGeneratedParallelWorker(*func)) {
        continue;
      }

      analysis::DominatorTree dom_tree(*func);
      analysis::LoopInfo loop_info(*func, dom_tree);
      std::vector<analysis::Loop*> loops;
      for (const auto& loop_ptr : loop_info.GetLoops()) {
        loops.push_back(loop_ptr.get());
      }
      // Outermost loops first; header name breaks ties deterministically.
      std::sort(loops.begin(), loops.end(),
                [](analysis::Loop* lhs, analysis::Loop* rhs) {
                  if (lhs->GetDepth() != rhs->GetDepth()) {
                    return lhs->GetDepth() < rhs->GetDepth();
                  }
                  return lhs->GetHeader()->GetName() <
                         rhs->GetHeader()->GetName();
                });

      ParallelLoopCandidate fallback_store_only;
      bool have_fallback_store_only = false;
      for (auto* loop : loops) {
        // Use a fresh candidate per attempt. Reusing one object across loops
        // let a successful-but-unused candidate leave its context list behind,
        // which the next successful match would then append to.
        ParallelLoopCandidate attempt;
        if (!BuildParallelCandidate(*func, loop, &attempt)) continue;
        if (attempt.has_loads) {
          cand = std::move(attempt);
          found = true;
          break;
        }
        if (store_only_slots < 1 && !have_fallback_store_only) {
          fallback_store_only = std::move(attempt);
          have_fallback_store_only = true;
        }
      }
      if (!found && have_fallback_store_only) {
        cand = std::move(fallback_store_only);
        found = true;
      }
      if (found) break;
    }

    if (!found) break;
    const bool local_changed = ParallelizeCandidate(module, cand, slot);
    changed |= local_changed;
    if (local_changed && cand.parent) {
      RunSimpleDCE(*cand.parent);
      RunCFGSimplify(*cand.parent, module.GetContext());
    }
    if (!cand.has_loads) {
      ++store_only_slots;
    }
  }
  return changed;
}
|
||||
|
||||
} // namespace passes
|
||||
} // namespace ir
|
||||
@ -0,0 +1,309 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace passes {
|
||||
|
||||
struct CanonicalInductionInfo {
|
||||
PhiInst* phi = nullptr;
|
||||
Value* init = nullptr;
|
||||
Value* next = nullptr;
|
||||
int step = 0;
|
||||
};
|
||||
|
||||
struct CanonicalLoopMatch {
|
||||
analysis::Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* latch = nullptr;
|
||||
CondBranchInst* header_branch = nullptr;
|
||||
CmpInst* header_cmp = nullptr;
|
||||
Value* bound = nullptr;
|
||||
CanonicalInductionInfo induction;
|
||||
std::vector<PhiInst*> header_phis;
|
||||
};
|
||||
|
||||
// Maps an original value to the value that replaces it (see RemapValue).
using ValueMap = std::unordered_map<Value*, Value*>;
|
||||
|
||||
// Looks `value` up in `remap`; returns the mapped value when present and
// `value` itself otherwise.
inline Value* RemapValue(Value* value, const ValueMap& remap) {
  const auto found = remap.find(value);
  if (found == remap.end()) {
    return value;
  }
  return found->second;
}
|
||||
|
||||
// Returns the run of PhiInst instructions at the start of `header`, in order.
// Collection stops at the first non-phi instruction; a null header yields an
// empty vector.
inline std::vector<PhiInst*> CollectHeaderPhis(BasicBlock* header) {
  std::vector<PhiInst*> leading_phis;
  if (header == nullptr) {
    return leading_phis;
  }
  for (const auto& owned : header->GetInstructions()) {
    auto* as_phi = dynamic_cast<PhiInst*>(owned.get());
    if (as_phi == nullptr) {
      break;
    }
    leading_phis.push_back(as_phi);
  }
  return leading_phis;
}
|
||||
|
||||
// Transfers ownership of `inst` to `block`. When the block reports a
// terminator the instruction is placed directly in front of the block's last
// instruction; otherwise it is appended at the end.
inline void InsertInstruction(BasicBlock* block,
                              std::unique_ptr<Instruction> inst) {
  auto& list = block->MutableInstructions();
  const auto position = block->HasTerminator() ? list.end() - 1 : list.end();
  inst->SetParent(block);
  list.insert(position, std::move(inst));
}
|
||||
|
||||
// Same placement as InsertInstruction, but hands back a non-owning pointer to
// the instruction that was just inserted.
inline Instruction* AppendOwnedInstruction(BasicBlock* block,
                                           std::unique_ptr<Instruction> inst) {
  Instruction* const borrowed = inst.get();
  InsertInstruction(block, std::move(inst));
  return borrowed;
}
|
||||
|
||||
// Creates `lhs <opcode> rhs` named `name`, typed like `lhs`, and places it in
// `block` via InsertInstruction. Returns the new instruction.
inline BinaryInst* InsertBinaryBeforeTerminator(BasicBlock* block, Opcode opcode,
                                                Value* lhs, Value* rhs,
                                                const std::string& name) {
  std::unique_ptr<BinaryInst> created =
      std::make_unique<BinaryInst>(opcode, lhs->GetType(), lhs, rhs, name);
  BinaryInst* const result = created.get();
  InsertInstruction(block, std::move(created));
  return result;
}
|
||||
|
||||
// Creates the comparison `lhs <cmp_op> rhs` named `name` (constructed with
// Type::GetInt32Type()) and places it in `block` via InsertInstruction.
// Returns the new instruction.
inline CmpInst* InsertCmpBeforeTerminator(BasicBlock* block, CmpOp cmp_op,
                                          Value* lhs, Value* rhs,
                                          const std::string& name) {
  std::unique_ptr<CmpInst> created =
      std::make_unique<CmpInst>(cmp_op, Type::GetInt32Type(), lhs, rhs, name);
  CmpInst* const result = created.get();
  InsertInstruction(block, std::move(created));
  return result;
}
|
||||
|
||||
// Creates an unconditional branch to `target` and places it in `block` via
// InsertInstruction. Returns the new branch.
inline BranchInst* InsertBranchBeforeTerminator(BasicBlock* block,
                                                BasicBlock* target) {
  std::unique_ptr<BranchInst> created =
      std::make_unique<BranchInst>(Type::GetVoidType(), target);
  BranchInst* const result = created.get();
  InsertInstruction(block, std::move(created));
  return result;
}
|
||||
|
||||
// Creates a conditional branch on `cond` (to `true_bb` / `false_bb`) and places
// it in `block` via InsertInstruction. Returns the new branch.
inline CondBranchInst* InsertCondBrBeforeTerminator(BasicBlock* block, Value* cond,
                                                    BasicBlock* true_bb,
                                                    BasicBlock* false_bb) {
  std::unique_ptr<CondBranchInst> created = std::make_unique<CondBranchInst>(
      Type::GetVoidType(), cond, true_bb, false_bb);
  CondBranchInst* const result = created.get();
  InsertInstruction(block, std::move(created));
  return result;
}
|
||||
|
||||
// Creates a call to `callee` with `args`, named `name` and typed with
// callee->GetType(), and places it in `block` via InsertInstruction.
// Returns the new call.
inline CallInst* InsertCallBeforeTerminator(BasicBlock* block, Function* callee,
                                            const std::vector<Value*>& args,
                                            const std::string& name) {
  std::unique_ptr<CallInst> created =
      std::make_unique<CallInst>(callee->GetType(), callee, args, name);
  CallInst* const result = created.get();
  InsertInstruction(block, std::move(created));
  return result;
}
|
||||
|
||||
// Derives the name of a cloned value: `base` followed by `suffix`.
// An empty base stays empty so unnamed values remain unnamed.
inline std::string CloneName(const std::string& base, const std::string& suffix) {
  std::string derived = base;
  if (!derived.empty()) {
    derived += suffix;
  }
  return derived;
}
|
||||
|
||||
// Builds a detached copy of `inst` whose operands are rewritten through
// `remap` and whose name is CloneName(original name, suffix).
//
// Handled opcodes: Add, Sub, Mul, Div, Mod, Cmp, Cast, Load, Store, Call, Gep.
// Any other opcode yields nullptr. The copy is not inserted into any block.
inline std::unique_ptr<Instruction> CloneInstruction(Instruction* inst,
                                                     const ValueMap& remap,
                                                     const std::string& suffix) {
  const auto mapped = [&remap](Value* operand) {
    return RemapValue(operand, remap);
  };
  const auto fresh_name = [inst, &suffix]() {
    return CloneName(inst->GetName(), suffix);
  };

  switch (inst->GetOpcode()) {
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::Div:
    case Opcode::Mod: {
      auto* source = static_cast<BinaryInst*>(inst);
      return std::make_unique<BinaryInst>(inst->GetOpcode(), inst->GetType(),
                                          mapped(source->GetLhs()),
                                          mapped(source->GetRhs()), fresh_name());
    }
    case Opcode::Cmp: {
      auto* source = static_cast<CmpInst*>(inst);
      return std::make_unique<CmpInst>(source->GetCmpOp(), inst->GetType(),
                                       mapped(source->GetLhs()),
                                       mapped(source->GetRhs()), fresh_name());
    }
    case Opcode::Cast: {
      auto* source = static_cast<CastInst*>(inst);
      return std::make_unique<CastInst>(source->GetCastOp(), inst->GetType(),
                                        mapped(source->GetValue()), fresh_name());
    }
    case Opcode::Load: {
      auto* source = static_cast<LoadInst*>(inst);
      return std::make_unique<LoadInst>(inst->GetType(), mapped(source->GetPtr()),
                                        fresh_name());
    }
    case Opcode::Store: {
      // Stores produce no value, so the copy carries no name.
      auto* source = static_cast<StoreInst*>(inst);
      return std::make_unique<StoreInst>(Type::GetVoidType(),
                                         mapped(source->GetValue()),
                                         mapped(source->GetPtr()));
    }
    case Opcode::Call: {
      auto* source = static_cast<CallInst*>(inst);
      const size_t arg_count = source->GetNumArgs();
      std::vector<Value*> new_args;
      new_args.reserve(arg_count);
      for (size_t idx = 0; idx != arg_count; ++idx) {
        new_args.push_back(mapped(source->GetArg(idx)));
      }
      return std::make_unique<CallInst>(inst->GetType(), source->GetCallee(),
                                        new_args, fresh_name());
    }
    case Opcode::Gep: {
      auto* source = static_cast<GepInst*>(inst);
      return std::make_unique<GepInst>(inst->GetType(), mapped(source->GetBase()),
                                       mapped(source->GetIndex()), fresh_name());
    }
    default:
      return nullptr;
  }
}
|
||||
|
||||
// Reports whether `value` is defined outside `loop`.
// - null value: false
// - constants, arguments, global variables, functions: true
// - instructions: true unless their parent block is contained in `loop`
//   (an instruction without a parent block also counts as invariant)
// - any other kind of value: true
inline bool IsLoopInvariantValue(Value* value, analysis::Loop* loop) {
  if (value == nullptr) {
    return false;
  }
  const bool defined_without_block =
      dynamic_cast<ConstantValue*>(value) != nullptr ||
      dynamic_cast<Argument*>(value) != nullptr ||
      dynamic_cast<GlobalVariable*>(value) != nullptr ||
      dynamic_cast<Function*>(value) != nullptr;
  if (defined_without_block) {
    return true;
  }

  auto* defining_inst = dynamic_cast<Instruction*>(value);
  if (defining_inst == nullptr) {
    return true;
  }
  auto* home_block = defining_inst->GetParent();
  if (!home_block) {
    return true;
  }
  return !loop->Contains(home_block);
}
|
||||
|
||||
// Searches the leading phis of `header` for a canonical induction variable:
// an Int32 phi with exactly two incoming edges, one from `preheader` (the
// initial value) and one from `latch`, where the latch value is
// `phi + C`, `C + phi` or `phi - C` for a non-zero ConstantInt C.
//
// Returns the first phi that matches (step is +C for Add, -C for Sub), or
// std::nullopt when any block is null or no phi qualifies.
inline std::optional<CanonicalInductionInfo> MatchCanonicalInduction(
    BasicBlock* header, BasicBlock* preheader, BasicBlock* latch) {
  if (header == nullptr || preheader == nullptr || latch == nullptr) {
    return std::nullopt;
  }

  for (PhiInst* candidate : CollectHeaderPhis(header)) {
    if (candidate == nullptr) continue;
    if (candidate->GetType() == nullptr || !candidate->GetType()->IsInt32()) continue;
    if (candidate->GetNumIncoming() != 2) continue;

    // Sort the two incoming values by the edge they arrive on.
    Value* from_preheader = nullptr;
    Value* from_latch = nullptr;
    for (size_t idx = 0; idx < candidate->GetNumIncoming(); ++idx) {
      auto* source_block = candidate->GetIncomingBlock(idx);
      if (source_block == preheader) {
        from_preheader = candidate->GetIncomingValue(idx);
      } else if (source_block == latch) {
        from_latch = candidate->GetIncomingValue(idx);
      }
    }
    if (from_preheader == nullptr || from_latch == nullptr) continue;

    auto* update = dynamic_cast<BinaryInst*>(from_latch);
    if (update == nullptr) continue;
    const bool is_sub = update->GetOpcode() == Opcode::Sub;
    if (!is_sub && update->GetOpcode() != Opcode::Add) continue;

    // Locate the phi among the operands; the other operand is the delta.
    const bool iv_is_lhs = update->GetLhs() == candidate;
    if (!iv_is_lhs && update->GetRhs() != candidate) continue;
    Value* delta = iv_is_lhs ? update->GetRhs() : update->GetLhs();

    auto* delta_const = dynamic_cast<ConstantInt*>(delta);
    if (delta_const == nullptr) continue;

    int step = delta_const->GetValue();
    if (is_sub) {
      // `C - phi` is not a linear recurrence of the phi.
      if (!iv_is_lhs) continue;
      step = -step;
    }
    if (step == 0) continue;

    return CanonicalInductionInfo{candidate, from_preheader, from_latch, step};
  }

  return std::nullopt;
}
|
||||
|
||||
// Tries to recognise `loop` as a canonical counted loop. All of the following
// must hold, otherwise std::nullopt is returned:
//   - the loop has a header, a preheader and exactly one (non-null) latch;
//   - the header ends in a CondBranchInst whose condition is a CmpInst;
//   - exactly one successor of that branch lies inside the loop (the body),
//     the other one is the exit;
//   - MatchCanonicalInduction finds an induction phi;
//   - the compare has that phi on one side and a loop-invariant value (the
//     bound) on the other.
inline std::optional<CanonicalLoopMatch> MatchCanonicalLoop(analysis::Loop* loop) {
  if (loop == nullptr) return std::nullopt;

  auto* header = loop->GetHeader();
  auto* preheader = loop->GetPreheader();
  if (!header || !preheader) return std::nullopt;

  if (loop->GetLatches().size() != 1) return std::nullopt;
  auto* latch = loop->GetLatches().front();
  if (!latch) return std::nullopt;

  CondBranchInst* header_term = nullptr;
  if (header->HasTerminator()) {
    header_term =
        dynamic_cast<CondBranchInst*>(header->MutableInstructions().back().get());
  }
  if (header_term == nullptr) return std::nullopt;

  auto* cmp = dynamic_cast<CmpInst*>(header_term->GetCond());
  if (cmp == nullptr) return std::nullopt;

  // Exactly one branch target may stay inside the loop.
  auto* on_true = header_term->GetTrueBlock();
  auto* on_false = header_term->GetFalseBlock();
  const bool true_stays = loop->Contains(on_true);
  const bool false_stays = loop->Contains(on_false);
  if (true_stays == false_stays) return std::nullopt;
  BasicBlock* body = true_stays ? on_true : on_false;
  BasicBlock* exit = true_stays ? on_false : on_true;

  auto induction = MatchCanonicalInduction(header, preheader, latch);
  if (!induction) return std::nullopt;

  // The compare must relate the induction phi to a loop-invariant bound.
  auto* cmp_lhs = cmp->GetLhs();
  auto* cmp_rhs = cmp->GetRhs();
  Value* bound = nullptr;
  if (cmp_lhs == induction->phi && IsLoopInvariantValue(cmp_rhs, loop)) {
    bound = cmp_rhs;
  } else if (cmp_rhs == induction->phi && IsLoopInvariantValue(cmp_lhs, loop)) {
    bound = cmp_lhs;
  }
  if (bound == nullptr) return std::nullopt;

  CanonicalLoopMatch result;
  result.loop = loop;
  result.preheader = preheader;
  result.header = header;
  result.body = body;
  result.exit = exit;
  result.latch = latch;
  result.header_branch = header_term;
  result.header_cmp = cmp;
  result.bound = bound;
  result.induction = *induction;
  result.header_phis = CollectHeaderPhis(header);
  return result;
}
|
||||
|
||||
} // namespace passes
|
||||
} // namespace ir
|
||||
@ -0,0 +1,143 @@
|
||||
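// Loop unrolling:
// - Conservatively unrolls single-block innermost canonical loops by a factor of 2.
// - Guards the remainder path with one extra comparison, so no static trip count is required.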
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
namespace ir {
|
||||
namespace passes {
|
||||
|
||||
namespace {
|
||||
|
||||
bool IsUnrollableLoop(const CanonicalLoopMatch& match) {
|
||||
if (match.induction.step <= 0) return false;
|
||||
if (match.loop->GetChildren().size() != 0) return false;
|
||||
if (match.loop->GetBlocks().size() != 2) return false;
|
||||
if (match.body != match.latch) return false;
|
||||
if (!match.body || !match.body->HasTerminator()) return false;
|
||||
|
||||
auto* body_term =
|
||||
dynamic_cast<BranchInst*>(match.body->MutableInstructions().back().get());
|
||||
if (!body_term || body_term->GetTarget() != match.header) return false;
|
||||
|
||||
if (match.header_cmp->GetLhs() != match.induction.phi) return false;
|
||||
if (match.header_cmp->GetCmpOp() != CmpOp::Lt &&
|
||||
match.header_cmp->GetCmpOp() != CmpOp::Le) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t body_inst_count = match.body->GetInstructions().size();
|
||||
if (body_inst_count <= 1 || body_inst_count > 18) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the value `phi` receives along the edge from `latch`, or nullptr
// when the phi has no incoming entry for that block.
Value* GetLatchIncomingForPhi(PhiInst* phi, BasicBlock* latch) {
  size_t idx = 0;
  while (idx < phi->GetNumIncoming()) {
    if (phi->GetIncomingBlock(idx) == latch) {
      return phi->GetIncomingValue(idx);
    }
    ++idx;
  }
  return nullptr;
}
|
||||
|
||||
bool RunUnrollOnLoop(Function& func, const CanonicalLoopMatch& match,
|
||||
Context& ctx, int unroll_index) {
|
||||
(void)unroll_index;
|
||||
if (!IsUnrollableLoop(match)) return false;
|
||||
|
||||
auto* body = match.body;
|
||||
auto* header = match.header;
|
||||
auto* body_term =
|
||||
static_cast<BranchInst*>(body->MutableInstructions().back().get());
|
||||
(void)body_term;
|
||||
|
||||
std::string block_suffix = ctx.NextTemp();
|
||||
if (!block_suffix.empty() && block_suffix.front() == '%') {
|
||||
block_suffix.erase(0, 1);
|
||||
}
|
||||
auto* body2 = func.CreateBlock(body->GetName() + ".unroll." + block_suffix);
|
||||
|
||||
std::unordered_map<Value*, Value*> seed_map;
|
||||
for (auto* phi : match.header_phis) {
|
||||
auto* incoming = GetLatchIncomingForPhi(phi, match.latch);
|
||||
if (!incoming) return false;
|
||||
seed_map.emplace(phi, incoming);
|
||||
}
|
||||
|
||||
ValueMap clone_map = seed_map;
|
||||
std::vector<Instruction*> originals;
|
||||
for (const auto& inst_ptr : body->GetInstructions()) {
|
||||
if (inst_ptr.get()->IsTerminator()) continue;
|
||||
originals.push_back(inst_ptr.get());
|
||||
}
|
||||
|
||||
for (auto* inst : originals) {
|
||||
auto cloned = CloneInstruction(inst, clone_map, ".u2");
|
||||
if (!cloned) return false;
|
||||
auto* raw = cloned.get();
|
||||
body2->MutableInstructions().push_back(std::move(cloned));
|
||||
raw->SetParent(body2);
|
||||
clone_map[inst] = raw;
|
||||
}
|
||||
|
||||
body2->Append<BranchInst>(Type::GetVoidType(), header);
|
||||
|
||||
Value* iv_after_one = GetLatchIncomingForPhi(match.induction.phi, match.latch);
|
||||
if (!iv_after_one) return false;
|
||||
|
||||
auto* first_cmp = InsertCmpBeforeTerminator(
|
||||
body, match.header_cmp->GetCmpOp(), iv_after_one, match.bound,
|
||||
ctx.NextTemp());
|
||||
|
||||
body->RemoveInstruction(body->MutableInstructions().back().get());
|
||||
body->Append<CondBranchInst>(Type::GetVoidType(), first_cmp, body2, header);
|
||||
|
||||
body->AddSuccessor(body2);
|
||||
body2->AddPredecessor(body);
|
||||
body2->AddSuccessor(header);
|
||||
header->AddPredecessor(body2);
|
||||
|
||||
for (auto* phi : match.header_phis) {
|
||||
auto* incoming = GetLatchIncomingForPhi(phi, match.latch);
|
||||
Value* second_value = incoming;
|
||||
auto it = clone_map.find(incoming);
|
||||
if (it != clone_map.end()) {
|
||||
second_value = it->second;
|
||||
} else {
|
||||
second_value = RemapValue(incoming, clone_map);
|
||||
}
|
||||
phi->AddIncoming(second_value, body2);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Unrolls at most one loop of `func` per invocation: loops are tried in
// LoopInfo order and the pass returns as soon as one has been unrolled.
// Returns true when a loop was unrolled; external functions are skipped.
bool RunLoopUnroll(Function& func, Context& ctx) {
  if (func.IsExternal()) {
    return false;
  }

  analysis::DominatorTree dom_tree(func);
  analysis::LoopInfo loop_info(func, dom_tree);

  int unroll_index = 0;
  for (const auto& loop_ptr : loop_info.GetLoops()) {
    const auto match = MatchCanonicalLoop(loop_ptr.get());
    if (!match) {
      continue;
    }
    const int index_for_attempt = unroll_index++;
    if (RunUnrollOnLoop(func, *match, ctx, index_for_attempt)) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
} // namespace passes
|
||||
} // namespace ir
|
||||
@ -0,0 +1,115 @@
|
||||
// Strength reduction:
// - Recognizes the canonical induction variable iv.
// - Rewrites iv * C inside the loop into an auxiliary phi with a constant-increment recurrence.
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
namespace ir {
|
||||
namespace passes {
|
||||
|
||||
namespace {
|
||||
|
||||
// True when `inst` is an Int32 multiplication that has `iv` as one of its
// operands.
bool IsStrengthReductionCandidate(Instruction* inst, PhiInst* iv) {
  auto* product = dynamic_cast<BinaryInst*>(inst);
  if (product == nullptr) {
    return false;
  }
  if (product->GetOpcode() != Opcode::Mul) {
    return false;
  }
  if (!product->GetType() || !product->GetType()->IsInt32()) {
    return false;
  }
  if (product->GetLhs() == iv) {
    return true;
  }
  return product->GetRhs() == iv;
}
|
||||
|
||||
// For `iv * C` or `C * iv` with a ConstantInt C, returns C.
// Returns 0 when `mul` does not have that shape (so a literal factor of 0 is
// indistinguishable from "no match").
int ExtractScale(BinaryInst* mul, PhiInst* iv) {
  auto* left = mul->GetLhs();
  auto* right = mul->GetRhs();

  if (left == iv) {
    if (auto* factor = dynamic_cast<ConstantInt*>(right)) {
      return factor->GetValue();
    }
  }
  if (right == iv) {
    if (auto* factor = dynamic_cast<ConstantInt*>(left)) {
      return factor->GetValue();
    }
  }
  return 0;
}
|
||||
|
||||
// Replaces `mul` (= iv * scale) by a recurrence that needs no multiplication
// inside the loop:
//
//   preheader:  init_scale = init * scale
//   header:     sr_phi     = phi [init_scale, preheader], [sr_next, latch]
//   latch:      sr_next    = sr_phi + (step * scale)
//
// All uses of `mul` are redirected to `sr_phi` and `mul` is removed.
// Returns false, without touching the IR, when `mul` has no constant scale or
// does not live inside the matched loop. `func` is unused.
bool ReplaceMulWithRecurrence(Function& func, const CanonicalLoopMatch& match,
                              BinaryInst* mul, Context& ctx) {
  (void)func;
  const int scale = ExtractScale(mul, match.induction.phi);
  if (scale == 0) return false;
  if (match.latch != mul->GetParent() &&
      !match.loop->Contains(mul->GetParent())) {
    return false;
  }

  auto* init_scale =
      InsertBinaryBeforeTerminator(match.preheader, Opcode::Mul,
                                   match.induction.init, ctx.GetConstInt(scale),
                                   ctx.NextTemp());

  auto* sr_phi =
      match.header->PrependPhi(Type::GetInt32Type(), ctx.NextTemp());

  // step and scale are both compile-time constants, so their product is folded
  // here instead of emitting `mul const, const` into the latch, which would be
  // executed on every iteration and defeat the purpose of this pass.
  // The product is formed in unsigned arithmetic so that it wraps like a
  // 32-bit multiply instead of invoking signed-overflow UB in the compiler.
  const int stride = static_cast<int>(
      static_cast<unsigned>(match.induction.step) * static_cast<unsigned>(scale));
  auto* sr_next =
      InsertBinaryBeforeTerminator(match.latch, Opcode::Add, sr_phi,
                                   ctx.GetConstInt(stride), ctx.NextTemp());

  sr_phi->AddIncoming(init_scale, match.preheader);
  sr_phi->AddIncoming(sr_next, match.latch);

  mul->ReplaceAllUsesWith(sr_phi);
  mul->GetParent()->RemoveInstruction(mul);
  return true;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Applies strength reduction to every canonical loop of `func`.
// Loops are processed deepest-first (ties broken by header name). For each
// matched loop all `iv * constant` multiplications found in the loop's blocks
// are collected first and then rewritten with ReplaceMulWithRecurrence.
// Returns true when at least one multiplication was replaced; external
// functions are skipped.
bool RunStrengthReduction(Function& func, Context& ctx) {
  if (func.IsExternal()) {
    return false;
  }

  analysis::DominatorTree dom_tree(func);
  analysis::LoopInfo loop_info(func, dom_tree);

  std::vector<analysis::Loop*> ordered;
  for (const auto& loop_ptr : loop_info.GetLoops()) {
    ordered.push_back(loop_ptr.get());
  }
  const auto deeper_first = [](analysis::Loop* a, analysis::Loop* b) {
    if (a->GetDepth() == b->GetDepth()) {
      return a->GetHeader()->GetName() < b->GetHeader()->GetName();
    }
    return a->GetDepth() > b->GetDepth();
  };
  std::sort(ordered.begin(), ordered.end(), deeper_first);

  bool changed = false;
  for (auto* loop : ordered) {
    auto match = MatchCanonicalLoop(loop);
    if (!match) {
      continue;
    }
    PhiInst* const iv = match->induction.phi;

    // Collect before rewriting: the rewrite removes instructions from the
    // blocks being scanned.
    std::vector<BinaryInst*> worklist;
    for (auto* block : loop->GetBlocks()) {
      for (const auto& owned : block->GetInstructions()) {
        auto* inst = owned.get();
        if (!IsStrengthReductionCandidate(inst, iv)) continue;
        auto* mul = static_cast<BinaryInst*>(inst);
        if (ExtractScale(mul, iv) == 0) continue;
        worklist.push_back(mul);
      }
    }

    for (auto* mul : worklist) {
      if (ReplaceMulWithRecurrence(func, *match, mul, ctx)) {
        changed = true;
      }
    }
  }

  return changed;
}
|
||||
|
||||
} // namespace passes
|
||||
} // namespace ir
|
||||
Loading…
Reference in new issue