forked from p4jyxwm3q/nudt-compiler-cpp
master
parent
252073efe8
commit
b33ede5457
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,73 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
|
||||
class DominatorTree {
|
||||
public:
|
||||
explicit DominatorTree(Function& function);
|
||||
|
||||
void Recalculate();
|
||||
|
||||
Function& GetFunction() const { return *function_; }
|
||||
bool IsReachable(BasicBlock* block) const;
|
||||
bool Dominates(BasicBlock* dom, BasicBlock* node) const;
|
||||
bool Dominates(Instruction* dom, Instruction* user) const;
|
||||
BasicBlock* GetIDom(BasicBlock* block) const;
|
||||
const std::vector<BasicBlock*>& GetChildren(BasicBlock* block) const;
|
||||
const std::vector<BasicBlock*>& GetReversePostOrder() const {
|
||||
return reverse_post_order_;
|
||||
}
|
||||
|
||||
private:
|
||||
Function* function_ = nullptr;
|
||||
std::vector<BasicBlock*> reverse_post_order_;
|
||||
std::unordered_map<BasicBlock*, std::size_t> block_index_;
|
||||
std::vector<std::vector<std::uint8_t>> dominates_;
|
||||
std::unordered_map<BasicBlock*, BasicBlock*> immediate_dominator_;
|
||||
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_children_;
|
||||
};
|
||||
|
||||
struct Loop {
|
||||
BasicBlock* header = nullptr;
|
||||
std::unordered_set<BasicBlock*> blocks;
|
||||
std::vector<BasicBlock*> block_list;
|
||||
std::vector<BasicBlock*> latches;
|
||||
std::vector<BasicBlock*> exiting_blocks;
|
||||
std::vector<BasicBlock*> exit_blocks;
|
||||
BasicBlock* preheader = nullptr;
|
||||
Loop* parent = nullptr;
|
||||
std::vector<Loop*> subloops;
|
||||
|
||||
bool Contains(BasicBlock* block) const;
|
||||
bool Contains(const Loop* other) const;
|
||||
bool IsInnermost() const;
|
||||
};
|
||||
|
||||
class LoopInfo {
|
||||
public:
|
||||
LoopInfo(Function& function, const DominatorTree& dom_tree);
|
||||
|
||||
void Recalculate();
|
||||
|
||||
const std::vector<std::unique_ptr<Loop>>& GetLoops() const { return loops_; }
|
||||
std::vector<Loop*> GetTopLevelLoops() const;
|
||||
std::vector<Loop*> GetLoopsInPostOrder() const;
|
||||
Loop* GetLoopFor(BasicBlock* block) const;
|
||||
|
||||
private:
|
||||
Function* function_ = nullptr;
|
||||
const DominatorTree* dom_tree_ = nullptr;
|
||||
std::vector<std::unique_ptr<Loop>> loops_;
|
||||
std::vector<Loop*> top_level_loops_;
|
||||
std::unordered_map<BasicBlock*, Loop*> block_to_loop_;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
@ -1,4 +1,167 @@
|
||||
// 支配树分析:
|
||||
// - 构建/查询 Dominator Tree 及相关关系
|
||||
// - 为 mem2reg、CFG 优化与循环分析提供基础能力
|
||||
#include "ir/Analysis.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Returns the blocks reachable from the entry block of `function`, in the
// reverse of the post-order produced by a depth-first walk over successor
// edges. The result is empty when the function has no entry block.
std::vector<BasicBlock*> BuildReversePostOrder(Function& function) {
  std::vector<BasicBlock*> finished;
  auto* const root = function.GetEntryBlock();
  if (root == nullptr) {
    return finished;
  }

  std::unordered_set<BasicBlock*> seen;
  std::function<void(BasicBlock*)> visit;
  visit = [&seen, &finished, &visit](BasicBlock* current) {
    if (current == nullptr) {
      return;
    }
    // A block is expanded only the first time it is encountered.
    const bool first_visit = seen.insert(current).second;
    if (!first_visit) {
      return;
    }
    for (auto* next : current->GetSuccessors()) {
      visit(next);
    }
    // Recorded after all successors: this is the post-order position.
    finished.push_back(current);
  };
  visit(root);

  return std::vector<BasicBlock*>(finished.rbegin(), finished.rend());
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DominatorTree::DominatorTree(Function& function) : function_(&function) {
|
||||
Recalculate();
|
||||
}
|
||||
|
||||
void DominatorTree::Recalculate() {
|
||||
reverse_post_order_ = BuildReversePostOrder(*function_);
|
||||
block_index_.clear();
|
||||
dominates_.clear();
|
||||
immediate_dominator_.clear();
|
||||
dom_children_.clear();
|
||||
|
||||
const auto num_blocks = reverse_post_order_.size();
|
||||
for (std::size_t i = 0; i < num_blocks; ++i) {
|
||||
block_index_.emplace(reverse_post_order_[i], i);
|
||||
}
|
||||
if (num_blocks == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dominates_.assign(num_blocks, std::vector<std::uint8_t>(num_blocks, 1));
|
||||
dominates_[0].assign(num_blocks, 0);
|
||||
dominates_[0][0] = 1;
|
||||
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (std::size_t i = 1; i < num_blocks; ++i) {
|
||||
auto* block = reverse_post_order_[i];
|
||||
std::vector<std::uint8_t> next(num_blocks, 1);
|
||||
bool has_reachable_pred = false;
|
||||
for (auto* pred : block->GetPredecessors()) {
|
||||
auto pred_it = block_index_.find(pred);
|
||||
if (pred_it == block_index_.end()) {
|
||||
continue;
|
||||
}
|
||||
has_reachable_pred = true;
|
||||
const auto& pred_dom = dominates_[pred_it->second];
|
||||
for (std::size_t bit = 0; bit < num_blocks; ++bit) {
|
||||
next[bit] &= pred_dom[bit];
|
||||
}
|
||||
}
|
||||
if (!has_reachable_pred) {
|
||||
next.assign(num_blocks, 0);
|
||||
}
|
||||
next[i] = 1;
|
||||
if (next != dominates_[i]) {
|
||||
dominates_[i] = std::move(next);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (std::size_t i = 1; i < num_blocks; ++i) {
|
||||
auto* block = reverse_post_order_[i];
|
||||
BasicBlock* idom = nullptr;
|
||||
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
|
||||
if (candidate == i || !dominates_[i][candidate]) {
|
||||
continue;
|
||||
}
|
||||
auto* candidate_block = reverse_post_order_[candidate];
|
||||
bool immediate = true;
|
||||
for (std::size_t other = 0; other < num_blocks; ++other) {
|
||||
if (other == i || other == candidate || !dominates_[i][other]) {
|
||||
continue;
|
||||
}
|
||||
if (Dominates(reverse_post_order_[other], candidate_block)) {
|
||||
immediate = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (immediate) {
|
||||
idom = candidate_block;
|
||||
break;
|
||||
}
|
||||
}
|
||||
immediate_dominator_.emplace(block, idom);
|
||||
if (idom) {
|
||||
dom_children_[idom].push_back(block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A block is reachable when it received a reverse post-order index during the
// last Recalculate(). Null is never reachable.
bool DominatorTree::IsReachable(BasicBlock* block) const {
  if (block == nullptr) {
    return false;
  }
  return block_index_.count(block) != 0;
}
|
||||
|
||||
// Block-level dominance query. Returns false when either block is null or has
// no index (i.e. was not reachable during the last Recalculate()).
bool DominatorTree::Dominates(BasicBlock* dom, BasicBlock* node) const {
  if (dom == nullptr || node == nullptr) {
    return false;
  }

  const auto dom_pos = block_index_.find(dom);
  if (dom_pos == block_index_.end()) {
    return false;
  }
  const auto node_pos = block_index_.find(node);
  if (node_pos == block_index_.end()) {
    return false;
  }

  // Row = dominated block, column = dominating block.
  const auto& dominators_of_node = dominates_[node_pos->second];
  return dominators_of_node[dom_pos->second] != 0;
}
|
||||
|
||||
// Instruction-level dominance query.
//
// - Null arguments never dominate / are never dominated.
// - An instruction dominates itself.
// - Instructions in different blocks follow block dominance.
// - Instructions in the same block are compared by position: `dom` dominates
//   `user` when it appears first in the block's instruction list.
//
// Instructions that are not attached to any block (GetParent() == nullptr)
// yield false instead of dereferencing a null block.
bool DominatorTree::Dominates(Instruction* dom, Instruction* user) const {
  if (!dom || !user) {
    return false;
  }
  if (dom == user) {
    return true;
  }

  auto* dom_block = dom->GetParent();
  auto* user_block = user->GetParent();
  if (dom_block != user_block) {
    return Dominates(dom_block, user_block);
  }
  if (!dom_block) {
    // Both instructions are detached; there is no block to scan.
    return false;
  }

  // Same block: whichever of the two is met first decides the answer.
  for (const auto& inst_ptr : dom_block->GetInstructions()) {
    if (inst_ptr.get() == dom) {
      return true;
    }
    if (inst_ptr.get() == user) {
      return false;
    }
  }
  return false;
}
|
||||
|
||||
// Looks up the recorded immediate dominator of `block`; nullptr when the
// block has no entry in the map.
BasicBlock* DominatorTree::GetIDom(BasicBlock* block) const {
  const auto entry = immediate_dominator_.find(block);
  if (entry == immediate_dominator_.end()) {
    return nullptr;
  }
  return entry->second;
}
|
||||
|
||||
// Returns the dominator-tree children of `block`. Blocks without children
// share one empty list so that a reference can always be returned.
const std::vector<BasicBlock*>& DominatorTree::GetChildren(BasicBlock* block) const {
  static const std::vector<BasicBlock*> kEmpty;
  const auto entry = dom_children_.find(block);
  if (entry != dom_children_.end()) {
    return entry->second;
  }
  return kEmpty;
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@ -1,4 +1,214 @@
|
||||
// 循环分析:
|
||||
// - 识别循环结构与层级关系
|
||||
// - 为后续优化(可选)提供循环信息
|
||||
#include "ir/Analysis.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Collects the natural loop of the back edge `latch -> header`: the header
// plus every block that can reach the latch without passing through the
// header. The blocks are returned in unspecified order.
//
// The walk goes backwards over predecessor edges starting at the latch and
// never expands the header. For a self-loop (latch == header) nothing is
// expanded at all, so the loop consists of the header alone; seeding the
// worklist with the header would leak out through its out-of-loop
// predecessors.
//
// NOTE(review): predecessors are followed without a reachability check, so
// an unreachable block that branches into the loop would be included.
std::vector<BasicBlock*> CollectNaturalLoopBlocks(BasicBlock* header,
                                                  BasicBlock* latch) {
  std::unordered_set<BasicBlock*> loop_blocks{header, latch};
  std::vector<BasicBlock*> stack;
  if (latch != header) {
    stack.push_back(latch);
  }
  while (!stack.empty()) {
    auto* block = stack.back();
    stack.pop_back();
    for (auto* pred : block->GetPredecessors()) {
      // The header is already in the set, so it is never pushed/expanded.
      if (!pred || !loop_blocks.insert(pred).second) {
        continue;
      }
      stack.push_back(pred);
    }
  }
  return {loop_blocks.begin(), loop_blocks.end()};
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Membership test for a single block; null is never a member.
bool Loop::Contains(BasicBlock* block) const {
  if (block == nullptr) {
    return false;
  }
  return blocks.count(block) != 0;
}
|
||||
|
||||
bool Loop::Contains(const Loop* other) const {
|
||||
if (!other) {
|
||||
return false;
|
||||
}
|
||||
for (auto* block : other->blocks) {
|
||||
if (!Contains(block)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// A loop is innermost when no other loop is nested directly inside it.
bool Loop::IsInnermost() const {
  return subloops.size() == 0;
}
|
||||
|
||||
LoopInfo::LoopInfo(Function& function, const DominatorTree& dom_tree)
|
||||
: function_(&function), dom_tree_(&dom_tree) {
|
||||
Recalculate();
|
||||
}
|
||||
|
||||
void LoopInfo::Recalculate() {
|
||||
loops_.clear();
|
||||
top_level_loops_.clear();
|
||||
block_to_loop_.clear();
|
||||
|
||||
std::unordered_map<BasicBlock*, Loop*> loops_by_header;
|
||||
for (auto* block : dom_tree_->GetReversePostOrder()) {
|
||||
for (auto* succ : block->GetSuccessors()) {
|
||||
if (!dom_tree_->Dominates(succ, block)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Loop* loop = nullptr;
|
||||
auto it = loops_by_header.find(succ);
|
||||
if (it == loops_by_header.end()) {
|
||||
auto new_loop = std::make_unique<Loop>();
|
||||
new_loop->header = succ;
|
||||
loop = new_loop.get();
|
||||
loops_.push_back(std::move(new_loop));
|
||||
loops_by_header.emplace(succ, loop);
|
||||
} else {
|
||||
loop = it->second;
|
||||
}
|
||||
|
||||
if (std::find(loop->latches.begin(), loop->latches.end(), block) ==
|
||||
loop->latches.end()) {
|
||||
loop->latches.push_back(block);
|
||||
}
|
||||
for (auto* natural_block : CollectNaturalLoopBlocks(succ, block)) {
|
||||
loop->blocks.insert(natural_block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<BasicBlock*, std::size_t> function_order;
|
||||
for (std::size_t i = 0; i < function_->GetBlocks().size(); ++i) {
|
||||
function_order.emplace(function_->GetBlocks()[i].get(), i);
|
||||
}
|
||||
|
||||
for (const auto& loop_ptr : loops_) {
|
||||
auto& loop = *loop_ptr;
|
||||
loop.block_list.clear();
|
||||
loop.exiting_blocks.clear();
|
||||
loop.exit_blocks.clear();
|
||||
loop.subloops.clear();
|
||||
loop.parent = nullptr;
|
||||
|
||||
for (const auto& block_ptr : function_->GetBlocks()) {
|
||||
if (loop.Contains(block_ptr.get())) {
|
||||
loop.block_list.push_back(block_ptr.get());
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(loop.latches.begin(), loop.latches.end(),
|
||||
[&](BasicBlock* lhs, BasicBlock* rhs) {
|
||||
return function_order[lhs] < function_order[rhs];
|
||||
});
|
||||
|
||||
std::vector<BasicBlock*> outside_preds;
|
||||
for (auto* pred : loop.header->GetPredecessors()) {
|
||||
if (!loop.Contains(pred)) {
|
||||
outside_preds.push_back(pred);
|
||||
}
|
||||
}
|
||||
if (outside_preds.size() == 1 &&
|
||||
outside_preds.front()->GetSuccessors().size() == 1) {
|
||||
loop.preheader = outside_preds.front();
|
||||
} else {
|
||||
loop.preheader = nullptr;
|
||||
}
|
||||
|
||||
std::unordered_set<BasicBlock*> exiting_seen;
|
||||
std::unordered_set<BasicBlock*> exit_seen;
|
||||
for (auto* block : loop.block_list) {
|
||||
for (auto* succ : block->GetSuccessors()) {
|
||||
if (loop.Contains(succ)) {
|
||||
continue;
|
||||
}
|
||||
if (exiting_seen.insert(block).second) {
|
||||
loop.exiting_blocks.push_back(block);
|
||||
}
|
||||
if (exit_seen.insert(succ).second) {
|
||||
loop.exit_blocks.push_back(succ);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(loop.exiting_blocks.begin(), loop.exiting_blocks.end(),
|
||||
[&](BasicBlock* lhs, BasicBlock* rhs) {
|
||||
return function_order[lhs] < function_order[rhs];
|
||||
});
|
||||
std::sort(loop.exit_blocks.begin(), loop.exit_blocks.end(),
|
||||
[&](BasicBlock* lhs, BasicBlock* rhs) {
|
||||
return function_order[lhs] < function_order[rhs];
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto& loop_ptr : loops_) {
|
||||
auto* loop = loop_ptr.get();
|
||||
Loop* parent = nullptr;
|
||||
for (const auto& candidate_ptr : loops_) {
|
||||
auto* candidate = candidate_ptr.get();
|
||||
if (candidate == loop || !candidate->Contains(loop)) {
|
||||
continue;
|
||||
}
|
||||
if (!parent || candidate->blocks.size() < parent->blocks.size()) {
|
||||
parent = candidate;
|
||||
}
|
||||
}
|
||||
loop->parent = parent;
|
||||
if (parent) {
|
||||
parent->subloops.push_back(loop);
|
||||
} else {
|
||||
top_level_loops_.push_back(loop);
|
||||
}
|
||||
}
|
||||
|
||||
auto loop_order = [&](Loop* lhs, Loop* rhs) {
|
||||
return function_order[lhs->header] < function_order[rhs->header];
|
||||
};
|
||||
std::sort(top_level_loops_.begin(), top_level_loops_.end(), loop_order);
|
||||
for (const auto& loop_ptr : loops_) {
|
||||
std::sort(loop_ptr->subloops.begin(), loop_ptr->subloops.end(), loop_order);
|
||||
}
|
||||
|
||||
for (const auto& block_ptr : function_->GetBlocks()) {
|
||||
Loop* innermost = nullptr;
|
||||
for (const auto& loop_ptr : loops_) {
|
||||
auto* loop = loop_ptr.get();
|
||||
if (!loop->Contains(block_ptr.get())) {
|
||||
continue;
|
||||
}
|
||||
if (!innermost || loop->blocks.size() < innermost->blocks.size()) {
|
||||
innermost = loop;
|
||||
}
|
||||
}
|
||||
if (innermost) {
|
||||
block_to_loop_.emplace(block_ptr.get(), innermost);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a copy of the list of loops that have no parent loop.
std::vector<Loop*> LoopInfo::GetTopLevelLoops() const {
  return top_level_loops_;
}
|
||||
|
||||
std::vector<Loop*> LoopInfo::GetLoopsInPostOrder() const {
|
||||
std::vector<Loop*> ordered;
|
||||
std::function<void(Loop*)> dfs = [&](Loop* loop) {
|
||||
for (auto* subloop : loop->subloops) {
|
||||
dfs(subloop);
|
||||
}
|
||||
ordered.push_back(loop);
|
||||
};
|
||||
for (auto* loop : top_level_loops_) {
|
||||
dfs(loop);
|
||||
}
|
||||
return ordered;
|
||||
}
|
||||
|
||||
// Innermost loop recorded for `block`; nullptr when the block is in no loop.
Loop* LoopInfo::GetLoopFor(BasicBlock* block) const {
  const auto entry = block_to_loop_.find(block);
  if (entry == block_to_loop_.end()) {
    return nullptr;
  }
  return entry->second;
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@ -1,4 +1,107 @@
|
||||
// CFG 简化:
|
||||
// - 删除不可达块、合并空块、简化分支等
|
||||
// - 改善 IR 结构,便于后续优化与后端生成
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Decides whether a conditional branch always goes to the same block.
//
// On success returns true and sets:
//   target  - the block that is always branched to;
//   removed - the successor that is never taken, or nullptr when both arms
//             already point at the same block.
// Returns false (outputs untouched) when `br` is null or its outcome is not
// known statically.
bool TryGetConstBranchTarget(CondBrInst* br, BasicBlock*& target, BasicBlock*& removed) {
  if (br == nullptr) {
    return false;
  }

  auto* const on_true = br->GetThenBlock();
  auto* const on_false = br->GetElseBlock();

  // Both arms identical: the condition is irrelevant and no edge disappears.
  if (on_true == on_false) {
    removed = nullptr;
    target = on_true;
    return true;
  }

  auto* const constant = dyncast<ConstantI1>(br->GetCondition());
  if (constant == nullptr) {
    return false;
  }
  if (constant->GetValue()) {
    target = on_true;
    removed = on_false;
  } else {
    target = on_false;
    removed = on_true;
  }
  return true;
}
|
||||
|
||||
// Rewrites a conditional branch with a statically known outcome into an
// unconditional branch. Returns true when the block was changed.
//
// When one successor is dropped, the edge block -> successor is detached
// first: the successor is notified through
// passutils::RemoveIncomingFromSuccessor (presumably to drop phi inputs coming
// from `block` - confirm in PassUtils), then both sides of the CFG edge are
// removed.
bool SimplifyBlockTerminator(BasicBlock* block) {
  if (block == nullptr) {
    return false;
  }
  if (block->GetInstructions().empty()) {
    return false;
  }

  // Only the last instruction of the block is inspected.
  auto* const last = block->GetInstructions().back().get();
  auto* const branch = dyncast<CondBrInst>(last);
  if (branch == nullptr) {
    return false;
  }

  BasicBlock* kept = nullptr;
  BasicBlock* dropped = nullptr;
  const bool decidable = TryGetConstBranchTarget(branch, kept, dropped);
  if (!decidable) {
    return false;
  }

  if (dropped != nullptr) {
    passutils::RemoveIncomingFromSuccessor(dropped, block);
    dropped->RemovePredecessor(block);
    block->RemoveSuccessor(dropped);
  }
  passutils::ReplaceTerminatorWithBr(block, kept);
  return true;
}
|
||||
|
||||
// Runs passutils::SimplifyPhiInst over the leading phi instructions of every
// block until no phi in that block can be simplified any more. Returns true
// when at least one phi was simplified.
bool SimplifyPhiNodes(Function& function) {
  bool any_change = false;
  for (const auto& block_ptr : function.GetBlocks()) {
    // Scans the phis at the top of the block and stops at the first
    // simplification. The scan is restarted from the beginning afterwards, so
    // the instruction list is never iterated after it has been modified.
    const auto simplify_one = [&block_ptr]() -> bool {
      for (const auto& inst_ptr : block_ptr->GetInstructions()) {
        auto* const phi = dyncast<PhiInst>(inst_ptr.get());
        if (phi == nullptr) {
          return false;  // First non-phi instruction ends the phi prefix.
        }
        if (passutils::SimplifyPhiInst(phi)) {
          return true;
        }
      }
      return false;
    };

    while (simplify_one()) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
// Simplifies the CFG of one function in three steps: fold decidable
// conditional branches, remove unreachable blocks, then simplify phi nodes.
// External functions and functions without an entry block are left alone.
// Returns true when any step changed the function.
bool RunCFGSimplifyOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }
  if (function.GetEntryBlock() == nullptr) {
    return false;
  }

  bool branches_changed = false;
  for (const auto& block_ptr : function.GetBlocks()) {
    if (SimplifyBlockTerminator(block_ptr.get())) {
      branches_changed = true;
    }
  }

  // Each step always runs, regardless of what the previous one reported.
  const bool blocks_removed = passutils::RemoveUnreachableBlocks(function);
  const bool phis_changed = SimplifyPhiNodes(function);

  return branches_changed || blocks_removed || phis_changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: applies CFG simplification to every function of the
// module. Returns true when at least one function was modified.
bool RunCFGSimplify(Module& module) {
  bool module_changed = false;
  for (const auto& function : module.GetFunctions()) {
    if (!function) {
      continue;  // Skip empty slots.
    }
    const bool function_changed = RunCFGSimplifyOnFunction(*function);
    module_changed = module_changed || function_changed;
  }
  return module_changed;
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@ -1,4 +1,469 @@
|
||||
// IR 常量折叠:
|
||||
// - 折叠可判定的常量表达式
|
||||
// - 简化常量控制流分支(按实现范围裁剪)
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Fetches the i32 constant `value` from the context.
Value* GetInt32Const(Context& ctx, std::int32_t value) {
  const int as_int = static_cast<int>(value);
  return ctx.GetConstInt(as_int);
}
|
||||
|
||||
// Fetches the i1 constant `value` from the context.
Value* GetBoolConst(Context& ctx, bool value) {
  return ctx.GetConstBool(value);
}
|
||||
|
||||
// Creates a float constant holding `value`.
// NOTE(review): unlike the integer/bool helpers this allocates with a raw
// `new` and does not go through the Context; nothing visible here takes
// ownership of the object - confirm who frees it.
Value* GetFloatConst(float value) {
  auto* const constant = new ConstantFloat(Type::GetFloatType(), value);
  return constant;
}
|
||||
|
||||
// If `value` is a ConstantInt, stores its value in `out` and returns true.
// Otherwise returns false and leaves `out` untouched.
bool TryGetInt32(Value* value, std::int32_t& out) {
  auto* const constant = dyncast<ConstantInt>(value);
  if (constant == nullptr) {
    return false;
  }
  out = static_cast<std::int32_t>(constant->GetValue());
  return true;
}
|
||||
|
||||
// If `value` is a ConstantI1, stores its value in `out` and returns true.
// Otherwise returns false and leaves `out` untouched.
bool TryGetBool(Value* value, bool& out) {
  auto* const constant = dyncast<ConstantI1>(value);
  if (constant == nullptr) {
    return false;
  }
  out = constant->GetValue();
  return true;
}
|
||||
|
||||
// If `value` is a ConstantFloat, stores its value in `out` and returns true.
// Otherwise returns false and leaves `out` untouched.
bool TryGetFloat(Value* value, float& out) {
  auto* const constant = dyncast<ConstantFloat>(value);
  if (constant == nullptr) {
    return false;
  }
  out = constant->GetValue();
  return true;
}
|
||||
|
||||
// True when `value` is the i32 constant 0, the i1 constant false, or a float
// constant whose bit pattern is all zeros (+0.0; -0.0 does not match).
bool IsZeroValue(Value* value) {
  std::int32_t as_int = 0;
  if (TryGetInt32(value, as_int) && as_int == 0) {
    return true;
  }

  bool as_bool = false;
  if (TryGetBool(value, as_bool) && !as_bool) {
    return true;
  }

  float as_float = 0.0f;
  if (!TryGetFloat(value, as_float)) {
    return false;
  }
  return passutils::FloatBits(as_float) == 0;
}
|
||||
|
||||
// True when `value` is the i32 constant 1, the i1 constant true, or a float
// constant with exactly the bit pattern of 1.0f.
bool IsOneValue(Value* value) {
  std::int32_t as_int = 0;
  if (TryGetInt32(value, as_int) && as_int == 1) {
    return true;
  }

  bool as_bool = false;
  if (TryGetBool(value, as_bool) && as_bool) {
    return true;
  }

  float as_float = 0.0f;
  if (!TryGetFloat(value, as_float)) {
    return false;
  }
  return passutils::FloatBits(as_float) == passutils::FloatBits(1.0f);
}
|
||||
|
||||
// True when `value` is the i32 constant -1 (all 32 bits set).
bool IsAllOnesInt(Value* value) {
  std::int32_t as_int = 0;
  if (!TryGetInt32(value, as_int)) {
    return false;
  }
  return as_int == -1;
}
|
||||
|
||||
// Converts an unsigned 32-bit result back to int32. Arithmetic is done in
// uint32_t so it wraps instead of overflowing a signed type; this cast maps
// the wrapped value onto the signed range.
std::int32_t WrapInt32(std::uint32_t value) {
  const auto wrapped = static_cast<std::int32_t>(value);
  return wrapped;
}
|
||||
|
||||
Value* FoldBinary(Context& ctx, BinaryInst* inst) {
|
||||
const auto opcode = inst->GetOpcode();
|
||||
auto* lhs = inst->GetLhs();
|
||||
auto* rhs = inst->GetRhs();
|
||||
|
||||
std::int32_t lhs_i32 = 0;
|
||||
std::int32_t rhs_i32 = 0;
|
||||
bool lhs_i1 = false;
|
||||
bool rhs_i1 = false;
|
||||
float lhs_f32 = 0.0f;
|
||||
float rhs_f32 = 0.0f;
|
||||
|
||||
const bool has_lhs_i32 = TryGetInt32(lhs, lhs_i32);
|
||||
const bool has_rhs_i32 = TryGetInt32(rhs, rhs_i32);
|
||||
const bool has_lhs_i1 = TryGetBool(lhs, lhs_i1);
|
||||
const bool has_rhs_i1 = TryGetBool(rhs, rhs_i1);
|
||||
const bool has_lhs_f32 = TryGetFloat(lhs, lhs_f32);
|
||||
const bool has_rhs_f32 = TryGetFloat(rhs, rhs_f32);
|
||||
|
||||
if (has_lhs_i32 && has_rhs_i32) {
|
||||
switch (opcode) {
|
||||
case Opcode::Add:
|
||||
return GetInt32Const(
|
||||
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) +
|
||||
static_cast<std::uint32_t>(rhs_i32)));
|
||||
case Opcode::Sub:
|
||||
return GetInt32Const(
|
||||
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) -
|
||||
static_cast<std::uint32_t>(rhs_i32)));
|
||||
case Opcode::Mul:
|
||||
return GetInt32Const(
|
||||
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) *
|
||||
static_cast<std::uint32_t>(rhs_i32)));
|
||||
case Opcode::Div:
|
||||
if (rhs_i32 == 0 ||
|
||||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(ctx, lhs_i32 / rhs_i32);
|
||||
case Opcode::Rem:
|
||||
if (rhs_i32 == 0 ||
|
||||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(ctx, lhs_i32 % rhs_i32);
|
||||
case Opcode::And:
|
||||
return GetInt32Const(ctx, lhs_i32 & rhs_i32);
|
||||
case Opcode::Or:
|
||||
return GetInt32Const(ctx, lhs_i32 | rhs_i32);
|
||||
case Opcode::Xor:
|
||||
return GetInt32Const(ctx, lhs_i32 ^ rhs_i32);
|
||||
case Opcode::Shl:
|
||||
if (rhs_i32 < 0 || rhs_i32 >= 32) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32)
|
||||
<< rhs_i32));
|
||||
case Opcode::AShr:
|
||||
if (rhs_i32 < 0 || rhs_i32 >= 32) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(ctx, lhs_i32 >> rhs_i32);
|
||||
case Opcode::LShr:
|
||||
if (rhs_i32 < 0 || rhs_i32 >= 32) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(
|
||||
ctx,
|
||||
WrapInt32(static_cast<std::uint32_t>(lhs_i32) >> rhs_i32));
|
||||
case Opcode::ICmpEQ:
|
||||
return GetBoolConst(ctx, lhs_i32 == rhs_i32);
|
||||
case Opcode::ICmpNE:
|
||||
return GetBoolConst(ctx, lhs_i32 != rhs_i32);
|
||||
case Opcode::ICmpLT:
|
||||
return GetBoolConst(ctx, lhs_i32 < rhs_i32);
|
||||
case Opcode::ICmpGT:
|
||||
return GetBoolConst(ctx, lhs_i32 > rhs_i32);
|
||||
case Opcode::ICmpLE:
|
||||
return GetBoolConst(ctx, lhs_i32 <= rhs_i32);
|
||||
case Opcode::ICmpGE:
|
||||
return GetBoolConst(ctx, lhs_i32 >= rhs_i32);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_lhs_i1 && has_rhs_i1) {
|
||||
switch (opcode) {
|
||||
case Opcode::And:
|
||||
return GetBoolConst(ctx, lhs_i1 && rhs_i1);
|
||||
case Opcode::Or:
|
||||
return GetBoolConst(ctx, lhs_i1 || rhs_i1);
|
||||
case Opcode::Xor:
|
||||
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
|
||||
case Opcode::ICmpEQ:
|
||||
return GetBoolConst(ctx, lhs_i1 == rhs_i1);
|
||||
case Opcode::ICmpNE:
|
||||
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
|
||||
case Opcode::ICmpLT:
|
||||
return GetBoolConst(ctx, static_cast<int>(lhs_i1) < static_cast<int>(rhs_i1));
|
||||
case Opcode::ICmpGT:
|
||||
return GetBoolConst(ctx, static_cast<int>(lhs_i1) > static_cast<int>(rhs_i1));
|
||||
case Opcode::ICmpLE:
|
||||
return GetBoolConst(ctx, static_cast<int>(lhs_i1) <= static_cast<int>(rhs_i1));
|
||||
case Opcode::ICmpGE:
|
||||
return GetBoolConst(ctx, static_cast<int>(lhs_i1) >= static_cast<int>(rhs_i1));
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_lhs_f32 && has_rhs_f32) {
|
||||
switch (opcode) {
|
||||
case Opcode::FAdd:
|
||||
return GetFloatConst(lhs_f32 + rhs_f32);
|
||||
case Opcode::FSub:
|
||||
return GetFloatConst(lhs_f32 - rhs_f32);
|
||||
case Opcode::FMul:
|
||||
return GetFloatConst(lhs_f32 * rhs_f32);
|
||||
case Opcode::FDiv:
|
||||
return GetFloatConst(lhs_f32 / rhs_f32);
|
||||
case Opcode::FRem:
|
||||
return GetFloatConst(std::fmod(lhs_f32, rhs_f32));
|
||||
case Opcode::FCmpEQ:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 == rhs_f32);
|
||||
case Opcode::FCmpNE:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 != rhs_f32);
|
||||
case Opcode::FCmpLT:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 < rhs_f32);
|
||||
case Opcode::FCmpGT:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 > rhs_f32);
|
||||
case Opcode::FCmpLE:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 <= rhs_f32);
|
||||
case Opcode::FCmpGE:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 >= rhs_f32);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
switch (opcode) {
|
||||
case Opcode::Add:
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsZeroValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::Sub:
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::Mul:
|
||||
if (IsOneValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsOneValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::Div:
|
||||
if (IsOneValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::Rem:
|
||||
if ((has_rhs_i32 && (rhs_i32 == 1 || rhs_i32 == -1)) ||
|
||||
(has_rhs_i1 && rhs_i1)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::And:
|
||||
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
|
||||
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
|
||||
: GetInt32Const(ctx, 0);
|
||||
}
|
||||
if (has_lhs_i1 && lhs_i1) {
|
||||
return rhs;
|
||||
}
|
||||
if (has_rhs_i1 && rhs_i1) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsAllOnesInt(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsAllOnesInt(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (passutils::AreEquivalentValues(lhs, rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::Or:
|
||||
if (IsZeroValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (has_lhs_i1 && lhs_i1) {
|
||||
return GetBoolConst(ctx, true);
|
||||
}
|
||||
if (has_rhs_i1 && rhs_i1) {
|
||||
return GetBoolConst(ctx, true);
|
||||
}
|
||||
if (passutils::AreEquivalentValues(lhs, rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::Xor:
|
||||
if (IsZeroValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (passutils::AreEquivalentValues(lhs, rhs)) {
|
||||
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
|
||||
: GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsZeroValue(lhs)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::FAdd:
|
||||
if (IsZeroValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::FSub:
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::FMul:
|
||||
if (IsOneValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsOneValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::FDiv:
|
||||
if (IsOneValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Tries to evaluate a unary instruction whose operand is a constant.
// Returns the folded constant, or nullptr when the operand is not a suitable
// constant or the result cannot be computed safely at compile time.
//
//   Neg  : i32 negation with wrap-around (computed in uint32_t).
//   Not  : logical not for i1; for i32 the lowest bit is flipped (x ^ 1).
//   FNeg : float negation.
//   FtoI : float -> i32 truncation; NaN and values outside the int32 range
//          are not folded.
//   IToF : i32 or i1 -> float.
Value* FoldUnary(Context& ctx, UnaryInst* inst) {
  auto* operand = inst->GetOprd();
  std::int32_t i32 = 0;
  bool i1 = false;
  float f32 = 0.0f;

  switch (inst->GetOpcode()) {
    case Opcode::Neg:
      if (TryGetInt32(operand, i32)) {
        return GetInt32Const(ctx, WrapInt32(0u - static_cast<std::uint32_t>(i32)));
      }
      break;
    case Opcode::Not:
      if (TryGetBool(operand, i1)) {
        return GetBoolConst(ctx, !i1);
      }
      if (TryGetInt32(operand, i32)) {
        return GetInt32Const(ctx, i32 ^ 1);
      }
      break;
    case Opcode::FNeg:
      if (TryGetFloat(operand, f32)) {
        return GetFloatConst(-f32);
      }
      break;
    case Opcode::FtoI:
      if (TryGetFloat(operand, f32)) {
        // Converting a float that does not fit into the destination integer
        // type is undefined behaviour in C++, so such constants are left for
        // run time. 2^31 is exactly representable as a float; the negated
        // comparison also rejects NaN.
        constexpr float kTwoPow31 = 2147483648.0f;
        if (!(f32 >= -kTwoPow31 && f32 < kTwoPow31)) {
          return nullptr;
        }
        return GetInt32Const(ctx, static_cast<std::int32_t>(f32));
      }
      break;
    case Opcode::IToF:
      if (TryGetInt32(operand, i32)) {
        return GetFloatConst(static_cast<float>(i32));
      }
      if (TryGetBool(operand, i1)) {
        return GetFloatConst(i1 ? 1.0f : 0.0f);
      }
      break;
    default:
      break;
  }
  return nullptr;
}
|
||||
|
||||
// Folds a zero-extension of a constant operand to the instruction's result
// type. Handles i1 and i32 results with i1 or i32 constant operands; returns
// nullptr in every other case.
Value* FoldZext(Context& ctx, ZextInst* inst) {
  auto* const source = inst->GetValue();

  bool as_bool = false;
  std::int32_t as_int = 0;

  if (inst->GetType()->IsInt1()) {
    if (TryGetBool(source, as_bool)) {
      return GetBoolConst(ctx, as_bool);
    }
    if (TryGetInt32(source, as_int)) {
      // Any non-zero integer becomes true.
      return GetBoolConst(ctx, as_int != 0);
    }
  }

  if (inst->GetType()->IsInt32()) {
    if (TryGetBool(source, as_bool)) {
      const std::int32_t widened = as_bool ? 1 : 0;
      return GetInt32Const(ctx, widened);
    }
    if (TryGetInt32(source, as_int)) {
      return GetInt32Const(ctx, as_int);
    }
  }

  return nullptr;
}
|
||||
|
||||
// Runs constant folding over every instruction of one function.
//
// For each binary, unary or zext instruction that can be simplified, all of
// its uses are redirected to the replacement value. The instruction itself is
// erased only after the whole block has been scanned, so the instruction list
// is not modified while it is being iterated. External functions are skipped.
// Returns true when at least one instruction was replaced.
bool FoldFunction(Function& function, Context& ctx) {
  if (function.IsExternal()) {
    return false;
  }

  bool any_folded = false;
  for (const auto& block_ptr : function.GetBlocks()) {
    std::vector<Instruction*> dead;

    for (const auto& inst_ptr : block_ptr->GetInstructions()) {
      auto* const current = inst_ptr.get();

      Value* simplified = nullptr;
      if (auto* as_binary = dyncast<BinaryInst>(current)) {
        simplified = FoldBinary(ctx, as_binary);
      } else if (auto* as_unary = dyncast<UnaryInst>(current)) {
        simplified = FoldUnary(ctx, as_unary);
      } else if (auto* as_zext = dyncast<ZextInst>(current)) {
        simplified = FoldZext(ctx, as_zext);
      }

      const bool usable = simplified != nullptr && simplified != current;
      if (usable) {
        current->ReplaceAllUsesWith(simplified);
        dead.push_back(current);
        any_folded = true;
      }
    }

    for (auto* victim : dead) {
      block_ptr->EraseInstruction(victim);
    }
  }
  return any_folded;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pass entry point: constant-folds every function of the module using the
// module's context. Returns true when at least one function was modified.
bool RunConstFold(Module& module) {
  auto& ctx = module.GetContext();
  bool module_changed = false;
  for (const auto& function : module.GetFunctions()) {
    if (!function) {
      continue;  // Skip empty slots.
    }
    const bool function_changed = FoldFunction(*function, ctx);
    module_changed = module_changed || function_changed;
  }
  return module_changed;
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@ -0,0 +1,196 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "MemoryUtils.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct ExprKey {
|
||||
Opcode opcode = Opcode::Add;
|
||||
std::uintptr_t result_type = 0;
|
||||
std::uintptr_t aux_type = 0;
|
||||
std::vector<std::uintptr_t> operands;
|
||||
|
||||
bool operator==(const ExprKey& rhs) const {
|
||||
return opcode == rhs.opcode && result_type == rhs.result_type &&
|
||||
aux_type == rhs.aux_type && operands == rhs.operands;
|
||||
}
|
||||
};
|
||||
|
||||
struct ExprKeyHash {
|
||||
std::size_t operator()(const ExprKey& key) const {
|
||||
std::size_t h = static_cast<std::size_t>(key.opcode);
|
||||
h ^= std::hash<std::uintptr_t>{}(key.result_type) + 0x9e3779b9 + (h << 6) +
|
||||
(h >> 2);
|
||||
h ^= std::hash<std::uintptr_t>{}(key.aux_type) + 0x9e3779b9 + (h << 6) +
|
||||
(h >> 2);
|
||||
for (auto operand : key.operands) {
|
||||
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
struct ScopedExpr {
|
||||
ExprKey key;
|
||||
Value* previous = nullptr;
|
||||
bool had_previous = false;
|
||||
};
|
||||
|
||||
// Tells whether `inst` may take part in value numbering.
//
// Null and void instructions never qualify. Calls qualify only when
// memutils::IsPureCall accepts them; otherwise the opcode must be one of the
// arithmetic, bitwise, comparison, conversion, GEP or zext opcodes listed
// below.
bool IsSupportedGVNInstruction(Instruction* inst) {
  if (inst == nullptr) {
    return false;
  }
  if (inst->IsVoid()) {
    return false;
  }

  const Opcode op = inst->GetOpcode();
  if (op == Opcode::Call) {
    return memutils::IsPureCall(dyncast<CallInst>(inst));
  }

  bool supported = false;
  switch (op) {
    // Integer and floating-point arithmetic.
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::Div:
    case Opcode::Rem:
    case Opcode::FAdd:
    case Opcode::FSub:
    case Opcode::FMul:
    case Opcode::FDiv:
    case Opcode::FRem:
    // Bitwise operations and shifts.
    case Opcode::And:
    case Opcode::Or:
    case Opcode::Xor:
    case Opcode::Shl:
    case Opcode::AShr:
    case Opcode::LShr:
    // Comparisons.
    case Opcode::ICmpEQ:
    case Opcode::ICmpNE:
    case Opcode::ICmpLT:
    case Opcode::ICmpGT:
    case Opcode::ICmpLE:
    case Opcode::ICmpGE:
    case Opcode::FCmpEQ:
    case Opcode::FCmpNE:
    case Opcode::FCmpLT:
    case Opcode::FCmpGT:
    case Opcode::FCmpLE:
    case Opcode::FCmpGE:
    // Unary operations and conversions.
    case Opcode::Neg:
    case Opcode::Not:
    case Opcode::FNeg:
    case Opcode::FtoI:
    case Opcode::IToF:
    // Address computation and widening.
    case Opcode::GetElementPtr:
    case Opcode::Zext:
      supported = true;
      break;
    default:
      break;
  }
  return supported;
}
|
||||
|
||||
// Builds the GVN lookup key of `inst` from its opcode, result type, operand
// identities and, for GEPs, the source element type.
//
// For two-operand commutative opcodes the operands are ordered by address so
// that `a op b` and `b op a` produce the same key.
//
// NOTE(review): for Call instructions the key only distinguishes callees if
// the callee is one of the instruction's operands -- confirm against CallInst.
ExprKey BuildExprKey(Instruction* inst) {
  const std::size_t operand_count = inst->GetNumOperands();

  ExprKey result;
  result.opcode = inst->GetOpcode();
  result.result_type = reinterpret_cast<std::uintptr_t>(inst->GetType().get());
  if (auto* as_gep = dyncast<GetElementPtrInst>(inst)) {
    result.aux_type =
        reinterpret_cast<std::uintptr_t>(as_gep->GetSourceType().get());
  }

  result.operands.reserve(operand_count);
  for (std::size_t index = 0; index != operand_count; ++index) {
    const auto identity =
        reinterpret_cast<std::uintptr_t>(inst->GetOperand(index));
    result.operands.push_back(identity);
  }

  const bool needs_swap = operand_count == 2 &&
                          passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
                          result.operands[1] < result.operands[0];
  if (needs_swap) {
    std::swap(result.operands[0], result.operands[1]);
  }
  return result;
}
|
||||
|
||||
// Value-numbers `block` and, recursively, every block it dominates.
//
// `available` maps expression keys to the value that already computes them on
// the current dominator-tree path. A supported instruction whose key is
// already present is replaced by the recorded value and erased; otherwise the
// instruction itself is recorded. Everything this block recorded is removed
// again before returning, so sibling subtrees never see it.
//
// Returns true if any instruction in the subtree was replaced.
bool RunGVNInDomSubtree(
    BasicBlock* block, const DominatorTree& dom_tree,
    std::unordered_map<ExprKey, Value*, ExprKeyHash>& available) {
  if (!block) {
    return false;
  }

  bool changed = false;
  std::vector<ScopedExpr> scope;
  std::vector<Instruction*> to_remove;

  for (const auto& inst_ptr : block->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (!IsSupportedGVNInstruction(inst)) {
      continue;
    }

    auto key = BuildExprKey(inst);
    // A single emplace both probes and inserts: it leaves an existing entry
    // untouched and reports through `.second` whether the key was new.
    const auto inserted = available.emplace(key, inst);
    if (!inserted.second) {
      inst->ReplaceAllUsesWith(inserted.first->second);
      to_remove.push_back(inst);
      changed = true;
      continue;
    }

    // The key was absent, so there is no previous value to restore later.
    scope.push_back(ScopedExpr{std::move(key), nullptr, false});
  }

  // Erase only after the scan so the instruction list is not mutated while it
  // is being iterated.
  for (auto* inst : to_remove) {
    block->EraseInstruction(inst);
  }

  for (auto* child : dom_tree.GetChildren(block)) {
    changed |= RunGVNInDomSubtree(child, dom_tree, available);
  }

  // Unwind this block's entries in reverse insertion order.
  for (auto it = scope.rbegin(); it != scope.rend(); ++it) {
    if (it->had_previous) {
      available[it->key] = it->previous;
    } else {
      available.erase(it->key);
    }
  }

  return changed;
}
|
||||
|
||||
// Runs GVN on one function: builds its dominator tree and walks it from the
// entry block with an initially empty expression table. External functions
// and functions without an entry block are skipped.
bool RunGVNOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }
  auto* entry = function.GetEntryBlock();
  if (entry == nullptr) {
    return false;
  }

  DominatorTree dom_tree(function);
  std::unordered_map<ExprKey, Value*, ExprKeyHash> table;
  return RunGVNInDomSubtree(entry, dom_tree, table);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Module-level entry point of GVN: processes every non-null function and
// reports whether any of them changed.
bool RunGVN(Module& module) {
  bool any_change = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunGVNOnFunction(*fn)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,403 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
// Result of analysing a function as an inlining candidate. A default
// constructed value means "not inlineable".
struct InlineCandidateInfo {
  bool valid{false};            // true when the function may be inlined
  int cost{0};                  // summed per-instruction cost estimate
  bool has_nested_call{false};  // true when the body contains a call
};
|
||||
|
||||
bool IsInlineableInstruction(const Instruction* inst) {
|
||||
if (!inst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::FRem:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF:
|
||||
case Opcode::Load:
|
||||
case Opcode::Store:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Zext:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Call:
|
||||
case Opcode::Return:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
int EstimateInstructionCost(const Instruction* inst) {
|
||||
if (!inst) {
|
||||
return 0;
|
||||
}
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Return:
|
||||
return 0;
|
||||
case Opcode::Load:
|
||||
case Opcode::Store:
|
||||
case Opcode::Memset:
|
||||
return 3;
|
||||
case Opcode::Call:
|
||||
return 8;
|
||||
case Opcode::GetElementPtr:
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Decides whether `function` can serve as an inlining callee and estimates
// its cost.
//
// A candidate must be non-external, non-recursive and consist of exactly one
// non-empty block whose instructions are all inlineable and whose only return
// is the last instruction. The cost is the sum of the per-instruction
// estimates (the return contributes nothing). Any violation yields a default,
// invalid result.
InlineCandidateInfo AnalyzeInlineCandidate(const Function& function) {
  const InlineCandidateInfo rejected{};

  if (function.IsExternal() || function.IsRecursive()) {
    return rejected;
  }
  const auto& blocks = function.GetBlocks();
  if (blocks.size() != 1) {
    return rejected;
  }
  const auto& body = blocks.front();
  if (!body) {
    return rejected;
  }
  const auto& insts = body->GetInstructions();
  if (insts.empty()) {
    return rejected;
  }

  InlineCandidateInfo summary;
  const std::size_t last = insts.size() - 1;
  bool ends_with_return = false;
  for (std::size_t pos = 0; pos <= last; ++pos) {
    auto* inst = insts[pos].get();
    if (!IsInlineableInstruction(inst)) {
      return rejected;
    }

    if (dyncast<ReturnInst>(inst)) {
      // A return anywhere but at the very end disqualifies the function.
      if (pos != last) {
        return rejected;
      }
      ends_with_return = true;
      continue;
    }

    if (dyncast<CallInst>(inst)) {
      summary.has_nested_call = true;
    }
    summary.cost += EstimateInstructionCost(inst);
  }

  if (!ends_with_return) {
    return rejected;
  }
  summary.valid = true;
  return summary;
}
|
||||
|
||||
// Counts, for every callee, how many call instructions in the module target
// it. Calls whose callee is null are ignored; functions that are never called
// have no entry in the result.
std::unordered_map<Function*, int> CountDirectCalls(Module& module) {
  std::unordered_map<Function*, int> tally;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    for (const auto& bb : fn->GetBlocks()) {
      for (const auto& owned : bb->GetInstructions()) {
        auto* site = dyncast<CallInst>(owned.get());
        if (site == nullptr) {
          continue;
        }
        auto* target = site->GetCallee();
        if (target != nullptr) {
          ++tally[target];
        }
      }
    }
  }
  return tally;
}
|
||||
|
||||
bool ShouldInlineCallSite(const Function& caller, const CallInst& call,
|
||||
const InlineCandidateInfo& callee_info, int call_count) {
|
||||
auto* callee = call.GetCallee();
|
||||
if (!callee || callee == &caller || !callee_info.valid) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int budget = callee->CanDiscardUnusedCall() ? 40 : 24;
|
||||
if (call_count <= 1) {
|
||||
budget += 12;
|
||||
}
|
||||
if (callee_info.has_nested_call) {
|
||||
budget -= 8;
|
||||
}
|
||||
if (callee->MayWriteMemory()) {
|
||||
budget -= 4;
|
||||
}
|
||||
return callee_info.cost <= budget;
|
||||
}
|
||||
|
||||
// Inserts a copy of `inst` into `dest` at position `insert_index`.
//
// Operands are translated through `remap` (looputils::RemapValue). When the
// original produces a value, the copy gets a fresh "inline." synthetic name
// and `remap[inst]` is pointed at it so later clones pick up the copy.
// Returns the new instruction, or nullptr when an argument is null, the
// opcode is not clonable here (return, alloca, phi, branches, unreachable) or
// the insertion itself yields null.
Instruction* CloneInstructionAt(Function& function, Instruction* inst, BasicBlock* dest,
                                std::size_t insert_index,
                                std::unordered_map<Value*, Value*>& remap) {
  if (inst == nullptr || dest == nullptr) {
    return nullptr;
  }

  const bool produces_value = !inst->IsVoid();
  // The synthetic name is requested for every value-producing instruction
  // before the opcode is inspected.
  const auto clone_name = produces_value
                              ? looputils::NextSyntheticName(function, "inline.")
                              : std::string();

  // Translates an operand of the original into the caller's value space.
  const auto mapped = [&remap](Value* original) {
    return looputils::RemapValue(remap, original);
  };
  // Registers the copy as the replacement of `inst` for subsequent clones.
  const auto record = [&](Instruction* copy) -> Instruction* {
    if (copy != nullptr && produces_value) {
      remap[inst] = copy;
    }
    return copy;
  };

  const auto op = inst->GetOpcode();
  switch (op) {
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::Div:
    case Opcode::Rem:
    case Opcode::FAdd:
    case Opcode::FSub:
    case Opcode::FMul:
    case Opcode::FDiv:
    case Opcode::FRem:
    case Opcode::And:
    case Opcode::Or:
    case Opcode::Xor:
    case Opcode::Shl:
    case Opcode::AShr:
    case Opcode::LShr:
    case Opcode::ICmpEQ:
    case Opcode::ICmpNE:
    case Opcode::ICmpLT:
    case Opcode::ICmpGT:
    case Opcode::ICmpLE:
    case Opcode::ICmpGE:
    case Opcode::FCmpEQ:
    case Opcode::FCmpNE:
    case Opcode::FCmpLT:
    case Opcode::FCmpGT:
    case Opcode::FCmpLE:
    case Opcode::FCmpGE: {
      auto* source = static_cast<BinaryInst*>(inst);
      auto* lhs = mapped(source->GetLhs());
      auto* rhs = mapped(source->GetRhs());
      return record(dest->Insert<BinaryInst>(insert_index, op, inst->GetType(), lhs, rhs,
                                             nullptr, clone_name));
    }
    case Opcode::Neg:
    case Opcode::Not:
    case Opcode::FNeg:
    case Opcode::FtoI:
    case Opcode::IToF: {
      auto* source = static_cast<UnaryInst*>(inst);
      auto* operand = mapped(source->GetOprd());
      return record(dest->Insert<UnaryInst>(insert_index, op, inst->GetType(), operand,
                                            nullptr, clone_name));
    }
    case Opcode::Load: {
      auto* source = static_cast<LoadInst*>(inst);
      auto* address = mapped(source->GetPtr());
      return record(
          dest->Insert<LoadInst>(insert_index, inst->GetType(), address, nullptr, clone_name));
    }
    case Opcode::Store: {
      // Stores produce no value, so nothing is recorded in `remap`.
      auto* source = static_cast<StoreInst*>(inst);
      auto* stored = mapped(source->GetValue());
      auto* address = mapped(source->GetPtr());
      return dest->Insert<StoreInst>(insert_index, stored, address, nullptr);
    }
    case Opcode::Memset: {
      auto* source = static_cast<MemsetInst*>(inst);
      auto* target = mapped(source->GetDest());
      auto* fill = mapped(source->GetValue());
      auto* length = mapped(source->GetLength());
      auto* is_volatile = mapped(source->GetIsVolatile());
      return dest->Insert<MemsetInst>(insert_index, target, fill, length, is_volatile,
                                      nullptr);
    }
    case Opcode::GetElementPtr: {
      auto* source = static_cast<GetElementPtrInst*>(inst);
      const std::size_t index_count = source->GetNumIndices();
      std::vector<Value*> indices;
      indices.reserve(index_count);
      for (std::size_t pos = 0; pos != index_count; ++pos) {
        indices.push_back(mapped(source->GetIndex(pos)));
      }
      auto* base = mapped(source->GetPointer());
      return record(dest->Insert<GetElementPtrInst>(insert_index, source->GetSourceType(),
                                                    base, indices, nullptr, clone_name));
    }
    case Opcode::Zext: {
      auto* source = static_cast<ZextInst*>(inst);
      auto* narrow = mapped(source->GetValue());
      return record(
          dest->Insert<ZextInst>(insert_index, narrow, inst->GetType(), nullptr, clone_name));
    }
    case Opcode::Call: {
      auto* source = static_cast<CallInst*>(inst);
      const auto original_args = source->GetArguments();
      std::vector<Value*> args;
      args.reserve(original_args.size());
      for (auto* argument : original_args) {
        args.push_back(mapped(argument));
      }
      return record(dest->Insert<CallInst>(insert_index, source->GetCallee(), args, nullptr,
                                           clone_name));
    }
    // Opcodes this helper does not clone.
    case Opcode::Return:
    case Opcode::Alloca:
    case Opcode::Phi:
    case Opcode::Br:
    case Opcode::CondBr:
    case Opcode::Unreachable:
      break;
  }
  return nullptr;
}
|
||||
|
||||
// Replaces `call` with a copy of its single-block callee body.
//
// The callee's formal arguments are mapped to the call's actual arguments,
// every instruction up to the callee's return is cloned immediately before the
// call, uses of the call are redirected to the remapped return value, and the
// call is erased.
//
// Returns true on success. On any failure the function returns false and the
// caller's block is left as it was: clones that were already inserted are
// removed again, so the callee's side effects are never duplicated next to the
// original call.
bool InlineCallSite(Function& caller, CallInst* call) {
  if (!call) {
    return false;
  }
  auto* callee = call->GetCallee();
  if (!callee || callee->GetBlocks().size() != 1) {
    return false;
  }
  const auto& callee_body = callee->GetBlocks().front();
  if (!callee_body) {
    return false;
  }

  const auto& callee_args = callee->GetArguments();
  const auto call_args = call->GetArguments();
  if (callee_args.size() != call_args.size()) {
    return false;
  }

  auto* block = call->GetParent();
  if (!block) {
    return false;
  }

  auto& instructions = block->GetInstructions();
  auto call_it = std::find_if(instructions.begin(), instructions.end(),
                              [&](const std::unique_ptr<Instruction>& current) {
                                return current.get() == call;
                              });
  if (call_it == instructions.end()) {
    return false;
  }

  // Only the index is kept; iterators may be invalidated by the insertions.
  std::size_t insert_index = static_cast<std::size_t>(call_it - instructions.begin());
  std::unordered_map<Value*, Value*> remap;
  for (std::size_t i = 0; i < call_args.size(); ++i) {
    remap[callee_args[i].get()] = call_args[i];
  }

  // Clones inserted so far, in insertion order.
  std::vector<Instruction*> clones;
  // Undoes a partial inlining. Clones are erased newest-first so that an
  // instruction is removed before the clones it uses.
  const auto rollback = [&]() {
    for (auto it = clones.rbegin(); it != clones.rend(); ++it) {
      block->EraseInstruction(*it);
    }
    return false;
  };

  Value* return_value = nullptr;
  for (const auto& inst_ptr : callee_body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (auto* ret = dyncast<ReturnInst>(inst)) {
      if (ret->HasReturnValue()) {
        return_value = looputils::RemapValue(remap, ret->GetReturnValue());
      }
      break;
    }
    auto* clone = CloneInstructionAt(caller, inst, block, insert_index, remap);
    if (!clone) {
      return rollback();
    }
    clones.push_back(clone);
    ++insert_index;
  }

  if (!call->GetType()->IsVoid()) {
    // A value-producing call needs a value to stand in for it.
    if (!return_value) {
      return rollback();
    }
    call->ReplaceAllUsesWith(return_value);
  }

  block->EraseInstruction(call);
  return true;
}
|
||||
|
||||
// Inlines the profitable call sites of one function.
//
// For each block the call instructions are snapshotted first, so calls that
// inlining itself introduces into the block are not revisited. A call is
// inlined when its callee has an analysis entry and ShouldInlineCallSite
// approves it; a callee missing from `call_counts` is treated as having zero
// call sites. Returns true if at least one call was inlined.
bool RunFunctionInliningOnFunction(
    Function& function,
    const std::unordered_map<Function*, InlineCandidateInfo>& callee_info,
    const std::unordered_map<Function*, int>& call_counts) {
  if (function.IsExternal()) {
    return false;
  }

  bool changed = false;
  for (auto& bb : function.GetBlocks()) {
    std::vector<CallInst*> call_sites;
    for (const auto& owned : bb->GetInstructions()) {
      if (auto* site = dyncast<CallInst>(owned.get())) {
        call_sites.push_back(site);
      }
    }

    for (CallInst* site : call_sites) {
      auto* target = site->GetCallee();
      if (target == nullptr) {
        continue;
      }
      const auto analysed = callee_info.find(target);
      if (analysed == callee_info.end()) {
        continue;
      }
      const auto counted = call_counts.find(target);
      const int site_count = counted == call_counts.end() ? 0 : counted->second;
      if (ShouldInlineCallSite(function, *site, analysed->second, site_count)) {
        changed |= InlineCallSite(function, site);
      }
    }
  }
  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Module-level entry point of the inliner. Analyses every function as a
// potential callee and counts call sites once up front, then inlines into
// each function in module order. Returns true if anything was inlined.
//
// The analysis results are not refreshed while functions are being rewritten.
bool RunFunctionInlining(Module& module) {
  const auto& functions = module.GetFunctions();

  std::unordered_map<Function*, InlineCandidateInfo> callee_info;
  for (const auto& fn : functions) {
    if (!fn) {
      continue;
    }
    callee_info.emplace(fn.get(), AnalyzeInlineCandidate(*fn));
  }

  const auto call_counts = CountDirectCalls(module);

  bool any_change = false;
  for (const auto& fn : functions) {
    if (!fn) {
      continue;
    }
    if (RunFunctionInliningOnFunction(*fn, callee_info, call_counts)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,236 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
#include "MemoryUtils.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct HoistedLoadKey {
|
||||
memutils::AddressKey address;
|
||||
std::uintptr_t type_id = 0;
|
||||
|
||||
bool operator==(const HoistedLoadKey& rhs) const {
|
||||
return type_id == rhs.type_id && address == rhs.address;
|
||||
}
|
||||
};
|
||||
|
||||
struct HoistedLoadKeyHash {
|
||||
std::size_t operator()(const HoistedLoadKey& key) const {
|
||||
std::size_t h = memutils::AddressKeyHash{}(key.address);
|
||||
h ^= std::hash<std::uintptr_t>{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
bool IsHoistableInstruction(const Instruction* inst) {
|
||||
if (!inst || inst->IsTerminator() || inst->IsVoid()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Zext:
|
||||
case Opcode::Load:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsLoopInvariantInstruction(
|
||||
const Loop& loop, Instruction* inst,
|
||||
const std::unordered_set<Instruction*>& invariant_insts,
|
||||
PhiInst* iv, int iv_stride,
|
||||
const std::vector<loopmem::MemoryAccessInfo>& accesses,
|
||||
const memutils::EscapeSummary& escapes) {
|
||||
if (!IsHoistableInstruction(inst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
|
||||
auto* operand = inst->GetOperand(i);
|
||||
auto* operand_inst = dyncast<Instruction>(operand);
|
||||
if (!operand_inst) {
|
||||
continue;
|
||||
}
|
||||
if (!loop.Contains(operand_inst->GetParent())) {
|
||||
continue;
|
||||
}
|
||||
if (invariant_insts.find(operand_inst) == invariant_insts.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
return loopmem::IsSafeInvariantLoadToHoist(loop, load, iv, iv_stride, accesses, &escapes);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Moves the loop-invariant instructions of `loop` in front of the terminator
// of `preheader`.
//
// Steps:
//   1. Look for a simple induction variable among the leading phis of the
//      loop header (the scan stops at the first non-phi or first match).
//   2. Collect the loop's memory accesses and the function's escape summary.
//   3. Iterate to a fixed point, marking instructions invariant once all of
//      their in-loop operands are invariant. Instructions are appended to the
//      hoist list in the order they are discovered.
//   4. Move them in that order. Loads with an exact address key are
//      deduplicated: a second load of the same address and type is replaced
//      by the already-hoisted one and erased.
//
// Returns true if an instruction was moved or removed; false when there is no
// preheader.
bool HoistLoopInvariants(Function& function, const Loop& loop,
                         BasicBlock* preheader) {
  if (preheader == nullptr) {
    return false;
  }

  // Step 1: induction variable detection.
  loopmem::SimpleInductionVar induction_var;
  PhiInst* iv = nullptr;
  int iv_stride = 1;
  for (const auto& owned : loop.header->GetInstructions()) {
    auto* header_phi = dyncast<PhiInst>(owned.get());
    if (header_phi == nullptr) {
      break;
    }
    if (!loopmem::MatchSimpleInductionVariable(loop, preheader, header_phi,
                                               induction_var)) {
      continue;
    }
    iv = induction_var.phi;
    iv_stride = induction_var.stride;
    break;
  }

  // Step 2: memory facts.
  const auto escapes = memutils::AnalyzeEscapes(function);
  const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes);

  // Step 3: fixed-point discovery of invariant instructions.
  std::unordered_set<Instruction*> invariant_insts;
  std::vector<Instruction*> hoist_list;
  bool found_more = false;
  do {
    found_more = false;
    for (const auto& bb : function.GetBlocks()) {
      BasicBlock* candidate_block = bb.get();
      if (candidate_block == preheader || !loop.Contains(candidate_block)) {
        continue;
      }
      for (const auto& owned : candidate_block->GetInstructions()) {
        Instruction* candidate = owned.get();
        if (invariant_insts.count(candidate) != 0) {
          continue;
        }
        if (IsLoopInvariantInstruction(loop, candidate, invariant_insts, iv, iv_stride,
                                       accesses, escapes)) {
          invariant_insts.insert(candidate);
          hoist_list.push_back(candidate);
          found_more = true;
        }
      }
    }
  } while (found_more);

  // Step 4: perform the moves.
  bool changed = false;
  std::unordered_map<HoistedLoadKey, LoadInst*, HoistedLoadKeyHash> hoisted_loads;
  for (Instruction* candidate : hoist_list) {
    auto* as_load = dyncast<LoadInst>(candidate);
    if (as_load != nullptr) {
      auto pointer_info = loopmem::AnalyzePointer(as_load->GetPtr(), iv, loop,
                                                  as_load->GetType()->GetSize(), &escapes);
      if (pointer_info.exact_key_valid) {
        HoistedLoadKey key{pointer_info.exact_key,
                           reinterpret_cast<std::uintptr_t>(as_load->GetType().get())};
        const auto known = hoisted_loads.find(key);
        if (known == hoisted_loads.end()) {
          // First load of this address: move it and remember it.
          auto* relocated = dyncast<LoadInst>(
              looputils::MoveInstructionBeforeTerminator(as_load, preheader));
          if (relocated) {
            hoisted_loads.emplace(std::move(key), relocated);
            changed = true;
          }
        } else {
          // Same address and type already hoisted: reuse that value.
          as_load->ReplaceAllUsesWith(known->second);
          as_load->GetParent()->EraseInstruction(as_load);
          changed = true;
        }
        continue;
      }
    }

    // Non-loads and loads without an exact key are simply moved.
    if (looputils::MoveInstructionBeforeTerminator(candidate, preheader)) {
      changed = true;
    }
  }
  return changed;
}
|
||||
|
||||
// Runs LICM on one function until nothing changes any more.
//
// Each round rebuilds the dominator tree and loop info, then visits the loops
// in post order. As soon as one loop is modified (a preheader different from
// the recorded one was produced, or something was hoisted) the round is
// abandoned and the analyses are rebuilt, because they may be stale.
//
// NOTE(review): termination relies on looputils::EnsurePreheader eventually
// returning the same preheader that LoopInfo records for the loop -- confirm.
bool RunLICMOnFunction(Function& function) {
  if (function.IsExternal()) {
    return false;
  }
  if (function.GetEntryBlock() == nullptr) {
    return false;
  }

  bool changed = false;
  bool rerun = true;
  while (rerun) {
    rerun = false;
    DominatorTree dom_tree(function);
    LoopInfo loop_info(function, dom_tree);

    for (auto* loop : loop_info.GetLoopsInPostOrder()) {
      auto* const recorded_preheader = loop->preheader;
      auto* const preheader = looputils::EnsurePreheader(function, *loop);
      const bool preheader_changed = preheader != recorded_preheader;
      // Hoisting is attempted regardless of whether the preheader changed.
      const bool hoisted = HoistLoopInvariants(function, *loop, preheader);
      if (preheader_changed || hoisted) {
        changed = true;
        rerun = true;
        break;
      }
    }
  }
  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Module-level entry point of LICM: processes every non-null function and
// reports whether any of them changed.
bool RunLICM(Module& module) {
  bool any_change = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunLICMOnFunction(*fn)) {
      any_change = true;
    }
  }
  return any_change;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,319 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "MemoryUtils.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct AvailableValue {
|
||||
Value* value = nullptr;
|
||||
|
||||
bool operator==(const AvailableValue& rhs) const {
|
||||
return passutils::AreEquivalentValues(value, rhs.value) || value == rhs.value;
|
||||
}
|
||||
};
|
||||
|
||||
using MemoryState =
|
||||
std::unordered_map<memutils::AddressKey, AvailableValue,
|
||||
memutils::AddressKeyHash>;
|
||||
|
||||
// Named wrapper around AvailableValue's equality operator.
bool SameAvailableValue(const AvailableValue& lhs, const AvailableValue& rhs) {
  const bool same = (lhs == rhs);
  return same;
}
|
||||
|
||||
bool SameMemoryState(const MemoryState& lhs, const MemoryState& rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& [key, value] : lhs) {
|
||||
auto it = rhs.find(key);
|
||||
if (it == rhs.end() || !SameAvailableValue(value, it->second)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Meet operator of the dataflow: the intersection of the predecessor states.
// An address survives only if every predecessor knows it with a matching
// value. An empty predecessor list yields an empty state.
MemoryState MeetMemoryStates(const std::vector<MemoryState*>& predecessors) {
  if (predecessors.empty()) {
    return {};
  }

  // True when all predecessors after the first agree on (key, value).
  const auto agreed_by_rest = [&predecessors](const memutils::AddressKey& key,
                                              const AvailableValue& value) {
    for (std::size_t index = 1; index < predecessors.size(); ++index) {
      const MemoryState& other = *predecessors[index];
      const auto found = other.find(key);
      if (found == other.end() || !SameAvailableValue(value, found->second)) {
        return false;
      }
    }
    return true;
  };

  // Start from the first predecessor and drop whatever the others dispute.
  MemoryState meet = *predecessors.front();
  auto entry = meet.begin();
  while (entry != meet.end()) {
    if (agreed_by_rest(entry->first, entry->second)) {
      ++entry;
    } else {
      entry = meet.erase(entry);
    }
  }
  return meet;
}
|
||||
|
||||
// Forgets every known value whose address may alias `key` according to
// memutils::MayAliasConservatively.
void InvalidateAliasStates(MemoryState& state,
                           const memutils::AddressKey& key) {
  auto entry = state.begin();
  while (entry != state.end()) {
    if (memutils::MayAliasConservatively(entry->first, key)) {
      entry = state.erase(entry);
    } else {
      ++entry;
    }
  }
}
|
||||
|
||||
// Forgets every known value whose address root kind may be written by a call
// to `callee`, as reported by memutils::CallMayWriteRoot.
void InvalidateStatesForCall(MemoryState& state, Function* callee) {
  auto entry = state.begin();
  while (entry != state.end()) {
    if (memutils::CallMayWriteRoot(callee, entry->first.kind)) {
      entry = state.erase(entry);
    } else {
      ++entry;
    }
  }
}
|
||||
|
||||
// Transfer function of the dataflow for a single instruction.
//
//   load   - leaves the state alone, except that a load from an address
//            without an exact key clears it. Load results are not recorded.
//   store  - unknown address clears the state; otherwise aliasing entries are
//            dropped and the stored value is recorded for the address.
//   call   - drops entries the callee may write.
//   memset - unknown destination clears the state; otherwise aliasing entries
//            are dropped.
// Any other instruction has no effect.
void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* inst,
                         MemoryState& state) {
  if (inst == nullptr) {
    return;
  }

  memutils::AddressKey address;
  if (auto* load = dyncast<LoadInst>(inst)) {
    const bool exact = memutils::BuildExactAddressKey(load->GetPtr(), &escapes, address);
    if (!exact) {
      state.clear();
    }
  } else if (auto* store = dyncast<StoreInst>(inst)) {
    if (memutils::BuildExactAddressKey(store->GetPtr(), &escapes, address)) {
      InvalidateAliasStates(state, address);
      state[address] = {store->GetValue()};
    } else {
      state.clear();
    }
  } else if (auto* call = dyncast<CallInst>(inst)) {
    InvalidateStatesForCall(state, call->GetCallee());
  } else if (auto* memset = dyncast<MemsetInst>(inst)) {
    if (memutils::BuildExactAddressKey(memset->GetDest(), &escapes, address)) {
      InvalidateAliasStates(state, address);
    } else {
      state.clear();
    }
  }
}
|
||||
|
||||
// Computes the out-state of `block` by applying the transfer function to each
// of its instructions in order, starting from a copy of `in_state`.
MemoryState SimulateBlock(const memutils::EscapeSummary& escapes, BasicBlock* block,
                          const MemoryState& in_state) {
  MemoryState running = in_state;
  const auto& insts = block->GetInstructions();
  for (auto owned = insts.begin(); owned != insts.end(); ++owned) {
    SimulateInstruction(escapes, owned->get(), running);
  }
  return running;
}
|
||||
|
||||
bool MarkLoadObserved(
|
||||
const memutils::AddressKey& key,
|
||||
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
|
||||
pending_stores) {
|
||||
bool changed = false;
|
||||
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
|
||||
if (memutils::MayAliasConservatively(it->first, key)) {
|
||||
it = pending_stores.erase(it);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
++it;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
void InvalidatePendingForCall(
|
||||
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
|
||||
pending_stores,
|
||||
Function* callee) {
|
||||
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
|
||||
if (memutils::CallMayReadRoot(callee, it->first.kind) ||
|
||||
memutils::CallMayWriteRoot(callee, it->first.kind)) {
|
||||
it = pending_stores.erase(it);
|
||||
continue;
|
||||
}
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
bool OptimizeBlock(
|
||||
const memutils::EscapeSummary& escapes, BasicBlock* block,
|
||||
const MemoryState& in_state) {
|
||||
bool changed = false;
|
||||
MemoryState state = in_state;
|
||||
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>
|
||||
pending_stores;
|
||||
std::vector<Instruction*> to_remove;
|
||||
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
|
||||
state.clear();
|
||||
pending_stores.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
MarkLoadObserved(key, pending_stores);
|
||||
auto it = state.find(key);
|
||||
if (it != state.end() && it->second.value != load) {
|
||||
load->ReplaceAllUsesWith(it->second.value);
|
||||
to_remove.push_back(load);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
// Keep block-local load reuse, but do not expose load results to cross-block
|
||||
// dataflow because the defining load itself may be removed later.
|
||||
state[key] = {load};
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* store = dyncast<StoreInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
|
||||
state.clear();
|
||||
pending_stores.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
|
||||
if (!memutils::MayAliasConservatively(it->first, key)) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
if (it->first == key) {
|
||||
to_remove.push_back(it->second);
|
||||
changed = true;
|
||||
}
|
||||
it = pending_stores.erase(it);
|
||||
}
|
||||
|
||||
pending_stores.emplace(key, store);
|
||||
InvalidateAliasStates(state, key);
|
||||
state[key] = {store->GetValue()};
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
InvalidateStatesForCall(state, call->GetCallee());
|
||||
InvalidatePendingForCall(pending_stores, call->GetCallee());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* memset = dyncast<MemsetInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
|
||||
state.clear();
|
||||
pending_stores.clear();
|
||||
continue;
|
||||
}
|
||||
InvalidateAliasStates(state, key);
|
||||
MarkLoadObserved(key, pending_stores);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* inst : to_remove) {
|
||||
if (inst->GetParent() == block) {
|
||||
block->EraseInstruction(inst);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Runs load/store elimination on one function.
//
// Phase 1 iterates a forward dataflow over the reachable blocks until neither
// the per-block input state nor the per-block output state changes. The entry
// block always starts from an empty state; every other block starts from the
// meet of the output states of those predecessors that have been simulated.
// Phase 2 rewrites each reachable block using its converged input state.
//
// Returns true when any block was modified.
bool RunLoadStoreElimOnFunction(Function& function) {
  if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
    return false;
  }

  const auto escapes = memutils::AnalyzeEscapes(function);
  const auto reachable_blocks = passutils::CollectReachableBlocks(function);
  if (reachable_blocks.empty()) {
    return false;
  }

  std::unordered_map<BasicBlock*, MemoryState> block_inputs;
  std::unordered_map<BasicBlock*, MemoryState> block_outputs;

  bool iterate_again = false;
  do {
    iterate_again = false;
    for (auto* block : reachable_blocks) {
      MemoryState incoming;
      if (block != function.GetEntryBlock()) {
        // Predecessors without a recorded output yet are simply left out of
        // the meet.
        std::vector<MemoryState*> known_pred_outputs;
        for (auto* pred : block->GetPredecessors()) {
          const auto found = block_outputs.find(pred);
          if (found != block_outputs.end()) {
            known_pred_outputs.push_back(&found->second);
          }
        }
        incoming = MeetMemoryStates(known_pred_outputs);
      }

      auto outgoing = SimulateBlock(escapes, block, incoming);

      const auto prev_in = block_inputs.find(block);
      const bool input_stable =
          prev_in != block_inputs.end() && SameMemoryState(prev_in->second, incoming);
      if (!input_stable) {
        block_inputs[block] = incoming;
        iterate_again = true;
      }

      const auto prev_out = block_outputs.find(block);
      const bool output_stable =
          prev_out != block_outputs.end() && SameMemoryState(prev_out->second, outgoing);
      if (!output_stable) {
        block_outputs[block] = std::move(outgoing);
        iterate_again = true;
      }
    }
  } while (iterate_again);

  bool modified = false;
  for (auto* block : reachable_blocks) {
    if (OptimizeBlock(escapes, block, block_inputs[block])) {
      modified = true;
    }
  }
  return modified;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Module-level entry point: applies load/store elimination to every non-null
// function of `module` and reports whether any of them changed.
bool RunLoadStoreElim(Module& module) {
  bool any_changed = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunLoadStoreElimOnFunction(*fn)) {
      any_changed = true;
    }
  }
  return any_changed;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,326 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct FissionLoopInfo {
|
||||
Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
CondBrInst* branch = nullptr;
|
||||
BinaryInst* compare = nullptr;
|
||||
Opcode compare_opcode = Opcode::ICmpLT;
|
||||
Value* bound = nullptr;
|
||||
loopmem::SimpleInductionVar induction_var;
|
||||
PhiInst* iv = nullptr;
|
||||
BinaryInst* step_inst = nullptr;
|
||||
};
|
||||
|
||||
// Returns true when `name` contains one of the markers used for blocks that
// were generated by loop transformations ("unroll." or "fission.").
bool HasSyntheticLoopTag(const std::string& name) {
  if (name.find("unroll.") != std::string::npos) {
    return true;
  }
  return name.find("fission.") != std::string::npos;
}
|
||||
|
||||
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
|
||||
if (!loop.preheader || !loop.header || !body) {
|
||||
return true;
|
||||
}
|
||||
return HasSyntheticLoopTag(loop.preheader->GetName()) ||
|
||||
HasSyntheticLoopTag(loop.header->GetName()) ||
|
||||
HasSyntheticLoopTag(body->GetName());
|
||||
}
|
||||
|
||||
// Returns the comparison opcode to use after exchanging the two operands:
// LT <-> GT and LE <-> GE. Any other opcode is returned unchanged.
Opcode SwapCompareOpcode(Opcode opcode) {
  if (opcode == Opcode::ICmpLT) {
    return Opcode::ICmpGT;
  }
  if (opcode == Opcode::ICmpLE) {
    return Opcode::ICmpGE;
  }
  if (opcode == Opcode::ICmpGT) {
    return Opcode::ICmpLT;
  }
  if (opcode == Opcode::ICmpGE) {
    return Opcode::ICmpLE;
  }
  return opcode;
}
|
||||
|
||||
// Recognises the loop shape that loop fission can split and fills `info`.
//
// A loop matches when all of the following hold:
//  - it is innermost, has a preheader, and has the header/body form accepted
//    by loopmem::GetCanonicalLoopBlocks;
//  - none of its preheader/header/body names carries a synthetic loop tag;
//  - the header has exactly one phi and that phi is a simple induction
//    variable;
//  - the header ends in a CondBrInst whose then-target is the body and whose
//    condition is a BinaryInst comparing the induction variable with a
//    loop-invariant value;
//  - the header contains nothing besides that phi, the compare and the
//    branch (see the comment at the check for why);
//  - the induction step is a BinaryInst located in the body, and every other
//    non-terminator body instruction is cloneable and is not a call, memset
//    or alloca.
//
// `info` is written only when the function returns true.
bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) {
  if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
    return false;
  }

  BasicBlock* body = nullptr;
  BasicBlock* exit = nullptr;
  if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
    return false;
  }
  if (IsAlreadyTransformedLoop(loop, body)) {
    return false;
  }

  // Collect the leading phis of the header; the first one that matches the
  // induction-variable pattern is remembered.
  std::vector<PhiInst*> phis;
  loopmem::SimpleInductionVar induction_var;
  bool found_iv = false;
  for (const auto& inst_ptr : loop.header->GetInstructions()) {
    auto* phi = dyncast<PhiInst>(inst_ptr.get());
    if (!phi) {
      break;
    }
    phis.push_back(phi);
    if (!found_iv &&
        loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
      found_iv = true;
    }
  }
  // Any additional phi would be a loop-carried scalar, which is not handled.
  if (!found_iv || phis.size() != 1) {
    return false;
  }

  auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
  auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
  if (!branch || branch->GetThenBlock() != body || !compare) {
    return false;
  }

  // The second loop's header is rebuilt from scratch with only an induction
  // phi, a compare and a branch. Any other header instruction would either
  // be read by moved body instructions with the value it had when the first
  // loop finished, or (for memory operations) take part in dependences that
  // the cut analysis never sees, so such headers are rejected.
  for (const auto& inst_ptr : loop.header->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst == induction_var.phi || inst == compare || inst == branch) {
      continue;
    }
    return false;
  }

  // Normalise the compare to "iv <op> bound".
  Opcode compare_opcode = compare->GetOpcode();
  Value* bound = nullptr;
  if (compare->GetLhs() == induction_var.phi &&
      looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
    bound = compare->GetRhs();
  } else if (compare->GetRhs() == induction_var.phi &&
             looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
    bound = compare->GetLhs();
    compare_opcode = SwapCompareOpcode(compare_opcode);
  } else {
    return false;
  }

  auto* step_inst = dyncast<BinaryInst>(induction_var.latch_value);
  if (!step_inst || step_inst->GetParent() != body) {
    return false;
  }

  // Every payload instruction must be safe to clone into the second loop.
  for (const auto& inst_ptr : body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator() || inst == step_inst) {
      continue;
    }
    if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
        dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
      return false;
    }
  }

  info.loop = &loop;
  info.preheader = loop.preheader;
  info.header = loop.header;
  info.body = body;
  info.exit = exit;
  info.branch = branch;
  info.compare = compare;
  info.compare_opcode = compare_opcode;
  info.bound = bound;
  info.induction_var = induction_var;
  info.iv = induction_var.phi;
  info.step_inst = step_inst;
  return true;
}
|
||||
|
||||
// A group is worth splitting off only if it touches memory, i.e. contains at
// least one load or store.
bool ContainsInterestingPayload(const std::vector<Instruction*>& group) {
  for (auto* candidate : group) {
    if (dyncast<LoadInst>(candidate) || dyncast<StoreInst>(candidate)) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Maps a value flowing into an exit phi: the old induction phi is replaced by
// the new one, every other value is passed through untouched.
Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) {
  return value == old_iv ? static_cast<Value*>(new_iv) : value;
}
|
||||
|
||||
// Builds the second loop of a fission and splices it between the original
// header and the original exit block.
//
// The new loop consists of a header (induction phi, compare, conditional
// branch) and a body holding clones of `second_group` plus a clone of the
// induction step. The original header's exit edge is redirected to the new
// header, and phis in the exit block that had an incoming edge from the
// original header are rewritten to come from the new header.
//
// @param function      Function that owns the loop; receives the new blocks.
// @param info          Description of the loop produced by MatchFissionLoop.
// @param second_group  Body instructions that move into the second loop.
// @return true on success. The incoming-edge lookup is validated before any
//         block is created, so that particular failure leaves `function`
//         untouched.
bool BuildSecondLoop(Function& function, const FissionLoopInfo& info,
                     const std::vector<Instruction*>& second_group) {
  // Check the precondition first: returning after CreateBlock would leave two
  // empty, terminator-less blocks behind in the function.
  const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader);
  if (preheader_index < 0) {
    return false;
  }

  auto* second_header =
      function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header"));
  auto* second_body =
      function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body"));

  // The new induction variable restarts from the original start value; its
  // entry edge comes from the original header.
  auto* second_iv = second_header->Append<PhiInst>(
      info.iv->GetType(), nullptr,
      looputils::NextSyntheticName(function, "fission.iv."));
  second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header);

  auto* second_cmp = second_header->Append<BinaryInst>(
      info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr,
      looputils::NextSyntheticName(function, "fission.cmp."));
  second_header->Append<CondBrInst>(second_cmp, second_body, info.exit, nullptr);
  second_header->AddPredecessor(info.header);
  second_header->AddSuccessor(second_body);
  second_header->AddSuccessor(info.exit);

  // Clone the selected instructions (and the step) in original body order,
  // rewriting uses of the old induction phi to the new one.
  std::unordered_map<Value*, Value*> remap;
  remap[info.iv] = second_iv;
  std::unordered_set<Instruction*> selected(second_group.begin(), second_group.end());
  selected.insert(info.step_inst);
  for (const auto& inst_ptr : info.body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator() || selected.find(inst) == selected.end()) {
      continue;
    }
    looputils::CloneInstruction(function, inst, second_body, remap, "fission.");
  }
  auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst);
  if (!cloned_step_value) {
    return false;
  }
  second_iv->AddIncoming(cloned_step_value, second_body);
  second_body->Append<UncondBrInst>(second_header, nullptr);
  second_body->AddPredecessor(second_header);
  second_body->AddSuccessor(second_header);
  second_header->AddPredecessor(second_body);

  // Original header now exits into the second loop instead of the exit block.
  if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) {
    return false;
  }
  info.exit->RemovePredecessor(info.header);
  info.exit->AddPredecessor(second_header);

  // Patch exit phis: operand 2*k is the incoming value, 2*k+1 the block.
  for (const auto& inst_ptr : info.exit->GetInstructions()) {
    auto* phi = dyncast<PhiInst>(inst_ptr.get());
    if (!phi) {
      break;
    }
    const int incoming = looputils::GetPhiIncomingIndex(phi, info.header);
    if (incoming < 0) {
      continue;
    }
    phi->SetOperand(static_cast<std::size_t>(2 * incoming),
                    RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv));
    phi->SetOperand(static_cast<std::size_t>(2 * incoming + 1), second_header);
  }

  return true;
}
|
||||
|
||||
// Applies loop fission to `function` until no further loop can be split.
//
// Each round rebuilds the dominator tree and loop info, then walks the loops
// in post order. For the first loop that matches, the body payload (all
// non-terminator instructions except the induction step) is cut at the first
// position where both halves contain a load or store and neither a scalar
// nor a memory dependence crosses the cut. The second half is rebuilt as a
// new loop and removed from the original body; the round then restarts
// because the analyses are stale.
//
// Returns true when at least one loop was split.
bool RunLoopFissionOnFunction(Function& function) {
  if (function.IsExternal() || !function.GetEntryBlock()) {
    return false;
  }

  using InstList = std::vector<Instruction*>;
  using InstSet = std::unordered_set<Instruction*>;

  bool any_split = false;
  for (;;) {
    DominatorTree dom_tree(function);
    LoopInfo loop_info(function, dom_tree);
    bool split_this_round = false;

    for (auto* loop : loop_info.GetLoopsInPostOrder()) {
      FissionLoopInfo info;
      if (!MatchFissionLoop(*loop, info)) {
        continue;
      }

      const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv);

      // Instructions that may be distributed between the two loops.
      InstList payload;
      for (const auto& inst_ptr : info.body->GetInstructions()) {
        auto* inst = inst_ptr.get();
        if (!inst->IsTerminator() && inst != info.step_inst) {
          payload.push_back(inst);
        }
      }
      if (payload.size() < 2) {
        continue;
      }

      // Try every cut position from front to back; take the first legal one.
      bool found_cut = false;
      InstList first_group;
      InstList second_group;
      for (std::size_t cut = 1; cut < payload.size(); ++cut) {
        const auto split_point =
            payload.begin() + static_cast<InstList::difference_type>(cut);
        InstList head(payload.begin(), split_point);
        InstList tail(split_point, payload.end());
        if (!(ContainsInterestingPayload(head) && ContainsInterestingPayload(tail))) {
          continue;
        }

        const InstSet head_set(head.begin(), head.end());
        const InstSet tail_set(tail.begin(), tail.end());
        const bool dependence_crosses_cut =
            loopmem::HasScalarDependenceAcrossCut(head, tail_set) ||
            loopmem::HasMemoryDependenceAcrossCut(accesses, head_set, tail_set,
                                                  info.induction_var.stride);
        if (dependence_crosses_cut) {
          continue;
        }

        first_group = std::move(head);
        second_group = std::move(tail);
        found_cut = true;
        break;
      }
      if (!found_cut) {
        continue;
      }

      // Everything that is neither kept in the first loop nor a terminator
      // is scheduled for removal once the second loop exists.
      InstSet kept(first_group.begin(), first_group.end());
      kept.insert(info.step_inst);
      InstList doomed;
      for (const auto& inst_ptr : info.body->GetInstructions()) {
        auto* inst = inst_ptr.get();
        if (!inst->IsTerminator() && kept.find(inst) == kept.end()) {
          doomed.push_back(inst);
        }
      }

      if (!BuildSecondLoop(function, info, second_group)) {
        continue;
      }
      for (auto* inst : doomed) {
        info.body->EraseInstruction(inst);
      }

      any_split = true;
      split_this_round = true;
      break;
    }

    if (!split_this_round) {
      break;
    }
  }
  return any_split;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Module-level entry point: runs loop fission on every non-null function of
// `module` and reports whether any of them changed.
bool RunLoopFission(Module& module) {
  bool any_changed = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunLoopFissionOnFunction(*fn)) {
      any_changed = true;
    }
  }
  return any_changed;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,506 @@
|
||||
#pragma once
|
||||
|
||||
#include "LoopPassUtils.h"
|
||||
#include "MemoryUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace ir::loopmem {
|
||||
|
||||
struct SimpleInductionVar {
|
||||
PhiInst* phi = nullptr;
|
||||
Value* start = nullptr;
|
||||
Value* latch_value = nullptr;
|
||||
BasicBlock* latch = nullptr;
|
||||
int stride = 0;
|
||||
};
|
||||
|
||||
inline bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
|
||||
PhiInst* phi, SimpleInductionVar& info) {
|
||||
if (!phi || !preheader || phi->GetParent() != loop.header ||
|
||||
!phi->GetType()->IsInt32() || phi->GetNumIncomings() != 2 ||
|
||||
loop.latches.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* latch = loop.latches.front();
|
||||
const int preheader_index = looputils::GetPhiIncomingIndex(phi, preheader);
|
||||
const int latch_index = looputils::GetPhiIncomingIndex(phi, latch);
|
||||
if (preheader_index < 0 || latch_index < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
|
||||
if (!step_inst || step_inst->GetParent() != latch) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int stride = 0;
|
||||
if (step_inst->GetOpcode() == Opcode::Add) {
|
||||
if (step_inst->GetLhs() == phi) {
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = delta->GetValue();
|
||||
} else if (step_inst->GetRhs() == phi) {
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = delta->GetValue();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else if (step_inst->GetOpcode() == Opcode::Sub) {
|
||||
if (step_inst->GetLhs() != phi) {
|
||||
return false;
|
||||
}
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = -delta->GetValue();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stride == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
info.phi = phi;
|
||||
info.start = phi->GetIncomingValue(preheader_index);
|
||||
info.latch_value = phi->GetIncomingValue(latch_index);
|
||||
info.latch = latch;
|
||||
info.stride = stride;
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool GetCanonicalLoopBlocks(const Loop& loop, BasicBlock*& body,
|
||||
BasicBlock*& exit) {
|
||||
body = nullptr;
|
||||
exit = nullptr;
|
||||
if (!loop.header || loop.latches.size() != 1 || loop.block_list.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* condbr = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
|
||||
if (!condbr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* then_block = condbr->GetThenBlock();
|
||||
auto* else_block = condbr->GetElseBlock();
|
||||
const bool then_in_loop = loop.Contains(then_block);
|
||||
const bool else_in_loop = loop.Contains(else_block);
|
||||
if (then_in_loop == else_in_loop) {
|
||||
return false;
|
||||
}
|
||||
|
||||
body = then_in_loop ? then_block : else_block;
|
||||
exit = then_in_loop ? else_block : then_block;
|
||||
if (!body || !exit || body != loop.latches.front() ||
|
||||
body->GetSuccessors().size() != 1 || body->GetSuccessors().front() != loop.header) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
struct AffineExpr {
|
||||
bool valid = false;
|
||||
Value* var = nullptr;
|
||||
std::int64_t coeff = 0;
|
||||
std::int64_t constant = 0;
|
||||
};
|
||||
|
||||
// Builds a valid affine expression that is just the constant `value`
// (no variable, zero coefficient).
inline AffineExpr MakeConst(std::int64_t value) {
  AffineExpr result;
  result.valid = true;
  result.var = nullptr;
  result.coeff = 0;
  result.constant = value;
  return result;
}
|
||||
|
||||
// Multiplies both the coefficient and the constant of `expr` by `factor`.
// An invalid input yields a default (invalid) expression.
inline AffineExpr Scale(const AffineExpr& expr, std::int64_t factor) {
  AffineExpr scaled;
  if (expr.valid) {
    scaled.valid = true;
    scaled.var = expr.var;
    scaled.coeff = expr.coeff * factor;
    scaled.constant = expr.constant * factor;
  }
  return scaled;
}
|
||||
|
||||
// Computes `lhs + sign * rhs` component-wise.
// The result is invalid when either input is invalid or when both inputs have
// a variable and the variables differ. Otherwise the result carries whichever
// variable is present (lhs preferred).
inline AffineExpr Combine(const AffineExpr& lhs, const AffineExpr& rhs, int sign) {
  AffineExpr combined;  // starts out invalid
  if (!lhs.valid || !rhs.valid) {
    return combined;
  }
  const bool both_have_var = lhs.var != nullptr && rhs.var != nullptr;
  if (both_have_var && lhs.var != rhs.var) {
    return combined;
  }

  combined.valid = true;
  combined.var = lhs.var != nullptr ? lhs.var : rhs.var;
  combined.coeff = lhs.coeff + sign * rhs.coeff;
  combined.constant = lhs.constant + sign * rhs.constant;
  return combined;
}
|
||||
|
||||
// Tries to express `value` as an affine function of the induction phi `iv`.
//
// Recognised forms: integer constants, `iv` itself, and Add/Sub/Mul/Neg built
// from recognised forms (Mul needs at least one side without a variable).
// A ZextInst is looked through. Any other loop-invariant value, and anything
// else that is not recognised, yields an invalid expression.
inline AffineExpr AnalyzeAffine(Value* value, PhiInst* iv, const Loop& loop) {
  if (value == nullptr) {
    return {};
  }
  if (auto* constant = dyncast<ConstantInt>(value)) {
    return MakeConst(constant->GetValue());
  }
  if (value == iv) {
    AffineExpr identity;
    identity.valid = true;
    identity.var = iv;
    identity.coeff = 1;
    identity.constant = 0;
    return identity;
  }
  // Non-constant invariants are not representable in AffineExpr.
  if (looputils::IsLoopInvariantValue(loop, value)) {
    return {};
  }

  if (auto* zext = dyncast<ZextInst>(value)) {
    return AnalyzeAffine(zext->GetValue(), iv, loop);
  }

  auto* inst = dyncast<Instruction>(value);
  if (inst == nullptr) {
    return {};
  }

  const auto opcode = inst->GetOpcode();
  if (opcode == Opcode::Add || opcode == Opcode::Sub) {
    auto* binary = static_cast<BinaryInst*>(inst);
    const AffineExpr left = AnalyzeAffine(binary->GetLhs(), iv, loop);
    const AffineExpr right = AnalyzeAffine(binary->GetRhs(), iv, loop);
    return Combine(left, right, opcode == Opcode::Add ? +1 : -1);
  }
  if (opcode == Opcode::Mul) {
    auto* binary = static_cast<BinaryInst*>(inst);
    const AffineExpr left = AnalyzeAffine(binary->GetLhs(), iv, loop);
    const AffineExpr right = AnalyzeAffine(binary->GetRhs(), iv, loop);
    if (!left.valid || !right.valid) {
      return {};
    }
    // One factor must be a pure constant for the product to stay affine.
    if (left.var == nullptr) {
      return Scale(right, left.constant);
    }
    if (right.var == nullptr) {
      return Scale(left, right.constant);
    }
    return {};
  }
  if (opcode == Opcode::Neg) {
    auto* unary = static_cast<UnaryInst*>(inst);
    return Scale(AnalyzeAffine(unary->GetOprd(), iv, loop), -1);
  }
  return {};
}
|
||||
|
||||
struct PointerInfo {
|
||||
Value* base = nullptr;
|
||||
AffineExpr byte_offset;
|
||||
bool invariant_address = false;
|
||||
bool distinct_root = false;
|
||||
bool argument_root = false;
|
||||
bool readonly_root = false;
|
||||
bool exact_key_valid = false;
|
||||
memutils::PointerRootKind root_kind = memutils::PointerRootKind::Unknown;
|
||||
memutils::AddressKey exact_key;
|
||||
int access_size = 0;
|
||||
};
|
||||
|
||||
// Follows a chain of GetElementPtrInst pointer operands and returns the first
// value that is not itself a GetElementPtrInst.
inline Value* StripPointerBase(Value* pointer) {
  Value* current = pointer;
  for (auto* gep = dyncast<GetElementPtrInst>(current); gep != nullptr;
       gep = dyncast<GetElementPtrInst>(current)) {
    current = gep->GetPointer();
  }
  return current;
}
|
||||
|
||||
inline std::shared_ptr<Type> AdvanceGEPType(std::shared_ptr<Type> current) {
|
||||
if (current && current->IsArray()) {
|
||||
return current->GetElementType();
|
||||
}
|
||||
return current;
|
||||
}
|
||||
|
||||
// Classifies `pointer` for the loop dependence tests.
//
// Fills in the stripped base, its root classification, the optional exact
// address key (only attempted when `escapes` is non-null), whether the
// address is loop-invariant, and the byte offset as an affine function of
// `iv`.
//
// For a non-GEP pointer the offset is the constant 0 and invariance is that
// of the pointer itself. For a GEP, each index is scaled by the size of the
// type it indexes (starting from the GEP's source type and descending into
// array element types); invariance requires the GEP's pointer operand and all
// of its indices to be loop-invariant.
//
// NOTE(review): only the indices of the outermost GEP contribute to
// `byte_offset`, while `base` is stripped through every GEP layer. Offsets
// of inner GEPs appear to be ignored — confirm this is intended.
inline PointerInfo AnalyzePointer(Value* pointer, PhiInst* iv, const Loop& loop,
                                  int access_size,
                                  const memutils::EscapeSummary* escapes = nullptr) {
  PointerInfo result;
  result.access_size = access_size;
  result.base = StripPointerBase(pointer);

  const auto kind = memutils::ClassifyRoot(result.base, escapes);
  result.root_kind = kind;
  result.argument_root = (kind == memutils::PointerRootKind::Param);
  result.readonly_root = (kind == memutils::PointerRootKind::ReadonlyGlobal);
  result.distinct_root = (kind == memutils::PointerRootKind::Local) ||
                         (kind == memutils::PointerRootKind::Global) ||
                         (kind == memutils::PointerRootKind::ReadonlyGlobal);
  result.exact_key_valid =
      escapes != nullptr &&
      memutils::BuildExactAddressKey(pointer, escapes, result.exact_key);

  result.invariant_address = looputils::IsLoopInvariantValue(loop, pointer);

  auto* gep = dyncast<GetElementPtrInst>(pointer);
  if (gep == nullptr) {
    result.byte_offset = MakeConst(0);
    return result;
  }

  std::shared_ptr<Type> indexed_type = gep->GetSourceType();
  AffineExpr offset = MakeConst(0);
  bool operands_invariant = looputils::IsLoopInvariantValue(loop, gep->GetPointer());
  for (std::size_t idx = 0; idx < gep->GetNumIndices(); ++idx) {
    auto* index_value = gep->GetIndex(idx);
    if (!looputils::IsLoopInvariantValue(loop, index_value)) {
      operands_invariant = false;
    }

    const std::int64_t element_bytes = indexed_type ? indexed_type->GetSize() : 0;
    const AffineExpr index_expr = AnalyzeAffine(index_value, iv, loop);
    if (!index_expr.valid) {
      // One unanalysable index poisons the whole offset.
      offset = {};
    } else if (offset.valid) {
      offset = Combine(offset, Scale(index_expr, element_bytes), +1);
    }
    indexed_type = AdvanceGEPType(indexed_type);
  }

  result.invariant_address = operands_invariant;
  result.byte_offset = offset;
  return result;
}
|
||||
|
||||
struct MemoryAccessInfo {
|
||||
Instruction* inst = nullptr;
|
||||
Value* pointer = nullptr;
|
||||
PointerInfo ptr;
|
||||
bool is_read = false;
|
||||
bool is_write = false;
|
||||
};
|
||||
|
||||
inline std::vector<MemoryAccessInfo> CollectMemoryAccesses(const Loop& loop,
|
||||
PhiInst* iv,
|
||||
const memutils::EscapeSummary* escapes =
|
||||
nullptr) {
|
||||
std::vector<MemoryAccessInfo> accesses;
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
accesses.push_back(
|
||||
{inst, load->GetPtr(),
|
||||
AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes),
|
||||
true,
|
||||
false});
|
||||
} else if (auto* store = dyncast<StoreInst>(inst)) {
|
||||
accesses.push_back({inst, store->GetPtr(),
|
||||
AnalyzePointer(store->GetPtr(), iv, loop,
|
||||
store->GetValue()->GetType()->GetSize(), escapes),
|
||||
false, true});
|
||||
} else if (auto* memset = dyncast<MemsetInst>(inst)) {
|
||||
accesses.push_back(
|
||||
{inst, memset->GetDest(),
|
||||
AnalyzePointer(memset->GetDest(), iv, loop, 1, escapes), false, true});
|
||||
}
|
||||
}
|
||||
}
|
||||
return accesses;
|
||||
}
|
||||
|
||||
inline bool SameAffineAddress(const PointerInfo& lhs, const PointerInfo& rhs) {
|
||||
return lhs.base == rhs.base && lhs.byte_offset.valid && rhs.byte_offset.valid &&
|
||||
lhs.byte_offset.var == rhs.byte_offset.var &&
|
||||
lhs.byte_offset.coeff == rhs.byte_offset.coeff &&
|
||||
lhs.byte_offset.constant == rhs.byte_offset.constant;
|
||||
}
|
||||
|
||||
inline bool MayAliasSameIteration(const PointerInfo& lhs, const PointerInfo& rhs) {
|
||||
if (lhs.exact_key_valid && rhs.exact_key_valid) {
|
||||
return memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key);
|
||||
}
|
||||
if (!lhs.base || !rhs.base) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.base != rhs.base) {
|
||||
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.byte_offset.var != rhs.byte_offset.var) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.byte_offset.coeff != rhs.byte_offset.coeff) {
|
||||
return true;
|
||||
}
|
||||
const auto diff = std::llabs(lhs.byte_offset.constant - rhs.byte_offset.constant);
|
||||
const auto overlap = std::min(lhs.access_size, rhs.access_size);
|
||||
return diff < overlap;
|
||||
}
|
||||
|
||||
inline bool HasCrossIterationDependence(const PointerInfo& lhs, const PointerInfo& rhs,
|
||||
int iv_stride) {
|
||||
if (lhs.exact_key_valid && rhs.exact_key_valid &&
|
||||
!memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key)) {
|
||||
return false;
|
||||
}
|
||||
if (!lhs.base || !rhs.base) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.base != rhs.base) {
|
||||
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.byte_offset.var != rhs.byte_offset.var) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto lhs_step = lhs.byte_offset.coeff * iv_stride;
|
||||
const auto rhs_step = rhs.byte_offset.coeff * iv_stride;
|
||||
if (lhs_step == 0 && rhs_step == 0) {
|
||||
return MayAliasSameIteration(lhs, rhs);
|
||||
}
|
||||
if (lhs_step == rhs_step && lhs_step != 0) {
|
||||
const auto diff = rhs.byte_offset.constant - lhs.byte_offset.constant;
|
||||
return diff != 0 && diff % std::llabs(lhs_step) == 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Whether a call to `callee` may write memory reachable through `ptr`.
// Pointers rooted in a read-only global are never written; everything else is
// decided by memutils::CallMayWriteRoot on the pointer's root kind.
inline bool CallMayWritePointer(Function* callee, const PointerInfo& ptr) {
  return !ptr.readonly_root && memutils::CallMayWriteRoot(callee, ptr.root_kind);
}
|
||||
|
||||
inline bool IsSafeInvariantLoadToHoist(const Loop& loop, LoadInst* load, PhiInst* iv,
|
||||
int iv_stride,
|
||||
const std::vector<MemoryAccessInfo>& accesses,
|
||||
const memutils::EscapeSummary* escapes = nullptr) {
|
||||
if (!load) {
|
||||
return false;
|
||||
}
|
||||
auto ptr = AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes);
|
||||
if (!ptr.invariant_address) {
|
||||
return false;
|
||||
}
|
||||
if (ptr.readonly_root) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst == load) {
|
||||
continue;
|
||||
}
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
if (CallMayWritePointer(call->GetCallee(), ptr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& access : accesses) {
|
||||
if (access.inst == load || !access.is_write) {
|
||||
continue;
|
||||
}
|
||||
if (MayAliasSameIteration(ptr, access.ptr)) {
|
||||
return false;
|
||||
}
|
||||
if (HasCrossIterationDependence(ptr, access.ptr, iv_stride)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// True when some value-producing instruction of `first_group` is used by an
// instruction contained in `second_set`. Null entries and void instructions
// are ignored.
inline bool HasScalarDependenceAcrossCut(const std::vector<Instruction*>& first_group,
                                         const std::unordered_set<Instruction*>& second_set) {
  for (auto* producer : first_group) {
    if (producer == nullptr || producer->IsVoid()) {
      continue;
    }
    for (const auto& use : producer->GetUses()) {
      auto* consumer = dyncast<Instruction>(use.GetUser());
      if (consumer != nullptr && second_set.count(consumer) != 0) {
        return true;
      }
    }
  }
  return false;
}
|
||||
|
||||
inline bool HasMemoryDependenceAcrossCut(const std::vector<MemoryAccessInfo>& accesses,
|
||||
const std::unordered_set<Instruction*>& first_set,
|
||||
const std::unordered_set<Instruction*>& second_set,
|
||||
int iv_stride) {
|
||||
for (const auto& lhs : accesses) {
|
||||
if (first_set.find(lhs.inst) == first_set.end()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& rhs : accesses) {
|
||||
if (second_set.find(rhs.inst) == second_set.end()) {
|
||||
continue;
|
||||
}
|
||||
if (!lhs.is_write && !rhs.is_write) {
|
||||
continue;
|
||||
}
|
||||
if (MayAliasSameIteration(lhs.ptr, rhs.ptr) ||
|
||||
HasCrossIterationDependence(lhs.ptr, rhs.ptr, iv_stride)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool IsLoopParallelizable(const Loop& loop, PhiInst* iv, int iv_stride,
|
||||
const std::vector<MemoryAccessInfo>& accesses) {
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
if (phi != iv) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
auto* callee = call->GetCallee();
|
||||
if (callee == nullptr || callee->HasObservableSideEffects() || callee->IsRecursive()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& access : accesses) {
|
||||
if (CallMayWritePointer(callee, access.ptr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (dyncast<MemsetInst>(inst)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < accesses.size(); ++i) {
|
||||
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
|
||||
if (!accesses[i].is_write && !accesses[j].is_write) {
|
||||
continue;
|
||||
}
|
||||
if (HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr, iv_stride)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace ir::loopmem
|
||||
@ -0,0 +1,440 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir::looputils {
|
||||
|
||||
// Returns the last instruction of `block` if it is a terminator; nullptr for
// a null or empty block, or when the last instruction is not a terminator.
inline Instruction* GetTerminator(BasicBlock* block) {
  if (block == nullptr) {
    return nullptr;
  }
  if (block->GetInstructions().empty()) {
    return nullptr;
  }
  Instruction* last = block->GetInstructions().back().get();
  if (last == nullptr || !last->IsTerminator()) {
    return nullptr;
  }
  return last;
}
|
||||
|
||||
// Index at which the terminator of `block` sits, or the instruction count
// when the block reports no terminator. A null block yields 0.
inline std::size_t GetTerminatorIndex(BasicBlock* block) {
  if (block == nullptr) {
    return 0;
  }
  const auto count = block->GetInstructions().size();
  if (block->HasTerminator() && count != 0) {
    return count - 1;
  }
  return count;
}
|
||||
|
||||
// Number of consecutive PhiInst instructions at the start of `block`, which
// is also the index of the first non-phi instruction. A null block yields 0.
inline std::size_t GetFirstNonPhiIndex(BasicBlock* block) {
  std::size_t leading_phis = 0;
  if (block == nullptr) {
    return leading_phis;
  }
  for (const auto& inst_ptr : block->GetInstructions()) {
    if (dyncast<PhiInst>(inst_ptr.get()) == nullptr) {
      return leading_phis;
    }
    ++leading_phis;
  }
  return leading_phis;
}
|
||||
|
||||
// Produces a fresh value name of the form "%<prefix><n>", where n counts up
// from 1 separately for each Function address. The counters live in a
// function-local static map.
inline std::string NextSyntheticName(Function& function, const std::string& prefix) {
  static std::unordered_map<Function*, int> next_id_by_function;
  int& counter = next_id_by_function[&function];
  ++counter;

  std::string name = "%";
  name += prefix;
  name += std::to_string(counter);
  return name;
}
|
||||
|
||||
// Produces a fresh block name of the form "<prefix>.<n>", where n counts up
// from 1 separately for each Function address. These counters are independent
// of the ones used for value names.
inline std::string NextSyntheticBlockName(Function& function,
                                          const std::string& prefix) {
  static std::unordered_map<Function*, int> next_id_by_function;
  int& counter = next_id_by_function[&function];
  ++counter;

  std::string name = prefix;
  name += ".";
  name += std::to_string(counter);
  return name;
}
|
||||
|
||||
// Allocates a new 32-bit integer constant holding `value`.
// NOTE(review): the object is created with a raw `new` and no owner is visible
// here — confirm who is responsible for its lifetime.
inline ConstantInt* ConstInt(int value) {
  auto* constant = new ConstantInt(Type::GetInt32Type(), value);
  return constant;
}
|
||||
|
||||
// Finds the position of the incoming edge of `phi` that originates from
// `block`. Returns the first matching incoming index, or -1 when either
// argument is null or no incoming edge comes from `block`.
inline int GetPhiIncomingIndex(PhiInst* phi, BasicBlock* block) {
  if (phi == nullptr || block == nullptr) {
    return -1;
  }
  int slot = 0;
  while (slot < phi->GetNumIncomings()) {
    if (phi->GetIncomingBlock(slot) == block) {
      return slot;
    }
    ++slot;
  }
  return -1;
}
|
||||
|
||||
// Rewrites the incoming pair of `phi` that currently comes from `old_block`
// so that it carries `new_value` from `new_block` instead.
// Returns false (and changes nothing) when any argument is null or when
// `phi` has no incoming edge from `old_block`.
inline bool ReplacePhiIncoming(PhiInst* phi, BasicBlock* old_block,
                               Value* new_value, BasicBlock* new_block) {
  const bool has_all_args = phi && old_block && new_value && new_block;
  if (!has_all_args) {
    return false;
  }
  const int slot = GetPhiIncomingIndex(phi, old_block);
  if (slot < 0) {
    return false;
  }
  // Incoming pair k occupies operands 2k (value) and 2k + 1 (block).
  const auto value_operand = static_cast<std::size_t>(2 * slot);
  phi->SetOperand(value_operand, new_value);
  phi->SetOperand(value_operand + 1, new_block);
  return true;
}
|
||||
|
||||
// Retargets every edge pred -> old_succ to pred -> new_succ.
// Handles unconditional branches (operand 0) and conditional branches
// (operand 1 = then target, operand 2 = else target; both are rewritten when
// both point at `old_succ`). On success the successor list of `pred` is updated
// as well. Predecessor lists of the two successors are NOT touched here.
// Returns false without modifying anything when `pred` has no terminator, an
// argument is null, the terminator is of another kind, or it does not target
// `old_succ`.
inline bool RedirectSuccessorEdge(BasicBlock* pred, BasicBlock* old_succ,
                                  BasicBlock* new_succ) {
  auto* terminator = GetTerminator(pred);
  if (!terminator || !old_succ || !new_succ) {
    return false;
  }

  bool rewired = false;
  if (auto* jump = dyncast<UncondBrInst>(terminator)) {
    if (jump->GetDest() == old_succ) {
      jump->SetOperand(0, new_succ);
      rewired = true;
    }
  } else if (auto* branch = dyncast<CondBrInst>(terminator)) {
    if (branch->GetThenBlock() == old_succ) {
      branch->SetOperand(1, new_succ);
      rewired = true;
    }
    if (branch->GetElseBlock() == old_succ) {
      branch->SetOperand(2, new_succ);
      rewired = true;
    }
  }
  if (!rewired) {
    return false;
  }

  pred->RemoveSuccessor(old_succ);
  pred->AddSuccessor(new_succ);
  return true;
}
|
||||
|
||||
// Transfers ownership of `inst` from its current block to `dest`, inserting it
// at the position reported by GetTerminatorIndex(dest).
// Returns:
//  - nullptr when an argument is null or `inst` is not found in its parent's
//    instruction list;
//  - `inst` unchanged when it has no parent or already lives in `dest`;
//  - the (same) instruction pointer after a successful move.
inline Instruction* MoveInstructionBeforeTerminator(Instruction* inst,
                                                    BasicBlock* dest) {
  if (inst == nullptr || dest == nullptr) {
    return nullptr;
  }
  BasicBlock* src = inst->GetParent();
  if (src == nullptr || src == dest) {
    return inst;
  }

  // Locate the owning slot in the source block.
  auto& source_list = src->GetInstructions();
  auto position = source_list.begin();
  while (position != source_list.end() && position->get() != inst) {
    ++position;
  }
  if (position == source_list.end()) {
    return nullptr;
  }

  // Detach, re-parent, then splice in front of the destination terminator.
  auto owned = std::move(*position);
  source_list.erase(position);
  owned->SetParent(dest);

  auto& target_list = dest->GetInstructions();
  const auto target_position =
      target_list.begin() + static_cast<long long>(GetTerminatorIndex(dest));
  Instruction* raw = owned.get();
  target_list.insert(target_position, std::move(owned));
  return raw;
}
|
||||
|
||||
inline bool IsLoopInvariantValue(const Loop& loop, Value* value) {
|
||||
auto* inst = dyncast<Instruction>(value);
|
||||
return inst == nullptr || !loop.Contains(inst->GetParent());
|
||||
}
|
||||
|
||||
// Looks `value` up in `remap`; returns the mapped value when present and
// `value` itself otherwise.
inline Value* RemapValue(const std::unordered_map<Value*, Value*>& remap,
                         Value* value) {
  const auto found = remap.find(value);
  if (found != remap.end()) {
    return found->second;
  }
  return value;
}
|
||||
|
||||
inline bool IsCloneableInstruction(const Instruction* inst) {
|
||||
if (!inst || inst->IsTerminator() || inst->GetOpcode() == Opcode::Phi) {
|
||||
return false;
|
||||
}
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::FRem:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF:
|
||||
case Opcode::Alloca:
|
||||
case Opcode::Load:
|
||||
case Opcode::Store:
|
||||
case Opcode::Memset:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Zext:
|
||||
case Opcode::Call:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a copy of `inst` inside `dest`, inserted at the index reported by
// GetTerminatorIndex(dest).
//
// Every operand of the copy is first looked up in `remap` (unmapped operands
// are used as-is). When `inst` is not void, the copy gets a fresh name built
// from `prefix` and `remap[inst]` is pointed at the copy so later clones see
// it. Store and Memset copies are never recorded in `remap`.
//
// Returns the new instruction, or nullptr when `inst`/`dest` is null or `inst`
// is rejected by IsCloneableInstruction.
inline Instruction* CloneInstruction(Function& function, Instruction* inst,
                                     BasicBlock* dest,
                                     std::unordered_map<Value*, Value*>& remap,
                                     const std::string& prefix) {
  if (!inst || !dest || !IsCloneableInstruction(inst)) {
    return nullptr;
  }

  const auto position = GetTerminatorIndex(dest);
  const bool produces_value = !inst->IsVoid();
  // Void instructions carry no name, so no synthetic id is consumed for them.
  const std::string clone_name =
      produces_value ? NextSyntheticName(function, prefix) : std::string();

  auto mapped = [&remap](Value* operand) { return RemapValue(remap, operand); };
  auto record = [&](Instruction* clone) -> Instruction* {
    if (clone != nullptr && produces_value) {
      remap[inst] = clone;
    }
    return clone;
  };

  const Opcode opcode = inst->GetOpcode();
  switch (opcode) {
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::Div:
    case Opcode::Rem:
    case Opcode::FAdd:
    case Opcode::FSub:
    case Opcode::FMul:
    case Opcode::FDiv:
    case Opcode::FRem:
    case Opcode::And:
    case Opcode::Or:
    case Opcode::Xor:
    case Opcode::Shl:
    case Opcode::AShr:
    case Opcode::LShr:
    case Opcode::ICmpEQ:
    case Opcode::ICmpNE:
    case Opcode::ICmpLT:
    case Opcode::ICmpGT:
    case Opcode::ICmpLE:
    case Opcode::ICmpGE:
    case Opcode::FCmpEQ:
    case Opcode::FCmpNE:
    case Opcode::FCmpLT:
    case Opcode::FCmpGT:
    case Opcode::FCmpLE:
    case Opcode::FCmpGE: {
      auto* binary = static_cast<BinaryInst*>(inst);
      return record(dest->Insert<BinaryInst>(
          position, opcode, inst->GetType(), mapped(binary->GetLhs()),
          mapped(binary->GetRhs()), nullptr, clone_name));
    }

    case Opcode::Neg:
    case Opcode::Not:
    case Opcode::FNeg:
    case Opcode::FtoI:
    case Opcode::IToF: {
      auto* unary = static_cast<UnaryInst*>(inst);
      return record(dest->Insert<UnaryInst>(position, opcode, inst->GetType(),
                                            mapped(unary->GetOprd()), nullptr,
                                            clone_name));
    }

    case Opcode::Alloca: {
      auto* source = static_cast<AllocaInst*>(inst);
      return record(dest->Insert<AllocaInst>(
          position, source->GetAllocatedType(), nullptr, clone_name));
    }

    case Opcode::Load: {
      auto* source = static_cast<LoadInst*>(inst);
      return record(dest->Insert<LoadInst>(position, inst->GetType(),
                                           mapped(source->GetPtr()), nullptr,
                                           clone_name));
    }

    case Opcode::Store: {
      // Not recorded in `remap`.
      auto* source = static_cast<StoreInst*>(inst);
      return dest->Insert<StoreInst>(position, mapped(source->GetValue()),
                                     mapped(source->GetPtr()), nullptr);
    }

    case Opcode::Memset: {
      // Not recorded in `remap`.
      auto* source = static_cast<MemsetInst*>(inst);
      return dest->Insert<MemsetInst>(
          position, mapped(source->GetDest()), mapped(source->GetValue()),
          mapped(source->GetLength()), mapped(source->GetIsVolatile()),
          nullptr);
    }

    case Opcode::GetElementPtr: {
      auto* source = static_cast<GetElementPtrInst*>(inst);
      std::vector<Value*> mapped_indices;
      mapped_indices.reserve(source->GetNumIndices());
      for (std::size_t k = 0; k < source->GetNumIndices(); ++k) {
        mapped_indices.push_back(mapped(source->GetIndex(k)));
      }
      return record(dest->Insert<GetElementPtrInst>(
          position, source->GetSourceType(), mapped(source->GetPointer()),
          mapped_indices, nullptr, clone_name));
    }

    case Opcode::Zext: {
      auto* source = static_cast<ZextInst*>(inst);
      return record(dest->Insert<ZextInst>(position, mapped(source->GetValue()),
                                           inst->GetType(), nullptr,
                                           clone_name));
    }

    case Opcode::Call: {
      auto* source = static_cast<CallInst*>(inst);
      const auto source_args = source->GetArguments();
      std::vector<Value*> mapped_args;
      mapped_args.reserve(source_args.size());
      for (auto* argument : source_args) {
        mapped_args.push_back(mapped(argument));
      }
      return record(dest->Insert<CallInst>(position, source->GetCallee(),
                                           mapped_args, nullptr, clone_name));
    }

    // Already filtered out by IsCloneableInstruction; listed so the switch
    // stays exhaustive over the opcodes it names.
    case Opcode::Phi:
    case Opcode::Br:
    case Opcode::CondBr:
    case Opcode::Return:
    case Opcode::Unreachable:
      break;
  }
  return nullptr;
}
|
||||
|
||||
// Returns a preheader for `loop`, creating one when necessary, and stores it
// in `loop.preheader`.
//
//  1. An already recorded preheader is returned unchanged.
//  2. If the header has exactly one predecessor outside the loop and that
//     predecessor has exactly one successor, it is adopted as the preheader.
//  3. Otherwise a new block is created that branches unconditionally to the
//     header. For each leading phi of the header, the incoming pairs coming
//     from outside the loop are removed and replaced by a single pair coming
//     from the new block (merged through a new phi in the preheader when there
//     was more than one). Outside predecessors are then redirected to the new
//     block.
//
// Returns nullptr when the loop has no header or the header has no predecessor
// outside the loop.
inline BasicBlock* EnsurePreheader(Function& function, Loop& loop) {
  if (loop.preheader != nullptr) {
    return loop.preheader;
  }
  auto* header = loop.header;
  if (header == nullptr) {
    return nullptr;
  }

  // Predecessors of the header that enter the loop from outside.
  std::vector<BasicBlock*> entry_preds;
  for (auto* pred : header->GetPredecessors()) {
    if (loop.Contains(pred)) {
      continue;
    }
    entry_preds.push_back(pred);
  }
  if (entry_preds.empty()) {
    return nullptr;
  }

  // Case 2: reuse a dedicated entry block that already exists.
  if (entry_preds.size() == 1) {
    auto* only_pred = entry_preds.front();
    if (only_pred->GetSuccessors().size() == 1) {
      loop.preheader = only_pred;
      return loop.preheader;
    }
  }

  // Case 3: synthesize a new block.
  auto* preheader = function.CreateBlock(
      NextSyntheticBlockName(function, header->GetName() + ".preheader"));

  std::size_t next_phi_slot = 0;  // where the next merge phi goes in preheader
  for (const auto& owned : header->GetInstructions()) {
    auto* phi = dyncast<PhiInst>(owned.get());
    if (phi == nullptr) {
      break;  // phis are only scanned at the top of the header
    }

    // Incoming indices of this phi whose source block lies outside the loop.
    std::vector<int> entry_slots;
    for (int slot = 0; slot < phi->GetNumIncomings(); ++slot) {
      if (!loop.Contains(phi->GetIncomingBlock(slot))) {
        entry_slots.push_back(slot);
      }
    }
    if (entry_slots.empty()) {
      continue;
    }

    Value* merged_value = nullptr;
    if (entry_slots.size() != 1) {
      // Several entry values: merge them with a phi placed in the preheader.
      auto merge_phi = std::make_unique<PhiInst>(
          phi->GetType(), nullptr,
          NextSyntheticName(function, "preheader.phi."));
      auto* merge_phi_raw = merge_phi.get();
      merge_phi_raw->SetParent(preheader);

      auto& preheader_list = preheader->GetInstructions();
      preheader_list.insert(
          preheader_list.begin() + static_cast<long long>(next_phi_slot),
          std::move(merge_phi));
      ++next_phi_slot;

      for (int slot : entry_slots) {
        merge_phi_raw->AddIncoming(phi->GetIncomingValue(slot),
                                   phi->GetIncomingBlock(slot));
      }
      merged_value = merge_phi_raw;
    } else {
      merged_value = phi->GetIncomingValue(entry_slots.front());
    }

    // Drop the outside pairs back-to-front so earlier indices stay valid;
    // within a pair the block operand (2k + 1) goes before the value (2k).
    for (std::size_t k = entry_slots.size(); k-- > 0;) {
      const int slot = entry_slots[k];
      phi->RemoveOperand(static_cast<std::size_t>(2 * slot + 1));
      phi->RemoveOperand(static_cast<std::size_t>(2 * slot));
    }
    phi->AddIncoming(merged_value, preheader);
  }

  // Wire preheader -> header.
  preheader->Append<UncondBrInst>(header, nullptr);
  preheader->AddSuccessor(header);
  header->AddPredecessor(preheader);

  // Redirect every outside predecessor to the new block.
  for (auto* pred : entry_preds) {
    if (!RedirectSuccessorEdge(pred, header, preheader)) {
      continue;
    }
    preheader->AddPredecessor(pred);
    header->RemovePredecessor(pred);
  }

  loop.preheader = preheader;
  return preheader;
}
|
||||
|
||||
} // namespace ir::looputils
|
||||
@ -0,0 +1,295 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct InductionVarInfo {
|
||||
PhiInst* phi = nullptr;
|
||||
Value* start = nullptr;
|
||||
BasicBlock* latch = nullptr;
|
||||
int stride = 0;
|
||||
};
|
||||
|
||||
// Produces a value equal to `lhs * rhs`, simplifying where possible:
//  - a constant 0 on either side yields a fresh constant 0;
//  - a constant 1 on either side yields the other operand unchanged;
//  - two constants are folded into a fresh constant;
//  - otherwise an i32 Mul instruction named from `prefix` is inserted into
//    `block` at GetTerminatorIndex(block).
// The checks are applied in that order, lhs before rhs.
Value* BuildMulValue(Function& function, BasicBlock* block, Value* lhs, Value* rhs,
                     const std::string& prefix) {
  auto* lhs_const = dyncast<ConstantInt>(lhs);
  auto* rhs_const = dyncast<ConstantInt>(rhs);

  if (lhs_const) {
    if (lhs_const->GetValue() == 0) {
      return looputils::ConstInt(0);
    }
    if (lhs_const->GetValue() == 1) {
      return rhs;
    }
  }
  if (rhs_const) {
    if (rhs_const->GetValue() == 0) {
      return looputils::ConstInt(0);
    }
    if (rhs_const->GetValue() == 1) {
      return lhs;
    }
  }
  if (lhs_const && rhs_const) {
    // Fold in unsigned arithmetic: a signed multiply that overflows is
    // undefined behaviour in the compiler itself, whereas unsigned wraps.
    // NOTE(review): assumes the emitted i32 Mul wraps the same way at runtime.
    const unsigned product = static_cast<unsigned>(lhs_const->GetValue()) *
                             static_cast<unsigned>(rhs_const->GetValue());
    return looputils::ConstInt(static_cast<int>(product));
  }
  return block->Insert<BinaryInst>(looputils::GetTerminatorIndex(block), Opcode::Mul,
                                   Type::GetInt32Type(), lhs, rhs, nullptr,
                                   looputils::NextSyntheticName(function, prefix));
}
|
||||
|
||||
// Produces a value equal to `base * factor` for a compile-time `factor`:
//  - factor 0 yields a fresh constant 0;
//  - factor 1 yields `base` unchanged;
//  - a constant `base` is folded into a fresh constant;
//  - factor -1 inserts a Neg instruction (typed like `base`) into `block`;
//  - anything else is delegated to BuildMulValue.
// New instructions are placed at GetTerminatorIndex(block) and named from
// `prefix`.
Value* BuildScaledValue(Function& function, BasicBlock* block, Value* base,
                        int factor, const std::string& prefix) {
  if (factor == 0) {
    return looputils::ConstInt(0);
  }
  if (factor == 1) {
    return base;
  }
  if (auto* base_const = dyncast<ConstantInt>(base)) {
    // Fold in unsigned arithmetic: a signed multiply that overflows is
    // undefined behaviour in the compiler itself, whereas unsigned wraps.
    // NOTE(review): assumes the emitted i32 Mul wraps the same way at runtime.
    const unsigned scaled = static_cast<unsigned>(base_const->GetValue()) *
                            static_cast<unsigned>(factor);
    return looputils::ConstInt(static_cast<int>(scaled));
  }
  if (factor == -1) {
    return block->Insert<UnaryInst>(looputils::GetTerminatorIndex(block), Opcode::Neg,
                                    base->GetType(), base, nullptr,
                                    looputils::NextSyntheticName(function, prefix));
  }
  return BuildMulValue(function, block, base, looputils::ConstInt(factor), prefix);
}
|
||||
|
||||
bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
|
||||
PhiInst* phi, InductionVarInfo& info) {
|
||||
if (!phi || !phi->GetType() || !phi->GetType()->IsInt32() ||
|
||||
phi->GetParent() != loop.header || phi->GetNumIncomings() != 2 ||
|
||||
loop.latches.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* latch = loop.latches.front();
|
||||
int preheader_index = -1;
|
||||
int latch_index = -1;
|
||||
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
|
||||
if (phi->GetIncomingBlock(i) == preheader) {
|
||||
preheader_index = i;
|
||||
} else if (phi->GetIncomingBlock(i) == latch) {
|
||||
latch_index = i;
|
||||
}
|
||||
}
|
||||
if (preheader_index < 0 || latch_index < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
|
||||
if (!step_inst || step_inst->GetParent() != latch) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int stride = 0;
|
||||
if (step_inst->GetOpcode() == Opcode::Add) {
|
||||
if (step_inst->GetLhs() == phi) {
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = delta->GetValue();
|
||||
} else if (step_inst->GetRhs() == phi) {
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = delta->GetValue();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else if (step_inst->GetOpcode() == Opcode::Sub) {
|
||||
if (step_inst->GetLhs() != phi) {
|
||||
return false;
|
||||
}
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = -delta->GetValue();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stride == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
info.phi = phi;
|
||||
info.start = phi->GetIncomingValue(preheader_index);
|
||||
info.latch = latch;
|
||||
info.stride = stride;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsMulCandidate(const Loop& loop, Instruction* inst, PhiInst* phi, Value*& factor) {
|
||||
auto* mul = dyncast<BinaryInst>(inst);
|
||||
if (!mul || mul->GetOpcode() != Opcode::Mul || !mul->GetType()->IsInt32()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mul->GetLhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetRhs())) {
|
||||
factor = mul->GetRhs();
|
||||
return true;
|
||||
}
|
||||
if (mul->GetRhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetLhs())) {
|
||||
factor = mul->GetLhs();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Builds a new header phi that tracks `iv.phi * factor` without multiplying
// inside the loop:
//  - the initial value `iv.start * factor` is built in `preheader`;
//  - the per-iteration step `factor * |iv.stride|` is built in `preheader`;
//  - the latch gets `phi + step` (positive stride) or `phi - step` otherwise.
// The new phi is inserted at the first non-phi position of `header` and
// returned.
Value* CreateReducedPhi(Function& function, BasicBlock* header, BasicBlock* preheader,
                        const InductionVarInfo& iv, Value* factor) {
  auto* reduced_phi = header->Insert<PhiInst>(
      looputils::GetFirstNonPhiIndex(header), Type::GetInt32Type(), nullptr,
      looputils::NextSyntheticName(function, "lsr.phi."));

  // Entry edge: start * factor.
  Value* initial = BuildMulValue(function, preheader, iv.start, factor, "lsr.init.");
  reduced_phi->AddIncoming(initial, preheader);

  // Back edge: advance by factor * |stride| in the direction of the stride.
  Value* step = BuildScaledValue(function, preheader, factor, std::abs(iv.stride),
                                 "lsr.step.");
  const Opcode advance = iv.stride > 0 ? Opcode::Add : Opcode::Sub;
  Instruction* next = iv.latch->Insert<BinaryInst>(
      looputils::GetTerminatorIndex(iv.latch), advance, Type::GetInt32Type(),
      reduced_phi, step, nullptr,
      looputils::NextSyntheticName(function, "lsr.next."));
  reduced_phi->AddIncoming(next, iv.latch);

  return reduced_phi;
}
|
||||
|
||||
// Replaces multiplications of an induction variable by a loop-invariant value
// with an additive recurrence (see CreateReducedPhi).
//
// Steps:
//  1. Collect the induction variables among the leading phis of the header.
//  2. For each of them, collect matching Mul instructions in the loop blocks.
//  3. Replace every use of each Mul with a reduced phi; one phi is shared
//     between Muls that use the same factor pointer for the same variable.
//  4. Erase the replaced Mul instructions.
//
// Returns true when at least one multiplication was replaced. Does nothing
// when `preheader` is null or the loop does not have exactly one latch.
bool ReduceLoopMultiplications(Function& function, const Loop& loop,
                               BasicBlock* preheader) {
  if (preheader == nullptr || loop.latches.size() != 1) {
    return false;
  }

  // Step 1: induction variables.
  std::vector<InductionVarInfo> induction_vars;
  for (const auto& owned : loop.header->GetInstructions()) {
    auto* phi = dyncast<PhiInst>(owned.get());
    if (phi == nullptr) {
      break;
    }
    InductionVarInfo matched;
    if (MatchSimpleInductionVariable(loop, preheader, phi, matched)) {
      induction_vars.push_back(matched);
    }
  }
  if (induction_vars.empty()) {
    return false;
  }

  bool changed = false;
  std::vector<Instruction*> dead_muls;
  for (const auto& iv : induction_vars) {
    // Step 2: gather first, rewrite afterwards, so that inserting new
    // instructions never happens while a block is being iterated.
    std::vector<std::pair<Instruction*, Value*>> candidates;
    for (auto* block : loop.block_list) {
      for (const auto& owned : block->GetInstructions()) {
        Instruction* current = owned.get();
        if (current == iv.phi) {
          continue;
        }
        Value* factor = nullptr;
        if (IsMulCandidate(loop, current, iv.phi, factor)) {
          candidates.emplace_back(current, factor);
        }
      }
    }
    if (candidates.empty()) {
      continue;
    }

    // Step 3: rewrite, sharing one reduced phi per distinct factor pointer.
    std::unordered_map<Value*, Value*> phi_for_factor;
    for (const auto& candidate : candidates) {
      Instruction* mul = candidate.first;
      Value* factor = candidate.second;

      Value* replacement = nullptr;
      const auto cached = phi_for_factor.find(factor);
      if (cached == phi_for_factor.end()) {
        replacement = CreateReducedPhi(function, loop.header, preheader, iv, factor);
        phi_for_factor.emplace(factor, replacement);
      } else {
        replacement = cached->second;
      }

      mul->ReplaceAllUsesWith(replacement);
      dead_muls.push_back(mul);
      changed = true;
    }
  }

  // Step 4: remove the now-unused multiplications.
  for (auto* mul : dead_muls) {
    if (mul == nullptr) {
      continue;
    }
    if (auto* owner = mul->GetParent()) {
      owner->EraseInstruction(mul);
    }
  }
  return changed;
}
|
||||
|
||||
// Runs strength reduction over all loops of `function` until nothing changes.
// After each loop that changes, the dominator tree and loop info are rebuilt
// from scratch and the scan restarts, because the transformation adds blocks
// and instructions. External functions and functions without an entry block
// are skipped. Returns true when anything was reported as changed.
//
// NOTE(review): a loop counts as changed whenever EnsurePreheader returns a
// block different from the previously recorded preheader, including the case
// where it merely adopts an existing block. Confirm LoopInfo records such a
// block itself; otherwise this loop may not terminate.
bool RunLoopStrengthReductionOnFunction(Function& function) {
  if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
    return false;
  }

  bool changed = false;
  bool progress = true;
  while (progress) {
    progress = false;

    DominatorTree dom_tree(function);
    LoopInfo loop_info(function, dom_tree);

    for (auto* loop : loop_info.GetLoopsInPostOrder()) {
      auto* previous_preheader = loop->preheader;
      auto* preheader = looputils::EnsurePreheader(function, *loop);
      // The reduction is attempted regardless of whether the preheader changed.
      const bool reduced = ReduceLoopMultiplications(function, *loop, preheader);
      if (preheader != previous_preheader || reduced) {
        changed = true;
        progress = true;
        break;  // analyses are stale now; recompute and rescan
      }
    }
  }
  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Module-level entry point: applies loop strength reduction to every non-null
// function of `module`. Returns true when any function changed.
bool RunLoopStrengthReduction(Module& module) {
  bool changed = false;
  for (const auto& function : module.GetFunctions()) {
    if (!function) {
      continue;
    }
    if (RunLoopStrengthReductionOnFunction(*function)) {
      changed = true;
    }
  }
  return changed;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,400 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct CountedLoopInfo {
|
||||
Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
CondBrInst* branch = nullptr;
|
||||
BinaryInst* compare = nullptr;
|
||||
Opcode compare_opcode = Opcode::ICmpLT;
|
||||
Value* bound = nullptr;
|
||||
loopmem::SimpleInductionVar induction_var;
|
||||
std::vector<PhiInst*> phis;
|
||||
};
|
||||
|
||||
// True when `name` contains the substring "unroll." anywhere.
bool HasSyntheticLoopTag(const std::string& name) {
  const auto tag_position = name.find("unroll.");
  return tag_position != std::string::npos;
}
|
||||
|
||||
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
|
||||
if (!loop.preheader || !loop.header || !body) {
|
||||
return true;
|
||||
}
|
||||
if (HasSyntheticLoopTag(loop.preheader->GetName()) ||
|
||||
HasSyntheticLoopTag(loop.header->GetName()) ||
|
||||
HasSyntheticLoopTag(body->GetName())) {
|
||||
return true;
|
||||
}
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
const int incoming = looputils::GetPhiIncomingIndex(phi, loop.preheader);
|
||||
if (incoming < 0) {
|
||||
continue;
|
||||
}
|
||||
auto* incoming_phi = dyncast<PhiInst>(phi->GetIncomingValue(incoming));
|
||||
if (incoming_phi && incoming_phi->GetParent() &&
|
||||
HasSyntheticLoopTag(incoming_phi->GetParent()->GetName())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only the four ordered integer comparisons (<, <=, >, >=) are supported.
bool IsSupportedCompareOpcode(Opcode opcode) {
  return opcode == Opcode::ICmpLT || opcode == Opcode::ICmpLE ||
         opcode == Opcode::ICmpGT || opcode == Opcode::ICmpGE;
}
|
||||
|
||||
// Returns the comparison that is equivalent after exchanging the two operands
// (a < b  <=>  b > a, and so on). Any other opcode is returned unchanged.
Opcode SwapCompareOpcode(Opcode opcode) {
  if (opcode == Opcode::ICmpLT) {
    return Opcode::ICmpGT;
  }
  if (opcode == Opcode::ICmpLE) {
    return Opcode::ICmpGE;
  }
  if (opcode == Opcode::ICmpGT) {
    return Opcode::ICmpLT;
  }
  if (opcode == Opcode::ICmpGE) {
    return Opcode::ICmpLE;
  }
  return opcode;
}
|
||||
|
||||
// Counts the instructions of `block` that precede its first terminator.
// A null block counts as 0.
int CountPayloadInstructions(BasicBlock* block) {
  if (block == nullptr) {
    return 0;
  }
  int payload = 0;
  for (const auto& owned : block->GetInstructions()) {
    if (owned->IsTerminator()) {
      break;
    }
    ++payload;
  }
  return payload;
}
|
||||
|
||||
// Picks an unroll factor from the size of the loop body.
//
// Only instructions before the first terminator are considered:
//  - 2..6 instructions with at most 2 loads/stores -> 4
//  - 2..18 instructions                             -> 2
//  - anything else (including a null `body`)        -> 1 (do not unroll)
int ChooseUnrollFactor(BasicBlock* body) {
  // A missing body cannot be unrolled; treat it like an empty one instead of
  // dereferencing it.
  if (!body) {
    return 1;
  }

  constexpr int kMinPayload = 2;          // smaller bodies are not unrolled
  constexpr int kSmallBodyMaxInsts = 6;   // upper size bound for factor 4
  constexpr int kSmallBodyMaxMemOps = 2;  // load/store budget for factor 4
  constexpr int kMediumBodyMaxInsts = 18; // upper size bound for factor 2

  // Single pass: count payload instructions and memory operations together.
  int inst_count = 0;
  int mem_ops = 0;
  for (const auto& inst_ptr : body->GetInstructions()) {
    auto* inst = inst_ptr.get();
    if (inst->IsTerminator()) {
      break;
    }
    ++inst_count;
    if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
      ++mem_ops;
    }
  }

  if (inst_count < kMinPayload) {
    return 1;
  }
  if (inst_count <= kSmallBodyMaxInsts && mem_ops <= kSmallBodyMaxMemOps) {
    return 4;
  }
  if (inst_count <= kMediumBodyMaxInsts) {
    return 2;
  }
  return 1;
}
|
||||
|
||||
bool HasUnsafeLoopCarriedMemoryDependence(
|
||||
const std::vector<loopmem::MemoryAccessInfo>& accesses, int iv_stride) {
|
||||
for (std::size_t i = 0; i < accesses.size(); ++i) {
|
||||
if (accesses[i].is_write &&
|
||||
loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[i].ptr,
|
||||
iv_stride)) {
|
||||
return true;
|
||||
}
|
||||
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
|
||||
if (!accesses[i].is_write && !accesses[j].is_write) {
|
||||
continue;
|
||||
}
|
||||
if (loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr,
|
||||
iv_stride)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Recognises the loop shape the unroller can handle and records it in `info`.
//
// A loop is accepted only when all of the following hold:
//  - it has a preheader and a header and is innermost;
//  - loopmem::GetCanonicalLoopBlocks yields a body and an exit block;
//  - it is not flagged by IsAlreadyTransformedLoop;
//  - the header ends in a conditional branch whose "then" target is the body
//    and whose condition is a bool-typed <, <=, > or >= BinaryInst;
//  - one of the leading header phis is a simple induction variable (the first
//    match wins) and the compare tests it against a loop-invariant bound;
//  - the compare direction agrees with the sign of the stride
//    (positive: < or <=, negative: > or >=);
//  - no unsafe loop-carried memory dependence is reported;
//  - the header holds nothing but phis, the compare and its terminator;
//  - every non-terminator body instruction is cloneable and is not a call,
//    memset or alloca.
//
// `info` is written only on success.
bool MatchCountedLoop(Loop& loop, CountedLoopInfo& info) {
  const bool has_required_shape =
      loop.preheader && loop.header && loop.IsInnermost();
  if (!has_required_shape) {
    return false;
  }

  BasicBlock* body = nullptr;
  BasicBlock* exit = nullptr;
  if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
    return false;
  }
  if (IsAlreadyTransformedLoop(loop, body)) {
    return false;
  }

  // Header must end in `br cond, body, ...`.
  auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
  if (!branch || branch->GetThenBlock() != body) {
    return false;
  }
  auto* compare = dyncast<BinaryInst>(branch->GetCondition());
  if (!compare) {
    return false;
  }
  if (!compare->GetType()->IsBool() ||
      !IsSupportedCompareOpcode(compare->GetOpcode())) {
    return false;
  }

  // Gather all leading phis; remember the first one that is an induction
  // variable.
  std::vector<PhiInst*> header_phis;
  loopmem::SimpleInductionVar induction_var;
  bool have_induction_var = false;
  for (const auto& owned : loop.header->GetInstructions()) {
    auto* phi = dyncast<PhiInst>(owned.get());
    if (!phi) {
      break;
    }
    header_phis.push_back(phi);
    if (have_induction_var) {
      continue;
    }
    if (loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi,
                                              induction_var)) {
      have_induction_var = true;
    }
  }
  if (!have_induction_var) {
    return false;
  }

  // Normalise the compare to the form `iv <op> bound`.
  Opcode compare_opcode = compare->GetOpcode();
  Value* bound = nullptr;
  if (compare->GetLhs() == induction_var.phi &&
      looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
    bound = compare->GetRhs();
  } else if (compare->GetRhs() == induction_var.phi &&
             looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
    bound = compare->GetLhs();
    compare_opcode = SwapCompareOpcode(compare_opcode);
  } else {
    return false;
  }
  if (!bound) {
    return false;
  }

  // The comparison must point in the direction the variable moves.
  const bool tests_upper_bound =
      compare_opcode == Opcode::ICmpLT || compare_opcode == Opcode::ICmpLE;
  const bool tests_lower_bound =
      compare_opcode == Opcode::ICmpGT || compare_opcode == Opcode::ICmpGE;
  if (induction_var.stride > 0 && !tests_upper_bound) {
    return false;
  }
  if (induction_var.stride < 0 && !tests_lower_bound) {
    return false;
  }

  const auto accesses = loopmem::CollectMemoryAccesses(loop, induction_var.phi);
  if (HasUnsafeLoopCarriedMemoryDependence(accesses, induction_var.stride)) {
    return false;
  }

  // Header: phis + compare + terminator only.
  for (const auto& owned : loop.header->GetInstructions()) {
    auto* current = owned.get();
    const bool allowed_in_header = dyncast<PhiInst>(current) ||
                                   current == compare ||
                                   current->IsTerminator();
    if (!allowed_in_header) {
      return false;
    }
  }

  // Body: cloneable instructions only, minus calls, memsets and allocas.
  for (const auto& owned : body->GetInstructions()) {
    auto* current = owned.get();
    if (current->IsTerminator()) {
      continue;
    }
    if (!looputils::IsCloneableInstruction(current)) {
      return false;
    }
    if (dyncast<CallInst>(current) || dyncast<MemsetInst>(current) ||
        dyncast<AllocaInst>(current)) {
      return false;
    }
  }

  info.loop = &loop;
  info.preheader = loop.preheader;
  info.header = loop.header;
  info.body = body;
  info.exit = exit;
  info.branch = branch;
  info.compare = compare;
  info.compare_opcode = compare_opcode;
  info.bound = bound;
  info.induction_var = induction_var;
  info.phis = std::move(header_phis);
  return true;
}
|
||||
|
||||
// Computes the bound used by the unrolled loop's compare.
// With delta = |stride| * (factor - 1):
//  - delta == 0      -> `bound` itself;
//  - positive stride -> bound - delta;
//  - otherwise       -> bound + delta.
// A constant bound is folded into a fresh constant; any other bound gets a
// Sub/Add instruction inserted into `preheader` at its terminator index.
// NOTE(review): neither path guards against the adjusted bound overflowing
// the i32 range — confirm callers cannot reach that case.
Value* BuildAdjustedBound(Function& function, BasicBlock* preheader, Value* bound,
                          int stride, int factor) {
  const int delta = std::abs(stride) * (factor - 1);
  if (delta == 0) {
    return bound;
  }

  const bool counts_up = stride > 0;
  if (auto* constant_bound = dyncast<ConstantInt>(bound)) {
    const auto original = constant_bound->GetValue();
    return looputils::ConstInt(counts_up ? original - delta : original + delta);
  }

  const Opcode adjust = counts_up ? Opcode::Sub : Opcode::Add;
  return preheader->Insert<BinaryInst>(
      looputils::GetTerminatorIndex(preheader), adjust, Type::GetInt32Type(),
      bound, looputils::ConstInt(delta), nullptr,
      looputils::NextSyntheticName(function, "unroll.bound."));
}
|
||||
|
||||
// Unrolls counted loops of `function`.
//
// For every loop accepted by MatchCountedLoop for which ChooseUnrollFactor
// returns a factor greater than one, a second loop made of three new blocks
// (unroll.header / unroll.body / unroll.exit) is inserted between the
// preheader and the original header:
//   * unroll.header holds one phi per original header phi and a compare of
//     the induction phi against the bound produced by BuildAdjustedBound;
//   * unroll.body holds `factor` consecutive clones of the non-terminator
//     instructions of the original body;
//   * unroll.exit forwards the header phis into the original header, whose
//     loop is left in place behind the unrolled one.
// After each successful rewrite the dominator tree and loop info are rebuilt
// and the scan restarts; the function returns once a full scan rewrites
// nothing.
//
// Returns true if at least one loop was unrolled.
bool RunLoopUnrollOnFunction(Function& function) {
  if (function.IsExternal() || !function.GetEntryBlock()) {
    return false;
  }

  bool changed = false;
  while (true) {
    // The CFG is rewritten below, so both analyses are recomputed per round.
    DominatorTree dom_tree(function);
    LoopInfo loop_info(function, dom_tree);
    bool local_changed = false;

    for (auto* loop : loop_info.GetLoopsInPostOrder()) {
      CountedLoopInfo info;
      if (!MatchCountedLoop(*loop, info)) {
        continue;
      }

      const int factor = ChooseUnrollFactor(info.body);
      if (factor <= 1) {
        continue;
      }

      // Validate every header phi BEFORE touching the IR. Previously a phi
      // without a preheader/latch incoming was skipped only after its clone
      // had been appended to the new header, and the later map lookups
      // default-inserted null pointers for it.
      bool phis_well_formed = true;
      bool has_induction_phi = false;
      for (auto* phi : info.phis) {
        if (looputils::GetPhiIncomingIndex(phi, info.preheader) < 0 ||
            looputils::GetPhiIncomingIndex(phi, info.body) < 0) {
          phis_well_formed = false;
          break;
        }
        if (phi == info.induction_var.phi) {
          has_induction_phi = true;
        }
      }
      if (!phis_well_formed || !has_induction_phi) {
        continue;
      }

      auto* unrolled_header =
          function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.header"));
      auto* unrolled_body =
          function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.body"));
      auto* unrolled_exit =
          function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.exit"));

      // remap:              original value -> value to use in the next clone.
      // unrolled_phis:      original header phi -> phi in unroll.header.
      // exit_phis:          original header phi -> phi in unroll.exit.
      // current_phi_values: original header phi -> its value after the lanes
      //                     cloned so far.
      // latch_values:       original header phi -> its incoming value from
      //                     the original body.
      std::unordered_map<Value*, Value*> remap;
      std::unordered_map<PhiInst*, PhiInst*> unrolled_phis;
      std::unordered_map<PhiInst*, PhiInst*> exit_phis;
      std::unordered_map<PhiInst*, Value*> current_phi_values;
      std::unordered_map<PhiInst*, Value*> latch_values;

      for (auto* phi : info.phis) {
        auto* cloned_phi = unrolled_header->Append<PhiInst>(
            phi->GetType(), nullptr,
            looputils::NextSyntheticName(function, "unroll.phi."));
        // Both indices are non-negative: checked by the validation loop above.
        const int preheader_index = looputils::GetPhiIncomingIndex(phi, info.preheader);
        const int latch_index = looputils::GetPhiIncomingIndex(phi, info.body);
        cloned_phi->AddIncoming(phi->GetIncomingValue(preheader_index), info.preheader);
        remap[phi] = cloned_phi;
        unrolled_phis.emplace(phi, cloned_phi);
        current_phi_values.emplace(phi, cloned_phi);
        latch_values.emplace(phi, phi->GetIncomingValue(latch_index));
      }

      auto* adjusted_bound = BuildAdjustedBound(function, info.preheader, info.bound,
                                                info.induction_var.stride, factor);
      auto* unrolled_cond = unrolled_header->Append<BinaryInst>(
          info.compare_opcode, Type::GetBoolType(),
          unrolled_phis.at(info.induction_var.phi), adjusted_bound, nullptr,
          looputils::NextSyntheticName(function, "unroll.cmp."));
      unrolled_header->Append<CondBrInst>(unrolled_cond, unrolled_body, unrolled_exit, nullptr);
      unrolled_header->AddPredecessor(info.preheader);
      unrolled_header->AddSuccessor(unrolled_body);
      unrolled_header->AddSuccessor(unrolled_exit);

      for (int lane = 0; lane < factor; ++lane) {
        for (const auto& inst_ptr : info.body->GetInstructions()) {
          auto* inst = inst_ptr.get();
          if (inst->IsTerminator()) {
            continue;
          }
          looputils::CloneInstruction(function, inst, unrolled_body, remap,
                                      "unroll." + std::to_string(lane) + ".");
        }

        // Resolve all latch values against the pre-update remap first, then
        // publish them, so one phi's new value cannot leak into another
        // phi's lookup within the same lane.
        std::unordered_map<PhiInst*, Value*> next_phi_values;
        for (const auto& entry : latch_values) {
          next_phi_values.emplace(entry.first,
                                  looputils::RemapValue(remap, entry.second));
        }
        for (const auto& entry : next_phi_values) {
          remap[entry.first] = entry.second;
          current_phi_values[entry.first] = entry.second;
        }
      }

      // Close the back edge unroll.body -> unroll.header.
      for (const auto& entry : unrolled_phis) {
        entry.second->AddIncoming(current_phi_values.at(entry.first), unrolled_body);
      }
      unrolled_body->Append<UncondBrInst>(unrolled_header, nullptr);
      unrolled_body->AddPredecessor(unrolled_header);
      unrolled_body->AddSuccessor(unrolled_header);
      unrolled_header->AddPredecessor(unrolled_body);

      // unroll.exit is entered only from unroll.header, so each exit phi has
      // a single incoming value: the matching header phi.
      for (const auto& entry : unrolled_phis) {
        auto* exit_phi = unrolled_exit->Append<PhiInst>(
            entry.first->GetType(), nullptr,
            looputils::NextSyntheticName(function, "unroll.exit."));
        exit_phi->AddIncoming(entry.second, unrolled_header);
        exit_phis.emplace(entry.first, exit_phi);
      }
      unrolled_exit->Append<UncondBrInst>(info.header, nullptr);
      unrolled_exit->AddPredecessor(unrolled_header);
      unrolled_exit->AddSuccessor(info.header);

      // NOTE(review): if the redirect fails, the three new blocks and any
      // instruction BuildAdjustedBound inserted stay in the function while
      // `changed` is left untouched - confirm a later cleanup removes them.
      if (!looputils::RedirectSuccessorEdge(info.preheader, info.header, unrolled_header)) {
        continue;
      }
      info.header->RemovePredecessor(info.preheader);
      info.header->AddPredecessor(unrolled_exit);

      // The original header is now entered from unroll.exit instead of the
      // preheader, carrying the values forwarded by the exit phis.
      for (auto* phi : info.phis) {
        looputils::ReplacePhiIncoming(phi, info.preheader, exit_phis.at(phi), unrolled_exit);
      }

      changed = true;
      local_changed = true;
      // The loop list was computed for the old CFG; restart with fresh
      // analyses.
      break;
    }

    if (!local_changed) {
      break;
    }
  }

  return changed;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Applies RunLoopUnrollOnFunction to every non-null function of `module`.
// Returns true if any of those calls returned true.
bool RunLoopUnroll(Module& module) {
  bool any_changed = false;
  for (const auto& fn : module.GetFunctions()) {
    if (!fn) {
      continue;
    }
    if (RunLoopUnrollOnFunction(*fn)) {
      any_changed = true;
    }
  }
  return any_changed;
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,260 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir::memutils {
|
||||
|
||||
// Category of the value a pointer is ultimately derived from, as assigned by
// ClassifyRoot in this header.
enum class PointerRootKind {
  Local,           // alloca instruction not recorded as escaped
  Global,          // global value that is not constant
  ReadonlyGlobal,  // global value reporting IsConstant()
  Param,           // function argument
  Unknown,         // null, escaped alloca, or any other kind of value
};
|
||||
|
||||
struct AddressComponent {
|
||||
bool is_constant = false;
|
||||
std::int64_t constant = 0;
|
||||
Value* value = nullptr;
|
||||
|
||||
bool operator==(const AddressComponent& rhs) const {
|
||||
return is_constant == rhs.is_constant && constant == rhs.constant &&
|
||||
value == rhs.value;
|
||||
}
|
||||
};
|
||||
|
||||
struct AddressKey {
|
||||
PointerRootKind kind = PointerRootKind::Unknown;
|
||||
Value* root = nullptr;
|
||||
std::vector<AddressComponent> components;
|
||||
|
||||
bool operator==(const AddressKey& rhs) const {
|
||||
return kind == rhs.kind && root == rhs.root && components == rhs.components;
|
||||
}
|
||||
};
|
||||
|
||||
struct AddressKeyHash {
|
||||
std::size_t operator()(const AddressKey& key) const {
|
||||
std::size_t h = static_cast<std::size_t>(key.kind);
|
||||
h ^= std::hash<Value*>{}(key.root) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
for (const auto& component : key.components) {
|
||||
h ^= std::hash<bool>{}(component.is_constant) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
if (component.is_constant) {
|
||||
h ^= std::hash<std::int64_t>{}(component.constant) + 0x9e3779b9 + (h << 6) +
|
||||
(h >> 2);
|
||||
} else {
|
||||
h ^= std::hash<Value*>{}(component.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
}
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
struct EscapeSummary {
|
||||
std::unordered_set<Value*> escaped_locals;
|
||||
|
||||
bool IsEscaped(Value* value) const {
|
||||
return value != nullptr && escaped_locals.find(value) != escaped_locals.end();
|
||||
}
|
||||
};
|
||||
|
||||
// Returns true when `user` only accesses memory through `current`: a load
// whose pointer is `current`, a store whose pointer is `current`, or a
// memset whose destination is `current`. Every other instruction, and any
// null argument, yields false.
inline bool IsNoEscapePointerUse(Value* current, Instruction* user) {
  if (current == nullptr || user == nullptr) {
    return false;
  }
  if (auto* as_load = dyncast<LoadInst>(user)) {
    return as_load->GetPtr() == current;
  }
  if (auto* as_store = dyncast<StoreInst>(user)) {
    // A store that merely has `current` as the stored value does not match.
    return as_store->GetPtr() == current;
  }
  if (auto* as_memset = dyncast<MemsetInst>(user)) {
    return as_memset->GetDest() == current;
  }
  return false;
}
|
||||
|
||||
// Walks the uses of `current` and reports whether any of them lets the
// pointer escape.
//   current:  pointer value whose uses are inspected.
//   root:     only checked for null here; forwarded unchanged on recursion.
//   visiting: values already inspected; a value seen again is reported as
//             not escaping so that each value is walked at most once.
// A use escapes when its user is not an instruction, or is an instruction
// that is neither a GEP nor accepted by IsNoEscapePointerUse. GEPs based on
// `current` are followed recursively.
inline bool PointerValueEscapes(Value* current, Value* root,
                                std::unordered_set<Value*>& visiting) {
  if (current == nullptr || root == nullptr) {
    return false;
  }
  if (!visiting.insert(current).second) {
    return false;
  }

  for (const auto& use : current->GetUses()) {
    auto* inst = dyncast<Instruction>(use.GetUser());
    if (inst == nullptr) {
      return true;
    }

    auto* gep = dyncast<GetElementPtrInst>(inst);
    if (gep == nullptr) {
      if (!IsNoEscapePointerUse(current, inst)) {
        return true;
      }
      continue;
    }

    // Only a GEP that uses `current` as its base pointer is followed.
    const bool derives_from_current = gep->GetPointer() == current;
    if (derives_from_current && PointerValueEscapes(gep, root, visiting)) {
      return true;
    }
  }

  return false;
}
|
||||
|
||||
// Runs PointerValueEscapes on every alloca instruction of `function` and
// collects the ones that escape into the returned summary.
inline EscapeSummary AnalyzeEscapes(Function& function) {
  EscapeSummary result;
  for (const auto& block : function.GetBlocks()) {
    for (const auto& inst : block->GetInstructions()) {
      if (auto* slot = dyncast<AllocaInst>(inst.get())) {
        // Each alloca is analysed with its own, initially empty visited set.
        std::unordered_set<Value*> seen;
        const bool escapes = PointerValueEscapes(slot, slot, seen);
        if (escapes) {
          result.escaped_locals.insert(slot);
        }
      }
    }
  }
  return result;
}
|
||||
|
||||
// Maps a root value to its PointerRootKind.
//   root:    value to classify; null yields Unknown.
//   summary: optional escape information; when null, every alloca is
//            classified as Local.
inline PointerRootKind ClassifyRoot(Value* root, const EscapeSummary* summary) {
  if (!root) {
    return PointerRootKind::Unknown;
  }
  if (auto* as_global = dyncast<GlobalValue>(root)) {
    if (as_global->IsConstant()) {
      return PointerRootKind::ReadonlyGlobal;
    }
    return PointerRootKind::Global;
  }
  if (isa<Argument>(root)) {
    return PointerRootKind::Param;
  }
  if (!isa<AllocaInst>(root)) {
    return PointerRootKind::Unknown;
  }
  // An alloca recorded as escaped is downgraded to Unknown.
  const bool escaped = summary != nullptr && summary->IsEscaped(root);
  return escaped ? PointerRootKind::Unknown : PointerRootKind::Local;
}
|
||||
|
||||
// Follows the base pointer of every GEP in the chain starting at `pointer`
// and returns the first value that is not a GEP.
inline Value* StripPointerRoot(Value* pointer) {
  Value* base = pointer;
  for (auto* gep = dyncast<GetElementPtrInst>(base); gep != nullptr;
       gep = dyncast<GetElementPtrInst>(base)) {
    base = gep->GetPointer();
  }
  return base;
}
|
||||
|
||||
// Converts an index value into an AddressComponent: ConstantInt and
// ConstantI1 become constant components (a true i1 maps to 1, false to 0);
// anything else is kept as a reference to the value itself.
inline AddressComponent MakeAddressComponent(Value* value) {
  AddressComponent component;
  component.is_constant = false;
  component.constant = 0;
  component.value = nullptr;

  if (auto* int_const = dyncast<ConstantInt>(value)) {
    component.is_constant = true;
    component.constant = int_const->GetValue();
  } else if (auto* bool_const = dyncast<ConstantI1>(value)) {
    component.is_constant = true;
    component.constant = bool_const->GetValue() ? 1 : 0;
  } else {
    component.value = value;
  }
  return component;
}
|
||||
|
||||
// Fills `key` with the structural description of `pointer`.
//   pointer: address to describe; may be a chain of GEPs.
//   summary: optional escape information forwarded to ClassifyRoot.
//   key:     output; receives the classified root and, in root-to-leaf
//            order, one component per index of every GEP in the chain.
// Returns false, leaving `key` untouched, when `pointer` or the base of any
// GEP in the chain is null; returns true otherwise.
inline bool BuildExactAddressKey(Value* pointer, const EscapeSummary* summary,
                                 AddressKey& key) {
  // Collect the GEP chain from the outermost GEP down to the root.
  std::vector<GetElementPtrInst*> chain;
  Value* cursor = pointer;
  while (cursor != nullptr) {
    auto* gep = dyncast<GetElementPtrInst>(cursor);
    if (gep == nullptr) {
      break;
    }
    chain.push_back(gep);
    cursor = gep->GetPointer();
  }
  if (cursor == nullptr) {
    return false;
  }

  key.kind = ClassifyRoot(cursor, summary);
  key.root = cursor;
  key.components.clear();

  // Indices are appended starting with the GEP closest to the root.
  for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
    auto* gep = *it;
    for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
      key.components.push_back(MakeAddressComponent(gep->GetIndex(i)));
    }
  }
  return true;
}
|
||||
|
||||
inline bool HasOnlyConstantComponents(const AddressKey& key) {
|
||||
for (const auto& component : key.components) {
|
||||
if (!component.is_constant) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool MayAliasConservatively(const AddressKey& lhs, const AddressKey& rhs) {
|
||||
if (lhs.kind == PointerRootKind::Unknown || rhs.kind == PointerRootKind::Unknown) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.kind != rhs.kind || lhs.root != rhs.root) {
|
||||
return false;
|
||||
}
|
||||
if (lhs.components == rhs.components) {
|
||||
return true;
|
||||
}
|
||||
if (HasOnlyConstantComponents(lhs) && HasOnlyConstantComponents(rhs)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Reports whether a call to `callee` has to be treated as reading memory
// rooted at a pointer of the given kind. A null callee or one with unknown
// effects always yields true; a Local root always yields false; the other
// kinds are answered from the callee's memory-effect queries.
inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) {
  if (callee == nullptr || callee->HasUnknownEffects()) {
    return true;
  }
  if (kind == PointerRootKind::Local) {
    return false;
  }
  if (kind == PointerRootKind::ReadonlyGlobal) {
    return callee->ReadsGlobalMemory();
  }
  if (kind == PointerRootKind::Global) {
    return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory();
  }
  if (kind == PointerRootKind::Param) {
    return callee->ReadsParamMemory() || callee->WritesParamMemory();
  }
  if (kind == PointerRootKind::Unknown) {
    return callee->MayReadMemory();
  }
  // Value outside the enumerators: stay conservative.
  return true;
}
|
||||
|
||||
// Reports whether a call to `callee` has to be treated as writing memory
// rooted at a pointer of the given kind. A null callee or one with unknown
// effects always yields true; Local and ReadonlyGlobal roots always yield
// false; the other kinds are answered from the callee's write queries.
inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) {
  if (callee == nullptr || callee->HasUnknownEffects()) {
    return true;
  }
  if (kind == PointerRootKind::ReadonlyGlobal || kind == PointerRootKind::Local) {
    return false;
  }
  if (kind == PointerRootKind::Global) {
    return callee->WritesGlobalMemory();
  }
  if (kind == PointerRootKind::Param) {
    return callee->WritesParamMemory();
  }
  if (kind == PointerRootKind::Unknown) {
    return callee->MayWriteMemory();
  }
  // Value outside the enumerators: stay conservative.
  return true;
}
|
||||
|
||||
inline bool IsPureCall(const CallInst* call) {
|
||||
auto* callee = call == nullptr ? nullptr : call->GetCallee();
|
||||
return callee != nullptr && callee->CanDiscardUnusedCall() &&
|
||||
!callee->MayReadMemory();
|
||||
}
|
||||
|
||||
} // namespace ir::memutils
|
||||
@ -0,0 +1,234 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir::passutils {
|
||||
|
||||
// Returns the object representation of `value` as a 32-bit unsigned integer.
// The bytes are copied with memcpy rather than reinterpreted through a cast.
inline std::uint32_t FloatBits(float value) {
  std::uint32_t raw{};
  std::memcpy(&raw, &value, sizeof raw);
  return raw;
}
|
||||
|
||||
// Two values are equivalent when they are the same object, or when both are
// constants of the same class (ConstantInt, ConstantI1 or ConstantFloat)
// holding the same value. Floats are compared by bit pattern via FloatBits.
inline bool AreEquivalentValues(Value* lhs, Value* rhs) {
  if (lhs == rhs) {
    return true;
  }

  auto* int_a = dyncast<ConstantInt>(lhs);
  auto* int_b = dyncast<ConstantInt>(rhs);
  if (int_a != nullptr && int_b != nullptr) {
    return int_a->GetValue() == int_b->GetValue();
  }

  auto* bool_a = dyncast<ConstantI1>(lhs);
  auto* bool_b = dyncast<ConstantI1>(rhs);
  if (bool_a != nullptr && bool_b != nullptr) {
    return bool_a->GetValue() == bool_b->GetValue();
  }

  auto* float_a = dyncast<ConstantFloat>(lhs);
  auto* float_b = dyncast<ConstantFloat>(rhs);
  if (float_a != nullptr && float_b != nullptr) {
    const auto bits_a = FloatBits(float_a->GetValue());
    const auto bits_b = FloatBits(float_b->GetValue());
    return bits_a == bits_b;
  }

  return false;
}
|
||||
|
||||
// Returns the blocks reachable from the entry block of `function`, in
// depth-first preorder. Successors are pushed in reverse so that they are
// visited in their listed order. The result is empty when the function has
// no entry block.
inline std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
  std::vector<BasicBlock*> preorder;
  auto* entry = function.GetEntryBlock();
  if (entry == nullptr) {
    return preorder;
  }

  std::unordered_set<BasicBlock*> seen;
  std::vector<BasicBlock*> worklist;
  worklist.push_back(entry);
  while (!worklist.empty()) {
    auto* top = worklist.back();
    worklist.pop_back();

    const bool first_visit = top != nullptr && seen.insert(top).second;
    if (!first_visit) {
      continue;
    }
    preorder.push_back(top);

    const auto& successors = top->GetSuccessors();
    for (auto rit = successors.rbegin(); rit != successors.rend(); ++rit) {
      if (*rit == nullptr) {
        continue;
      }
      worklist.push_back(*rit);
    }
  }

  return preorder;
}
|
||||
|
||||
inline bool IsSideEffectingInstruction(const Instruction* inst) {
|
||||
if (!inst) {
|
||||
return false;
|
||||
}
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Store:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Br:
|
||||
case Opcode::CondBr:
|
||||
case Opcode::Return:
|
||||
case Opcode::Unreachable:
|
||||
return true;
|
||||
case Opcode::Call: {
|
||||
auto* call = dyncast<const CallInst>(inst);
|
||||
auto* callee = call == nullptr ? nullptr : call->GetCallee();
|
||||
return callee == nullptr || !callee->CanDiscardUnusedCall();
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// True when `inst` is non-null, is not side-effecting according to
// IsSideEffectingInstruction, and has no uses.
inline bool IsTriviallyDead(Instruction* inst) {
  if (inst == nullptr) {
    return false;
  }
  if (IsSideEffectingInstruction(inst)) {
    return false;
  }
  return inst->GetUses().empty();
}
|
||||
|
||||
// Removes from `phi` every incoming entry whose block is `block`.
// Incoming entry i occupies operands 2*i and 2*i+1. Entries are scanned from
// the back, and the higher operand is removed first, so pending indices stay
// valid. Does nothing when either argument is null.
inline void RemoveIncomingForBlock(PhiInst* phi, BasicBlock* block) {
  if (phi == nullptr || block == nullptr) {
    return;
  }
  int i = phi->GetNumIncomings();
  while (i-- > 0) {
    if (phi->GetIncomingBlock(i) == block) {
      phi->RemoveOperand(static_cast<size_t>(2 * i + 1));
      phi->RemoveOperand(static_cast<size_t>(2 * i));
    }
  }
}
|
||||
|
||||
// Drops the incoming entries for `pred` from the leading phis of `succ`.
// The scan stops at the first instruction that is not a phi. Does nothing
// when either argument is null.
inline void RemoveIncomingFromSuccessor(BasicBlock* succ, BasicBlock* pred) {
  if (succ == nullptr || pred == nullptr) {
    return;
  }
  for (const auto& inst : succ->GetInstructions()) {
    if (auto* phi = dyncast<PhiInst>(inst.get())) {
      RemoveIncomingForBlock(phi, pred);
      continue;
    }
    break;
  }
}
|
||||
|
||||
// Replaces the terminator of `block` with an unconditional branch to `dest`.
//
// Does nothing when `block` is null (matching the null tolerance of the
// other helpers in this header), when the block is empty, or when its last
// instruction is not a terminator. The operands of the old terminator are
// cleared before it is replaced. Predecessor/successor lists are not touched
// here. `dest` is not validated.
inline void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) {
  if (!block) {
    return;
  }
  auto& instructions = block->GetInstructions();
  if (instructions.empty() || !instructions.back()->IsTerminator()) {
    return;
  }
  instructions.back()->ClearAllOperands();
  auto branch = std::make_unique<UncondBrInst>(dest, nullptr);
  branch->SetParent(block);
  // Assigning over the slot destroys the old terminator.
  instructions.back() = std::move(branch);
}
|
||||
|
||||
// Replaces `phi` by its single distinct incoming value, if there is one.
// Self-references are ignored and incoming values are compared with
// AreEquivalentValues. On success all uses of the phi are redirected to that
// value, the phi is erased from its parent block, and true is returned.
// Returns false when `phi` is null, when two incoming values differ, or when
// no candidate value was found.
inline bool SimplifyPhiInst(PhiInst* phi) {
  if (phi == nullptr) {
    return false;
  }

  Value* candidate = nullptr;
  for (int idx = 0; idx < phi->GetNumIncomings(); ++idx) {
    auto* incoming = phi->GetIncomingValue(idx);
    if (incoming == phi) {
      continue;
    }
    if (candidate == nullptr) {
      candidate = incoming;
    } else if (!AreEquivalentValues(candidate, incoming)) {
      return false;
    }
  }
  if (candidate == nullptr) {
    return false;
  }

  // Fetch the parent before the phi is erased.
  auto* owner = phi->GetParent();
  phi->ReplaceAllUsesWith(candidate);
  owner->EraseInstruction(phi);
  return true;
}
|
||||
|
||||
// Removes `block` from the block list of `function`, which destroys the
// owning unique_ptr entry. Does nothing when `block` is null.
inline void EraseBlock(Function& function, BasicBlock* block) {
  if (block == nullptr) {
    return;
  }
  const auto owns_block = [block](const std::unique_ptr<BasicBlock>& candidate) {
    return candidate.get() == block;
  };
  auto& owned = function.GetBlocks();
  const auto new_end = std::remove_if(owned.begin(), owned.end(), owns_block);
  owned.erase(new_end, owned.end());
}
|
||||
|
||||
// Deletes every block of `function` that CollectReachableBlocks does not
// return. Works in three passes over the dead blocks:
//   1. detach them from the CFG (phi incomings and predecessor lists of
//      their successors, successor lists of their predecessors);
//   2. clear the operands of all their instructions;
//   3. erase the blocks from the function.
// Returns true when at least one block was removed.
inline bool RemoveUnreachableBlocks(Function& function) {
  const auto live_order = CollectReachableBlocks(function);
  const std::unordered_set<BasicBlock*> live(live_order.begin(), live_order.end());

  std::vector<BasicBlock*> doomed;
  for (const auto& owner : function.GetBlocks()) {
    if (live.count(owner.get()) == 0) {
      doomed.push_back(owner.get());
    }
  }
  if (doomed.empty()) {
    return false;
  }

  for (auto* victim : doomed) {
    // Copies of the edge lists are taken before the unlinking calls below.
    const auto incoming_edges = victim->GetPredecessors();
    const auto outgoing_edges = victim->GetSuccessors();
    for (auto* target : outgoing_edges) {
      RemoveIncomingFromSuccessor(target, victim);
      target->RemovePredecessor(victim);
    }
    for (auto* source : incoming_edges) {
      source->RemoveSuccessor(victim);
    }
  }

  for (auto* victim : doomed) {
    for (const auto& inst : victim->GetInstructions()) {
      inst->ClearAllOperands();
    }
  }

  for (auto* victim : doomed) {
    EraseBlock(function, victim);
  }

  return true;
}
|
||||
|
||||
// True for the opcodes this header treats as commutative: integer and float
// add/mul, the bitwise and/or/xor, and the (in)equality comparisons.
inline bool IsCommutativeOpcode(Opcode opcode) {
  const bool arithmetic = opcode == Opcode::Add || opcode == Opcode::Mul ||
                          opcode == Opcode::FAdd || opcode == Opcode::FMul;
  const bool bitwise =
      opcode == Opcode::And || opcode == Opcode::Or || opcode == Opcode::Xor;
  const bool equality = opcode == Opcode::ICmpEQ || opcode == Opcode::ICmpNE ||
                        opcode == Opcode::FCmpEQ || opcode == Opcode::FCmpNE;
  return arithmetic || bitwise || equality;
}
|
||||
|
||||
} // namespace ir::passutils
|
||||
Loading…
Reference in new issue