You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

263 lines
8.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// 循环不变代码外提LICM
// - 基于 DominatorTree + LoopInfo 识别自然循环
// - 将循环内不变且可安全提前执行的指令移动到 preheader
// - 顺带消除同一循环中重复的不变表达式
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
namespace passes {
namespace {
/// Lookup key describing an expression structurally: opcode, compare
/// predicate, cast kind and the exact operand values (by identity).
/// Used to find an already-available instruction equivalent to a candidate.
/// cmp_op / cast_op keep fixed defaults for instructions that are not
/// comparisons or casts, so those fields never cause a mismatch there.
struct ExprKey {
  Opcode opcode;
  CmpOp cmp_op = CmpOp::Eq;
  CastOp cast_op = CastOp::IntToFloat;
  std::vector<Value*> operands;

  bool operator==(const ExprKey& rhs) const {
    if (opcode != rhs.opcode) return false;
    if (cmp_op != rhs.cmp_op) return false;
    if (cast_op != rhs.cast_op) return false;
    return operands == rhs.operands;
  }
};
struct ExprKeyHash {
size_t operator()(const ExprKey& key) const {
size_t h = std::hash<int>()(static_cast<int>(key.opcode));
h ^= std::hash<int>()(static_cast<int>(key.cmp_op)) + 0x9e3779b9 +
(h << 6) + (h >> 2);
h ^= std::hash<int>()(static_cast<int>(key.cast_op)) + 0x9e3779b9 +
(h << 6) + (h >> 2);
for (auto* operand : key.operands) {
h ^= std::hash<void*>()(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
/// Whitelist of opcodes this LICM pass will consider hoisting: integer/float
/// arithmetic (add, sub, mul), comparisons, casts, GEPs and loads.
/// Every other opcode is rejected. Load is listed here but still needs a
/// separate memory-dependence check before it can be hoisted.
bool IsSupportedInvariantOpcode(Opcode op) {
  return op == Opcode::Add || op == Opcode::Sub || op == Opcode::Mul ||
         op == Opcode::Cmp || op == Opcode::Cast || op == Opcode::Gep ||
         op == Opcode::Load;
}
/// Decides whether `value` can be treated as unchanging across iterations of
/// `loop`.
///   - nullptr is never invariant.
///   - Constants, arguments, globals, functions, basic blocks and any other
///     non-instruction value are invariant.
///   - An instruction is invariant when it has no parent block, lives outside
///     `loop`, or is already recorded in `invariant`.
bool IsLoopInvariantValue(Value* value, analysis::Loop* loop,
                          const std::unordered_set<Instruction*>& invariant) {
  if (value == nullptr) return false;
  const bool is_leaf_value = dynamic_cast<ConstantValue*>(value) != nullptr ||
                             dynamic_cast<Argument*>(value) != nullptr ||
                             dynamic_cast<GlobalVariable*>(value) != nullptr ||
                             dynamic_cast<Function*>(value) != nullptr ||
                             dynamic_cast<BasicBlock*>(value) != nullptr;
  if (is_leaf_value) return true;
  auto* def = dynamic_cast<Instruction*>(value);
  if (def == nullptr) return true;
  auto* owner = def->GetParent();
  const bool defined_in_loop = owner != nullptr && loop->Contains(owner);
  return !defined_in_loop || invariant.count(def) != 0;
}
/// Walks through any chain of GepInst and returns the first value that is not
/// a GEP, i.e. the pointer the address computation started from.
/// A nullptr input is returned unchanged.
Value* GetPointerBase(Value* ptr) {
  Value* base = ptr;
  for (auto* gep = dynamic_cast<GepInst*>(base); gep != nullptr;
       gep = dynamic_cast<GepInst*>(base)) {
    base = gep->GetBase();
  }
  return base;
}
bool MayAlias(Value* lhs, Value* rhs) {
if (lhs == rhs) return true;
return GetPointerBase(lhs) == GetPointerBase(rhs);
}
/// Conservatively reports whether any instruction inside `loop` may write the
/// memory addressed by `ptr`.
///
/// Stores are checked against `ptr` with MayAlias. Calls are treated as
/// writing arbitrary memory: a callee can modify globals or memory reachable
/// through its pointer arguments, and this pass has no mod/ref information
/// about callees. Ignoring calls would let a load be hoisted past a call that
/// changes the loaded value.
bool IsStoredInLoop(Value* ptr, analysis::Loop* loop) {
  for (auto* block : loop->GetBlocks()) {
    for (const auto& inst_ptr : block->GetInstructions()) {
      auto* inst = inst_ptr.get();
      if (!inst) continue;
      if (inst->GetOpcode() == Opcode::Call) return true;
      auto* store = dynamic_cast<StoreInst*>(inst);
      if (store && MayAlias(store->GetPtr(), ptr)) return true;
    }
  }
  return false;
}
/// Checks whether `inst` may be hoisted out of `loop`:
///   - its opcode must pass IsSupportedInvariantOpcode;
///   - for a Load, only the address is examined: it must be loop-invariant and
///     IsStoredInLoop must not report a write to it inside the loop;
///   - for every other opcode, all operands must be loop-invariant.
/// NOTE(review): no dominance / dereferenceability check is made here, so a
/// hoisted Load executes even on paths (or zero-trip loops) that originally
/// skipped it — confirm this speculation is acceptable for the target.
bool IsSafeInvariantInstruction(Instruction* inst, analysis::Loop* loop,
                                const std::unordered_set<Instruction*>& invariant) {
  if (inst == nullptr) return false;
  const Opcode op = inst->GetOpcode();
  if (!IsSupportedInvariantOpcode(op)) return false;
  if (op == Opcode::Load) {
    auto* address = static_cast<LoadInst*>(inst)->GetPtr();
    return IsLoopInvariantValue(address, loop, invariant) &&
           !IsStoredInLoop(address, loop);
  }
  const size_t operand_count = inst->GetNumOperands();
  for (size_t idx = 0; idx != operand_count; ++idx) {
    if (!IsLoopInvariantValue(inst->GetOperand(idx), loop, invariant)) {
      return false;
    }
  }
  return true;
}
/// Builds the lookup key for `inst`. The key holds:
///   - the opcode;
///   - the predicate, when `inst` is a CmpInst;
///   - the cast kind, when `inst` is a CastInst;
///   - all operands in order.
/// Fields that do not apply keep ExprKey's default member initializers.
ExprKey MakeExprKey(Instruction* inst) {
  ExprKey key;
  key.opcode = inst->GetOpcode();
  if (auto* cmp = dynamic_cast<CmpInst*>(inst)) key.cmp_op = cmp->GetCmpOp();
  if (auto* cast = dynamic_cast<CastInst*>(inst)) key.cast_op = cast->GetCastOp();
  const size_t operand_count = inst->GetNumOperands();
  key.operands.reserve(operand_count);
  for (size_t idx = 0; idx < operand_count; ++idx) {
    key.operands.push_back(inst->GetOperand(idx));
  }
  return key;
}
/// Lists every instruction contained in `loop`.
/// Blocks are visited in the order `func.GetBlocks()` yields them, skipping
/// null entries and blocks outside the loop. Instructions within a block are
/// listed front to back. The result holds non-owning pointers.
std::vector<Instruction*> CollectLoopInstructions(analysis::Loop* loop,
                                                  Function& func) {
  std::vector<Instruction*> collected;
  for (const auto& owned_block : func.GetBlocks()) {
    auto* bb = owned_block.get();
    if (bb != nullptr && loop->Contains(bb)) {
      for (const auto& owned_inst : bb->GetInstructions()) {
        collected.push_back(owned_inst.get());
      }
    }
  }
  return collected;
}
/// Unlinks the first entry of `block`'s instruction list that owns `inst`.
/// Ownership is returned to the caller and the instruction's parent is
/// cleared. Returns nullptr (leaving the block untouched) when `inst` is not
/// found.
std::unique_ptr<Instruction> DetachInstruction(BasicBlock* block,
                                               Instruction* inst) {
  auto& list = block->MutableInstructions();
  for (auto pos = list.begin(); pos != list.end(); ++pos) {
    if (pos->get() != inst) continue;
    std::unique_ptr<Instruction> taken = std::move(*pos);
    list.erase(pos);
    taken->SetParent(nullptr);
    return taken;
  }
  return nullptr;
}
/// Takes ownership of `inst`, sets its parent to `block`, and places it into
/// the block's instruction list. If the block reports a terminator, `inst` goes
/// immediately before the last list element, which is assumed to be that
/// terminator. Otherwise it is appended at the end.
void InsertBeforeTerminator(BasicBlock* block, std::unique_ptr<Instruction> inst) {
  auto& list = block->MutableInstructions();
  inst->SetParent(block);
  if (!block->HasTerminator()) {
    list.insert(list.end(), std::move(inst));
    return;
  }
  list.insert(list.end() - 1, std::move(inst));
}
/// Records the preheader's own hoistable expressions in `available`. Equal
/// expressions found inside `loop` can then be replaced by them instead of
/// being hoisted a second time. Does nothing when `preheader` is null.
///
/// The preheader is scanned front to back. A recorded Load is only reusable if
/// nothing after it in the preheader may overwrite its address. Two things
/// invalidate the loads recorded so far:
///   - a later store whose pointer MayAlias the load's pointer;
///   - any call, which may write arbitrary memory.
/// Without this, a loop load could be replaced by a preheader load that
/// observed the value from before such a write.
void SeedAvailableInvariants(
    BasicBlock* preheader, analysis::Loop* loop,
    std::unordered_map<ExprKey, Instruction*, ExprKeyHash>& available,
    const std::unordered_set<Instruction*>& invariant) {
  if (!preheader) return;

  // Forgets recorded loads whose address may be written. A null
  // `written_ptr` means the written location is unknown (a call).
  auto forget_clobbered_loads = [&available](Value* written_ptr) {
    for (auto it = available.begin(); it != available.end();) {
      Instruction* recorded = it->second;
      bool clobbered = false;
      if (recorded->GetOpcode() == Opcode::Load) {
        clobbered = written_ptr == nullptr ||
                    MayAlias(written_ptr,
                             static_cast<LoadInst*>(recorded)->GetPtr());
      }
      if (clobbered) {
        it = available.erase(it);
      } else {
        ++it;
      }
    }
  };

  for (const auto& inst_ptr : preheader->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (!inst) continue;
    if (auto* store = dynamic_cast<StoreInst*>(inst)) {
      forget_clobbered_loads(store->GetPtr());
      continue;
    }
    if (inst->GetOpcode() == Opcode::Call) {
      forget_clobbered_loads(nullptr);
      continue;
    }
    // IsSafeInvariantInstruction already rejects unsupported opcodes.
    if (!IsSafeInvariantInstruction(inst, loop, invariant)) continue;
    // emplace keeps the earliest still-valid equivalent instruction.
    available.emplace(MakeExprKey(inst), inst);
  }
}
/// Hoists loop-invariant instructions of `loop` into its preheader. A candidate
/// whose expression key matches an instruction already available in the
/// preheader is not moved. Instead, its uses are redirected to that
/// instruction and it is removed.
///
/// Each sweep visits all loop instructions (function block order) once and
/// keeps going after a change. Sweeps repeat until one makes no change,
/// because hoisting an instruction can make users visited earlier in that
/// sweep invariant. The original code restarted the scan after every single
/// change, which re-collected and re-seeded once per hoisted instruction.
/// Each sweep re-seeds `available` from the preheader, which by then also
/// contains everything hoisted so far.
///
/// Returns true if any instruction was moved or removed. Loops without a
/// preheader are left untouched.
/// NOTE(review): candidates are hoisted regardless of whether their block
/// dominates the loop exits, so a hoisted Load runs even when the loop body
/// would not have executed it — confirm this speculation is acceptable.
bool RunLICMOnLoop(analysis::Loop* loop, Function& func) {
  auto* preheader = loop->GetPreheader();
  if (!preheader) return false;
  bool changed = false;
  // Instructions moved into the preheader by this invocation.
  std::unordered_set<Instruction*> invariant;
  bool progress = true;
  while (progress) {
    progress = false;
    std::unordered_map<ExprKey, Instruction*, ExprKeyHash> available;
    SeedAvailableInvariants(preheader, loop, available, invariant);
    // Snapshot of the loop's instructions for this sweep. Each entry is
    // visited exactly once, and RemoveInstruction is only ever called on the
    // entry currently being visited, so later entries remain valid.
    for (auto* inst : CollectLoopInstructions(loop, func)) {
      if (!inst || invariant.count(inst) != 0) continue;
      auto* block = inst->GetParent();
      if (!block || block == preheader) continue;
      if (inst->GetOpcode() == Opcode::Phi || inst->IsTerminator() ||
          inst->GetOpcode() == Opcode::Alloca || inst->GetOpcode() == Opcode::Ret ||
          inst->GetOpcode() == Opcode::Store || inst->GetOpcode() == Opcode::Call ||
          inst->GetOpcode() == Opcode::Div || inst->GetOpcode() == Opcode::Mod) {
        continue;
      }
      if (!IsSafeInvariantInstruction(inst, loop, invariant)) continue;
      ExprKey key = MakeExprKey(inst);
      auto avail_it = available.find(key);
      if (avail_it != available.end()) {
        // An equivalent value already lives in the preheader: reuse it.
        inst->ReplaceAllUsesWith(avail_it->second);
        block->RemoveInstruction(inst);
      } else {
        auto owned = DetachInstruction(block, inst);
        if (!owned) continue;
        auto* moved = owned.get();
        InsertBeforeTerminator(preheader, std::move(owned));
        available.emplace(std::move(key), moved);
        invariant.insert(moved);
      }
      changed = true;
      progress = true;
    }
  }
  return changed;
}
} // namespace
/// Runs loop-invariant code motion over every loop LoopInfo finds in `func`.
/// External functions are skipped and report no change.
/// Loops are visited deepest-first. Ties are broken by fewer blocks first,
/// then by header name. Returns true if any loop was changed.
bool RunLICM(Function& func) {
  if (func.IsExternal()) return false;
  analysis::DominatorTree dom_tree(func);
  analysis::LoopInfo loop_info(func, dom_tree);

  std::vector<analysis::Loop*> worklist;
  for (const auto& owned_loop : loop_info.GetLoops()) {
    worklist.push_back(owned_loop.get());
  }

  auto visits_before = [](analysis::Loop* a, analysis::Loop* b) {
    const auto depth_a = a->GetDepth();
    const auto depth_b = b->GetDepth();
    if (depth_a != depth_b) return depth_a > depth_b;
    const auto size_a = a->GetBlocks().size();
    const auto size_b = b->GetBlocks().size();
    if (size_a != size_b) return size_a < size_b;
    return a->GetHeader()->GetName() < b->GetHeader()->GetName();
  };
  std::sort(worklist.begin(), worklist.end(), visits_before);

  bool changed = false;
  for (auto* loop : worklist) {
    if (RunLICMOnLoop(loop, func)) changed = true;
  }
  return changed;
}
} // namespace passes
} // namespace ir