master
tangttangtang 3 weeks ago
parent 252073efe8
commit b33ede5457

@ -0,0 +1,220 @@
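# Lab4-Lab6 Completion Status
## 1. Purpose of This Document
This document checks the current compiler against `doc/Lab4-基本标量优化.md`, `doc/Lab5-寄存器分配.md`, and `doc/Lab6-并行与循环优化.md`. It describes the completion status of the Lab4, Lab5, and Lab6 stages and summarizes the most recent round of changes and optimizations aimed at competition-level goals.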
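## 2. Overall Conclusion
Based on the current state of the code:
- Lab4: complete, and already beyond the basic scalar optimization requirements in the lab document.
- Lab5: complete. There is now a real, runnable backend register allocation and backend optimization pipeline, no longer an example-level backend.
- Lab6: the main body is complete. Single-threaded loop optimization is usable at competition level, and the basis for loop parallelism analysis is wired in, but true multi-threaded runtime parallel execution is not implemented.
The current mainline pipeline is:
`SysY -> IR generation -> IR optimization -> MIR lowering -> MIR optimization -> register allocation -> frame lowering -> AArch64 assembly output`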
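## 3. Completion Status by Lab
### 3.1 Lab4: Basic Scalar Optimizations
The core requirements of the Lab4 document are:
1. Run `mem2reg` first to promote local variables to SSA.
2. Implement basic scalar optimizations such as constant folding, constant propagation, DCE, CFG simplification, and CSE.
3. Hook these optimizations into `PassManager` to form an optimization flow that can be run repeatedly.
4. Confirm through testing that semantics are preserved across optimization.
Current implementation:
- `Mem2Reg` is part of the optimization pipeline and runs as the prerequisite step before the scalar optimizations.
- `ConstProp`, `ConstFold`, `DCE`, `CFGSimplify`, and `CSE` are all implemented and wired in.
- Beyond the document's requirements, `GVN` and `LoadStoreElim` were added, further strengthening memory-related and cross-block redundancy elimination.
- `PassManager` now drives an iterative optimization flow rather than a single sequential run that stops after one pass over the list.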
The current `IR` pipeline in `src/ir/passes/PassManager.cpp` iteratively runs:
- `RunFunctionInlining`
- `RunConstProp`
- `RunConstFold`
- `RunGVN`
- `RunLoadStoreElim`
- `RunCSE`
- `RunDCE`
- `RunCFGSimplify`
- `RunLICM`
- `RunLoopStrengthReduction`
- `RunLoopFission`
- `RunLoopUnroll`
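The iterate-until-stable behavior described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in `PassManager.cpp`: `ToyModule`, `ToyPass`, `ToyShrinkPass`, `RunToFixedPoint`, and the `max_rounds` cap are hypothetical names introduced here. The only thing carried over from the pass declarations is that each `Run*` pass returns `bool` to report whether it changed the module.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for ir::Module: a counter that passes can shrink.
struct ToyModule {
  int work_left = 0;
};

using ToyPass = std::function<bool(ToyModule&)>;

// Runs every pass in order and repeats the whole sequence until one full
// round reports no change, or until max_rounds is reached.
// Returns the number of rounds that were executed.
int RunToFixedPoint(ToyModule& module, const std::vector<ToyPass>& passes,
                    int max_rounds) {
  assert(max_rounds > 0);
  int rounds = 0;
  bool changed = true;
  while (changed && rounds < max_rounds) {
    changed = false;
    for (const auto& pass : passes) {
      changed |= pass(module);
    }
    ++rounds;
  }
  return rounds;
}

// A toy pass that removes one unit of work per invocation.
bool ToyShrinkPass(ToyModule& module) {
  if (module.work_left == 0) {
    return false;
  }
  --module.work_left;
  return true;
}
```

The round cap is a design choice of this sketch: it bounds compile time if two passes keep undoing each other's work. Whether the real pipeline uses such a cap is not stated in this document.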
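Completion assessment:
- Lab4 is complete.
- Judged strictly against the document, the work not only meets the "basic scalar optimization" requirement but has been extended into a stronger mid-end optimization framework.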
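### 3.2 Lab5: Register Allocation and Backend Optimization
The core requirements of the Lab5 document are:
1. MIR no longer uses a small fixed set of physical registers; it generates virtual registers first.
2. Implement real register allocation, handling spill/reload, callee-saved registers, stack slots, and related issues.
3. Add a backend local optimization flow that reduces redundant `copy/move`, redundant `load/store`, and obvious identity instructions.
4. Verify correctness on all tests and improve generated code quality as far as possible.
Current implementation:
- `Lowering` now emits MIR with virtual registers rather than fixed-register templates.
- `RegAlloc` implements real register allocation. It currently follows a graph-coloring-style flow and handles:
  - liveness analysis
  - interference relations
  - `copy` coalescing
  - spill stack slot allocation
  - back-filling of callee-saved save/restore information
  - live-across-call constraints
- `FrameLowering` and `AsmPrinter` can now produce the final stack frame and assembly output from the RA results.
- The `MIR` optimization pipeline is now genuinely part of the main chain:
  - `PreRA`: `AddressHoisting + Peephole`
  - `PostRA`: `Peephole`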
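To make the liveness and interference steps concrete, here is a minimal sketch for a single basic block. It is an illustration only and does not reproduce `RegAlloc`: `ToyInst`, `Edge`, and `BuildInterference` are hypothetical names, virtual registers are plain integers, and control flow, `copy` coalescing, and spilling are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// Hypothetical three-address instruction over virtual register ids.
struct ToyInst {
  int def;                // defined vreg, or -1 if the instruction defines none
  std::vector<int> uses;  // vregs read by the instruction
};

using Edge = std::pair<int, int>;  // stored as (smaller id, larger id)

// Walks one basic block backwards while maintaining the live set, and
// records an interference edge between each definition and every other vreg
// that is live immediately after it. `live` starts as the live-out set.
std::set<Edge> BuildInterference(const std::vector<ToyInst>& block,
                                 std::set<int> live) {
  std::set<Edge> edges;
  for (auto it = block.rbegin(); it != block.rend(); ++it) {
    if (it->def >= 0) {
      for (int other : live) {
        if (other != it->def) {
          edges.emplace(std::min(it->def, other), std::max(it->def, other));
        }
      }
      live.erase(it->def);  // the value does not exist before its definition
    }
    for (int use : it->uses) {
      assert(use >= 0);
      live.insert(use);  // operands are live just before the instruction
    }
  }
  return edges;
}
```

Two vregs joined by an edge cannot share a physical register; vregs with no edge between them can. A graph-coloring allocator then assigns registers so that adjacent vregs get different colors.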
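Backend local optimizations currently cover:
- redundant `copy` elimination
- identity arithmetic instruction elimination
- conditional jump simplification
- local redundant `load/store` elimination
- store-to-load forwarding within a block
- removal of repeated `store`s to the same address
- CFG-based cross-block memory dataflow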
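The in-block store-to-load forwarding item can be sketched as below. This is a simplified model, not the `MIR Peephole` implementation: `MemOp`, `MemOpKind`, and `ForwardStores` are hypothetical, addresses are reduced to exact slot ids, every register is assumed to be written only once, and a call is assumed to invalidate all tracked slots.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

enum class MemOpKind { Store, Load, Call };

// Hypothetical memory operation: `slot` is an exact slot id and `reg` is the
// register stored from (Store) or loaded into (Load). Unused for Call.
struct MemOp {
  MemOpKind kind;
  int slot;
  int reg;
};

// For each instruction, returns the register that an existing value can be
// taken from instead of performing the load, or -1 if nothing can be reused.
// Assumes registers are not redefined inside the block.
std::vector<int> ForwardStores(const std::vector<MemOp>& block) {
  std::map<int, int> known;  // slot -> register currently holding its value
  std::vector<int> forwarded(block.size(), -1);
  for (std::size_t i = 0; i < block.size(); ++i) {
    const MemOp& op = block[i];
    assert(op.kind == MemOpKind::Call || op.slot >= 0);
    switch (op.kind) {
      case MemOpKind::Store:
        known[op.slot] = op.reg;
        break;
      case MemOpKind::Load: {
        auto it = known.find(op.slot);
        if (it != known.end()) {
          forwarded[i] = it->second;  // reuse the stored or loaded value
        } else {
          known[op.slot] = op.reg;    // later loads can reuse this load
        }
        break;
      }
      case MemOpKind::Call:
        known.clear();  // conservatively assume the callee writes memory
        break;
    }
  }
  return forwarded;
}
```

The cross-block variant listed above would, in this model, merge the `known` maps of all predecessors at block entry and keep only the entries they agree on.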
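The most recent backend round did two key things:
1. `MIR Peephole` was raised from "local optimization within a single basic block" to "cross-block memory optimization with CFG dataflow".
2. `MIR Lowering` now lowers blocks in reachable-CFG order, which fixes a lowering failure in complex post-inlining CFGs where an SSA value was used before its definition had been lowered.
Note:
- An attempt was made to add `v16-v18` as extra allocatable FPRs, but it produced errors on call-heavy floating-point test cases, so it was rolled back and the stable register set was kept. This change is not in the mainline.
Completion assessment:
- Lab5 is complete.
- Measured against the document's goal of "advancing from a minimal backend to a real backend", the current implementation exceeds the course's minimum bar.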
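### 3.3 Lab6: Parallelism and Loop Optimization
The core requirements of the Lab6 document are:
1. Build the loop analysis foundation: identify loop headers, loop bodies, preheaders, exit blocks, back edges, and similar structures.
2. Implement effective loop optimizations and hook them into `PassManager`.
3. Work together with the Lab4 scalar optimizations.
4. For further performance gains, optionally go on to identify and parallelize parallelizable loops.
Current implementation:
- `DominatorTree` and `LoopInfo` are implemented and can identify natural loops and their nesting.
- The `LoopPassUtils` helpers needed by loop transformations have been filled in.
- Loop optimizations wired in so far:
  - `LICM`
  - `LoopStrengthReduction`
  - `LoopUnroll`
  - `LoopFission`
- `LoopMemoryUtils` has been upgraded from a fairly weak loop address analysis to a stronger version that combines:
  - simple induction variables
  - affine address expressions
  - exact-address keys
  - root-aware alias/mod-ref
  - non-escaping local object analysis
- `LICM` can now hoist safe `load`s more aggressively and deduplicates hoisted loads from the same address.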
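The natural-loop identification mentioned above can be illustrated with a small standalone sketch. It assumes the back edge `latch -> header` has already been found (the header dominates the latch) and only shows how the loop body is collected. Blocks are plain indices and `CollectNaturalLoop` is a hypothetical helper, not the `LoopInfo` code.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// preds[b] lists the predecessors of block b.
// Returns every block that can reach `latch` without passing through
// `header`, plus the header itself: the natural loop of the back edge.
std::set<int> CollectNaturalLoop(const std::vector<std::vector<int>>& preds,
                                 int header, int latch) {
  assert(header >= 0 && static_cast<std::size_t>(header) < preds.size());
  assert(latch >= 0 && static_cast<std::size_t>(latch) < preds.size());
  std::set<int> body = {header, latch};
  std::vector<int> worklist;
  if (latch != header) {
    worklist.push_back(latch);  // a self-loop has no further blocks to visit
  }
  while (!worklist.empty()) {
    const int block = worklist.back();
    worklist.pop_back();
    for (int pred : preds[static_cast<std::size_t>(block)]) {
      // The header is already in `body`, so the walk never moves above it.
      if (body.insert(pred).second) {
        worklist.push_back(pred);
      }
    }
  }
  return body;
}
```

For the CFG `0 -> 1 -> 2 -> 3 -> 1` with an exit edge `1 -> 4`, the back edge `3 -> 1` yields the body `{1, 2, 3}`. Block 0 stays outside because the walk stops at the header.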
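On the parallel part of "parallelism and loop optimization":
- The basis for parallelizable-loop identification and dependence analysis is in place.
- However, no real multi-threaded parallel runtime has been integrated, and loops are not rewritten into runtime calls that can execute concurrently.
- Going by the document's wording, this part reads more as a "direction for further work" than as a hard requirement of the Lab6 baseline.
Completion assessment:
- The main body of Lab6 is complete.
- From the perspective of a competition-level compiler, the current single-threaded loop optimization capability is fairly complete.
- If "true runtime parallel execution" is treated as an additional goal, that part can still be extended, but it does not affect the assessment that the main body of Lab6 is complete.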
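## 4. Most Recent Round of Changes and Optimizations
This round targeted competition-level goals and mainly added or strengthened the following.
### 4.1 Mid-end Additions and Enhancements
- Added `GVN` to reuse pure expression results over a wider scope.
- Added `LoadStoreElim`, supporting cross-block redundant `load` elimination, store-to-load forwarding, and dead `store` removal.
- Strengthened `LoopMemoryUtils` so that loop memory optimization no longer relies only on very conservative rules.
- Strengthened `LICM` so that it hoists safe `load`s more aggressively and can merge hoisted loads.
- Added IR-level inlining of small functions so that the benefit reaches mid-end optimizations such as `ConstProp`, `GVN`, `DCE`, and `LICM` earlier.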
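As a rough picture of what reusing pure expression results means, the sketch below performs value numbering over a straight-line sequence of expressions. It is deliberately smaller than a real `GVN` pass: there is no control flow or dominance reasoning, and `ToyExpr` and `NumberValues` are hypothetical names. Leaf values have ids `0..num_leaves-1`, and expression `i` defines value id `num_leaves + i`.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical pure binary expression over value ids.
struct ToyExpr {
  char op;  // '+', '-', or '*'
  int lhs;
  int rhs;
};

// Returns rep, where rep[id] is the earliest value known to compute the same
// result as value `id`. Operands are first mapped to their representatives,
// and commutative operands are put in a canonical order before lookup.
std::vector<int> NumberValues(int num_leaves, const std::vector<ToyExpr>& exprs) {
  assert(num_leaves >= 0);
  std::vector<int> rep(static_cast<std::size_t>(num_leaves) + exprs.size());
  for (int i = 0; i < num_leaves; ++i) {
    rep[static_cast<std::size_t>(i)] = i;
  }
  std::map<std::tuple<char, int, int>, int> table;
  for (std::size_t i = 0; i < exprs.size(); ++i) {
    const int id = num_leaves + static_cast<int>(i);
    const ToyExpr& e = exprs[i];
    // Operands must be defined before they are used.
    assert(e.lhs >= 0 && e.lhs < id && e.rhs >= 0 && e.rhs < id);
    int lhs = rep[static_cast<std::size_t>(e.lhs)];
    int rhs = rep[static_cast<std::size_t>(e.rhs)];
    if ((e.op == '+' || e.op == '*') && rhs < lhs) {
      std::swap(lhs, rhs);  // a + b and b + a share one key
    }
    const auto inserted = table.emplace(std::make_tuple(e.op, lhs, rhs), id);
    rep[static_cast<std::size_t>(id)] = inserted.first->second;
  }
  return rep;
}
```

Mapping operands to their representatives before building the key is what lets equivalence propagate: once `b + a` is known to equal `a + b`, an expression built on either one hashes to the same entry.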
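### 4.2 Backend Additions and Enhancements
- `MIR Peephole` was extended from local in-block optimization to CFG-based cross-block propagation of memory state.
- `Call` now determines its `read/write` boundary from the effect information of the source `IR Function`, instead of always being handled at the coarsest granularity.
- Fixed the block-ordering problem in MIR lowering under complex control flow after inlining.
- After a full regression run, the stable FPR set was kept and the unstable `v16-v18` expansion was dropped.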
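One plausible way to obtain per-function effect information of the kind `Call` consults is to propagate callee effects into callers until nothing changes. The sketch below is an assumption about the general technique, not a description of how this compiler computes it. `ToyEffects` and `PropagateEffects` are hypothetical, only two flags are tracked, and recursion is ignored, whereas the `CanDiscardUnusedCall` in the headers also checks `is_recursive`.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical, reduced effect summary for one function.
struct ToyEffects {
  bool writes_memory = false;
  bool has_io = false;
  std::vector<std::string> direct_callees;
};

// Propagates callee effects into callers until a fixed point is reached.
// A callee that is missing from the table is treated as having every effect.
void PropagateEffects(std::map<std::string, ToyEffects>& table) {
  std::size_t rounds = 0;
  bool changed = true;
  while (changed) {
    changed = false;
    for (auto& entry : table) {
      ToyEffects& caller = entry.second;
      for (const std::string& name : caller.direct_callees) {
        const auto it = table.find(name);
        const bool writes = it == table.end() || it->second.writes_memory;
        const bool io = it == table.end() || it->second.has_io;
        if (writes && !caller.writes_memory) {
          caller.writes_memory = true;
          changed = true;
        }
        if (io && !caller.has_io) {
          caller.has_io = true;
          changed = true;
        }
      }
    }
    ++rounds;
    // Flags only flip from false to true, so the loop must terminate.
    assert(rounds <= 2 * table.size() + 1);
  }
}

// An unused call may be dropped only if the callee has no visible effect.
bool CanDiscardUnusedCall(const ToyEffects& f) {
  return !f.writes_memory && !f.has_io;
}
```

Treating unknown callees as having all effects matches the conservative default visible in the headers, where `has_unknown_effects` starts out `true`.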
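### 4.3 Practical Significance of This Round
This means the recent changes are no longer just "filling in course lab features"; they are starting to target competition gains:
- Mid-end: stronger redundancy elimination, memory optimization, function-level optimization, and cooperation among loop optimizations
- Backend: stronger `copy/load/store` elimination and more stable post-RA local optimization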
## 5. Current Verification Status
The following verification was completed in this regression run:
### 5.1 Full Correctness Regression
Command:
```bash
./scripts/lab3_build_test.sh test/test_case/functional test/test_case/h_functional
```
Result:
- `134 PASS / 0 FAIL / total 134`
This shows that with the Lab4-Lab6 optimizations wired in, the full `asm` path remains correct on `functional + h_functional`.
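### 5.2 Performance Hotspot Spot Checks
Ran and passed:
- `test/test_case/h_performance/fft2.sy`
- `test/test_case/h_performance/matmul3.sy`
- `test/test_case/h_performance/transpose2.sy`
- `test/test_case/h_performance/gameoflife-gosper.sy`
These cases cover:
- loop-heavy code
- memory-access-heavy code
- floating-point arithmetic
- matrix access
- relatively complex control flow
This indicates that the newly added optimizations at least remain runnable and produce correct results on a representative set of performance cases.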
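## 6. Conclusion
Overall, the completion status of the current compiler on Lab4, Lab5, and Lab6 can be summarized as:
- Lab4: complete, and extended into stronger mid-end optimization.
- Lab5: complete, with a real, runnable backend optimization pipeline.
- Lab6: main body complete; single-threaded loop optimization has reached a competition-usable level.
If work continues toward the competition, the most valuable next steps are no longer about "whether the labs are complete", but:
1. Systematic profiling on `h_performance`.
2. Further mid-end memory and loop transformation work driven by the performance hotspots.
3. Further improvement of backend spill, copy, and memory access quality.
4. To go deeper on Lab6, an attempt at integrating a real parallel runtime.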

@ -0,0 +1,73 @@
#pragma once
#include "ir/IR.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir {
class DominatorTree {
public:
explicit DominatorTree(Function& function);
void Recalculate();
Function& GetFunction() const { return *function_; }
bool IsReachable(BasicBlock* block) const;
bool Dominates(BasicBlock* dom, BasicBlock* node) const;
bool Dominates(Instruction* dom, Instruction* user) const;
BasicBlock* GetIDom(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetChildren(BasicBlock* block) const;
const std::vector<BasicBlock*>& GetReversePostOrder() const {
return reverse_post_order_;
}
private:
Function* function_ = nullptr;
std::vector<BasicBlock*> reverse_post_order_;
std::unordered_map<BasicBlock*, std::size_t> block_index_;
std::vector<std::vector<std::uint8_t>> dominates_;
std::unordered_map<BasicBlock*, BasicBlock*> immediate_dominator_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_children_;
};
struct Loop {
BasicBlock* header = nullptr;
std::unordered_set<BasicBlock*> blocks;
std::vector<BasicBlock*> block_list;
std::vector<BasicBlock*> latches;
std::vector<BasicBlock*> exiting_blocks;
std::vector<BasicBlock*> exit_blocks;
BasicBlock* preheader = nullptr;
Loop* parent = nullptr;
std::vector<Loop*> subloops;
bool Contains(BasicBlock* block) const;
bool Contains(const Loop* other) const;
bool IsInnermost() const;
};
class LoopInfo {
public:
LoopInfo(Function& function, const DominatorTree& dom_tree);
void Recalculate();
const std::vector<std::unique_ptr<Loop>>& GetLoops() const { return loops_; }
std::vector<Loop*> GetTopLevelLoops() const;
std::vector<Loop*> GetLoopsInPostOrder() const;
Loop* GetLoopFor(BasicBlock* block) const;
private:
Function* function_ = nullptr;
const DominatorTree* dom_tree_ = nullptr;
std::vector<std::unique_ptr<Loop>> loops_;
std::vector<Loop*> top_level_loops_;
std::unordered_map<BasicBlock*, Loop*> block_to_loop_;
};
} // namespace ir

@ -612,6 +612,40 @@ class Function : public Value {
bool IsExternal() const { return is_external_; }
void SetExternal(bool is_external) { is_external_ = is_external; }
void SetEffectInfo(bool reads_global_memory, bool writes_global_memory,
bool reads_param_memory, bool writes_param_memory,
bool has_io, bool has_unknown_effects, bool is_recursive) {
reads_global_memory_ = reads_global_memory;
writes_global_memory_ = writes_global_memory;
reads_param_memory_ = reads_param_memory;
writes_param_memory_ = writes_param_memory;
has_io_ = has_io;
has_unknown_effects_ = has_unknown_effects;
is_recursive_ = is_recursive;
}
bool ReadsGlobalMemory() const { return reads_global_memory_; }
bool WritesGlobalMemory() const { return writes_global_memory_; }
bool ReadsParamMemory() const { return reads_param_memory_; }
bool WritesParamMemory() const { return writes_param_memory_; }
bool HasIO() const { return has_io_; }
bool HasUnknownEffects() const { return has_unknown_effects_; }
bool IsRecursive() const { return is_recursive_; }
bool MayReadMemory() const {
return has_unknown_effects_ || reads_global_memory_ || writes_global_memory_ ||
reads_param_memory_ || writes_param_memory_;
}
bool MayWriteMemory() const {
return has_unknown_effects_ || writes_global_memory_ || writes_param_memory_;
}
bool HasObservableSideEffects() const {
return has_unknown_effects_ || writes_global_memory_ ||
writes_param_memory_ || has_io_;
}
bool CanDiscardUnusedCall() const {
return !has_unknown_effects_ && !writes_global_memory_ &&
!writes_param_memory_ && !has_io_ && !is_recursive_;
}
BasicBlock* GetEntryBlock() const { return entry_; }
BasicBlock* GetEntry() const { return entry_; }
void SetEntryBlock(BasicBlock* bb) { entry_ = bb; }
@ -633,6 +667,13 @@ class Function : public Value {
std::vector<std::shared_ptr<Type>> param_types_;
std::vector<std::unique_ptr<Argument>> arguments_;
bool is_external_ = false;
bool reads_global_memory_ = false;
bool writes_global_memory_ = false;
bool reads_param_memory_ = false;
bool writes_param_memory_ = false;
bool has_io_ = false;
bool has_unknown_effects_ = true;
bool is_recursive_ = false;
BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_;
};

@ -5,6 +5,18 @@ namespace ir {
class Module;
void RunMem2Reg(Module& module);
bool RunConstFold(Module& module);
bool RunConstProp(Module& module);
bool RunFunctionInlining(Module& module);
bool RunCSE(Module& module);
bool RunGVN(Module& module);
bool RunLoadStoreElim(Module& module);
bool RunDCE(Module& module);
bool RunCFGSimplify(Module& module);
bool RunLICM(Module& module);
bool RunLoopStrengthReduction(Module& module);
bool RunLoopUnroll(Module& module);
bool RunLoopFission(Module& module);
void RunIRPassPipeline(Module& module);
} // namespace ir
} // namespace ir

@ -53,6 +53,7 @@ class IRGenImpl final : public SysYBaseVisitor {
[[noreturn]] void ThrowError(const antlr4::ParserRuleContext* ctx,
const std::string& message) const;
void ApplyFunctionSema(const std::string& name, ir::Function& function);
void RegisterBuiltinFunctions();
void PredeclareTopLevel(SysYParser::CompUnitContext& ctx);
void PredeclareFunction(SysYParser::FuncDefContext& ctx);
@ -124,6 +125,11 @@ class IRGenImpl final : public SysYBaseVisitor {
ConstantValue EvalConstPrimaryExp(SysYParser::PrimaryExpContext& ctx);
ConstantValue EvalConstLVal(SysYParser::LValContext& ctx);
ConstantValue ConvertConst(ConstantValue value, SemanticType target_type) const;
bool IsZeroConstant(const ConstantValue& value) const;
bool IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type);
bool IsExplicitZeroInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type);
std::vector<ConstantValue> FlattenConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type,

@ -277,11 +277,14 @@ class MachineModule {
std::vector<std::unique_ptr<MachineFunction>> functions_;
};
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunAddressHoisting(MachineModule& module);
void RunRegAlloc(MachineModule& module);
void RunFrameLowering(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
bool RunPeephole(MachineModule& module);
void RunMIRPreRegAllocPassPipeline(MachineModule& module);
void RunMIRPostRegAllocPassPipeline(MachineModule& module);
void RunAddressHoisting(MachineModule& module);
void RunRegAlloc(MachineModule& module);
void RunFrameLowering(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
} // namespace mir

@ -1,7 +1,94 @@
#pragma once
#include "SysYParser.h"
#include "sem/SymbolTable.h"
class SemanticContext {};
#include <string>
#include <unordered_map>
#include <vector>
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
struct GlobalSemanticInfo {
SemanticType type = SemanticType::Int;
bool is_const = false;
bool is_array = false;
std::vector<int> dims;
};
struct FunctionSemanticInfo {
SemanticType return_type = SemanticType::Void;
std::vector<bool> param_is_array;
bool is_builtin = false;
bool is_defined = false;
bool reads_global_memory = false;
bool writes_global_memory = false;
bool reads_param_memory = false;
bool writes_param_memory = false;
bool has_io = false;
bool has_unknown_effects = true;
bool is_recursive = false;
std::vector<std::string> direct_callees;
bool MayReadMemory() const {
return has_unknown_effects || reads_global_memory || writes_global_memory ||
reads_param_memory || writes_param_memory;
}
bool MayWriteMemory() const {
return has_unknown_effects || writes_global_memory || writes_param_memory;
}
bool HasObservableSideEffects() const {
return has_unknown_effects || writes_global_memory || writes_param_memory ||
has_io;
}
bool CanDiscardUnusedCall() const {
return !has_unknown_effects && !writes_global_memory &&
!writes_param_memory && !has_io && !is_recursive;
}
};
class SemanticContext {
public:
FunctionSemanticInfo* LookupFunction(const std::string& name) {
auto it = functions_.find(name);
return it == functions_.end() ? nullptr : &it->second;
}
const FunctionSemanticInfo* LookupFunction(const std::string& name) const {
auto it = functions_.find(name);
return it == functions_.end() ? nullptr : &it->second;
}
GlobalSemanticInfo* LookupGlobal(const std::string& name) {
auto it = globals_.find(name);
return it == globals_.end() ? nullptr : &it->second;
}
const GlobalSemanticInfo* LookupGlobal(const std::string& name) const {
auto it = globals_.find(name);
return it == globals_.end() ? nullptr : &it->second;
}
FunctionSemanticInfo& UpsertFunction(const std::string& name) {
return functions_[name];
}
GlobalSemanticInfo& UpsertGlobal(const std::string& name) {
return globals_[name];
}
const std::unordered_map<std::string, FunctionSemanticInfo>& GetFunctions() const {
return functions_;
}
const std::unordered_map<std::string, GlobalSemanticInfo>& GetGlobals() const {
return globals_;
}
private:
std::unordered_map<std::string, FunctionSemanticInfo> functions_;
std::unordered_map<std::string, GlobalSemanticInfo> globals_;
};
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -48,6 +48,7 @@ struct SymbolEntry {
std::optional<ConstantValue> const_scalar;
std::vector<ConstantValue> const_array;
bool const_array_all_zero = false;
FunctionTypeInfo function_type;
};

@ -54,6 +54,10 @@ void User::RemoveOperand(size_t index) {
}
operands_.erase(operands_.begin() + static_cast<long long>(index));
for (size_t i = index; i < operands_.size(); ++i) {
if (auto* value = operands_[i].GetValue()) {
value->RemoveUse(this, i + 1);
value->AddUse(this, i);
}
operands_[i].SetOperandIndex(i);
}
}

@ -1,4 +1,167 @@
// Dominator tree analysis:
// - Builds and queries the dominator tree and related relations
// - Provides the foundation for mem2reg, CFG optimizations, and loop analysis
#include "ir/Analysis.h"
#include <algorithm>
#include <functional>
namespace ir {
namespace {
std::vector<BasicBlock*> BuildReversePostOrder(Function& function) {
std::vector<BasicBlock*> post_order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return post_order;
}
std::unordered_set<BasicBlock*> visited;
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* block) {
if (!block || !visited.insert(block).second) {
return;
}
for (auto* succ : block->GetSuccessors()) {
dfs(succ);
}
post_order.push_back(block);
};
dfs(entry);
std::reverse(post_order.begin(), post_order.end());
return post_order;
}
} // namespace
DominatorTree::DominatorTree(Function& function) : function_(&function) {
Recalculate();
}
void DominatorTree::Recalculate() {
reverse_post_order_ = BuildReversePostOrder(*function_);
block_index_.clear();
dominates_.clear();
immediate_dominator_.clear();
dom_children_.clear();
const auto num_blocks = reverse_post_order_.size();
for (std::size_t i = 0; i < num_blocks; ++i) {
block_index_.emplace(reverse_post_order_[i], i);
}
if (num_blocks == 0) {
return;
}
dominates_.assign(num_blocks, std::vector<std::uint8_t>(num_blocks, 1));
dominates_[0].assign(num_blocks, 0);
dominates_[0][0] = 1;
bool changed = true;
while (changed) {
changed = false;
for (std::size_t i = 1; i < num_blocks; ++i) {
auto* block = reverse_post_order_[i];
std::vector<std::uint8_t> next(num_blocks, 1);
bool has_reachable_pred = false;
for (auto* pred : block->GetPredecessors()) {
auto pred_it = block_index_.find(pred);
if (pred_it == block_index_.end()) {
continue;
}
has_reachable_pred = true;
const auto& pred_dom = dominates_[pred_it->second];
for (std::size_t bit = 0; bit < num_blocks; ++bit) {
next[bit] &= pred_dom[bit];
}
}
if (!has_reachable_pred) {
next.assign(num_blocks, 0);
}
next[i] = 1;
if (next != dominates_[i]) {
dominates_[i] = std::move(next);
changed = true;
}
}
}
for (std::size_t i = 1; i < num_blocks; ++i) {
auto* block = reverse_post_order_[i];
BasicBlock* idom = nullptr;
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
if (candidate == i || !dominates_[i][candidate]) {
continue;
}
auto* candidate_block = reverse_post_order_[candidate];
bool immediate = true;
for (std::size_t other = 0; other < num_blocks; ++other) {
if (other == i || other == candidate || !dominates_[i][other]) {
continue;
}
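// The immediate dominator is the closest strict dominator, so reject a
// candidate that itself dominates another strict dominator of this block.
if (Dominates(candidate_block, reverse_post_order_[other])) {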
immediate = false;
break;
}
}
if (immediate) {
idom = candidate_block;
break;
}
}
immediate_dominator_.emplace(block, idom);
if (idom) {
dom_children_[idom].push_back(block);
}
}
}
bool DominatorTree::IsReachable(BasicBlock* block) const {
return block != nullptr && block_index_.find(block) != block_index_.end();
}
bool DominatorTree::Dominates(BasicBlock* dom, BasicBlock* node) const {
if (!dom || !node) {
return false;
}
const auto dom_it = block_index_.find(dom);
const auto node_it = block_index_.find(node);
if (dom_it == block_index_.end() || node_it == block_index_.end()) {
return false;
}
return dominates_[node_it->second][dom_it->second] != 0;
}
bool DominatorTree::Dominates(Instruction* dom, Instruction* user) const {
if (!dom || !user) {
return false;
}
if (dom == user) {
return true;
}
auto* dom_block = dom->GetParent();
auto* user_block = user->GetParent();
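// Detached instructions have no position in the CFG.
if (!dom_block || !user_block) {
return false;
}
if (dom_block != user_block) {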
return Dominates(dom_block, user_block);
}
for (const auto& inst_ptr : dom_block->GetInstructions()) {
if (inst_ptr.get() == dom) {
return true;
}
if (inst_ptr.get() == user) {
return false;
}
}
return false;
}
BasicBlock* DominatorTree::GetIDom(BasicBlock* block) const {
auto it = immediate_dominator_.find(block);
return it == immediate_dominator_.end() ? nullptr : it->second;
}
const std::vector<BasicBlock*>& DominatorTree::GetChildren(BasicBlock* block) const {
static const std::vector<BasicBlock*> kEmpty;
auto it = dom_children_.find(block);
return it == dom_children_.end() ? kEmpty : it->second;
}
} // namespace ir

@ -1,4 +1,214 @@
// Loop analysis:
// - Identifies loop structures and their nesting
// - Provides loop information for later (optional) optimizations
#include "ir/Analysis.h"
#include <algorithm>
#include <functional>
namespace ir {
namespace {
std::vector<BasicBlock*> CollectNaturalLoopBlocks(BasicBlock* header,
BasicBlock* latch) {
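// Seed the worklist with the latch only when it differs from the header.
// For a self-loop, walking the header's predecessors would pull blocks from
// outside the loop into the body.
std::vector<BasicBlock*> stack;
if (latch != header) {
stack.push_back(latch);
}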
std::unordered_set<BasicBlock*> loop_blocks{header, latch};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
for (auto* pred : block->GetPredecessors()) {
if (!pred || !loop_blocks.insert(pred).second) {
continue;
}
stack.push_back(pred);
}
}
return {loop_blocks.begin(), loop_blocks.end()};
}
} // namespace
bool Loop::Contains(BasicBlock* block) const {
return block != nullptr && blocks.find(block) != blocks.end();
}
bool Loop::Contains(const Loop* other) const {
if (!other) {
return false;
}
for (auto* block : other->blocks) {
if (!Contains(block)) {
return false;
}
}
return true;
}
bool Loop::IsInnermost() const { return subloops.empty(); }
LoopInfo::LoopInfo(Function& function, const DominatorTree& dom_tree)
: function_(&function), dom_tree_(&dom_tree) {
Recalculate();
}
void LoopInfo::Recalculate() {
loops_.clear();
top_level_loops_.clear();
block_to_loop_.clear();
std::unordered_map<BasicBlock*, Loop*> loops_by_header;
for (auto* block : dom_tree_->GetReversePostOrder()) {
for (auto* succ : block->GetSuccessors()) {
if (!dom_tree_->Dominates(succ, block)) {
continue;
}
Loop* loop = nullptr;
auto it = loops_by_header.find(succ);
if (it == loops_by_header.end()) {
auto new_loop = std::make_unique<Loop>();
new_loop->header = succ;
loop = new_loop.get();
loops_.push_back(std::move(new_loop));
loops_by_header.emplace(succ, loop);
} else {
loop = it->second;
}
if (std::find(loop->latches.begin(), loop->latches.end(), block) ==
loop->latches.end()) {
loop->latches.push_back(block);
}
for (auto* natural_block : CollectNaturalLoopBlocks(succ, block)) {
loop->blocks.insert(natural_block);
}
}
}
std::unordered_map<BasicBlock*, std::size_t> function_order;
for (std::size_t i = 0; i < function_->GetBlocks().size(); ++i) {
function_order.emplace(function_->GetBlocks()[i].get(), i);
}
for (const auto& loop_ptr : loops_) {
auto& loop = *loop_ptr;
loop.block_list.clear();
loop.exiting_blocks.clear();
loop.exit_blocks.clear();
loop.subloops.clear();
loop.parent = nullptr;
for (const auto& block_ptr : function_->GetBlocks()) {
if (loop.Contains(block_ptr.get())) {
loop.block_list.push_back(block_ptr.get());
}
}
std::sort(loop.latches.begin(), loop.latches.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
std::vector<BasicBlock*> outside_preds;
for (auto* pred : loop.header->GetPredecessors()) {
if (!loop.Contains(pred)) {
outside_preds.push_back(pred);
}
}
if (outside_preds.size() == 1 &&
outside_preds.front()->GetSuccessors().size() == 1) {
loop.preheader = outside_preds.front();
} else {
loop.preheader = nullptr;
}
std::unordered_set<BasicBlock*> exiting_seen;
std::unordered_set<BasicBlock*> exit_seen;
for (auto* block : loop.block_list) {
for (auto* succ : block->GetSuccessors()) {
if (loop.Contains(succ)) {
continue;
}
if (exiting_seen.insert(block).second) {
loop.exiting_blocks.push_back(block);
}
if (exit_seen.insert(succ).second) {
loop.exit_blocks.push_back(succ);
}
}
}
std::sort(loop.exiting_blocks.begin(), loop.exiting_blocks.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
std::sort(loop.exit_blocks.begin(), loop.exit_blocks.end(),
[&](BasicBlock* lhs, BasicBlock* rhs) {
return function_order[lhs] < function_order[rhs];
});
}
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
Loop* parent = nullptr;
for (const auto& candidate_ptr : loops_) {
auto* candidate = candidate_ptr.get();
if (candidate == loop || !candidate->Contains(loop)) {
continue;
}
if (!parent || candidate->blocks.size() < parent->blocks.size()) {
parent = candidate;
}
}
loop->parent = parent;
if (parent) {
parent->subloops.push_back(loop);
} else {
top_level_loops_.push_back(loop);
}
}
auto loop_order = [&](Loop* lhs, Loop* rhs) {
return function_order[lhs->header] < function_order[rhs->header];
};
std::sort(top_level_loops_.begin(), top_level_loops_.end(), loop_order);
for (const auto& loop_ptr : loops_) {
std::sort(loop_ptr->subloops.begin(), loop_ptr->subloops.end(), loop_order);
}
for (const auto& block_ptr : function_->GetBlocks()) {
Loop* innermost = nullptr;
for (const auto& loop_ptr : loops_) {
auto* loop = loop_ptr.get();
if (!loop->Contains(block_ptr.get())) {
continue;
}
if (!innermost || loop->blocks.size() < innermost->blocks.size()) {
innermost = loop;
}
}
if (innermost) {
block_to_loop_.emplace(block_ptr.get(), innermost);
}
}
}
std::vector<Loop*> LoopInfo::GetTopLevelLoops() const { return top_level_loops_; }
std::vector<Loop*> LoopInfo::GetLoopsInPostOrder() const {
std::vector<Loop*> ordered;
std::function<void(Loop*)> dfs = [&](Loop* loop) {
for (auto* subloop : loop->subloops) {
dfs(subloop);
}
ordered.push_back(loop);
};
for (auto* loop : top_level_loops_) {
dfs(loop);
}
return ordered;
}
Loop* LoopInfo::GetLoopFor(BasicBlock* block) const {
auto it = block_to_loop_.find(block);
return it == block_to_loop_.end() ? nullptr : it->second;
}
} // namespace ir

@ -1,4 +1,107 @@
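// CFG simplification:
// - Removes unreachable blocks, merges empty blocks, simplifies branches, etc.
// - Improves IR structure for later optimizations and backend code generation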
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <vector>
namespace ir {
namespace {
bool TryGetConstBranchTarget(CondBrInst* br, BasicBlock*& target, BasicBlock*& removed) {
if (!br) {
return false;
}
auto* then_block = br->GetThenBlock();
auto* else_block = br->GetElseBlock();
if (then_block == else_block) {
target = then_block;
removed = nullptr;
return true;
}
if (auto* cond = dyncast<ConstantI1>(br->GetCondition())) {
target = cond->GetValue() ? then_block : else_block;
removed = cond->GetValue() ? else_block : then_block;
return true;
}
return false;
}
bool SimplifyBlockTerminator(BasicBlock* block) {
if (!block || block->GetInstructions().empty()) {
return false;
}
auto* term = block->GetInstructions().back().get();
auto* condbr = dyncast<CondBrInst>(term);
if (!condbr) {
return false;
}
BasicBlock* target = nullptr;
BasicBlock* removed = nullptr;
if (!TryGetConstBranchTarget(condbr, target, removed)) {
return false;
}
if (removed) {
passutils::RemoveIncomingFromSuccessor(removed, block);
removed->RemovePredecessor(block);
block->RemoveSuccessor(removed);
}
passutils::ReplaceTerminatorWithBr(block, target);
return true;
}
bool SimplifyPhiNodes(Function& function) {
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (!passutils::SimplifyPhiInst(phi)) {
continue;
}
local_changed = true;
changed = true;
break;
}
}
}
return changed;
}
bool RunCFGSimplifyOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
changed |= SimplifyBlockTerminator(block_ptr.get());
}
changed |= passutils::RemoveUnreachableBlocks(function);
changed |= SimplifyPhiNodes(function);
return changed;
}
} // namespace
bool RunCFGSimplify(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunCFGSimplifyOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -3,9 +3,16 @@ add_library(ir_passes STATIC
Mem2Reg.cpp
ConstFold.cpp
ConstProp.cpp
Inline.cpp
CSE.cpp
GVN.cpp
LoadStoreElim.cpp
DCE.cpp
CFGSimplify.cpp
LICM.cpp
LoopStrengthReduction.cpp
LoopUnroll.cpp
LoopFission.cpp
)
target_link_libraries(ir_passes PUBLIC

@ -1,4 +1,141 @@
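// Common subexpression elimination (CSE):
// - Identifies and reuses repeated computations of equivalent expressions
// - Typically placed after ConstFold and before DCE
// - Currently a framework placeholder for Lab4; the concrete algorithm is implemented in the lab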
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct ExprKey {
Opcode opcode = Opcode::Add;
std::vector<std::uintptr_t> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && operands == rhs.operands;
}
};
struct ExprKeyHash {
std::size_t operator()(const ExprKey& key) const {
std::size_t h = static_cast<std::size_t>(key.opcode);
for (auto operand : key.operands) {
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
bool IsSupportedCSEInstruction(Instruction* inst) {
if (!inst || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Zext:
return true;
default:
return false;
}
}
ExprKey BuildExprKey(Instruction* inst) {
ExprKey key;
key.opcode = inst->GetOpcode();
key.operands.reserve(inst->GetNumOperands());
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(
reinterpret_cast<std::uintptr_t>(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 && passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
key.operands[1] < key.operands[0]) {
std::swap(key.operands[0], key.operands[1]);
}
return key;
}
bool RunCSEOnFunction(Function& function) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::unordered_map<ExprKey, Value*, ExprKeyHash> available_exprs;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!IsSupportedCSEInstruction(inst)) {
continue;
}
const auto key = BuildExprKey(inst);
auto it = available_exprs.find(key);
if (it == available_exprs.end()) {
available_exprs.emplace(key, inst);
continue;
}
inst->ReplaceAllUsesWith(it->second);
to_remove.push_back(inst);
changed = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return changed;
}
} // namespace
bool RunCSE(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunCSEOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,469 @@
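// IR constant folding:
// - Folds constant expressions that can be decided at compile time
// - Simplifies constant control-flow branches (trimmed to the implemented scope)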
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>
namespace ir {
namespace {
Value* GetInt32Const(Context& ctx, std::int32_t value) {
return ctx.GetConstInt(static_cast<int>(value));
}
Value* GetBoolConst(Context& ctx, bool value) { return ctx.GetConstBool(value); }
Value* GetFloatConst(float value) {
return new ConstantFloat(Type::GetFloatType(), value);
}
bool TryGetInt32(Value* value, std::int32_t& out) {
if (auto* ci = dyncast<ConstantInt>(value)) {
out = static_cast<std::int32_t>(ci->GetValue());
return true;
}
return false;
}
bool TryGetBool(Value* value, bool& out) {
if (auto* cb = dyncast<ConstantI1>(value)) {
out = cb->GetValue();
return true;
}
return false;
}
bool TryGetFloat(Value* value, float& out) {
if (auto* cf = dyncast<ConstantFloat>(value)) {
out = cf->GetValue();
return true;
}
return false;
}
bool IsZeroValue(Value* value) {
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
return (TryGetInt32(value, i32) && i32 == 0) || (TryGetBool(value, i1) && !i1) ||
(TryGetFloat(value, f32) && passutils::FloatBits(f32) == 0);
}
bool IsOneValue(Value* value) {
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
return (TryGetInt32(value, i32) && i32 == 1) || (TryGetBool(value, i1) && i1) ||
(TryGetFloat(value, f32) &&
passutils::FloatBits(f32) == passutils::FloatBits(1.0f));
}
bool IsAllOnesInt(Value* value) {
std::int32_t i32 = 0;
return TryGetInt32(value, i32) && i32 == -1;
}
std::int32_t WrapInt32(std::uint32_t value) {
return static_cast<std::int32_t>(value);
}
Value* FoldBinary(Context& ctx, BinaryInst* inst) {
const auto opcode = inst->GetOpcode();
auto* lhs = inst->GetLhs();
auto* rhs = inst->GetRhs();
std::int32_t lhs_i32 = 0;
std::int32_t rhs_i32 = 0;
bool lhs_i1 = false;
bool rhs_i1 = false;
float lhs_f32 = 0.0f;
float rhs_f32 = 0.0f;
const bool has_lhs_i32 = TryGetInt32(lhs, lhs_i32);
const bool has_rhs_i32 = TryGetInt32(rhs, rhs_i32);
const bool has_lhs_i1 = TryGetBool(lhs, lhs_i1);
const bool has_rhs_i1 = TryGetBool(rhs, rhs_i1);
const bool has_lhs_f32 = TryGetFloat(lhs, lhs_f32);
const bool has_rhs_f32 = TryGetFloat(rhs, rhs_f32);
if (has_lhs_i32 && has_rhs_i32) {
switch (opcode) {
case Opcode::Add:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) +
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Sub:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) -
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Mul:
return GetInt32Const(
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) *
static_cast<std::uint32_t>(rhs_i32)));
case Opcode::Div:
if (rhs_i32 == 0 ||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 / rhs_i32);
case Opcode::Rem:
if (rhs_i32 == 0 ||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 % rhs_i32);
case Opcode::And:
return GetInt32Const(ctx, lhs_i32 & rhs_i32);
case Opcode::Or:
return GetInt32Const(ctx, lhs_i32 | rhs_i32);
case Opcode::Xor:
return GetInt32Const(ctx, lhs_i32 ^ rhs_i32);
case Opcode::Shl:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32)
<< rhs_i32));
case Opcode::AShr:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(ctx, lhs_i32 >> rhs_i32);
case Opcode::LShr:
if (rhs_i32 < 0 || rhs_i32 >= 32) {
return nullptr;
}
return GetInt32Const(
ctx,
WrapInt32(static_cast<std::uint32_t>(lhs_i32) >> rhs_i32));
case Opcode::ICmpEQ:
return GetBoolConst(ctx, lhs_i32 == rhs_i32);
case Opcode::ICmpNE:
return GetBoolConst(ctx, lhs_i32 != rhs_i32);
case Opcode::ICmpLT:
return GetBoolConst(ctx, lhs_i32 < rhs_i32);
case Opcode::ICmpGT:
return GetBoolConst(ctx, lhs_i32 > rhs_i32);
case Opcode::ICmpLE:
return GetBoolConst(ctx, lhs_i32 <= rhs_i32);
case Opcode::ICmpGE:
return GetBoolConst(ctx, lhs_i32 >= rhs_i32);
default:
break;
}
}
if (has_lhs_i1 && has_rhs_i1) {
switch (opcode) {
case Opcode::And:
return GetBoolConst(ctx, lhs_i1 && rhs_i1);
case Opcode::Or:
return GetBoolConst(ctx, lhs_i1 || rhs_i1);
case Opcode::Xor:
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
case Opcode::ICmpEQ:
return GetBoolConst(ctx, lhs_i1 == rhs_i1);
case Opcode::ICmpNE:
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
case Opcode::ICmpLT:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) < static_cast<int>(rhs_i1));
case Opcode::ICmpGT:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) > static_cast<int>(rhs_i1));
case Opcode::ICmpLE:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) <= static_cast<int>(rhs_i1));
case Opcode::ICmpGE:
return GetBoolConst(ctx, static_cast<int>(lhs_i1) >= static_cast<int>(rhs_i1));
default:
break;
}
}
if (has_lhs_f32 && has_rhs_f32) {
switch (opcode) {
case Opcode::FAdd:
return GetFloatConst(lhs_f32 + rhs_f32);
case Opcode::FSub:
return GetFloatConst(lhs_f32 - rhs_f32);
case Opcode::FMul:
return GetFloatConst(lhs_f32 * rhs_f32);
case Opcode::FDiv:
return GetFloatConst(lhs_f32 / rhs_f32);
case Opcode::FRem:
return GetFloatConst(std::fmod(lhs_f32, rhs_f32));
case Opcode::FCmpEQ:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 == rhs_f32);
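case Opcode::FCmpNE:
// Ordered comparison: folds to false when either operand is NaN
// (unlike C's `!=`, which yields true for NaN operands).
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 != rhs_f32);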
case Opcode::FCmpLT:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 < rhs_f32);
case Opcode::FCmpGT:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 > rhs_f32);
case Opcode::FCmpLE:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 <= rhs_f32);
case Opcode::FCmpGE:
return GetBoolConst(
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 >= rhs_f32);
default:
break;
}
}
switch (opcode) {
case Opcode::Add:
if (IsZeroValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs)) {
return rhs;
}
break;
case Opcode::Sub:
if (IsZeroValue(rhs)) {
return lhs;
}
break;
case Opcode::Mul:
if (IsOneValue(rhs)) {
return lhs;
}
if (IsOneValue(lhs)) {
return rhs;
}
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::Div:
if (IsOneValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::Rem:
if ((has_rhs_i32 && (rhs_i32 == 1 || rhs_i32 == -1)) ||
(has_rhs_i1 && rhs_i1)) {
return GetInt32Const(ctx, 0);
}
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
return GetInt32Const(ctx, 0);
}
break;
case Opcode::And:
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
: GetInt32Const(ctx, 0);
}
if (has_lhs_i1 && lhs_i1) {
return rhs;
}
if (has_rhs_i1 && rhs_i1) {
return lhs;
}
if (IsAllOnesInt(lhs)) {
return rhs;
}
if (IsAllOnesInt(rhs)) {
return lhs;
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return lhs;
}
break;
case Opcode::Or:
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
if (has_lhs_i1 && lhs_i1) {
return GetBoolConst(ctx, true);
}
if (has_rhs_i1 && rhs_i1) {
return GetBoolConst(ctx, true);
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return lhs;
}
break;
case Opcode::Xor:
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
if (passutils::AreEquivalentValues(lhs, rhs)) {
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
: GetInt32Const(ctx, 0);
}
break;
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
if (IsZeroValue(rhs)) {
return lhs;
}
if (IsZeroValue(lhs)) {
return GetInt32Const(ctx, 0);
}
break;
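case Opcode::FAdd:
// Note: IsZeroValue matches only +0.0. Under IEEE 754, (+0.0) + (-0.0) is
// +0.0, so this identity assumes the sign of a zero result need not be
// preserved.
if (IsZeroValue(lhs)) {
return rhs;
}
if (IsZeroValue(rhs)) {
return lhs;
}
break;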
case Opcode::FSub:
if (IsZeroValue(rhs)) {
return lhs;
}
break;
case Opcode::FMul:
if (IsOneValue(lhs)) {
return rhs;
}
if (IsOneValue(rhs)) {
return lhs;
}
break;
case Opcode::FDiv:
if (IsOneValue(rhs)) {
return lhs;
}
break;
default:
break;
}
return nullptr;
}
Value* FoldUnary(Context& ctx, UnaryInst* inst) {
auto* operand = inst->GetOprd();
std::int32_t i32 = 0;
bool i1 = false;
float f32 = 0.0f;
switch (inst->GetOpcode()) {
case Opcode::Neg:
if (TryGetInt32(operand, i32)) {
return GetInt32Const(ctx, WrapInt32(0u - static_cast<std::uint32_t>(i32)));
}
break;
case Opcode::Not:
if (TryGetBool(operand, i1)) {
return GetBoolConst(ctx, !i1);
}
if (TryGetInt32(operand, i32)) {
return GetInt32Const(ctx, i32 ^ 1);
}
break;
case Opcode::FNeg:
if (TryGetFloat(operand, f32)) {
return GetFloatConst(-f32);
}
break;
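case Opcode::FtoI:
// Fold only when the truncated value fits in int32: converting NaN or an
// out-of-range float to an integer is undefined behavior in C++.
// Both bounds are exactly representable as float; NaN fails both tests.
if (TryGetFloat(operand, f32) && f32 >= -2147483648.0f && f32 < 2147483648.0f) {
return GetInt32Const(ctx, static_cast<std::int32_t>(f32));
}
break;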
case Opcode::IToF:
if (TryGetInt32(operand, i32)) {
return GetFloatConst(static_cast<float>(i32));
}
if (TryGetBool(operand, i1)) {
return GetFloatConst(i1 ? 1.0f : 0.0f);
}
break;
default:
break;
}
return nullptr;
}
Value* FoldZext(Context& ctx, ZextInst* inst) {
auto* value = inst->GetValue();
bool i1 = false;
std::int32_t i32 = 0;
if (inst->GetType()->IsInt1()) {
if (TryGetBool(value, i1)) {
return GetBoolConst(ctx, i1);
}
if (TryGetInt32(value, i32)) {
return GetBoolConst(ctx, i32 != 0);
}
}
if (inst->GetType()->IsInt32()) {
if (TryGetBool(value, i1)) {
return GetInt32Const(ctx, i1 ? 1 : 0);
}
if (TryGetInt32(value, i32)) {
return GetInt32Const(ctx, i32);
}
}
return nullptr;
}
bool FoldFunction(Function& function, Context& ctx) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
Value* replacement = nullptr;
if (auto* binary = dyncast<BinaryInst>(inst)) {
replacement = FoldBinary(ctx, binary);
} else if (auto* unary = dyncast<UnaryInst>(inst)) {
replacement = FoldUnary(ctx, unary);
} else if (auto* zext = dyncast<ZextInst>(inst)) {
replacement = FoldZext(ctx, zext);
}
if (!replacement || replacement == inst) {
continue;
}
inst->ReplaceAllUsesWith(replacement);
to_remove.push_back(inst);
changed = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return changed;
}
} // namespace
bool RunConstFold(Module& module) {
bool changed = false;
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= FoldFunction(*function, ctx);
}
}
return changed;
}
} // namespace ir
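
The integer rules in `FoldBinary` can be illustrated in isolation. The following is a minimal sketch, not code from this repository: `IntOp` and `FoldInt32` are hypothetical names, and `std::optional` stands in for the "return `nullptr` to skip folding" convention used above. It shows add/sub/mul being computed in unsigned arithmetic so they wrap, and division, remainder, and shifts being refused when the C++ operation would be undefined.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical opcode subset used only by this sketch.
enum class IntOp { Add, Sub, Mul, Div, Rem, Shl };

// Returns the folded value, or std::nullopt when folding must be skipped
// because evaluating the operation at compile time would be undefined.
std::optional<std::int32_t> FoldInt32(IntOp op, std::int32_t lhs, std::int32_t rhs) {
  const auto ul = static_cast<std::uint32_t>(lhs);
  const auto ur = static_cast<std::uint32_t>(rhs);
  constexpr auto kMin = std::numeric_limits<std::int32_t>::min();
  switch (op) {
    case IntOp::Add:
      return static_cast<std::int32_t>(ul + ur);  // wraps modulo 2^32
    case IntOp::Sub:
      return static_cast<std::int32_t>(ul - ur);
    case IntOp::Mul:
      return static_cast<std::int32_t>(ul * ur);
    case IntOp::Div:
    case IntOp::Rem:
      // Division by zero and INT32_MIN / -1 are left for run time.
      if (rhs == 0 || (lhs == kMin && rhs == -1)) return std::nullopt;
      return op == IntOp::Div ? lhs / rhs : lhs % rhs;
    case IntOp::Shl:
      // Shift amounts outside [0, 31] are not folded.
      if (rhs < 0 || rhs >= 32) return std::nullopt;
      return static_cast<std::int32_t>(ul << rhs);
  }
  return std::nullopt;
}
```

For example, `FoldInt32(IntOp::Add, INT32_MAX, 1)` wraps to `INT32_MIN`, while `FoldInt32(IntOp::Div, INT32_MIN, -1)` yields no value, so the instruction would be kept as-is.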

@ -1,5 +1,550 @@
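// Constant propagation (ConstProp)
// - Propagates known constants along use-def chains
// - Rewrites replaceable SSA values to constants, exposing more folding opportunities
// - Usually run iteratively together with ConstFold, DCE, and CFGSimplify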
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
enum class LatticeKind { Unknown, Constant, Overdefined };
struct ConstantValue {
enum class Kind { Int32, Bool, Float };
Kind kind = Kind::Int32;
std::int32_t int32_value = 0;
bool bool_value = false;
float float_value = 0.0f;
};
struct LatticeValue {
LatticeKind kind = LatticeKind::Unknown;
ConstantValue constant;
};
bool EqualConstants(const ConstantValue& lhs, const ConstantValue& rhs) {
if (lhs.kind != rhs.kind) {
return false;
}
switch (lhs.kind) {
case ConstantValue::Kind::Int32:
return lhs.int32_value == rhs.int32_value;
case ConstantValue::Kind::Bool:
return lhs.bool_value == rhs.bool_value;
case ConstantValue::Kind::Float:
return passutils::FloatBits(lhs.float_value) ==
passutils::FloatBits(rhs.float_value);
}
return false;
}
Value* MaterializeConstant(Context& ctx, const ConstantValue& constant) {
switch (constant.kind) {
case ConstantValue::Kind::Int32:
return ctx.GetConstInt(static_cast<int>(constant.int32_value));
case ConstantValue::Kind::Bool:
return ctx.GetConstBool(constant.bool_value);
case ConstantValue::Kind::Float:
return new ConstantFloat(Type::GetFloatType(), constant.float_value);
}
return nullptr;
}
bool TryGetConstantValue(Value* value, ConstantValue& out) {
if (auto* ci = dyncast<ConstantInt>(value)) {
out.kind = ConstantValue::Kind::Int32;
out.int32_value = static_cast<std::int32_t>(ci->GetValue());
return true;
}
if (auto* cb = dyncast<ConstantI1>(value)) {
out.kind = ConstantValue::Kind::Bool;
out.bool_value = cb->GetValue();
return true;
}
if (auto* cf = dyncast<ConstantFloat>(value)) {
out.kind = ConstantValue::Kind::Float;
out.float_value = cf->GetValue();
return true;
}
return false;
}
LatticeValue ConstantLattice(const ConstantValue& constant) {
LatticeValue value;
value.kind = LatticeKind::Constant;
value.constant = constant;
return value;
}
LatticeValue OverdefinedLattice() {
LatticeValue value;
value.kind = LatticeKind::Overdefined;
return value;
}
LatticeValue GetValueState(
Value* value, const std::unordered_map<Value*, LatticeValue>& states) {
ConstantValue constant;
if (TryGetConstantValue(value, constant)) {
return ConstantLattice(constant);
}
auto it = states.find(value);
if (it != states.end()) {
return it->second;
}
return OverdefinedLattice();
}
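// Lattice meet: Unknown is the identity, Overdefined absorbs everything, and
// two constants meet to themselves only when they are equal.
LatticeValue Meet(LatticeValue lhs, const LatticeValue& rhs) {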
if (lhs.kind == LatticeKind::Overdefined || rhs.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (lhs.kind == LatticeKind::Unknown) {
return rhs;
}
if (rhs.kind == LatticeKind::Unknown) {
return lhs;
}
if (EqualConstants(lhs.constant, rhs.constant)) {
return lhs;
}
return OverdefinedLattice();
}
bool EvaluateUnary(Opcode opcode, const ConstantValue& operand,
ConstantValue& result) {
switch (opcode) {
case Opcode::Neg:
if (operand.kind != ConstantValue::Kind::Int32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
0u - static_cast<std::uint32_t>(operand.int32_value));
return true;
case Opcode::Not:
if (operand.kind == ConstantValue::Kind::Bool) {
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !operand.bool_value;
return true;
}
if (operand.kind == ConstantValue::Kind::Int32) {
result.kind = ConstantValue::Kind::Int32;
result.int32_value = operand.int32_value ^ 1;
return true;
}
return false;
case Opcode::FNeg:
if (operand.kind != ConstantValue::Kind::Float) {
return false;
}
result.kind = ConstantValue::Kind::Float;
result.float_value = -operand.float_value;
return true;
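case Opcode::FtoI:
// Refuse to evaluate when the value is NaN or does not fit in int32:
// the float-to-int cast below would be undefined behavior in C++.
if (operand.kind != ConstantValue::Kind::Float ||
!(operand.float_value >= -2147483648.0f && operand.float_value < 2147483648.0f)) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(operand.float_value);
return true;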
case Opcode::IToF:
if (operand.kind == ConstantValue::Kind::Int32) {
result.kind = ConstantValue::Kind::Float;
result.float_value = static_cast<float>(operand.int32_value);
return true;
}
if (operand.kind == ConstantValue::Kind::Bool) {
result.kind = ConstantValue::Kind::Float;
result.float_value = operand.bool_value ? 1.0f : 0.0f;
return true;
}
return false;
default:
return false;
}
}
bool EvaluateBinary(Opcode opcode, const ConstantValue& lhs,
const ConstantValue& rhs, ConstantValue& result) {
if (lhs.kind == ConstantValue::Kind::Int32 &&
rhs.kind == ConstantValue::Kind::Int32) {
const auto left = lhs.int32_value;
const auto right = rhs.int32_value;
switch (opcode) {
case Opcode::Add:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) + static_cast<std::uint32_t>(right));
return true;
case Opcode::Sub:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) - static_cast<std::uint32_t>(right));
return true;
case Opcode::Mul:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) * static_cast<std::uint32_t>(right));
return true;
case Opcode::Div:
if (right == 0 ||
(left == std::numeric_limits<std::int32_t>::min() && right == -1)) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left / right;
return true;
case Opcode::Rem:
if (right == 0 ||
(left == std::numeric_limits<std::int32_t>::min() && right == -1)) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left % right;
return true;
case Opcode::And:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left & right;
return true;
case Opcode::Or:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left | right;
return true;
case Opcode::Xor:
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left ^ right;
return true;
case Opcode::Shl:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value =
static_cast<std::int32_t>(static_cast<std::uint32_t>(left) << right);
return true;
case Opcode::AShr:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = left >> right;
return true;
case Opcode::LShr:
if (right < 0 || right >= 32) {
return false;
}
result.kind = ConstantValue::Kind::Int32;
result.int32_value = static_cast<std::int32_t>(
static_cast<std::uint32_t>(left) >> right);
return true;
case Opcode::ICmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left == right;
return true;
case Opcode::ICmpNE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left < right;
return true;
case Opcode::ICmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left > right;
return true;
case Opcode::ICmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left <= right;
return true;
case Opcode::ICmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left >= right;
return true;
default:
break;
}
}
if (lhs.kind == ConstantValue::Kind::Bool && rhs.kind == ConstantValue::Kind::Bool) {
const auto left = lhs.bool_value;
const auto right = rhs.bool_value;
switch (opcode) {
case Opcode::And:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left && right;
return true;
case Opcode::Or:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left || right;
return true;
case Opcode::Xor:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left == right;
return true;
case Opcode::ICmpNE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = left != right;
return true;
case Opcode::ICmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) < static_cast<int>(right);
return true;
case Opcode::ICmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) > static_cast<int>(right);
return true;
case Opcode::ICmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) <= static_cast<int>(right);
return true;
case Opcode::ICmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = static_cast<int>(left) >= static_cast<int>(right);
return true;
default:
break;
}
}
if (lhs.kind == ConstantValue::Kind::Float &&
rhs.kind == ConstantValue::Kind::Float) {
const auto left = lhs.float_value;
const auto right = rhs.float_value;
switch (opcode) {
case Opcode::FAdd:
result.kind = ConstantValue::Kind::Float;
result.float_value = left + right;
return true;
case Opcode::FSub:
result.kind = ConstantValue::Kind::Float;
result.float_value = left - right;
return true;
case Opcode::FMul:
result.kind = ConstantValue::Kind::Float;
result.float_value = left * right;
return true;
case Opcode::FDiv:
result.kind = ConstantValue::Kind::Float;
result.float_value = left / right;
return true;
case Opcode::FRem:
result.kind = ConstantValue::Kind::Float;
result.float_value = std::fmod(left, right);
return true;
case Opcode::FCmpEQ:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left == right;
return true;
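case Opcode::FCmpNE:
// Ordered comparison: evaluates to false when either operand is NaN
// (unlike C's `!=`, which yields true for NaN operands).
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left != right;
return true;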
case Opcode::FCmpLT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left < right;
return true;
case Opcode::FCmpGT:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left > right;
return true;
case Opcode::FCmpLE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left <= right;
return true;
case Opcode::FCmpGE:
result.kind = ConstantValue::Kind::Bool;
result.bool_value = !std::isnan(left) && !std::isnan(right) && left >= right;
return true;
default:
break;
}
}
return false;
}
LatticeValue EvaluateInstruction(
Instruction* inst, const std::unordered_map<Value*, LatticeValue>& states) {
if (!inst || inst->IsVoid()) {
return OverdefinedLattice();
}
if (auto* phi = dyncast<PhiInst>(inst)) {
LatticeValue merged;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
merged = Meet(merged, GetValueState(phi->GetIncomingValue(i), states));
if (merged.kind == LatticeKind::Overdefined) {
break;
}
}
return merged;
}
if (auto* binary = dyncast<BinaryInst>(inst)) {
const auto lhs = GetValueState(binary->GetLhs(), states);
const auto rhs = GetValueState(binary->GetRhs(), states);
if (lhs.kind == LatticeKind::Overdefined || rhs.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (lhs.kind != LatticeKind::Constant || rhs.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (!EvaluateBinary(binary->GetOpcode(), lhs.constant, rhs.constant, folded)) {
return OverdefinedLattice();
}
return ConstantLattice(folded);
}
if (auto* unary = dyncast<UnaryInst>(inst)) {
const auto operand = GetValueState(unary->GetOprd(), states);
if (operand.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (operand.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (!EvaluateUnary(unary->GetOpcode(), operand.constant, folded)) {
return OverdefinedLattice();
}
return ConstantLattice(folded);
}
if (auto* zext = dyncast<ZextInst>(inst)) {
const auto operand = GetValueState(zext->GetValue(), states);
if (operand.kind == LatticeKind::Overdefined) {
return OverdefinedLattice();
}
if (operand.kind != LatticeKind::Constant) {
return {};
}
ConstantValue folded;
if (zext->GetType()->IsInt1()) {
folded.kind = ConstantValue::Kind::Bool;
if (operand.constant.kind == ConstantValue::Kind::Bool) {
folded.bool_value = operand.constant.bool_value;
return ConstantLattice(folded);
}
if (operand.constant.kind == ConstantValue::Kind::Int32) {
folded.bool_value = operand.constant.int32_value != 0;
return ConstantLattice(folded);
}
return OverdefinedLattice();
}
if (zext->GetType()->IsInt32()) {
folded.kind = ConstantValue::Kind::Int32;
if (operand.constant.kind == ConstantValue::Kind::Bool) {
folded.int32_value = operand.constant.bool_value ? 1 : 0;
return ConstantLattice(folded);
}
if (operand.constant.kind == ConstantValue::Kind::Int32) {
folded.int32_value = operand.constant.int32_value;
return ConstantLattice(folded);
}
}
return OverdefinedLattice();
}
return OverdefinedLattice();
}
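// Runs constant propagation on one function in three steps:
// 1. initialize every value-producing instruction to Unknown,
// 2. re-evaluate all instructions until no lattice state changes,
// 3. replace operands and instructions whose state is Constant.
bool RewriteFunction(Function& function, Context& ctx) {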
if (function.IsExternal()) {
return false;
}
std::unordered_map<Value*, LatticeValue> states;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst->IsVoid()) {
states[inst] = {};
}
}
}
bool changed = true;
while (changed) {
changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsVoid()) {
continue;
}
const auto evaluated = EvaluateInstruction(inst, states);
if (evaluated.kind != states[inst].kind ||
(evaluated.kind == LatticeKind::Constant &&
!EqualConstants(evaluated.constant, states[inst].constant))) {
states[inst] = evaluated;
changed = true;
}
}
}
}
bool rewritten = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
if (isa<BasicBlock>(operand) || isa<Function>(operand) || operand->IsConstant()) {
continue;
}
const auto state = GetValueState(operand, states);
if (state.kind != LatticeKind::Constant) {
continue;
}
inst->SetOperand(i, MaterializeConstant(ctx, state.constant));
rewritten = true;
}
if (inst->IsVoid()) {
continue;
}
const auto state = states[inst];
if (state.kind != LatticeKind::Constant) {
continue;
}
inst->ReplaceAllUsesWith(MaterializeConstant(ctx, state.constant));
to_remove.push_back(inst);
rewritten = true;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
}
return rewritten;
}
} // namespace
bool RunConstProp(Module& module) {
bool changed = false;
auto& ctx = module.GetContext();
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RewriteFunction(*function, ctx);
}
}
return changed;
}
} // namespace ir
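
The lattice logic above (`Meet` plus the phi case of `EvaluateInstruction`) can be tried out on its own. The sketch below is a simplified, int-only illustration under stated assumptions: `Lattice`, `MakeConstant`, `MakeOverdefined`, and `MergePhi` are hypothetical names introduced here, and float/bool constants are left out.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, int-only version of the three-level lattice.
enum class Kind { Unknown, Constant, Overdefined };

struct Lattice {
  Kind kind = Kind::Unknown;
  std::int32_t value = 0;  // meaningful only when kind == Kind::Constant
};

Lattice MakeConstant(std::int32_t v) {
  Lattice l;
  l.kind = Kind::Constant;
  l.value = v;
  return l;
}

Lattice MakeOverdefined() {
  Lattice l;
  l.kind = Kind::Overdefined;
  return l;
}

// Unknown is the identity, Overdefined absorbs, equal constants survive.
Lattice Meet(const Lattice& a, const Lattice& b) {
  if (a.kind == Kind::Overdefined || b.kind == Kind::Overdefined) {
    return MakeOverdefined();
  }
  if (a.kind == Kind::Unknown) return b;
  if (b.kind == Kind::Unknown) return a;
  assert(a.kind == Kind::Constant && b.kind == Kind::Constant);
  if (a.value == b.value) return a;
  return MakeOverdefined();
}

// Merges the states of all incoming phi values, stopping early once the
// result can no longer improve.
Lattice MergePhi(const std::vector<Lattice>& incoming) {
  Lattice merged;  // starts as Unknown
  for (const auto& in : incoming) {
    merged = Meet(merged, in);
    if (merged.kind == Kind::Overdefined) break;
  }
  return merged;
}
```

A phi whose incoming values are `3`, still-unknown, and `3` merges to the constant `3`; as soon as a differing constant such as `4` appears, the result is overdefined and the phi is left alone.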

@ -1,4 +1,55 @@
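// Dead code elimination (DCE)
// - Removes trivially dead instructions, repeating until nothing more can be
//   removed (this pass does not delete basic blocks itself)
// - Usually used together with CFG simplification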
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "PassUtils.h"
#include <vector>
namespace ir {
namespace {
bool RunDCEOnFunction(Function& function) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& block_ptr : function.GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!passutils::IsTriviallyDead(inst)) {
continue;
}
to_remove.push_back(inst);
}
if (to_remove.empty()) {
continue;
}
for (auto* inst : to_remove) {
block_ptr->EraseInstruction(inst);
}
local_changed = true;
changed = true;
}
}
return changed;
}
} // namespace
bool RunDCE(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunDCEOnFunction(*function);
}
}
return changed;
}
} // namespace ir
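
Why the DCE loop above repeats until nothing changes can be seen on a toy model. This is a sketch under assumptions, not the repository's IR: `ToyInst` and `RunToyDCE` are hypothetical, and "trivially dead" is assumed to mean "no remaining uses and no side effects" (the real `passutils::IsTriviallyDead` is not shown in this diff).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical instruction model: operands are indices into the same vector.
struct ToyInst {
  std::vector<std::size_t> operands;
  bool has_side_effect = false;  // e.g. store, call, return
  bool erased = false;
};

// Erases dead instructions until a fixed point; returns how many were erased.
std::size_t RunToyDCE(std::vector<ToyInst>& insts) {
  std::size_t removed = 0;
  bool changed = true;
  while (changed) {
    changed = false;
    // Recount uses, ignoring instructions erased in earlier rounds.
    std::vector<std::size_t> uses(insts.size(), 0);
    for (const auto& inst : insts) {
      if (inst.erased) continue;
      for (auto op : inst.operands) {
        assert(op < insts.size());
        ++uses[op];
      }
    }
    for (std::size_t i = 0; i < insts.size(); ++i) {
      auto& inst = insts[i];
      if (inst.erased || inst.has_side_effect || uses[i] != 0) continue;
      inst.erased = true;  // erasing this may make its operands dead next round
      ++removed;
      changed = true;
    }
  }
  return removed;
}
```

With a chain `%0 -> %1 -> %2` where only `%2` is unused, the first round erases `%2`, and only then does `%1` lose its last use. A single sweep would leave `%1` behind, which is why both the sketch and `RunDCEOnFunction` iterate.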

@ -0,0 +1,196 @@
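// Global value numbering (GVN)
// - Walks the dominator tree with a scoped table of available expressions
// - Replaces an instruction with an identical one already computed in a
//   dominating position
#include "ir/PassManager.h"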
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "MemoryUtils.h"
#include "PassUtils.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct ExprKey {
Opcode opcode = Opcode::Add;
std::uintptr_t result_type = 0;
std::uintptr_t aux_type = 0;
std::vector<std::uintptr_t> operands;
bool operator==(const ExprKey& rhs) const {
return opcode == rhs.opcode && result_type == rhs.result_type &&
aux_type == rhs.aux_type && operands == rhs.operands;
}
};
struct ExprKeyHash {
std::size_t operator()(const ExprKey& key) const {
std::size_t h = static_cast<std::size_t>(key.opcode);
h ^= std::hash<std::uintptr_t>{}(key.result_type) + 0x9e3779b9 + (h << 6) +
(h >> 2);
h ^= std::hash<std::uintptr_t>{}(key.aux_type) + 0x9e3779b9 + (h << 6) +
(h >> 2);
for (auto operand : key.operands) {
h ^= std::hash<std::uintptr_t>{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
struct ScopedExpr {
ExprKey key;
Value* previous = nullptr;
bool had_previous = false;
};
bool IsSupportedGVNInstruction(Instruction* inst) {
if (!inst || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::GetElementPtr:
case Opcode::Zext:
return true;
case Opcode::Call:
return memutils::IsPureCall(dyncast<CallInst>(inst));
default:
return false;
}
}
ExprKey BuildExprKey(Instruction* inst) {
ExprKey key;
key.opcode = inst->GetOpcode();
key.result_type =
reinterpret_cast<std::uintptr_t>(inst->GetType().get());
if (auto* gep = dyncast<GetElementPtrInst>(inst)) {
key.aux_type = reinterpret_cast<std::uintptr_t>(gep->GetSourceType().get());
}
key.operands.reserve(inst->GetNumOperands());
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
key.operands.push_back(
reinterpret_cast<std::uintptr_t>(inst->GetOperand(i)));
}
if (inst->GetNumOperands() == 2 &&
passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
key.operands[1] < key.operands[0]) {
std::swap(key.operands[0], key.operands[1]);
}
return key;
}
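// Processes `block` and then its dominator-tree children. Expressions added by
// this block are removed from `available` on exit, so they are visible only to
// blocks that `block` dominates.
bool RunGVNInDomSubtree(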
BasicBlock* block, const DominatorTree& dom_tree,
std::unordered_map<ExprKey, Value*, ExprKeyHash>& available) {
if (!block) {
return false;
}
bool changed = false;
std::vector<ScopedExpr> scope;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!IsSupportedGVNInstruction(inst)) {
continue;
}
const auto key = BuildExprKey(inst);
auto it = available.find(key);
if (it != available.end()) {
inst->ReplaceAllUsesWith(it->second);
to_remove.push_back(inst);
changed = true;
continue;
}
// `key` is known to be absent here (the lookup above found nothing), so
// record it as available and remember it for removal when this scope exits.
available.emplace(key, inst);
scope.push_back(ScopedExpr{key, nullptr, false});
}
for (auto* inst : to_remove) {
block->EraseInstruction(inst);
}
for (auto* child : dom_tree.GetChildren(block)) {
changed |= RunGVNInDomSubtree(child, dom_tree, available);
}
for (auto it = scope.rbegin(); it != scope.rend(); ++it) {
if (it->had_previous) {
available[it->key] = it->previous;
} else {
available.erase(it->key);
}
}
return changed;
}
bool RunGVNOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
DominatorTree dom_tree(function);
std::unordered_map<ExprKey, Value*, ExprKeyHash> available;
return RunGVNInDomSubtree(function.GetEntryBlock(), dom_tree, available);
}
} // namespace
bool RunGVN(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunGVNOnFunction(*function);
}
}
return changed;
}
} // namespace ir
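
The combination of key canonicalization (`BuildExprKey`) and dominator-scoped availability (`RunGVNInDomSubtree`) can be sketched on a toy representation. Everything named here (`ToyExpr`, `ToyBlock`, `MakeKey`, `ToyGVN`) is hypothetical; operands are plain integer ids rather than `Value*` pointers, and the sketch only records replacements instead of rewriting uses.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <tuple>
#include <utility>
#include <vector>

struct ToyExpr { char op; int lhs; int rhs; int result; };
struct ToyBlock {
  std::vector<ToyExpr> exprs;
  std::vector<std::size_t> dom_children;  // children in the dominator tree
};

using Key = std::tuple<char, int, int>;

// Sorts operands of commutative ops so that a+b and b+a share one key.
Key MakeKey(const ToyExpr& e) {
  int a = e.lhs;
  int b = e.rhs;
  const bool commutative = e.op == '+' || e.op == '*';
  if (commutative && b < a) std::swap(a, b);
  return Key{e.op, a, b};
}

// Appends (redundant result, replacement result) pairs to `replaced`.
void ToyGVN(const std::vector<ToyBlock>& blocks, std::size_t block,
            std::map<Key, int>& available,
            std::vector<std::pair<int, int>>& replaced) {
  assert(block < blocks.size());
  std::vector<Key> scope;  // keys introduced by this block
  for (const auto& e : blocks[block].exprs) {
    const Key key = MakeKey(e);
    const auto it = available.find(key);
    if (it != available.end()) {
      replaced.emplace_back(e.result, it->second);
      continue;
    }
    available.emplace(key, e.result);
    scope.push_back(key);
  }
  for (auto child : blocks[block].dom_children) {
    ToyGVN(blocks, child, available, replaced);
  }
  // Leaving the subtree: siblings must not see this block's expressions.
  for (const auto& key : scope) available.erase(key);
}
```

An expression computed in the entry block is reused in every dominated block, whereas an expression computed in one branch is not visible in its sibling branch, because the scope is unwound before the sibling is visited.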

@ -0,0 +1,403 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct InlineCandidateInfo {
bool valid = false;
int cost = 0;
bool has_nested_call = false;
};
bool IsInlineableInstruction(const Instruction* inst) {
if (!inst) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Load:
case Opcode::Store:
case Opcode::GetElementPtr:
case Opcode::Zext:
case Opcode::Memset:
case Opcode::Call:
case Opcode::Return:
return true;
default:
return false;
}
}
int EstimateInstructionCost(const Instruction* inst) {
if (!inst) {
return 0;
}
switch (inst->GetOpcode()) {
case Opcode::Return:
return 0;
case Opcode::Load:
case Opcode::Store:
case Opcode::Memset:
return 3;
case Opcode::Call:
return 8;
case Opcode::GetElementPtr:
return 2;
default:
return 1;
}
}
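// A function is an inline candidate only if it is non-external, non-recursive,
// consists of a single basic block of inlineable instructions, and ends with
// its only return. The accumulated cost excludes the return itself.
InlineCandidateInfo AnalyzeInlineCandidate(const Function& function) {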
InlineCandidateInfo info;
if (function.IsExternal() || function.IsRecursive() || function.GetBlocks().size() != 1) {
return info;
}
const auto& block = function.GetBlocks().front();
if (!block || block->GetInstructions().empty()) {
return info;
}
bool saw_return = false;
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
auto* inst = block->GetInstructions()[i].get();
if (!IsInlineableInstruction(inst)) {
return {};
}
if (dyncast<ReturnInst>(inst)) {
if (i + 1 != block->GetInstructions().size()) {
return {};
}
saw_return = true;
continue;
}
if (dyncast<CallInst>(inst)) {
info.has_nested_call = true;
}
info.cost += EstimateInstructionCost(inst);
}
if (!saw_return) {
return {};
}
info.valid = true;
return info;
}
std::unordered_map<Function*, int> CountDirectCalls(Module& module) {
std::unordered_map<Function*, int> counts;
for (const auto& function_ptr : module.GetFunctions()) {
if (!function_ptr) {
continue;
}
for (const auto& block_ptr : function_ptr->GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
if (auto* call = dyncast<CallInst>(inst_ptr.get())) {
if (auto* callee = call->GetCallee()) {
++counts[callee];
}
}
}
}
}
return counts;
}
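// Budget heuristic: the base budget is 40 when unused calls to the callee can
// be discarded and 24 otherwise; +12 if the callee has at most one direct call
// site, -8 if it contains calls itself, -4 if it may write memory. The call is
// inlined when the callee's estimated cost fits within the budget.
bool ShouldInlineCallSite(const Function& caller, const CallInst& call,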
const InlineCandidateInfo& callee_info, int call_count) {
auto* callee = call.GetCallee();
if (!callee || callee == &caller || !callee_info.valid) {
return false;
}
int budget = callee->CanDiscardUnusedCall() ? 40 : 24;
if (call_count <= 1) {
budget += 12;
}
if (callee_info.has_nested_call) {
budget -= 8;
}
if (callee->MayWriteMemory()) {
budget -= 4;
}
return callee_info.cost <= budget;
}
Instruction* CloneInstructionAt(Function& function, Instruction* inst, BasicBlock* dest,
std::size_t insert_index,
std::unordered_map<Value*, Value*>& remap) {
if (!inst || !dest) {
return nullptr;
}
const auto name = inst->IsVoid() ? std::string()
: looputils::NextSyntheticName(function, "inline.");
auto remap_operand = [&](Value* value) { return looputils::RemapValue(remap, value); };
auto remember = [&](Instruction* clone) {
if (clone && !inst->IsVoid()) {
remap[inst] = clone;
}
return clone;
};
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<BinaryInst*>(inst);
return remember(dest->Insert<BinaryInst>(insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(bin->GetLhs()),
remap_operand(bin->GetRhs()), nullptr, name));
}
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF: {
auto* un = static_cast<UnaryInst*>(inst);
return remember(dest->Insert<UnaryInst>(insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(un->GetOprd()), nullptr, name));
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
return remember(dest->Insert<LoadInst>(insert_index, inst->GetType(),
remap_operand(load->GetPtr()), nullptr, name));
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
return dest->Insert<StoreInst>(insert_index, remap_operand(store->GetValue()),
remap_operand(store->GetPtr()), nullptr);
}
case Opcode::Memset: {
auto* memset = static_cast<MemsetInst*>(inst);
return dest->Insert<MemsetInst>(insert_index, remap_operand(memset->GetDest()),
remap_operand(memset->GetValue()),
remap_operand(memset->GetLength()),
remap_operand(memset->GetIsVolatile()), nullptr);
}
case Opcode::GetElementPtr: {
auto* gep = static_cast<GetElementPtrInst*>(inst);
std::vector<Value*> indices;
indices.reserve(gep->GetNumIndices());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
indices.push_back(remap_operand(gep->GetIndex(i)));
}
return remember(dest->Insert<GetElementPtrInst>(
insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()), indices, nullptr,
name));
}
case Opcode::Zext: {
auto* zext = static_cast<ZextInst*>(inst);
return remember(dest->Insert<ZextInst>(insert_index, remap_operand(zext->GetValue()),
inst->GetType(), nullptr, name));
}
case Opcode::Call: {
auto* call = static_cast<CallInst*>(inst);
std::vector<Value*> args;
const auto original_args = call->GetArguments();
args.reserve(original_args.size());
for (auto* arg : original_args) {
args.push_back(remap_operand(arg));
}
return remember(
dest->Insert<CallInst>(insert_index, call->GetCallee(), args, nullptr, name));
}
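// Terminators, phis, and allocas are never cloned. The nullptr returned below
// tells the caller that this instruction cannot be inlined.
case Opcode::Return: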
case Opcode::Alloca:
case Opcode::Phi:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Unreachable:
break;
}
return nullptr;
}
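// Inlines a single-block callee in front of `call`, then erases the call.
// Formal arguments are remapped to the actual arguments of the call site.
bool InlineCallSite(Function& caller, CallInst* call) {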
if (!call) {
return false;
}
auto* callee = call->GetCallee();
if (!callee || callee->GetBlocks().size() != 1) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
auto* block = call->GetParent();
if (!block) {
return false;
}
auto& instructions = block->GetInstructions();
auto call_it = std::find_if(instructions.begin(), instructions.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == call;
});
if (call_it == instructions.end()) {
return false;
}
std::size_t insert_index = static_cast<std::size_t>(call_it - instructions.begin());
std::unordered_map<Value*, Value*> remap;
for (std::size_t i = 0; i < call_args.size(); ++i) {
remap[callee_args[i].get()] = call_args[i];
}
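// Validate the callee body before mutating the caller. If cloning failed
// half-way, the instructions already cloned (including stores and calls)
// would stay in front of the surviving call and execute twice.
for (const auto& inst_ptr : callee->GetBlocks().front()->GetInstructions()) {
const Opcode opcode = inst_ptr->GetOpcode();
if (opcode == Opcode::Return) {
break;
}
if (opcode == Opcode::Alloca || opcode == Opcode::Phi || opcode == Opcode::Br ||
opcode == Opcode::CondBr || opcode == Opcode::Unreachable) {
return false;
}
}
Value* return_value = nullptr;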
for (const auto& inst_ptr : callee->GetBlocks().front()->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* ret = dyncast<ReturnInst>(inst)) {
if (ret->HasReturnValue()) {
return_value = looputils::RemapValue(remap, ret->GetReturnValue());
}
break;
}
if (!CloneInstructionAt(caller, inst, block, insert_index, remap)) {
return false;
}
++insert_index;
}
if (!call->GetType()->IsVoid()) {
if (!return_value) {
return false;
}
call->ReplaceAllUsesWith(return_value);
}
block->EraseInstruction(call);
return true;
}
bool RunFunctionInliningOnFunction(
Function& function,
const std::unordered_map<Function*, InlineCandidateInfo>& callee_info,
const std::unordered_map<Function*, int>& call_counts) {
if (function.IsExternal()) {
return false;
}
bool changed = false;
for (auto& block_ptr : function.GetBlocks()) {
std::vector<CallInst*> calls;
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
if (auto* call = dyncast<CallInst>(inst_ptr.get())) {
calls.push_back(call);
}
}
for (auto* call : calls) {
auto* callee = call->GetCallee();
if (!callee) {
continue;
}
auto info_it = callee_info.find(callee);
if (info_it == callee_info.end()) {
continue;
}
const auto count_it = call_counts.find(callee);
const int call_count = count_it != call_counts.end() ? count_it->second : 0;
if (!ShouldInlineCallSite(function, *call, info_it->second, call_count)) {
continue;
}
changed |= InlineCallSite(function, call);
}
}
return changed;
}
} // namespace
bool RunFunctionInlining(Module& module) {
std::unordered_map<Function*, InlineCandidateInfo> callee_info;
for (const auto& function_ptr : module.GetFunctions()) {
if (function_ptr) {
callee_info.emplace(function_ptr.get(), AnalyzeInlineCandidate(*function_ptr));
}
}
const auto call_counts = CountDirectCalls(module);
bool changed = false;
for (const auto& function_ptr : module.GetFunctions()) {
if (function_ptr) {
changed |= RunFunctionInliningOnFunction(*function_ptr, callee_info, call_counts);
}
}
return changed;
}
} // namespace ir
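
The budget arithmetic in `ShouldInlineCallSite` can be tried out in isolation. The following is a minimal sketch, not the pass itself: `CalleeTraits`, `InlineBudget`, and `ShouldInline` are hypothetical names that stand in for the real `Function` and `InlineCandidateInfo` queries, and only the constants are taken from the code above.

```cpp
#include <cassert>

// Hypothetical stand-in for the facts the pass reads from Function and
// InlineCandidateInfo. The field names are illustrative assumptions.
struct CalleeTraits {
  bool valid = false;             // the candidate analysis accepted the callee
  bool is_self_call = false;      // callee == caller
  bool discardable = false;       // mirrors CanDiscardUnusedCall()
  bool has_nested_call = false;   // callee body contains calls
  bool may_write_memory = false;  // mirrors MayWriteMemory()
  int cost = 0;                   // estimated size of the callee body
};

// Same arithmetic as ShouldInlineCallSite: a base budget adjusted by the
// call count, nested calls, and memory writes.
inline int InlineBudget(const CalleeTraits& callee, int call_count) {
  int budget = callee.discardable ? 40 : 24;
  if (call_count <= 1) {
    budget += 12;
  }
  if (callee.has_nested_call) {
    budget -= 8;
  }
  if (callee.may_write_memory) {
    budget -= 4;
  }
  return budget;
}

inline bool ShouldInline(const CalleeTraits& callee, int call_count) {
  assert(call_count >= 0);
  if (!callee.valid || callee.is_self_call) {
    return false;
  }
  return callee.cost <= InlineBudget(callee, call_count);
}
```

With these constants, a callee of cost 30 is rejected when it has three call sites (budget 24) but accepted when it has one (budget 36).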

@ -0,0 +1,236 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include "MemoryUtils.h"
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct HoistedLoadKey {
memutils::AddressKey address;
std::uintptr_t type_id = 0;
bool operator==(const HoistedLoadKey& rhs) const {
return type_id == rhs.type_id && address == rhs.address;
}
};
struct HoistedLoadKeyHash {
std::size_t operator()(const HoistedLoadKey& key) const {
std::size_t h = memutils::AddressKeyHash{}(key.address);
h ^= std::hash<std::uintptr_t>{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2);
return h;
}
};
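// Whitelist of opcodes that may be moved to the preheader. Div, Rem, FRem,
// calls, stores, and phis are absent from the list and are never hoisted.
bool IsHoistableInstruction(const Instruction* inst) {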
if (!inst || inst->IsTerminator() || inst->IsVoid()) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::GetElementPtr:
case Opcode::Zext:
case Opcode::Load:
return true;
default:
return false;
}
}
bool IsLoopInvariantInstruction(
const Loop& loop, Instruction* inst,
const std::unordered_set<Instruction*>& invariant_insts,
PhiInst* iv, int iv_stride,
const std::vector<loopmem::MemoryAccessInfo>& accesses,
const memutils::EscapeSummary& escapes) {
if (!IsHoistableInstruction(inst)) {
return false;
}
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
auto* operand_inst = dyncast<Instruction>(operand);
if (!operand_inst) {
continue;
}
if (!loop.Contains(operand_inst->GetParent())) {
continue;
}
if (invariant_insts.find(operand_inst) == invariant_insts.end()) {
return false;
}
}
if (auto* load = dyncast<LoadInst>(inst)) {
return loopmem::IsSafeInvariantLoadToHoist(loop, load, iv, iv_stride, accesses, &escapes);
}
return true;
}
bool HoistLoopInvariants(Function& function, const Loop& loop,
BasicBlock* preheader) {
if (!preheader) {
return false;
}
loopmem::SimpleInductionVar induction_var;
PhiInst* iv = nullptr;
int iv_stride = 1;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (loopmem::MatchSimpleInductionVariable(loop, preheader, phi, induction_var)) {
iv = induction_var.phi;
iv_stride = induction_var.stride;
break;
}
}
const auto escapes = memutils::AnalyzeEscapes(function);
const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes);
std::unordered_set<Instruction*> invariant_insts;
std::vector<Instruction*> hoist_list;
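// Iterate to a fixed point: an instruction becomes invariant once every operand
// defined inside the loop is already known to be invariant, so hoist_list ends
// up in dependency order.
bool progress = true;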
while (progress) {
progress = false;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (!loop.Contains(block) || block == preheader) {
continue;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (invariant_insts.find(inst) != invariant_insts.end()) {
continue;
}
if (!IsLoopInvariantInstruction(loop, inst, invariant_insts, iv, iv_stride,
accesses, escapes)) {
continue;
}
invariant_insts.insert(inst);
hoist_list.push_back(inst);
progress = true;
}
}
}
bool changed = false;
std::unordered_map<HoistedLoadKey, LoadInst*, HoistedLoadKeyHash> hoisted_loads;
for (auto* inst : hoist_list) {
if (auto* load = dyncast<LoadInst>(inst)) {
auto ptr = loopmem::AnalyzePointer(load->GetPtr(), iv, loop,
load->GetType()->GetSize(), &escapes);
if (ptr.exact_key_valid) {
HoistedLoadKey key{ptr.exact_key,
reinterpret_cast<std::uintptr_t>(load->GetType().get())};
auto it = hoisted_loads.find(key);
if (it != hoisted_loads.end()) {
load->ReplaceAllUsesWith(it->second);
load->GetParent()->EraseInstruction(load);
changed = true;
continue;
}
auto* moved = dyncast<LoadInst>(
looputils::MoveInstructionBeforeTerminator(load, preheader));
if (moved) {
hoisted_loads.emplace(std::move(key), moved);
changed = true;
}
continue;
}
}
if (looputils::MoveInstructionBeforeTerminator(inst, preheader)) {
changed = true;
}
}
return changed;
}
bool RunLICMOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
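// The dominator tree and loop info are rebuilt after every change, because
// creating a preheader or hoisting instructions invalidates them.
while (true) {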
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
auto* old_preheader = loop->preheader;
auto* preheader = looputils::EnsurePreheader(function, *loop);
bool loop_changed = preheader != old_preheader;
loop_changed |= HoistLoopInvariants(function, *loop, preheader);
if (!loop_changed) {
continue;
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLICM(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLICMOnFunction(*function);
}
}
return changed;
}
} // namespace ir
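
The fixed-point search in `HoistLoopInvariants` can be illustrated without the IR classes. This is a rough sketch under simplifying assumptions: `ToyInst` and `FindInvariants` are hypothetical, operands are plain indices, and the memory-safety check for loads is reduced to a single `hoistable` flag.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy instruction. Operands are indices into the same vector.
struct ToyInst {
  bool in_loop = false;       // defined inside the loop
  bool hoistable = false;     // opcode is on the hoisting whitelist
  std::vector<int> operands;  // indices of the defining instructions
};

// Returns the indices of loop-invariant instructions in discovery order.
// An instruction qualifies once all of its in-loop operands already qualify,
// so every entry appears after the entries it depends on.
inline std::vector<int> FindInvariants(const std::vector<ToyInst>& insts) {
  std::vector<bool> invariant(insts.size(), false);
  std::vector<int> order;
  bool progress = true;
  while (progress) {
    progress = false;
    for (std::size_t i = 0; i < insts.size(); ++i) {
      const ToyInst& inst = insts[i];
      if (!inst.in_loop || !inst.hoistable || invariant[i]) {
        continue;
      }
      bool all_operands_ok = true;
      for (int op : inst.operands) {
        assert(op >= 0 && static_cast<std::size_t>(op) < insts.size());
        const std::size_t idx = static_cast<std::size_t>(op);
        if (insts[idx].in_loop && !invariant[idx]) {
          all_operands_ok = false;
          break;
        }
      }
      if (!all_operands_ok) {
        continue;
      }
      invariant[i] = true;
      order.push_back(static_cast<int>(i));
      progress = true;
    }
  }
  return order;
}
```

An instruction that depends on the induction variable never enters the list, while one that depends on a later invariant instruction is picked up on the next sweep.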

@ -0,0 +1,319 @@
#include "ir/PassManager.h"
#include "ir/IR.h"
#include "MemoryUtils.h"
#include "PassUtils.h"
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct AvailableValue {
Value* value = nullptr;
bool operator==(const AvailableValue& rhs) const {
return value == rhs.value || passutils::AreEquivalentValues(value, rhs.value);
}
};
using MemoryState =
std::unordered_map<memutils::AddressKey, AvailableValue,
memutils::AddressKeyHash>;
bool SameAvailableValue(const AvailableValue& lhs, const AvailableValue& rhs) {
return lhs == rhs;
}
bool SameMemoryState(const MemoryState& lhs, const MemoryState& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto& [key, value] : lhs) {
auto it = rhs.find(key);
if (it == rhs.end() || !SameAvailableValue(value, it->second)) {
return false;
}
}
return true;
}
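// Must-availability meet: keeps only the entries on which every listed
// predecessor agrees. When two values are equivalent but not identical, the
// surviving representative is the one from the first predecessor.
MemoryState MeetMemoryStates(const std::vector<MemoryState*>& predecessors) {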
if (predecessors.empty()) {
return {};
}
MemoryState in = *predecessors.front();
for (auto it = in.begin(); it != in.end();) {
bool keep = true;
for (std::size_t i = 1; i < predecessors.size(); ++i) {
auto pred_it = predecessors[i]->find(it->first);
if (pred_it == predecessors[i]->end() ||
!SameAvailableValue(it->second, pred_it->second)) {
keep = false;
break;
}
}
if (!keep) {
it = in.erase(it);
continue;
}
++it;
}
return in;
}
void InvalidateAliasStates(MemoryState& state,
const memutils::AddressKey& key) {
for (auto it = state.begin(); it != state.end();) {
if (memutils::MayAliasConservatively(it->first, key)) {
it = state.erase(it);
continue;
}
++it;
}
}
void InvalidateStatesForCall(MemoryState& state, Function* callee) {
for (auto it = state.begin(); it != state.end();) {
if (memutils::CallMayWriteRoot(callee, it->first.kind)) {
it = state.erase(it);
continue;
}
++it;
}
}
void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* inst,
MemoryState& state) {
if (!inst) {
return;
}
if (auto* load = dyncast<LoadInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
state.clear();
}
return;
}
if (auto* store = dyncast<StoreInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
state.clear();
return;
}
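InvalidateAliasStates(state, key);
// OptimizeBlock may erase loads after the dataflow has converged. A stored
// load result is therefore kept out of the cross-block state, where it could
// otherwise dangle in another block's in-state.
if (dyncast<LoadInst>(store->GetValue())) {
state.erase(key);
} else {
state[key] = {store->GetValue()};
}
return;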
}
if (auto* call = dyncast<CallInst>(inst)) {
InvalidateStatesForCall(state, call->GetCallee());
return;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
state.clear();
return;
}
InvalidateAliasStates(state, key);
return;
}
}
MemoryState SimulateBlock(const memutils::EscapeSummary& escapes, BasicBlock* block,
const MemoryState& in_state) {
MemoryState state = in_state;
for (const auto& inst_ptr : block->GetInstructions()) {
SimulateInstruction(escapes, inst_ptr.get(), state);
}
return state;
}
bool MarkLoadObserved(
const memutils::AddressKey& key,
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
pending_stores) {
bool changed = false;
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (memutils::MayAliasConservatively(it->first, key)) {
it = pending_stores.erase(it);
changed = true;
continue;
}
++it;
}
return changed;
}
void InvalidatePendingForCall(
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
pending_stores,
Function* callee) {
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (memutils::CallMayReadRoot(callee, it->first.kind) ||
memutils::CallMayWriteRoot(callee, it->first.kind)) {
it = pending_stores.erase(it);
continue;
}
++it;
}
}
bool OptimizeBlock(
const memutils::EscapeSummary& escapes, BasicBlock* block,
const MemoryState& in_state) {
bool changed = false;
MemoryState state = in_state;
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>
pending_stores;
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* load = dyncast<LoadInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
MarkLoadObserved(key, pending_stores);
auto it = state.find(key);
if (it != state.end() && it->second.value != load) {
load->ReplaceAllUsesWith(it->second.value);
to_remove.push_back(load);
changed = true;
continue;
}
// Keep block-local load reuse, but do not expose load results to cross-block
// dataflow because the defining load itself may be removed later.
state[key] = {load};
continue;
}
if (auto* store = dyncast<StoreInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
if (!memutils::MayAliasConservatively(it->first, key)) {
++it;
continue;
}
if (it->first == key) {
to_remove.push_back(it->second);
changed = true;
}
it = pending_stores.erase(it);
}
pending_stores.emplace(key, store);
InvalidateAliasStates(state, key);
state[key] = {store->GetValue()};
continue;
}
if (auto* call = dyncast<CallInst>(inst)) {
InvalidateStatesForCall(state, call->GetCallee());
InvalidatePendingForCall(pending_stores, call->GetCallee());
continue;
}
if (auto* memset = dyncast<MemsetInst>(inst)) {
memutils::AddressKey key;
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
state.clear();
pending_stores.clear();
continue;
}
InvalidateAliasStates(state, key);
MarkLoadObserved(key, pending_stores);
continue;
}
}
for (auto* inst : to_remove) {
if (inst->GetParent() == block) {
block->EraseInstruction(inst);
}
}
return changed;
}
bool RunLoadStoreElimOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
const auto escapes = memutils::AnalyzeEscapes(function);
const auto reachable_blocks = passutils::CollectReachableBlocks(function);
if (reachable_blocks.empty()) {
return false;
}
std::unordered_map<BasicBlock*, MemoryState> in_states;
std::unordered_map<BasicBlock*, MemoryState> out_states;
bool dataflow_changed = true;
while (dataflow_changed) {
dataflow_changed = false;
for (auto* block : reachable_blocks) {
MemoryState in_state;
if (block != function.GetEntryBlock()) {
std::vector<MemoryState*> predecessors;
for (auto* pred : block->GetPredecessors()) {
auto it = out_states.find(pred);
if (it != out_states.end()) {
predecessors.push_back(&it->second);
}
}
in_state = MeetMemoryStates(predecessors);
}
auto out_state = SimulateBlock(escapes, block, in_state);
auto in_it = in_states.find(block);
if (in_it == in_states.end() || !SameMemoryState(in_it->second, in_state)) {
in_states[block] = in_state;
dataflow_changed = true;
}
auto out_it = out_states.find(block);
if (out_it == out_states.end() || !SameMemoryState(out_it->second, out_state)) {
out_states[block] = std::move(out_state);
dataflow_changed = true;
}
}
}
bool changed = false;
for (auto* block : reachable_blocks) {
changed |= OptimizeBlock(escapes, block, in_states[block]);
}
return changed;
}
} // namespace
bool RunLoadStoreElim(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoadStoreElimOnFunction(*function);
}
}
return changed;
}
} // namespace ir
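
The meet used by the cross-block dataflow is an intersection over agreeing entries. The sketch below shows the same shape on a toy state. `ToyState` and `Meet` are hypothetical names, with strings and integers standing in for `memutils::AddressKey` and `AvailableValue`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Toy memory state: address name -> known value. Both types are stand-ins
// chosen for illustration.
using ToyState = std::unordered_map<std::string, int>;

// Keeps an entry only if every predecessor holds the same value for it.
// An empty predecessor list yields an empty state.
inline ToyState Meet(const std::vector<const ToyState*>& preds) {
  if (preds.empty()) {
    return {};
  }
  for (const ToyState* pred : preds) {
    assert(pred != nullptr);
  }
  ToyState in = *preds.front();
  for (auto it = in.begin(); it != in.end();) {
    bool keep = true;
    for (std::size_t i = 1; i < preds.size(); ++i) {
      const auto other = preds[i]->find(it->first);
      if (other == preds[i]->end() || other->second != it->second) {
        keep = false;
        break;
      }
    }
    if (keep) {
      ++it;
    } else {
      it = in.erase(it);  // erase returns the next valid iterator
    }
  }
  return in;
}
```

An address survives the join only when all incoming paths agree on its contents, which is what allows a load in the join block to be replaced by that value.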

@ -0,0 +1,326 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ir {
namespace {
struct FissionLoopInfo {
Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
CondBrInst* branch = nullptr;
BinaryInst* compare = nullptr;
Opcode compare_opcode = Opcode::ICmpLT;
Value* bound = nullptr;
loopmem::SimpleInductionVar induction_var;
PhiInst* iv = nullptr;
BinaryInst* step_inst = nullptr;
};
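// Blocks created by the unroll and fission passes carry these name tags.
// Matching on them keeps the pass from transforming generated loops again.
bool HasSyntheticLoopTag(const std::string& name) {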
return name.find("unroll.") != std::string::npos ||
name.find("fission.") != std::string::npos;
}
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
if (!loop.preheader || !loop.header || !body) {
return true;
}
return HasSyntheticLoopTag(loop.preheader->GetName()) ||
HasSyntheticLoopTag(loop.header->GetName()) ||
HasSyntheticLoopTag(body->GetName());
}
Opcode SwapCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
return Opcode::ICmpGT;
case Opcode::ICmpLE:
return Opcode::ICmpGE;
case Opcode::ICmpGT:
return Opcode::ICmpLT;
case Opcode::ICmpGE:
return Opcode::ICmpLE;
default:
return opcode;
}
}
bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) {
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
return false;
}
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
return false;
}
if (IsAlreadyTransformedLoop(loop, body)) {
return false;
}
std::vector<PhiInst*> phis;
loopmem::SimpleInductionVar induction_var;
bool found_iv = false;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
phis.push_back(phi);
if (!found_iv &&
loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
found_iv = true;
}
}
if (!found_iv || phis.size() != 1) {
return false;
}
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
if (!branch || branch->GetThenBlock() != body || !compare) {
return false;
}
Opcode compare_opcode = compare->GetOpcode();
Value* bound = nullptr;
if (compare->GetLhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
bound = compare->GetRhs();
} else if (compare->GetRhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
bound = compare->GetLhs();
compare_opcode = SwapCompareOpcode(compare_opcode);
} else {
return false;
}
auto* step_inst = dyncast<BinaryInst>(induction_var.latch_value);
if (!step_inst || step_inst->GetParent() != body) {
return false;
}
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || inst == step_inst) {
continue;
}
if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
return false;
}
}
info.loop = &loop;
info.preheader = loop.preheader;
info.header = loop.header;
info.body = body;
info.exit = exit;
info.branch = branch;
info.compare = compare;
info.compare_opcode = compare_opcode;
info.bound = bound;
info.induction_var = induction_var;
info.iv = induction_var.phi;
info.step_inst = step_inst;
return true;
}
bool ContainsInterestingPayload(const std::vector<Instruction*>& group) {
bool has_memory = false;
for (auto* inst : group) {
if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
has_memory = true;
}
}
return has_memory;
}
Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) {
if (value == old_iv) {
return new_iv;
}
return value;
}
bool BuildSecondLoop(Function& function, const FissionLoopInfo& info,
const std::vector<Instruction*>& second_group) {
// Check the precondition first so that a failure leaves no orphan blocks behind.
const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader);
if (preheader_index < 0) {
return false;
}
auto* second_header =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header"));
auto* second_body =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body"));
auto* second_iv = second_header->Append<PhiInst>(
info.iv->GetType(), nullptr,
looputils::NextSyntheticName(function, "fission.iv."));
second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header);
auto* second_cmp = second_header->Append<BinaryInst>(
info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr,
looputils::NextSyntheticName(function, "fission.cmp."));
second_header->Append<CondBrInst>(second_cmp, second_body, info.exit, nullptr);
second_header->AddPredecessor(info.header);
second_header->AddSuccessor(second_body);
second_header->AddSuccessor(info.exit);
std::unordered_map<Value*, Value*> remap;
remap[info.iv] = second_iv;
std::unordered_set<Instruction*> selected(second_group.begin(), second_group.end());
selected.insert(info.step_inst);
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || selected.find(inst) == selected.end()) {
continue;
}
looputils::CloneInstruction(function, inst, second_body, remap, "fission.");
}
auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst);
if (!cloned_step_value) {
return false;
}
second_iv->AddIncoming(cloned_step_value, second_body);
second_body->Append<UncondBrInst>(second_header, nullptr);
second_body->AddPredecessor(second_header);
second_body->AddSuccessor(second_header);
second_header->AddPredecessor(second_body);
if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) {
return false;
}
info.exit->RemovePredecessor(info.header);
info.exit->AddPredecessor(second_header);
for (const auto& inst_ptr : info.exit->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int incoming = looputils::GetPhiIncomingIndex(phi, info.header);
if (incoming < 0) {
continue;
}
phi->SetOperand(static_cast<std::size_t>(2 * incoming),
RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv));
phi->SetOperand(static_cast<std::size_t>(2 * incoming + 1), second_header);
}
return true;
}
bool RunLoopFissionOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock()) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
FissionLoopInfo info;
if (!MatchFissionLoop(*loop, info)) {
continue;
}
const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv);
std::vector<Instruction*> payload;
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || inst == info.step_inst) {
continue;
}
payload.push_back(inst);
}
if (payload.size() < 2) {
continue;
}
int chosen_cut = -1;
std::vector<Instruction*> first_group;
std::vector<Instruction*> second_group;
for (std::size_t cut = 1; cut < payload.size(); ++cut) {
std::vector<Instruction*> first(payload.begin(), payload.begin() + static_cast<long long>(cut));
std::vector<Instruction*> second(payload.begin() + static_cast<long long>(cut),
payload.end());
if (!ContainsInterestingPayload(first) || !ContainsInterestingPayload(second)) {
continue;
}
std::unordered_set<Instruction*> first_set(first.begin(), first.end());
std::unordered_set<Instruction*> second_set(second.begin(), second.end());
if (loopmem::HasScalarDependenceAcrossCut(first, second_set) ||
loopmem::HasMemoryDependenceAcrossCut(accesses, first_set, second_set,
info.induction_var.stride)) {
continue;
}
chosen_cut = static_cast<int>(cut);
first_group = std::move(first);
second_group = std::move(second);
break;
}
if (chosen_cut < 0) {
continue;
}
std::unordered_set<Instruction*> keep(first_group.begin(), first_group.end());
keep.insert(info.step_inst);
std::vector<Instruction*> to_remove;
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator() || keep.find(inst) != keep.end()) {
continue;
}
to_remove.push_back(inst);
}
if (!BuildSecondLoop(function, info, second_group)) {
continue;
}
for (auto* inst : to_remove) {
info.body->EraseInstruction(inst);
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopFission(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopFissionOnFunction(*function);
}
}
return changed;
}
} // namespace ir
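
At source level, the effect of the fission pass corresponds to splitting one loop with two independent statements into two loops over the same range. The sketch below is only an illustration of that equivalence. `Fused` and `Fissioned` are hypothetical functions, and the two statements were chosen so that neither a scalar nor a memory dependence crosses the cut.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Original shape: both statements share one loop body.
inline void Fused(const std::vector<int>& src, std::vector<int>& a,
                  std::vector<int>& b) {
  assert(a.size() == src.size() && b.size() == src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    a[i] = src[i] + 1;
    b[i] = src[i] * 2;
  }
}

// After fission: the first group keeps the original loop, the second group
// runs in a new loop that restarts the induction variable from its start value.
inline void Fissioned(const std::vector<int>& src, std::vector<int>& a,
                      std::vector<int>& b) {
  assert(a.size() == src.size() && b.size() == src.size());
  for (std::size_t i = 0; i < src.size(); ++i) {
    a[i] = src[i] + 1;
  }
  for (std::size_t i = 0; i < src.size(); ++i) {
    b[i] = src[i] * 2;
  }
}
```

The split is only legal because the write to `b` never reads what the first group writes in a later iteration, which is the property the dependence checks across the cut are meant to establish.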

@ -0,0 +1,506 @@
#pragma once
#include "LoopPassUtils.h"
#include "MemoryUtils.h"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <vector>
namespace ir::loopmem {
struct SimpleInductionVar {
PhiInst* phi = nullptr;
Value* start = nullptr;
Value* latch_value = nullptr;
BasicBlock* latch = nullptr;
int stride = 0;
};
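// Matches an int32 header phi of the form [start, preheader], [phi +/- C, latch]
// in a loop with a single latch, where C is a nonzero integer constant.
inline bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
PhiInst* phi, SimpleInductionVar& info) {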
if (!phi || !preheader || phi->GetParent() != loop.header ||
!phi->GetType()->IsInt32() || phi->GetNumIncomings() != 2 ||
loop.latches.size() != 1) {
return false;
}
auto* latch = loop.latches.front();
const int preheader_index = looputils::GetPhiIncomingIndex(phi, preheader);
const int latch_index = looputils::GetPhiIncomingIndex(phi, latch);
if (preheader_index < 0 || latch_index < 0) {
return false;
}
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
if (!step_inst || step_inst->GetParent() != latch) {
return false;
}
int stride = 0;
if (step_inst->GetOpcode() == Opcode::Add) {
if (step_inst->GetLhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else if (step_inst->GetRhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else {
return false;
}
} else if (step_inst->GetOpcode() == Opcode::Sub) {
if (step_inst->GetLhs() != phi) {
return false;
}
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = -delta->GetValue();
} else {
return false;
}
if (stride == 0) {
return false;
}
info.phi = phi;
info.start = phi->GetIncomingValue(preheader_index);
info.latch_value = phi->GetIncomingValue(latch_index);
info.latch = latch;
info.stride = stride;
return true;
}
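// Accepts only two-block loops: a header that branches conditionally to the
// body or the exit, and a body that is the sole latch and jumps back to the header.
inline bool GetCanonicalLoopBlocks(const Loop& loop, BasicBlock*& body,
BasicBlock*& exit) {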
body = nullptr;
exit = nullptr;
if (!loop.header || loop.latches.size() != 1 || loop.block_list.size() != 2) {
return false;
}
auto* condbr = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
if (!condbr) {
return false;
}
auto* then_block = condbr->GetThenBlock();
auto* else_block = condbr->GetElseBlock();
const bool then_in_loop = loop.Contains(then_block);
const bool else_in_loop = loop.Contains(else_block);
if (then_in_loop == else_in_loop) {
return false;
}
body = then_in_loop ? then_block : else_block;
exit = then_in_loop ? else_block : then_block;
if (!body || !exit || body != loop.latches.front() ||
body->GetSuccessors().size() != 1 || body->GetSuccessors().front() != loop.header) {
return false;
}
return true;
}
// Represents coeff * var + constant. A null var means a pure constant.
struct AffineExpr {
bool valid = false;
Value* var = nullptr;
std::int64_t coeff = 0;
std::int64_t constant = 0;
};
inline AffineExpr MakeConst(std::int64_t value) {
return {true, nullptr, 0, value};
}
inline AffineExpr Scale(const AffineExpr& expr, std::int64_t factor) {
if (!expr.valid) {
return {};
}
return {true, expr.var, expr.coeff * factor, expr.constant * factor};
}
inline AffineExpr Combine(const AffineExpr& lhs, const AffineExpr& rhs, int sign) {
if (!lhs.valid || !rhs.valid) {
return {};
}
if (lhs.var != nullptr && rhs.var != nullptr && lhs.var != rhs.var) {
return {};
}
AffineExpr out;
out.valid = true;
out.var = lhs.var ? lhs.var : rhs.var;
out.coeff = lhs.coeff + sign * rhs.coeff;
out.constant = lhs.constant + sign * rhs.constant;
return out;
}
inline AffineExpr AnalyzeAffine(Value* value, PhiInst* iv, const Loop& loop) {
if (!value) {
return {};
}
if (auto* ci = dyncast<ConstantInt>(value)) {
return MakeConst(ci->GetValue());
}
if (value == iv) {
return {true, iv, 1, 0};
}
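// A non-constant loop-invariant value would need a symbolic term, which
// AffineExpr cannot represent, so it is reported as not affine.
if (looputils::IsLoopInvariantValue(loop, value)) {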
return {};
}
if (auto* zext = dyncast<ZextInst>(value)) {
return AnalyzeAffine(zext->GetValue(), iv, loop);
}
auto* inst = dyncast<Instruction>(value);
if (!inst) {
return {};
}
switch (inst->GetOpcode()) {
case Opcode::Add:
return Combine(AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetLhs(), iv, loop),
AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetRhs(), iv, loop), +1);
case Opcode::Sub:
return Combine(AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetLhs(), iv, loop),
AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetRhs(), iv, loop), -1);
case Opcode::Mul: {
auto* bin = static_cast<BinaryInst*>(inst);
auto lhs = AnalyzeAffine(bin->GetLhs(), iv, loop);
auto rhs = AnalyzeAffine(bin->GetRhs(), iv, loop);
if (lhs.valid && lhs.var == nullptr && rhs.valid) {
return Scale(rhs, lhs.constant);
}
if (rhs.valid && rhs.var == nullptr && lhs.valid) {
return Scale(lhs, rhs.constant);
}
return {};
}
case Opcode::Neg:
return Scale(AnalyzeAffine(static_cast<UnaryInst*>(inst)->GetOprd(), iv, loop), -1);
default:
return {};
}
}
struct PointerInfo {
Value* base = nullptr;
AffineExpr byte_offset;
bool invariant_address = false;
bool distinct_root = false;
bool argument_root = false;
bool readonly_root = false;
bool exact_key_valid = false;
memutils::PointerRootKind root_kind = memutils::PointerRootKind::Unknown;
memutils::AddressKey exact_key;
int access_size = 0;
};
inline Value* StripPointerBase(Value* pointer) {
auto* value = pointer;
while (auto* gep = dyncast<GetElementPtrInst>(value)) {
value = gep->GetPointer();
}
return value;
}
inline std::shared_ptr<Type> AdvanceGEPType(std::shared_ptr<Type> current) {
if (current && current->IsArray()) {
return current->GetElementType();
}
return current;
}
inline PointerInfo AnalyzePointer(Value* pointer, PhiInst* iv, const Loop& loop,
int access_size,
const memutils::EscapeSummary* escapes = nullptr) {
PointerInfo info;
info.access_size = access_size;
info.base = StripPointerBase(pointer);
info.root_kind = memutils::ClassifyRoot(info.base, escapes);
info.argument_root = info.root_kind == memutils::PointerRootKind::Param;
info.readonly_root = info.root_kind == memutils::PointerRootKind::ReadonlyGlobal;
info.distinct_root = info.root_kind == memutils::PointerRootKind::Local ||
info.root_kind == memutils::PointerRootKind::Global ||
info.root_kind == memutils::PointerRootKind::ReadonlyGlobal;
info.exact_key_valid =
escapes != nullptr && memutils::BuildExactAddressKey(pointer, escapes, info.exact_key);
info.invariant_address = looputils::IsLoopInvariantValue(loop, pointer);
if (!dyncast<GetElementPtrInst>(pointer)) {
info.byte_offset = MakeConst(0);
return info;
}
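auto* gep = static_cast<GetElementPtrInst*>(pointer);
AffineExpr total = MakeConst(0);
bool all_indices_loop_invariant = looputils::IsLoopInvariantValue(loop, gep->GetPointer());
// info.base has every GEP stripped, so the byte offset has to cover the whole
// GEP chain and not only the outermost GEP. Otherwise two addresses that
// differ only in an inner GEP would be compared with the wrong offsets.
for (auto* cur = gep; cur != nullptr; cur = dyncast<GetElementPtrInst>(cur->GetPointer())) {
std::shared_ptr<Type> current = cur->GetSourceType();
for (std::size_t i = 0; i < cur->GetNumIndices(); ++i) {
auto* index = cur->GetIndex(i);
all_indices_loop_invariant &= looputils::IsLoopInvariantValue(loop, index);
const std::int64_t stride = current ? current->GetSize() : 0;
auto term = AnalyzeAffine(index, iv, loop);
if (!current || !term.valid) {
// Unknown element size or non-affine index: the offset is unknown.
total = {};
} else if (total.valid) {
total = Combine(total, Scale(term, stride), +1);
}
current = AdvanceGEPType(current);
}
}
info.invariant_address = all_indices_loop_invariant;
info.byte_offset = total;
return info;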
}
struct MemoryAccessInfo {
Instruction* inst = nullptr;
Value* pointer = nullptr;
PointerInfo ptr;
bool is_read = false;
bool is_write = false;
};
inline std::vector<MemoryAccessInfo> CollectMemoryAccesses(const Loop& loop,
PhiInst* iv,
const memutils::EscapeSummary* escapes =
nullptr) {
std::vector<MemoryAccessInfo> accesses;
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* load = dyncast<LoadInst>(inst)) {
accesses.push_back(
{inst, load->GetPtr(),
AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes),
true,
false});
} else if (auto* store = dyncast<StoreInst>(inst)) {
accesses.push_back({inst, store->GetPtr(),
AnalyzePointer(store->GetPtr(), iv, loop,
store->GetValue()->GetType()->GetSize(), escapes),
false, true});
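} else if (auto* memset = dyncast<MemsetInst>(inst)) {
// A memset writes `length` bytes, not one. Use the constant length when it
// is known, otherwise a size large enough to cover any later offset.
constexpr int kUnknownAccessSize = 0x7fffffff;
auto* length = dyncast<ConstantInt>(memset->GetLength());
const int size = (length && length->GetValue() > 0) ? static_cast<int>(length->GetValue())
: kUnknownAccessSize;
accesses.push_back({inst, memset->GetDest(),
AnalyzePointer(memset->GetDest(), iv, loop, size, escapes), false, true});
}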
}
}
return accesses;
}
inline bool SameAffineAddress(const PointerInfo& lhs, const PointerInfo& rhs) {
return lhs.base == rhs.base && lhs.byte_offset.valid && rhs.byte_offset.valid &&
lhs.byte_offset.var == rhs.byte_offset.var &&
lhs.byte_offset.coeff == rhs.byte_offset.coeff &&
lhs.byte_offset.constant == rhs.byte_offset.constant;
}
inline bool MayAliasSameIteration(const PointerInfo& lhs, const PointerInfo& rhs) {
if (lhs.exact_key_valid && rhs.exact_key_valid) {
return memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key);
}
if (!lhs.base || !rhs.base) {
return true;
}
if (lhs.base != rhs.base) {
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
return false;
}
return true;
}
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
return true;
}
if (lhs.byte_offset.var != rhs.byte_offset.var) {
return true;
}
if (lhs.byte_offset.coeff != rhs.byte_offset.coeff) {
return true;
}
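// Two accesses overlap iff the one at the lower address extends past the
// start of the other. Comparing the distance against the smaller of the two
// sizes misses the case where the wider access comes first.
const auto delta = rhs.byte_offset.constant - lhs.byte_offset.constant;
return delta >= 0 ? delta < lhs.access_size : -delta < rhs.access_size;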
}
inline bool HasCrossIterationDependence(const PointerInfo& lhs, const PointerInfo& rhs,
int iv_stride) {
if (lhs.exact_key_valid && rhs.exact_key_valid &&
!memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key)) {
return false;
}
if (!lhs.base || !rhs.base) {
return true;
}
if (lhs.base != rhs.base) {
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
return false;
}
return true;
}
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
return true;
}
if (lhs.byte_offset.var != rhs.byte_offset.var) {
return true;
}
const auto lhs_step = lhs.byte_offset.coeff * iv_stride;
const auto rhs_step = rhs.byte_offset.coeff * iv_stride;
if (lhs_step == 0 && rhs_step == 0) {
return MayAliasSameIteration(lhs, rhs);
}
if (lhs_step == rhs_step && lhs_step != 0) {
const auto diff = rhs.byte_offset.constant - lhs.byte_offset.constant;
return diff != 0 && diff % std::llabs(lhs_step) == 0;
}
return true;
}
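// Worked example for the equal-step case above (illustrative numbers,
// assuming 4-byte elements and iv_stride == 1):
//   a[i]   vs a[i+1]  : offsets 4*i and 4*i+4, step 4, diff 4 -> 4 % 4 == 0, dependence
//   a[2*i] vs a[2*i+1]: offsets 8*i and 8*i+4, step 8, diff 4 -> 4 % 8 != 0, independent
//   a[i]   vs a[i]    : diff 0 -> the same element is only touched within one
//                       iteration, independent
//   a[i]   vs a[2*i]  : steps 4 and 8 differ -> conservatively reported as dependent
// The test ignores the trip count, so a distance larger than the iteration
// space is still reported as a dependence.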
inline bool CallMayWritePointer(Function* callee, const PointerInfo& ptr) {
if (ptr.readonly_root) {
return false;
}
return memutils::CallMayWriteRoot(callee, ptr.root_kind);
}
inline bool IsSafeInvariantLoadToHoist(const Loop& loop, LoadInst* load, PhiInst* iv,
int iv_stride,
const std::vector<MemoryAccessInfo>& accesses,
const memutils::EscapeSummary* escapes = nullptr) {
if (!load) {
return false;
}
auto ptr = AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes);
if (!ptr.invariant_address) {
return false;
}
if (ptr.readonly_root) {
return true;
}
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == load) {
continue;
}
if (auto* call = dyncast<CallInst>(inst)) {
if (CallMayWritePointer(call->GetCallee(), ptr)) {
return false;
}
}
}
}
for (const auto& access : accesses) {
if (access.inst == load || !access.is_write) {
continue;
}
if (MayAliasSameIteration(ptr, access.ptr)) {
return false;
}
if (HasCrossIterationDependence(ptr, access.ptr, iv_stride)) {
return false;
}
}
return true;
}
inline bool HasScalarDependenceAcrossCut(const std::vector<Instruction*>& first_group,
const std::unordered_set<Instruction*>& second_set) {
for (auto* inst : first_group) {
if (!inst || inst->IsVoid()) {
continue;
}
for (const auto& use : inst->GetUses()) {
auto* user = dyncast<Instruction>(use.GetUser());
if (user && second_set.find(user) != second_set.end()) {
return true;
}
}
}
return false;
}
inline bool HasMemoryDependenceAcrossCut(const std::vector<MemoryAccessInfo>& accesses,
const std::unordered_set<Instruction*>& first_set,
const std::unordered_set<Instruction*>& second_set,
int iv_stride) {
for (const auto& lhs : accesses) {
if (first_set.find(lhs.inst) == first_set.end()) {
continue;
}
for (const auto& rhs : accesses) {
if (second_set.find(rhs.inst) == second_set.end()) {
continue;
}
if (!lhs.is_write && !rhs.is_write) {
continue;
}
if (MayAliasSameIteration(lhs.ptr, rhs.ptr) ||
HasCrossIterationDependence(lhs.ptr, rhs.ptr, iv_stride)) {
return true;
}
}
}
return false;
}
inline bool IsLoopParallelizable(const Loop& loop, PhiInst* iv, int iv_stride,
const std::vector<MemoryAccessInfo>& accesses) {
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
if (phi != iv) {
return false;
}
}
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* call = dyncast<CallInst>(inst)) {
auto* callee = call->GetCallee();
if (callee == nullptr || callee->HasObservableSideEffects() || callee->IsRecursive()) {
return false;
}
for (const auto& access : accesses) {
if (CallMayWritePointer(callee, access.ptr)) {
return false;
}
}
continue;
}
if (dyncast<MemsetInst>(inst)) {
return false;
}
}
}
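for (std::size_t i = 0; i < accesses.size(); ++i) {
// Start at j == i so that every write is also checked against itself: a
// store to a loop-invariant address is an output dependence between
// iterations even when it is the only access in the loop.
for (std::size_t j = i; j < accesses.size(); ++j) {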
if (!accesses[i].is_write && !accesses[j].is_write) {
continue;
}
if (HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr, iv_stride)) {
return false;
}
}
}
return true;
}
} // namespace ir::loopmem

@ -0,0 +1,440 @@
#pragma once
#include "ir/Analysis.h"
#include "ir/IR.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ir::looputils {
inline Instruction* GetTerminator(BasicBlock* block) {
if (!block || block->GetInstructions().empty()) {
return nullptr;
}
auto* inst = block->GetInstructions().back().get();
return inst && inst->IsTerminator() ? inst : nullptr;
}
inline std::size_t GetTerminatorIndex(BasicBlock* block) {
if (!block) {
return 0;
}
const auto size = block->GetInstructions().size();
if (!block->HasTerminator()) {
return size;
}
return size == 0 ? 0 : size - 1;
}
inline std::size_t GetFirstNonPhiIndex(BasicBlock* block) {
if (!block) {
return 0;
}
std::size_t index = 0;
for (const auto& inst_ptr : block->GetInstructions()) {
if (!dyncast<PhiInst>(inst_ptr.get())) {
break;
}
++index;
}
return index;
}
inline std::string NextSyntheticName(Function& function, const std::string& prefix) {
static std::unordered_map<Function*, int> counters;
const int id = ++counters[&function];
return "%" + prefix + std::to_string(id);
}
inline std::string NextSyntheticBlockName(Function& function,
const std::string& prefix) {
static std::unordered_map<Function*, int> counters;
const int id = ++counters[&function];
return prefix + "." + std::to_string(id);
}
inline ConstantInt* ConstInt(int value) {
return new ConstantInt(Type::GetInt32Type(), value);
}
inline int GetPhiIncomingIndex(PhiInst* phi, BasicBlock* block) {
if (!phi || !block) {
return -1;
}
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (phi->GetIncomingBlock(i) == block) {
return i;
}
}
return -1;
}
inline bool ReplacePhiIncoming(PhiInst* phi, BasicBlock* old_block,
Value* new_value, BasicBlock* new_block) {
if (!phi || !old_block || !new_value || !new_block) {
return false;
}
const int index = GetPhiIncomingIndex(phi, old_block);
if (index < 0) {
return false;
}
phi->SetOperand(static_cast<std::size_t>(2 * index), new_value);
phi->SetOperand(static_cast<std::size_t>(2 * index + 1), new_block);
return true;
}
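// Rewrites the terminator of `pred` so that the edge pred -> old_succ becomes
// pred -> new_succ and updates the successor list of `pred`. The predecessor
// lists of old_succ and new_succ are left to the caller.
inline bool RedirectSuccessorEdge(BasicBlock* pred, BasicBlock* old_succ,
BasicBlock* new_succ) {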
auto* terminator = GetTerminator(pred);
if (!terminator || !old_succ || !new_succ) {
return false;
}
if (auto* br = dyncast<UncondBrInst>(terminator)) {
if (br->GetDest() != old_succ) {
return false;
}
br->SetOperand(0, new_succ);
} else if (auto* condbr = dyncast<CondBrInst>(terminator)) {
bool changed = false;
if (condbr->GetThenBlock() == old_succ) {
condbr->SetOperand(1, new_succ);
changed = true;
}
if (condbr->GetElseBlock() == old_succ) {
condbr->SetOperand(2, new_succ);
changed = true;
}
if (!changed) {
return false;
}
} else {
return false;
}
pred->RemoveSuccessor(old_succ);
pred->AddSuccessor(new_succ);
return true;
}
inline Instruction* MoveInstructionBeforeTerminator(Instruction* inst,
BasicBlock* dest) {
if (!inst || !dest) {
return nullptr;
}
auto* src = inst->GetParent();
if (!src || src == dest) {
return inst;
}
auto& src_insts = src->GetInstructions();
auto src_it = std::find_if(src_insts.begin(), src_insts.end(),
[&](const std::unique_ptr<Instruction>& current) {
return current.get() == inst;
});
if (src_it == src_insts.end()) {
return nullptr;
}
auto moved = std::move(*src_it);
src_insts.erase(src_it);
moved->SetParent(dest);
auto& dest_insts = dest->GetInstructions();
auto insert_it = dest_insts.begin() +
static_cast<long long>(GetTerminatorIndex(dest));
auto* ptr = moved.get();
dest_insts.insert(insert_it, std::move(moved));
return ptr;
}
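// "Invariant" here means "defined outside the loop": any non-instruction
// value, or an instruction whose block the loop does not contain. An
// instruction inside the loop is never reported as invariant, even if all of
// its operands are.
inline bool IsLoopInvariantValue(const Loop& loop, Value* value) {
auto* inst = dyncast<Instruction>(value);
return inst == nullptr || !loop.Contains(inst->GetParent());
}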
inline Value* RemapValue(const std::unordered_map<Value*, Value*>& remap,
Value* value) {
auto it = remap.find(value);
return it == remap.end() ? value : it->second;
}
inline bool IsCloneableInstruction(const Instruction* inst) {
if (!inst || inst->IsTerminator() || inst->GetOpcode() == Opcode::Phi) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF:
case Opcode::Alloca:
case Opcode::Load:
case Opcode::Store:
case Opcode::Memset:
case Opcode::GetElementPtr:
case Opcode::Zext:
case Opcode::Call:
return true;
default:
return false;
}
}
inline Instruction* CloneInstruction(Function& function, Instruction* inst,
BasicBlock* dest,
std::unordered_map<Value*, Value*>& remap,
const std::string& prefix) {
if (!inst || !dest || !IsCloneableInstruction(inst)) {
return nullptr;
}
const auto insert_index = GetTerminatorIndex(dest);
const auto name = inst->IsVoid() ? std::string()
: NextSyntheticName(function, prefix);
auto remap_operand = [&](Value* value) { return RemapValue(remap, value); };
auto remember = [&](Instruction* clone) {
if (clone && !inst->IsVoid()) {
remap[inst] = clone;
}
return clone;
};
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Rem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::FRem:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::Shl:
case Opcode::AShr:
case Opcode::LShr:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<BinaryInst*>(inst);
return remember(dest->Insert<BinaryInst>(
insert_index, inst->GetOpcode(), inst->GetType(),
remap_operand(bin->GetLhs()), remap_operand(bin->GetRhs()), nullptr,
name));
}
case Opcode::Neg:
case Opcode::Not:
case Opcode::FNeg:
case Opcode::FtoI:
case Opcode::IToF: {
auto* un = static_cast<UnaryInst*>(inst);
return remember(dest->Insert<UnaryInst>(insert_index, inst->GetOpcode(),
inst->GetType(),
remap_operand(un->GetOprd()),
nullptr, name));
}
case Opcode::Alloca: {
auto* alloca = static_cast<AllocaInst*>(inst);
return remember(dest->Insert<AllocaInst>(insert_index,
alloca->GetAllocatedType(),
nullptr, name));
}
case Opcode::Load: {
auto* load = static_cast<LoadInst*>(inst);
return remember(dest->Insert<LoadInst>(insert_index, inst->GetType(),
remap_operand(load->GetPtr()),
nullptr, name));
}
case Opcode::Store: {
auto* store = static_cast<StoreInst*>(inst);
return dest->Insert<StoreInst>(insert_index,
remap_operand(store->GetValue()),
remap_operand(store->GetPtr()), nullptr);
}
case Opcode::Memset: {
auto* memset = static_cast<MemsetInst*>(inst);
return dest->Insert<MemsetInst>(insert_index,
remap_operand(memset->GetDest()),
remap_operand(memset->GetValue()),
remap_operand(memset->GetLength()),
remap_operand(memset->GetIsVolatile()),
nullptr);
}
case Opcode::GetElementPtr: {
auto* gep = static_cast<GetElementPtrInst*>(inst);
std::vector<Value*> indices;
indices.reserve(gep->GetNumIndices());
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
indices.push_back(remap_operand(gep->GetIndex(i)));
}
return remember(dest->Insert<GetElementPtrInst>(
insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()),
indices, nullptr, name));
}
case Opcode::Zext: {
auto* zext = static_cast<ZextInst*>(inst);
return remember(dest->Insert<ZextInst>(insert_index,
remap_operand(zext->GetValue()),
inst->GetType(), nullptr, name));
}
case Opcode::Call: {
auto* call = static_cast<CallInst*>(inst);
std::vector<Value*> args;
const auto original_args = call->GetArguments();
args.reserve(original_args.size());
for (auto* arg : original_args) {
args.push_back(remap_operand(arg));
}
return remember(dest->Insert<CallInst>(insert_index, call->GetCallee(),
args, nullptr, name));
}
case Opcode::Phi:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Return:
case Opcode::Unreachable:
break;
}
return nullptr;
}
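// Returns a block that is the only out-of-loop predecessor of the header and
// branches only to the header. An existing block with these properties is
// reused; otherwise a new block is created, the out-of-loop incomings of each
// header phi are merged into it, and the outside edges are redirected.
inline BasicBlock* EnsurePreheader(Function& function, Loop& loop) {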
if (loop.preheader) {
return loop.preheader;
}
auto* header = loop.header;
if (!header) {
return nullptr;
}
std::vector<BasicBlock*> outside_preds;
for (auto* pred : header->GetPredecessors()) {
if (!loop.Contains(pred)) {
outside_preds.push_back(pred);
}
}
if (outside_preds.empty()) {
return nullptr;
}
if (outside_preds.size() == 1 &&
outside_preds.front()->GetSuccessors().size() == 1) {
loop.preheader = outside_preds.front();
return loop.preheader;
}
auto* preheader = function.CreateBlock(
NextSyntheticBlockName(function, header->GetName() + ".preheader"));
std::size_t phi_insert_index = 0;
for (const auto& inst_ptr : header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
std::vector<int> outside_incomings;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (!loop.Contains(phi->GetIncomingBlock(i))) {
outside_incomings.push_back(i);
}
}
if (outside_incomings.empty()) {
continue;
}
Value* merged_value = nullptr;
if (outside_incomings.size() == 1) {
merged_value = phi->GetIncomingValue(outside_incomings.front());
} else {
auto new_phi = std::make_unique<PhiInst>(
phi->GetType(), nullptr,
NextSyntheticName(function, "preheader.phi."));
auto* new_phi_ptr = new_phi.get();
new_phi_ptr->SetParent(preheader);
auto& preheader_insts = preheader->GetInstructions();
preheader_insts.insert(preheader_insts.begin() +
static_cast<long long>(phi_insert_index),
std::move(new_phi));
++phi_insert_index;
for (int incoming_index : outside_incomings) {
new_phi_ptr->AddIncoming(phi->GetIncomingValue(incoming_index),
phi->GetIncomingBlock(incoming_index));
}
merged_value = new_phi_ptr;
}
for (auto it = outside_incomings.rbegin(); it != outside_incomings.rend();
++it) {
phi->RemoveOperand(static_cast<std::size_t>(2 * *it + 1));
phi->RemoveOperand(static_cast<std::size_t>(2 * *it));
}
phi->AddIncoming(merged_value, preheader);
}
preheader->Append<UncondBrInst>(header, nullptr);
preheader->AddSuccessor(header);
header->AddPredecessor(preheader);
for (auto* pred : outside_preds) {
if (RedirectSuccessorEdge(pred, header, preheader)) {
preheader->AddPredecessor(pred);
header->RemovePredecessor(pred);
}
}
loop.preheader = preheader;
return preheader;
}
} // namespace ir::looputils

@ -0,0 +1,295 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopPassUtils.h"
#include <cstdlib>
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct InductionVarInfo {
PhiInst* phi = nullptr;
Value* start = nullptr;
BasicBlock* latch = nullptr;
int stride = 0;
};
Value* BuildMulValue(Function& function, BasicBlock* block, Value* lhs, Value* rhs,
const std::string& prefix) {
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
if (lhs_const->GetValue() == 0) {
return looputils::ConstInt(0);
}
if (lhs_const->GetValue() == 1) {
return rhs;
}
}
if (auto* rhs_const = dyncast<ConstantInt>(rhs)) {
if (rhs_const->GetValue() == 0) {
return looputils::ConstInt(0);
}
if (rhs_const->GetValue() == 1) {
return lhs;
}
}
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
if (auto* rhs_const = dyncast<ConstantInt>(rhs)) {
return looputils::ConstInt(lhs_const->GetValue() * rhs_const->GetValue());
}
}
return block->Insert<BinaryInst>(looputils::GetTerminatorIndex(block), Opcode::Mul,
Type::GetInt32Type(), lhs, rhs, nullptr,
looputils::NextSyntheticName(function, prefix));
}
Value* BuildScaledValue(Function& function, BasicBlock* block, Value* base,
int factor, const std::string& prefix) {
if (factor == 0) {
return looputils::ConstInt(0);
}
if (factor == 1) {
return base;
}
if (auto* base_const = dyncast<ConstantInt>(base)) {
return looputils::ConstInt(base_const->GetValue() * factor);
}
if (factor == -1) {
return block->Insert<UnaryInst>(looputils::GetTerminatorIndex(block), Opcode::Neg,
base->GetType(), base, nullptr,
looputils::NextSyntheticName(function, prefix));
}
return BuildMulValue(function, block, base, looputils::ConstInt(factor), prefix);
}
bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
PhiInst* phi, InductionVarInfo& info) {
if (!phi || !phi->GetType() || !phi->GetType()->IsInt32() ||
phi->GetParent() != loop.header || phi->GetNumIncomings() != 2 ||
loop.latches.size() != 1) {
return false;
}
auto* latch = loop.latches.front();
int preheader_index = -1;
int latch_index = -1;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
if (phi->GetIncomingBlock(i) == preheader) {
preheader_index = i;
} else if (phi->GetIncomingBlock(i) == latch) {
latch_index = i;
}
}
if (preheader_index < 0 || latch_index < 0) {
return false;
}
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
if (!step_inst || step_inst->GetParent() != latch) {
return false;
}
int stride = 0;
if (step_inst->GetOpcode() == Opcode::Add) {
if (step_inst->GetLhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else if (step_inst->GetRhs() == phi) {
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
if (!delta) {
return false;
}
stride = delta->GetValue();
} else {
return false;
}
} else if (step_inst->GetOpcode() == Opcode::Sub) {
if (step_inst->GetLhs() != phi) {
return false;
}
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
if (!delta) {
return false;
}
stride = -delta->GetValue();
} else {
return false;
}
if (stride == 0) {
return false;
}
info.phi = phi;
info.start = phi->GetIncomingValue(preheader_index);
info.latch = latch;
info.stride = stride;
return true;
}
bool IsMulCandidate(const Loop& loop, Instruction* inst, PhiInst* phi, Value*& factor) {
auto* mul = dyncast<BinaryInst>(inst);
if (!mul || mul->GetOpcode() != Opcode::Mul || !mul->GetType()->IsInt32()) {
return false;
}
if (mul->GetLhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetRhs())) {
factor = mul->GetRhs();
return true;
}
if (mul->GetRhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetLhs())) {
factor = mul->GetLhs();
return true;
}
return false;
}
Value* CreateReducedPhi(Function& function, BasicBlock* header, BasicBlock* preheader,
const InductionVarInfo& iv, Value* factor) {
auto* reduced_phi = header->Insert<PhiInst>(
looputils::GetFirstNonPhiIndex(header), Type::GetInt32Type(), nullptr,
looputils::NextSyntheticName(function, "lsr.phi."));
Value* init = BuildMulValue(function, preheader, iv.start, factor, "lsr.init.");
reduced_phi->AddIncoming(init, preheader);
Value* step = BuildScaledValue(function, preheader, factor, std::abs(iv.stride),
"lsr.step.");
Instruction* next = nullptr;
if (iv.stride > 0) {
next = iv.latch->Insert<BinaryInst>(
looputils::GetTerminatorIndex(iv.latch), Opcode::Add, Type::GetInt32Type(),
reduced_phi, step, nullptr,
looputils::NextSyntheticName(function, "lsr.next."));
} else {
next = iv.latch->Insert<BinaryInst>(
looputils::GetTerminatorIndex(iv.latch), Opcode::Sub, Type::GetInt32Type(),
reduced_phi, step, nullptr,
looputils::NextSyntheticName(function, "lsr.next."));
}
reduced_phi->AddIncoming(next, iv.latch);
return reduced_phi;
}
bool ReduceLoopMultiplications(Function& function, const Loop& loop,
BasicBlock* preheader) {
if (!preheader || loop.latches.size() != 1) {
return false;
}
std::vector<InductionVarInfo> induction_vars;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
InductionVarInfo info;
if (MatchSimpleInductionVariable(loop, preheader, phi, info)) {
induction_vars.push_back(info);
}
}
if (induction_vars.empty()) {
return false;
}
bool changed = false;
std::vector<Instruction*> to_remove;
for (const auto& iv : induction_vars) {
std::vector<std::pair<Instruction*, Value*>> candidates;
for (auto* block : loop.block_list) {
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst == iv.phi) {
continue;
}
Value* factor = nullptr;
if (IsMulCandidate(loop, inst, iv.phi, factor)) {
candidates.push_back({inst, factor});
}
}
}
if (candidates.empty()) {
continue;
}
std::unordered_map<Value*, Value*> reduced_cache;
for (const auto& candidate : candidates) {
auto* inst = candidate.first;
auto* factor = candidate.second;
auto cache_it = reduced_cache.find(factor);
Value* replacement = nullptr;
if (cache_it != reduced_cache.end()) {
replacement = cache_it->second;
} else {
replacement = CreateReducedPhi(function, loop.header, preheader, iv, factor);
reduced_cache.emplace(factor, replacement);
}
inst->ReplaceAllUsesWith(replacement);
to_remove.push_back(inst);
changed = true;
}
}
for (auto* inst : to_remove) {
if (inst && inst->GetParent()) {
inst->GetParent()->EraseInstruction(inst);
}
}
return changed;
}
bool RunLoopStrengthReductionOnFunction(Function& function) {
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
auto* old_preheader = loop->preheader;
auto* preheader = looputils::EnsurePreheader(function, *loop);
bool loop_changed = preheader != old_preheader;
loop_changed |= ReduceLoopMultiplications(function, *loop, preheader);
if (!loop_changed) {
continue;
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopStrengthReduction(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopStrengthReductionOnFunction(*function);
}
}
return changed;
}
} // namespace ir

@ -0,0 +1,400 @@
#include "ir/PassManager.h"
#include "ir/Analysis.h"
#include "ir/IR.h"
#include "LoopMemoryUtils.h"
#include "LoopPassUtils.h"
#include <unordered_map>
#include <vector>
namespace ir {
namespace {
struct CountedLoopInfo {
Loop* loop = nullptr;
BasicBlock* preheader = nullptr;
BasicBlock* header = nullptr;
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
CondBrInst* branch = nullptr;
BinaryInst* compare = nullptr;
Opcode compare_opcode = Opcode::ICmpLT;
Value* bound = nullptr;
loopmem::SimpleInductionVar induction_var;
std::vector<PhiInst*> phis;
};
bool HasSyntheticLoopTag(const std::string& name) {
return name.find("unroll.") != std::string::npos;
}
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
if (!loop.preheader || !loop.header || !body) {
return true;
}
if (HasSyntheticLoopTag(loop.preheader->GetName()) ||
HasSyntheticLoopTag(loop.header->GetName()) ||
HasSyntheticLoopTag(body->GetName())) {
return true;
}
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
const int incoming = looputils::GetPhiIncomingIndex(phi, loop.preheader);
if (incoming < 0) {
continue;
}
auto* incoming_phi = dyncast<PhiInst>(phi->GetIncomingValue(incoming));
if (incoming_phi && incoming_phi->GetParent() &&
HasSyntheticLoopTag(incoming_phi->GetParent()->GetName())) {
return true;
}
}
return false;
}
bool IsSupportedCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
case Opcode::ICmpLE:
case Opcode::ICmpGT:
case Opcode::ICmpGE:
return true;
default:
return false;
}
}
Opcode SwapCompareOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::ICmpLT:
return Opcode::ICmpGT;
case Opcode::ICmpLE:
return Opcode::ICmpGE;
case Opcode::ICmpGT:
return Opcode::ICmpLT;
case Opcode::ICmpGE:
return Opcode::ICmpLE;
default:
return opcode;
}
}
int CountPayloadInstructions(BasicBlock* block) {
int count = 0;
if (!block) {
return 0;
}
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
break;
}
++count;
}
return count;
}
int ChooseUnrollFactor(BasicBlock* body) {
const int inst_count = CountPayloadInstructions(body);
int mem_ops = 0;
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
break;
}
if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
++mem_ops;
}
}
if (inst_count >= 2 && inst_count <= 6 && mem_ops <= 2) {
return 4;
}
if (inst_count >= 2 && inst_count <= 18) {
return 2;
}
return 1;
}
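// Illustrative outcomes of the heuristic above (payload = non-terminator
// instructions in the body; the counts are made-up inputs):
//   5 payload, 2 loads/stores -> 4
//   5 payload, 3 loads/stores -> 2
//   18 payload                -> 2
//   1 payload or 19+ payload  -> 1 (no unrolling)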
bool HasUnsafeLoopCarriedMemoryDependence(
const std::vector<loopmem::MemoryAccessInfo>& accesses, int iv_stride) {
for (std::size_t i = 0; i < accesses.size(); ++i) {
if (accesses[i].is_write &&
loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[i].ptr,
iv_stride)) {
return true;
}
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
if (!accesses[i].is_write && !accesses[j].is_write) {
continue;
}
if (loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr,
iv_stride)) {
return true;
}
}
}
return false;
}
bool MatchCountedLoop(Loop& loop, CountedLoopInfo& info) {
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
return false;
}
BasicBlock* body = nullptr;
BasicBlock* exit = nullptr;
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
return false;
}
if (IsAlreadyTransformedLoop(loop, body)) {
return false;
}
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
if (!branch || branch->GetThenBlock() != body) {
return false;
}
auto* compare = dyncast<BinaryInst>(branch->GetCondition());
if (!compare || !compare->GetType()->IsBool() ||
!IsSupportedCompareOpcode(compare->GetOpcode())) {
return false;
}
bool found_iv = false;
loopmem::SimpleInductionVar induction_var;
std::vector<PhiInst*> phis;
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
phis.push_back(phi);
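// Only a phi tested by the exit compare can be the loop counter. Without this
// filter an earlier constant-step phi (for example an accumulator) would be
// picked and the loop rejected by the compare check below.
if (!found_iv && (compare->GetLhs() == phi || compare->GetRhs() == phi) &&
loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
found_iv = true;
}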
}
if (!found_iv) {
return false;
}
Opcode compare_opcode = compare->GetOpcode();
Value* bound = nullptr;
if (compare->GetLhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
bound = compare->GetRhs();
} else if (compare->GetRhs() == induction_var.phi &&
looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
bound = compare->GetLhs();
compare_opcode = SwapCompareOpcode(compare_opcode);
} else {
return false;
}
if (!bound) {
return false;
}
if ((induction_var.stride > 0 &&
!(compare_opcode == Opcode::ICmpLT || compare_opcode == Opcode::ICmpLE)) ||
(induction_var.stride < 0 &&
!(compare_opcode == Opcode::ICmpGT || compare_opcode == Opcode::ICmpGE))) {
return false;
}
const auto accesses = loopmem::CollectMemoryAccesses(loop, induction_var.phi);
if (HasUnsafeLoopCarriedMemoryDependence(accesses, induction_var.stride)) {
return false;
}
for (const auto& inst_ptr : loop.header->GetInstructions()) {
auto* inst = inst_ptr.get();
if (dyncast<PhiInst>(inst) || inst == compare || inst->IsTerminator()) {
continue;
}
return false;
}
for (const auto& inst_ptr : body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
continue;
}
if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
return false;
}
}
info.loop = &loop;
info.preheader = loop.preheader;
info.header = loop.header;
info.body = body;
info.exit = exit;
info.branch = branch;
info.compare = compare;
info.compare_opcode = compare_opcode;
info.bound = bound;
info.induction_var = induction_var;
info.phis = std::move(phis);
return true;
}
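// Tightens the loop bound so the unrolled loop only runs while all `factor`
// copies of the body stay in range: bound - |stride| * (factor - 1) for a
// positive stride, bound + |stride| * (factor - 1) for a negative one.
// Constant bounds are folded; otherwise the adjustment is emitted in the preheader.
Value* BuildAdjustedBound(Function& function, BasicBlock* preheader, Value* bound,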
int stride, int factor) {
const int delta = std::abs(stride) * (factor - 1);
if (delta == 0) {
return bound;
}
if (auto* ci = dyncast<ConstantInt>(bound)) {
return looputils::ConstInt(stride > 0 ? ci->GetValue() - delta : ci->GetValue() + delta);
}
return preheader->Insert<BinaryInst>(
looputils::GetTerminatorIndex(preheader),
stride > 0 ? Opcode::Sub : Opcode::Add, Type::GetInt32Type(), bound,
looputils::ConstInt(delta), nullptr,
looputils::NextSyntheticName(function, "unroll.bound."));
}
bool RunLoopUnrollOnFunction(Function& function) {
if (function.IsExternal() || !function.GetEntryBlock()) {
return false;
}
bool changed = false;
while (true) {
DominatorTree dom_tree(function);
LoopInfo loop_info(function, dom_tree);
bool local_changed = false;
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
CountedLoopInfo info;
if (!MatchCountedLoop(*loop, info)) {
continue;
}
const int factor = ChooseUnrollFactor(info.body);
if (factor <= 1) {
continue;
}
auto* unrolled_header =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.header"));
auto* unrolled_body =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.body"));
auto* unrolled_exit =
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.exit"));
std::unordered_map<Value*, Value*> remap;
std::unordered_map<PhiInst*, PhiInst*> unrolled_phis;
std::unordered_map<PhiInst*, PhiInst*> exit_phis;
std::unordered_map<PhiInst*, Value*> current_phi_values;
std::unordered_map<PhiInst*, Value*> latch_values;
for (auto* phi : info.phis) {
// Resolve both incoming edges before creating the clone, so that a skipped
// phi does not leave an empty phi behind in unrolled_header.
const int preheader_index = looputils::GetPhiIncomingIndex(phi, info.preheader);
const int latch_index = looputils::GetPhiIncomingIndex(phi, info.body);
if (preheader_index < 0 || latch_index < 0) {
continue;
}
auto* cloned_phi = unrolled_header->Append<PhiInst>(
phi->GetType(), nullptr,
looputils::NextSyntheticName(function, "unroll.phi."));
cloned_phi->AddIncoming(phi->GetIncomingValue(preheader_index), info.preheader);
remap[phi] = cloned_phi;
unrolled_phis.emplace(phi, cloned_phi);
current_phi_values.emplace(phi, cloned_phi);
latch_values.emplace(phi, phi->GetIncomingValue(latch_index));
}
auto* adjusted_bound = BuildAdjustedBound(function, info.preheader, info.bound,
info.induction_var.stride, factor);
auto* unrolled_cond = unrolled_header->Append<BinaryInst>(
info.compare_opcode, Type::GetBoolType(), unrolled_phis[info.induction_var.phi],
adjusted_bound, nullptr,
looputils::NextSyntheticName(function, "unroll.cmp."));
unrolled_header->Append<CondBrInst>(unrolled_cond, unrolled_body, unrolled_exit, nullptr);
unrolled_header->AddPredecessor(info.preheader);
unrolled_header->AddSuccessor(unrolled_body);
unrolled_header->AddSuccessor(unrolled_exit);
for (int lane = 0; lane < factor; ++lane) {
for (const auto& inst_ptr : info.body->GetInstructions()) {
auto* inst = inst_ptr.get();
if (inst->IsTerminator()) {
continue;
}
looputils::CloneInstruction(function, inst, unrolled_body, remap,
"unroll." + std::to_string(lane) + ".");
}
std::unordered_map<PhiInst*, Value*> next_phi_values;
for (const auto& entry : latch_values) {
next_phi_values.emplace(entry.first,
looputils::RemapValue(remap, entry.second));
}
for (const auto& entry : next_phi_values) {
remap[entry.first] = entry.second;
current_phi_values[entry.first] = entry.second;
}
}
for (const auto& entry : unrolled_phis) {
entry.second->AddIncoming(current_phi_values[entry.first], unrolled_body);
}
unrolled_body->Append<UncondBrInst>(unrolled_header, nullptr);
unrolled_body->AddPredecessor(unrolled_header);
unrolled_body->AddSuccessor(unrolled_header);
unrolled_header->AddPredecessor(unrolled_body);
for (const auto& entry : unrolled_phis) {
auto* exit_phi = unrolled_exit->Append<PhiInst>(
entry.first->GetType(), nullptr,
looputils::NextSyntheticName(function, "unroll.exit."));
exit_phi->AddIncoming(entry.second, unrolled_header);
exit_phis.emplace(entry.first, exit_phi);
}
unrolled_exit->Append<UncondBrInst>(info.header, nullptr);
unrolled_exit->AddPredecessor(unrolled_header);
unrolled_exit->AddSuccessor(info.header);
if (!looputils::RedirectSuccessorEdge(info.preheader, info.header, unrolled_header)) {
continue;
}
info.header->RemovePredecessor(info.preheader);
info.header->AddPredecessor(unrolled_exit);
for (auto* phi : info.phis) {
looputils::ReplacePhiIncoming(phi, info.preheader, exit_phis[phi], unrolled_exit);
}
changed = true;
local_changed = true;
break;
}
if (!local_changed) {
break;
}
}
return changed;
}
} // namespace
bool RunLoopUnroll(Module& module) {
bool changed = false;
for (const auto& function : module.GetFunctions()) {
if (function) {
changed |= RunLoopUnrollOnFunction(*function);
}
}
return changed;
}
} // namespace ir
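
The adjusted-bound logic in `BuildAdjustedBound` can be pictured at source level. The following is a minimal sketch, not the pass itself: `SumUnrolledBy4` is a hypothetical function that assumes a stride of +1, an unroll factor of 4, and a `<` comparison. It illustrates why the unrolled loop compares against `n - 3` and why the original loop can remain as the remainder loop.

```cpp
#include <cassert>

// Hypothetical source-level analogue of an unrolled counted loop.
// Assumptions: stride +1, unroll factor 4, comparison `i < n`.
int SumUnrolledBy4(const int* values, int n) {
  assert(n >= 0);
  constexpr int kFactor = 4;
  // Same arithmetic as BuildAdjustedBound: delta = |stride| * (factor - 1) = 3,
  // and a positive stride subtracts delta from the bound.
  const int adjusted_bound = n - (kFactor - 1);
  int sum = 0;
  int i = 0;
  // Unrolled loop: entered only while i + 3 is still a valid index.
  for (; i < adjusted_bound; i += kFactor) {
    sum += values[i];
    sum += values[i + 1];
    sum += values[i + 2];
    sum += values[i + 3];
  }
  // The original loop picks up from the exit value of i and handles the
  // remaining 0 to 3 iterations.
  for (; i < n; ++i) {
    sum += values[i];
  }
  return sum;
}
```

Calling `SumUnrolledBy4(data, 10)` runs the unrolled loop for `i = 0` and `i = 4`, then the remainder loop for indices 8 and 9. With fewer than four elements the unrolled loop is skipped entirely.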

@ -0,0 +1,260 @@
#pragma once
#include "ir/IR.h"
#include "PassUtils.h"
#include <cstdint>
#include <unordered_set>
#include <vector>
namespace ir::memutils {
enum class PointerRootKind {
Local,
Global,
ReadonlyGlobal,
Param,
Unknown,
};
struct AddressComponent {
bool is_constant = false;
std::int64_t constant = 0;
Value* value = nullptr;
bool operator==(const AddressComponent& rhs) const {
return is_constant == rhs.is_constant && constant == rhs.constant &&
value == rhs.value;
}
};
struct AddressKey {
PointerRootKind kind = PointerRootKind::Unknown;
Value* root = nullptr;
std::vector<AddressComponent> components;
bool operator==(const AddressKey& rhs) const {
return kind == rhs.kind && root == rhs.root && components == rhs.components;
}
};
struct AddressKeyHash {
std::size_t operator()(const AddressKey& key) const {
std::size_t h = static_cast<std::size_t>(key.kind);
h ^= std::hash<Value*>{}(key.root) + 0x9e3779b9 + (h << 6) + (h >> 2);
for (const auto& component : key.components) {
h ^= std::hash<bool>{}(component.is_constant) + 0x9e3779b9 + (h << 6) + (h >> 2);
if (component.is_constant) {
h ^= std::hash<std::int64_t>{}(component.constant) + 0x9e3779b9 + (h << 6) +
(h >> 2);
} else {
h ^= std::hash<Value*>{}(component.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
}
return h;
}
};
struct EscapeSummary {
std::unordered_set<Value*> escaped_locals;
bool IsEscaped(Value* value) const {
return value != nullptr && escaped_locals.find(value) != escaped_locals.end();
}
};
inline bool IsNoEscapePointerUse(Value* current, Instruction* user) {
if (!current || !user) {
return false;
}
if (auto* load = dyncast<LoadInst>(user)) {
return load->GetPtr() == current;
}
if (auto* store = dyncast<StoreInst>(user)) {
return store->GetPtr() == current;
}
if (auto* memset = dyncast<MemsetInst>(user)) {
return memset->GetDest() == current;
}
return false;
}
inline bool PointerValueEscapes(Value* current, Value* root,
std::unordered_set<Value*>& visiting) {
if (!current || !root || !visiting.insert(current).second) {
return false;
}
for (const auto& use : current->GetUses()) {
auto* user = dyncast<Instruction>(use.GetUser());
if (!user) {
return true;
}
if (auto* gep = dyncast<GetElementPtrInst>(user)) {
if (gep->GetPointer() == current &&
PointerValueEscapes(gep, root, visiting)) {
return true;
}
continue;
}
if (IsNoEscapePointerUse(current, user)) {
continue;
}
return true;
}
return false;
}
inline EscapeSummary AnalyzeEscapes(Function& function) {
EscapeSummary summary;
for (const auto& block_ptr : function.GetBlocks()) {
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
auto* alloca = dyncast<AllocaInst>(inst_ptr.get());
if (!alloca) {
continue;
}
std::unordered_set<Value*> visiting;
if (PointerValueEscapes(alloca, alloca, visiting)) {
summary.escaped_locals.insert(alloca);
}
}
}
return summary;
}
inline PointerRootKind ClassifyRoot(Value* root, const EscapeSummary* summary) {
if (root == nullptr) {
return PointerRootKind::Unknown;
}
if (auto* global = dyncast<GlobalValue>(root)) {
return global->IsConstant() ? PointerRootKind::ReadonlyGlobal
: PointerRootKind::Global;
}
if (isa<Argument>(root)) {
return PointerRootKind::Param;
}
if (isa<AllocaInst>(root)) {
if (summary != nullptr && summary->IsEscaped(root)) {
return PointerRootKind::Unknown;
}
return PointerRootKind::Local;
}
return PointerRootKind::Unknown;
}
inline Value* StripPointerRoot(Value* pointer) {
auto* current = pointer;
while (auto* gep = dyncast<GetElementPtrInst>(current)) {
current = gep->GetPointer();
}
return current;
}
inline AddressComponent MakeAddressComponent(Value* value) {
if (auto* ci = dyncast<ConstantInt>(value)) {
return {true, ci->GetValue(), nullptr};
}
if (auto* cb = dyncast<ConstantI1>(value)) {
return {true, cb->GetValue() ? 1 : 0, nullptr};
}
return {false, 0, value};
}
inline bool BuildExactAddressKey(Value* pointer, const EscapeSummary* summary,
AddressKey& key) {
if (!pointer) {
return false;
}
if (auto* gep = dyncast<GetElementPtrInst>(pointer)) {
if (!BuildExactAddressKey(gep->GetPointer(), summary, key)) {
return false;
}
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
key.components.push_back(MakeAddressComponent(gep->GetIndex(i)));
}
return true;
}
key.kind = ClassifyRoot(pointer, summary);
key.root = pointer;
key.components.clear();
return true;
}
inline bool HasOnlyConstantComponents(const AddressKey& key) {
for (const auto& component : key.components) {
if (!component.is_constant) {
return false;
}
}
return true;
}
inline bool MayAliasConservatively(const AddressKey& lhs, const AddressKey& rhs) {
if (lhs.kind == PointerRootKind::Unknown || rhs.kind == PointerRootKind::Unknown) {
return true;
}
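// Different root kinds or different roots are treated as disjoint. Note that
// this assumes pointer parameters never alias globals or each other.
if (lhs.kind != rhs.kind || lhs.root != rhs.root) {
return false;
}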
if (lhs.components == rhs.components) {
return true;
}
if (HasOnlyConstantComponents(lhs) && HasOnlyConstantComponents(rhs)) {
return false;
}
return true;
}
inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) {
if (!callee) {
return true;
}
if (callee->HasUnknownEffects()) {
return true;
}
switch (kind) {
case PointerRootKind::ReadonlyGlobal:
return callee->ReadsGlobalMemory();
case PointerRootKind::Global:
return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory();
case PointerRootKind::Param:
return callee->ReadsParamMemory() || callee->WritesParamMemory();
case PointerRootKind::Local:
return false;
case PointerRootKind::Unknown:
return callee->MayReadMemory();
}
return true;
}
inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) {
if (!callee) {
return true;
}
if (callee->HasUnknownEffects()) {
return true;
}
switch (kind) {
case PointerRootKind::ReadonlyGlobal:
return false;
case PointerRootKind::Global:
return callee->WritesGlobalMemory();
case PointerRootKind::Param:
return callee->WritesParamMemory();
case PointerRootKind::Local:
return false;
case PointerRootKind::Unknown:
return callee->MayWriteMemory();
}
return true;
}
inline bool IsPureCall(const CallInst* call) {
auto* callee = call == nullptr ? nullptr : call->GetCallee();
return callee != nullptr && callee->CanDiscardUnusedCall() &&
!callee->MayReadMemory();
}
} // namespace ir::memutils
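
The decision order in `MayAliasConservatively` is easier to check on concrete keys. The sketch below is a simplified, standalone model under stated assumptions: `RootKind`, `Component`, `Key`, `ConstIndex`, `SymbolIndex`, `MakeKey`, `MayAlias`, and `DemoAliasRules` are hypothetical stand-ins, with plain integers replacing `Value*` roots and index operands.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

enum class RootKind { Local, Global, Unknown };

struct Component {
  bool is_constant;
  std::int64_t constant;  // meaningful only when is_constant is true
  int symbol;             // stand-in for a non-constant index Value*
};

bool operator==(const Component& a, const Component& b) {
  return a.is_constant == b.is_constant && a.constant == b.constant &&
         a.symbol == b.symbol;
}

struct Key {
  RootKind kind;
  int root;  // stand-in for the root Value*
  std::vector<Component> components;
};

Component ConstIndex(std::int64_t v) { return Component{true, v, 0}; }
Component SymbolIndex(int id) { return Component{false, 0, id}; }
Key MakeKey(RootKind kind, int root, std::vector<Component> components) {
  return Key{kind, root, std::move(components)};
}

bool AllConstant(const Key& key) {
  for (const auto& c : key.components) {
    if (!c.is_constant) return false;
  }
  return true;
}

// Same decision order as the header: unknown roots alias everything,
// distinct roots are disjoint, identical paths alias, and two different
// all-constant paths under one root are disjoint.
bool MayAlias(const Key& lhs, const Key& rhs) {
  if (lhs.kind == RootKind::Unknown || rhs.kind == RootKind::Unknown) return true;
  if (lhs.kind != rhs.kind || lhs.root != rhs.root) return false;
  if (lhs.components == rhs.components) return true;
  if (AllConstant(lhs) && AllConstant(rhs)) return false;
  return true;
}

void DemoAliasRules() {
  const Key a0 = MakeKey(RootKind::Local, 1, {ConstIndex(0)});
  const Key a1 = MakeKey(RootKind::Local, 1, {ConstIndex(1)});
  const Key ai = MakeKey(RootKind::Local, 1, {SymbolIndex(7)});
  assert(!MayAlias(a0, a1));  // a[0] vs a[1]: provably different slots
  assert(MayAlias(a0, ai));   // a[0] vs a[i]: i might be 0
}
```

The design choice worth noting is that a non-constant index is never compared symbolically: any variable index under the same root falls back to "may alias" unless the two paths are identical.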

@ -11,7 +11,35 @@ void RunIRPassPipeline(Module& module) {
if (disable_mem2reg != nullptr && disable_mem2reg[0] != '\0' && disable_mem2reg[0] != '0') {
return;
}
RunMem2Reg(module);
constexpr int kMaxIterations = 8;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
bool changed = false;
changed |= RunFunctionInlining(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
changed |= RunLoadStoreElim(module);
changed |= RunCSE(module);
changed |= RunDCE(module);
changed |= RunCFGSimplify(module);
changed |= RunLICM(module);
changed |= RunLoopStrengthReduction(module);
changed |= RunLoopFission(module);
changed |= RunLoopUnroll(module);
changed |= RunConstProp(module);
changed |= RunConstFold(module);
changed |= RunGVN(module);
changed |= RunLoadStoreElim(module);
changed |= RunCSE(module);
changed |= RunDCE(module);
changed |= RunCFGSimplify(module);
if (!changed) {
break;
}
}
}
} // namespace ir
} // namespace ir
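
The loop in `RunIRPassPipeline` is a bounded fixed-point iteration: every pass reports whether it changed anything, and the pipeline stops once a full round reports no change or the iteration cap is reached. A minimal generic sketch of that driver follows; `RunToFixedPoint` and its parameters are hypothetical names, not part of this repository.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Runs every pass once per round until a round reports no change or
// max_iterations rounds have run. Returns the number of rounds executed.
template <typename ModuleT>
int RunToFixedPoint(ModuleT& module,
                    const std::vector<std::function<bool(ModuleT&)>>& passes,
                    int max_iterations) {
  assert(max_iterations >= 0);
  int rounds = 0;
  while (rounds < max_iterations) {
    bool changed = false;
    for (const auto& pass : passes) {
      // Evaluate the pass first so every pass runs even after a change.
      const bool pass_changed = pass(module);
      changed = changed || pass_changed;
    }
    ++rounds;
    if (!changed) {
      break;
    }
  }
  return rounds;
}
```

The cap matters because passes can keep enabling one another (for example, unrolling exposing new constants); without it, a pass that always reports a change would never let the pipeline terminate.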

@ -0,0 +1,234 @@
#pragma once
#include "ir/IR.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <unordered_set>
#include <vector>
namespace ir::passutils {
inline std::uint32_t FloatBits(float value) {
std::uint32_t bits = 0;
std::memcpy(&bits, &value, sizeof(bits));
return bits;
}
inline bool AreEquivalentValues(Value* lhs, Value* rhs) {
if (lhs == rhs) {
return true;
}
auto* lhs_i32 = dyncast<ConstantInt>(lhs);
auto* rhs_i32 = dyncast<ConstantInt>(rhs);
if (lhs_i32 && rhs_i32) {
return lhs_i32->GetValue() == rhs_i32->GetValue();
}
auto* lhs_i1 = dyncast<ConstantI1>(lhs);
auto* rhs_i1 = dyncast<ConstantI1>(rhs);
if (lhs_i1 && rhs_i1) {
return lhs_i1->GetValue() == rhs_i1->GetValue();
}
auto* lhs_f32 = dyncast<ConstantFloat>(lhs);
auto* rhs_f32 = dyncast<ConstantFloat>(rhs);
if (lhs_f32 && rhs_f32) {
return FloatBits(lhs_f32->GetValue()) == FloatBits(rhs_f32->GetValue());
}
return false;
}
inline std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
std::vector<BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it != nullptr) {
stack.push_back(*it);
}
}
}
return order;
}
inline bool IsSideEffectingInstruction(const Instruction* inst) {
if (!inst) {
return false;
}
switch (inst->GetOpcode()) {
case Opcode::Store:
case Opcode::Memset:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Return:
case Opcode::Unreachable:
return true;
case Opcode::Call: {
auto* call = dyncast<const CallInst>(inst);
auto* callee = call == nullptr ? nullptr : call->GetCallee();
return callee == nullptr || !callee->CanDiscardUnusedCall();
}
default:
return false;
}
}
inline bool IsTriviallyDead(Instruction* inst) {
return inst != nullptr && !IsSideEffectingInstruction(inst) &&
inst->GetUses().empty();
}
inline void RemoveIncomingForBlock(PhiInst* phi, BasicBlock* block) {
if (!phi || !block) {
return;
}
for (int i = phi->GetNumIncomings() - 1; i >= 0; --i) {
if (phi->GetIncomingBlock(i) != block) {
continue;
}
phi->RemoveOperand(static_cast<size_t>(2 * i + 1));
phi->RemoveOperand(static_cast<size_t>(2 * i));
}
}
inline void RemoveIncomingFromSuccessor(BasicBlock* succ, BasicBlock* pred) {
if (!succ || !pred) {
return;
}
for (const auto& inst_ptr : succ->GetInstructions()) {
auto* phi = dyncast<PhiInst>(inst_ptr.get());
if (!phi) {
break;
}
RemoveIncomingForBlock(phi, pred);
}
}
inline void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) {
auto& instructions = block->GetInstructions();
if (instructions.empty() || !instructions.back()->IsTerminator()) {
return;
}
instructions.back()->ClearAllOperands();
auto branch = std::make_unique<UncondBrInst>(dest, nullptr);
branch->SetParent(block);
instructions.back() = std::move(branch);
}
inline bool SimplifyPhiInst(PhiInst* phi) {
if (!phi) {
return false;
}
Value* unique_value = nullptr;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* incoming = phi->GetIncomingValue(i);
if (incoming == phi) {
continue;
}
if (unique_value == nullptr) {
unique_value = incoming;
continue;
}
if (!AreEquivalentValues(unique_value, incoming)) {
return false;
}
}
if (unique_value == nullptr) {
return false;
}
auto* parent = phi->GetParent();
phi->ReplaceAllUsesWith(unique_value);
parent->EraseInstruction(phi);
return true;
}
inline void EraseBlock(Function& function, BasicBlock* block) {
if (!block) {
return;
}
auto& blocks = function.GetBlocks();
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
[&](const std::unique_ptr<BasicBlock>& current) {
return current.get() == block;
}),
blocks.end());
}
inline bool RemoveUnreachableBlocks(Function& function) {
auto reachable = CollectReachableBlocks(function);
std::unordered_set<BasicBlock*> reachable_set(reachable.begin(), reachable.end());
std::vector<BasicBlock*> dead_blocks;
for (const auto& block_ptr : function.GetBlocks()) {
auto* block = block_ptr.get();
if (reachable_set.find(block) == reachable_set.end()) {
dead_blocks.push_back(block);
}
}
if (dead_blocks.empty()) {
return false;
}
for (auto* block : dead_blocks) {
auto preds = block->GetPredecessors();
auto succs = block->GetSuccessors();
for (auto* succ : succs) {
RemoveIncomingFromSuccessor(succ, block);
succ->RemovePredecessor(block);
}
for (auto* pred : preds) {
pred->RemoveSuccessor(block);
}
}
for (auto* block : dead_blocks) {
for (const auto& inst_ptr : block->GetInstructions()) {
inst_ptr->ClearAllOperands();
}
}
for (auto* block : dead_blocks) {
EraseBlock(function, block);
}
return true;
}
inline bool IsCommutativeOpcode(Opcode opcode) {
switch (opcode) {
case Opcode::Add:
case Opcode::Mul:
case Opcode::And:
case Opcode::Or:
case Opcode::Xor:
case Opcode::FAdd:
case Opcode::FMul:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
return true;
default:
return false;
}
}
} // namespace ir::passutils
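
`AreEquivalentValues` compares float constants by bit pattern rather than with `==`. The sketch below shows the two cases where that differs, assuming IEEE 754 single-precision `float`; `BitsOf`, `SameFloatConstant`, and `DemoFloatIdentity` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Reinterprets a float as its 32-bit pattern without undefined behavior.
std::uint32_t BitsOf(float value) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "assumes 32-bit float");
  std::uint32_t bits = 0;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

// Two float constants are interchangeable only if their bits match.
bool SameFloatConstant(float lhs, float rhs) {
  return BitsOf(lhs) == BitsOf(rhs);
}

void DemoFloatIdentity() {
  // +0.0f and -0.0f compare equal numerically but are different constants:
  // replacing one with the other would change results such as 1.0f / x.
  assert(0.0f == -0.0f);
  assert(!SameFloatConstant(0.0f, -0.0f));
  // A NaN never equals itself numerically, yet the same NaN constant is
  // still safe to treat as one value.
  const float nan = std::numeric_limits<float>::quiet_NaN();
  assert(!(nan == nan));
  assert(SameFloatConstant(nan, nan));
}
```

Bitwise comparison is the stricter test on signed zeros and the more permissive one on identical NaNs, which is the behavior a value-numbering or phi-simplification pass needs.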

@ -1,6 +1,8 @@
#include "irgen/IRGen.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <utility>
@ -213,6 +215,16 @@ void IRGenImpl::EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType ty
symbol->dims = ParseArrayDims(ctx.constExp());
if (symbol->is_array) {
// Leave uninitialized globals as zeroinitializer instead of materializing
// an explicit all-zero constant array, which can explode memory usage.
if (ctx.initVal() == nullptr) {
global->SetInitializer(nullptr);
return;
}
if (IsExplicitZeroInitVal(ctx.initVal(), type)) {
global->SetInitializer(nullptr);
return;
}
auto flat = FlattenInitVal(ctx.initVal(), type, symbol->dims);
std::vector<ir::Value*> elements;
elements.reserve(flat.size());
@ -250,7 +262,14 @@ void IRGenImpl::EmitGlobalConstDef(SysYParser::ConstDefContext& ctx,
global->SetConstant(true);
if (symbol->is_array) {
if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) {
symbol->const_array.clear();
symbol->const_array_all_zero = true;
global->SetInitializer(nullptr);
return;
}
symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
symbol->const_array_all_zero = false;
std::vector<ir::Value*> elements;
elements.reserve(symbol->const_array.size());
for (const auto& value : symbol->const_array) {
@ -359,7 +378,15 @@ void IRGenImpl::EmitLocalConstDef(SysYParser::ConstDefContext& ctx,
return;
}
if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) {
symbol->const_array.clear();
symbol->const_array_all_zero = true;
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
return;
}
symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
symbol->const_array_all_zero = false;
ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
for (size_t i = 0; i < symbol->const_array.size(); ++i) {
if (symbol->const_array[i].type == SemanticType::Int && symbol->const_array[i].int_value == 0) {
@ -537,6 +564,54 @@ size_t IRGenImpl::FlattenIndices(const std::vector<int>& dims,
return offset;
}
bool IRGenImpl::IsZeroConstant(const ConstantValue& value) const {
switch (value.type) {
case SemanticType::Int:
return value.int_value == 0;
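case SemanticType::Float: {
// Compare the bit pattern: only +0.0f (all bits zero) counts as zero, so
// -0.0f keeps its explicit initializer.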
std::uint32_t bits = 0;
std::memcpy(&bits, &value.float_value, sizeof(bits));
return bits == 0;
}
case SemanticType::Void:
return false;
}
return false;
}
bool IRGenImpl::IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx,
SemanticType base_type) {
if (ctx == nullptr) {
return true;
}
if (ctx->constExp() != nullptr) {
return IsZeroConstant(
ConvertConst(EvalConstAddExp(*ctx->constExp()->addExp()), base_type));
}
for (auto* child : ctx->constInitVal()) {
if (!IsExplicitZeroConstInitVal(child, base_type)) {
return false;
}
}
return true;
}
bool IRGenImpl::IsExplicitZeroInitVal(SysYParser::InitValContext* ctx,
SemanticType base_type) {
if (ctx == nullptr) {
return true;
}
if (ctx->exp() != nullptr) {
return IsZeroConstant(ConvertConst(EvalConstExp(*ctx->exp()), base_type));
}
for (auto* child : ctx->initVal()) {
if (!IsExplicitZeroInitVal(child, base_type)) {
return false;
}
}
return true;
}
ConstantValue IRGenImpl::ZeroConst(SemanticType type) const {
ConstantValue value;
value.type = type;

@ -280,6 +280,9 @@ ConstantValue IRGenImpl::EvalConstLVal(SysYParser::LValContext& ctx) {
}
}
const auto offset = FlattenIndices(symbol->dims, indices);
if (symbol->const_array_all_zero) {
return ZeroConst(symbol->type);
}
if (offset >= symbol->const_array.size()) {
ThrowError(&ctx, "constant array index out of range: " + name);
}

@ -25,6 +25,17 @@ std::string IRGenImpl::NextBlockName(const std::string& prefix) {
return module_.GetContext().NextBlockName(prefix);
}
void IRGenImpl::ApplyFunctionSema(const std::string& name, ir::Function& function) {
const auto* info = sema_.LookupFunction(name);
if (info == nullptr) {
return;
}
function.SetEffectInfo(info->reads_global_memory, info->writes_global_memory,
info->reads_param_memory, info->writes_param_memory,
info->has_io, info->has_unknown_effects,
info->is_recursive);
}
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (ctx == nullptr) {
ThrowError(ctx, "missing compilation unit");
@ -96,6 +107,7 @@ void IRGenImpl::RegisterBuiltinFunctions() {
auto* function = module_.CreateFunction(
builtin.name, GetIRScalarType(builtin.return_type), ir_param_types,
ir_param_names, true);
ApplyFunctionSema(builtin.name, *function);
SymbolEntry entry;
entry.kind = SymbolKind::Function;
@ -171,6 +183,7 @@ void IRGenImpl::PredeclareFunction(SysYParser::FuncDefContext& ctx) {
auto* function = module_.CreateFunction(
name, GetIRScalarType(function_type.return_type),
BuildFunctionIRParamTypes(function_type), BuildFunctionIRParamNames(ctx), false);
ApplyFunctionSema(name, *function);
SymbolEntry entry;
entry.kind = SymbolKind::Function;

@ -60,14 +60,16 @@ int main(int argc, char** argv) {
need_blank_line = true;
}
if (opts.emit_asm) {
auto machine_module = mir::LowerToMIR(*asm_module);
mir::RunRegAlloc(*machine_module);
mir::RunFrameLowering(*machine_module);
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_module, std::cout);
if (opts.emit_asm) {
auto machine_module = mir::LowerToMIR(*asm_module);
mir::RunMIRPreRegAllocPassPipeline(*machine_module);
mir::RunRegAlloc(*machine_module);
mir::RunMIRPostRegAllocPassPipeline(*machine_module);
mir::RunFrameLowering(*machine_module);
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_module, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {

@ -605,8 +605,8 @@ bool TryEmitDirectMemoryAccess(const MachineFunction& function, const AddressExp
void EmitAddressExpr(const MachineFunction& function, const AddressExpr& address,
const char* dst_reg, std::ostream& os) {
std::unordered_map<int, std::string> preserved_index_regs;
int next_preserve_scratch = 15;
int preserved_index_vreg = -1;
std::string preserved_index_reg;
for (const auto& term : address.scaled_vregs) {
const int index_vreg = term.first;
const auto& alloc = function.GetAllocation(index_vreg);
@ -617,16 +617,15 @@ void EmitAddressExpr(const MachineFunction& function, const AddressExpr& address
if (std::string(GetPhysRegName(alloc.phys, ValueType::Ptr)) != dst_reg) {
continue;
}
auto it = preserved_index_regs.find(index_vreg);
if (it != preserved_index_regs.end()) {
continue;
if (preserved_index_vreg >= 0 && preserved_index_vreg != index_vreg) {
throw std::runtime_error(
FormatError("mir", "multiple address indices conflict with lea destination"));
}
const auto scratch = PhysReg{RegClass::GPR, next_preserve_scratch--};
const char* scratch_name =
GetPhysRegName(scratch, function.GetVRegInfo(index_vreg).type);
EmitCopy(os, scratch_name,
preserved_index_vreg = index_vreg;
preserved_index_reg =
GetPhysRegName({RegClass::GPR, 12}, function.GetVRegInfo(index_vreg).type);
EmitCopy(os, preserved_index_reg.c_str(),
GetPhysRegName(alloc.phys, function.GetVRegInfo(index_vreg).type), false);
preserved_index_regs.emplace(index_vreg, scratch_name);
}
switch (address.base_kind) {
@ -656,9 +655,8 @@ void EmitAddressExpr(const MachineFunction& function, const AddressExpr& address
for (const auto& term : address.scaled_vregs) {
std::string index_reg;
auto preserved_it = preserved_index_regs.find(term.first);
if (preserved_it != preserved_index_regs.end()) {
index_reg = preserved_it->second;
if (term.first == preserved_index_vreg) {
index_reg = preserved_index_reg;
} else {
index_reg = MaterializeGprUse(function, MachineOperand::VReg(term.first),
ValueType::I32, 10, os);

@ -3,20 +3,54 @@
#include <algorithm>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
enum class LoweredKind { Invalid, VReg, StackObject, Global };
struct LoweredValue {
namespace mir {
namespace {
enum class LoweredKind { Invalid, VReg, StackObject, Global };
std::vector<ir::BasicBlock*> CollectLoweringOrder(ir::Function& function) {
std::vector<ir::BasicBlock*> order;
auto* entry = function.GetEntryBlock();
if (!entry) {
return order;
}
std::unordered_set<ir::BasicBlock*> visited;
std::vector<ir::BasicBlock*> stack{entry};
while (!stack.empty()) {
auto* block = stack.back();
stack.pop_back();
if (!block || !visited.insert(block).second) {
continue;
}
order.push_back(block);
const auto& succs = block->GetSuccessors();
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
if (*it != nullptr) {
stack.push_back(*it);
}
}
}
for (const auto& block : function.GetBlocks()) {
if (block && visited.insert(block.get()).second) {
order.push_back(block.get());
}
}
return order;
}
struct LoweredValue {
LoweredKind kind = LoweredKind::Invalid;
ValueType type = ValueType::Void;
int index = -1;
@ -905,13 +939,14 @@ class Lowerer {
param_types.push_back(LowerType(type));
}
auto machine_function = std::make_unique<MachineFunction>(
function.GetName(), LowerType(function.GetReturnType()), std::move(param_types));
current_ir_function_ = &function;
current_function_ = machine_function.get();
for (const auto& block : function.GetBlocks()) {
blocks_[block.get()] = current_function_->CreateBlock(block->GetName());
auto machine_function = std::make_unique<MachineFunction>(
function.GetName(), LowerType(function.GetReturnType()), std::move(param_types));
current_ir_function_ = &function;
current_function_ = machine_function.get();
const auto ordered_blocks = CollectLoweringOrder(function);
for (const auto& block : function.GetBlocks()) {
blocks_[block.get()] = current_function_->CreateBlock(block->GetName());
}
if (!function.GetBlocks().empty()) {
@ -924,14 +959,14 @@ class Lowerer {
values_[argument.get()] = lowered;
}
}
PreparePhiResults(function);
for (const auto& block : function.GetBlocks()) {
current_block_ = blocks_.at(block.get());
for (const auto& inst : block->GetInstructions()) {
LowerInstruction(*inst);
}
PreparePhiResults(function);
for (auto* block : ordered_blocks) {
current_block_ = blocks_.at(block);
for (const auto& inst : block->GetInstructions()) {
LowerInstruction(*inst);
}
}
EmitPhiCopies(function);

@ -36,26 +36,37 @@ bool BelongsToClass(ValueType type, RegClass reg_class) {
return IsFPR(type) ? reg_class == RegClass::FPR : reg_class == RegClass::GPR;
}
bool IsCalleeSaved(PhysReg reg) {
if (reg.reg_class == RegClass::GPR) {
return reg.index >= 19 && reg.index <= 28;
}
return reg.index >= 8 && reg.index <= 15;
}
std::vector<PhysReg> GetAllocatableRegs(RegClass reg_class) {
std::vector<PhysReg> regs;
if (reg_class == RegClass::FPR) {
for (int i = 8; i <= 15; ++i) {
regs.push_back({RegClass::FPR, i});
}
return regs;
}
for (int i = 19; i <= 28; ++i) {
regs.push_back({RegClass::GPR, i});
}
return regs;
}
bool IsCalleeSaved(PhysReg reg) {
if (reg.reg_class == RegClass::GPR) {
return reg.index >= 19 && reg.index <= 28;
}
return reg.index >= 8 && reg.index <= 15;
}
bool IsCallerSaved(PhysReg reg) {
return !IsCalleeSaved(reg);
}
std::vector<PhysReg> GetAllocatableRegs(RegClass reg_class) {
std::vector<PhysReg> regs;
if (reg_class == RegClass::FPR) {
for (int i = 19; i <= 31; ++i) {
regs.push_back({RegClass::FPR, i});
}
for (int i = 8; i <= 15; ++i) {
regs.push_back({RegClass::FPR, i});
}
return regs;
}
regs.push_back({RegClass::GPR, 8});
for (int i = 13; i <= 15; ++i) {
regs.push_back({RegClass::GPR, i});
}
for (int i = 19; i <= 28; ++i) {
regs.push_back({RegClass::GPR, i});
}
return regs;
}
int CreateSpillSlot(MachineFunction& function, int vreg) {
const auto type = function.GetVRegInfo(vreg).type;
@ -171,12 +182,13 @@ class GeorgeColoringAllocator {
reg_class_(reg_class),
regs_(GetAllocatableRegs(reg_class)),
k_(static_cast<int>(regs_.size())),
block_infos_(block_infos),
num_vregs_(static_cast<int>(function.GetVRegs().size())),
in_class_(static_cast<size_t>(num_vregs_), 0),
adjacency_(static_cast<size_t>(num_vregs_)),
degree_(static_cast<size_t>(num_vregs_), 0),
spill_cost_(static_cast<size_t>(num_vregs_), 0.0),
block_infos_(block_infos),
num_vregs_(static_cast<int>(function.GetVRegs().size())),
in_class_(static_cast<size_t>(num_vregs_), 0),
live_across_call_(static_cast<size_t>(num_vregs_), 0),
adjacency_(static_cast<size_t>(num_vregs_)),
degree_(static_cast<size_t>(num_vregs_), 0),
spill_cost_(static_cast<size_t>(num_vregs_), 0.0),
move_list_(static_cast<size_t>(num_vregs_)),
alias_(static_cast<size_t>(num_vregs_), -1),
color_index_(static_cast<size_t>(num_vregs_), -1),
@ -235,15 +247,25 @@ class GeorgeColoringAllocator {
}
}
const auto& instructions = block->GetInstructions();
for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
const auto& inst = *it;
const auto& instructions = block->GetInstructions();
for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
const auto& inst = *it;
auto defs = FilterClass(inst.GetDefs());
auto uses = FilterClass(inst.GetUses());
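// Every vreg of this class that is live at a call (memset is treated like a
// call here) is marked so that coloring later restricts it to callee-saved
// registers.
if (inst.GetOpcode() == MachineInstr::Opcode::Call ||
inst.GetOpcode() == MachineInstr::Opcode::Memset) {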
for (int vreg = 0; vreg < num_vregs_; ++vreg) {
if (live[static_cast<size_t>(vreg)] &&
in_class_[static_cast<size_t>(vreg)]) {
live_across_call_[static_cast<size_t>(vreg)] = 1;
}
}
}
for (int def : defs) {
spill_cost_[static_cast<size_t>(def)] += block_weight;
}
}
for (int use : uses) {
spill_cost_[static_cast<size_t>(use)] += block_weight;
}
@ -389,12 +411,19 @@ class GeorgeColoringAllocator {
select_stack_.pop_back();
in_select_stack_[static_cast<size_t>(node)] = 0;
std::vector<std::uint8_t> ok_colors(static_cast<size_t>(regs_.size()), 1);
for (int neighbor : adjacency_[static_cast<size_t>(node)]) {
const int alias = GetAlias(neighbor);
if (!is_colored_[static_cast<size_t>(alias)]) {
continue;
}
std::vector<std::uint8_t> ok_colors(static_cast<size_t>(regs_.size()), 1);
if (live_across_call_[static_cast<size_t>(node)]) {
for (size_t i = 0; i < regs_.size(); ++i) {
if (IsCallerSaved(regs_[i])) {
ok_colors[i] = 0;
}
}
}
for (int neighbor : adjacency_[static_cast<size_t>(node)]) {
const int alias = GetAlias(neighbor);
if (!is_colored_[static_cast<size_t>(alias)]) {
continue;
}
const int color = color_index_[static_cast<size_t>(alias)];
if (color >= 0 && color < static_cast<int>(regs_.size())) {
ok_colors[static_cast<size_t>(color)] = 0;
@ -506,12 +535,14 @@ class GeorgeColoringAllocator {
void Combine(int keep, int remove) {
simplify_worklist_[static_cast<size_t>(remove)] = 0;
freeze_worklist_[static_cast<size_t>(remove)] = 0;
spill_worklist_[static_cast<size_t>(remove)] = 0;
is_coalesced_[static_cast<size_t>(remove)] = 1;
alias_[static_cast<size_t>(remove)] = keep;
auto& keep_moves = move_list_[static_cast<size_t>(keep)];
const auto& remove_moves = move_list_[static_cast<size_t>(remove)];
spill_worklist_[static_cast<size_t>(remove)] = 0;
is_coalesced_[static_cast<size_t>(remove)] = 1;
alias_[static_cast<size_t>(remove)] = keep;
live_across_call_[static_cast<size_t>(keep)] |=
live_across_call_[static_cast<size_t>(remove)];
auto& keep_moves = move_list_[static_cast<size_t>(keep)];
const auto& remove_moves = move_list_[static_cast<size_t>(remove)];
keep_moves.insert(keep_moves.end(), remove_moves.begin(), remove_moves.end());
EnableMoves({remove});
@ -679,11 +710,12 @@ class GeorgeColoringAllocator {
std::vector<PhysReg> regs_;
int k_ = 0;
const std::vector<BlockInfo>& block_infos_;
int num_vregs_ = 0;
std::vector<std::uint8_t> in_class_;
std::vector<std::unordered_set<int>> adjacency_;
std::vector<int> degree_;
int num_vregs_ = 0;
std::vector<std::uint8_t> in_class_;
std::vector<std::uint8_t> live_across_call_;
std::vector<std::unordered_set<int>> adjacency_;
std::vector<int> degree_;
std::vector<double> spill_cost_;
std::vector<std::vector<int>> move_list_;
std::vector<MoveEdge> moves_;
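
The color-selection hunk above first rules out caller-saved registers for any node that is live across a call, then rules out colors already taken by colored neighbors. A minimal standalone sketch of that selection step follows. The `Reg` struct and the `PickColor` function are hypothetical names introduced only for illustration; they are not part of `GeorgeColoringAllocator`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical register descriptor used only by this sketch.
struct Reg {
  int id;
  bool caller_saved;
};

// Returns the index of the first usable register, or -1 when the node
// would have to be spilled. `neighbor_colors` holds the color indices of
// already-colored interference neighbors.
inline int PickColor(const std::vector<Reg>& regs,
                     const std::vector<int>& neighbor_colors,
                     bool live_across_call) {
  std::vector<bool> ok(regs.size(), true);
  // A value that survives a call must not sit in a caller-saved register.
  if (live_across_call) {
    for (std::size_t i = 0; i < regs.size(); ++i) {
      if (regs[i].caller_saved) {
        ok[i] = false;
      }
    }
  }
  // Colors used by interfering neighbors are unavailable as well.
  for (int color : neighbor_colors) {
    if (color >= 0 && color < static_cast<int>(regs.size())) {
      ok[static_cast<std::size_t>(color)] = false;
    }
  }
  for (std::size_t i = 0; i < ok.size(); ++i) {
    if (ok[i]) {
      return static_cast<int>(i);
    }
  }
  return -1;
}
```

Restricting the palette at select time keeps the interference graph unchanged, at the cost of possibly spilling a node that a precolored-interference encoding could have colored.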

@ -1,4 +1,25 @@
// MIR pass management:
// - Organizes the run order of backend passes (PreRA/PostRA/PEI stages, etc.)
// - Runs passes uniformly and emits debug output (extend as needed)
#include "mir/MIR.h"
namespace mir {
void RunMIRPreRegAllocPassPipeline(MachineModule& module) {
RunAddressHoisting(module);
constexpr int kMaxIterations = 4;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
if (!RunPeephole(module)) {
break;
}
}
}
void RunMIRPostRegAllocPassPipeline(MachineModule& module) {
constexpr int kMaxIterations = 2;
for (int iteration = 0; iteration < kMaxIterations; ++iteration) {
if (!RunPeephole(module)) {
break;
}
}
}
} // namespace mir
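
Both pipelines above rerun `RunPeephole` until it reports no change or a fixed iteration budget runs out. The same pattern can be sketched as a small generic helper. `RunToFixpoint` is a hypothetical name, not something defined in this repository.

```cpp
#include <cassert>

// Hypothetical helper: invokes `pass` until it returns false (no change)
// or `max_iterations` invocations have happened. Returns how many times
// the pass actually ran.
template <typename Pass>
int RunToFixpoint(Pass&& pass, int max_iterations) {
  int runs = 0;
  while (runs < max_iterations) {
    ++runs;
    if (!pass()) {
      break;  // Fixpoint reached: the last run changed nothing.
    }
  }
  return runs;
}
```

The budget bounds compile time even if two rewrites keep undoing each other, and the early exit avoids wasted passes once the code has stabilized.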

@ -1,4 +1,904 @@
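// Peephole optimization:
// - Removes redundant moves and merges common instruction patterns
// - Improves final assembly quality (scoped to what is implemented)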
#include "mir/MIR.h"
#include "ir/IR.h"
#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace mir {
namespace {
using AliasMap = std::unordered_map<int, MachineOperand>;
struct CFGInfo {
std::vector<std::vector<int>> predecessors;
std::vector<std::vector<int>> successors;
};
struct AddressKey {
AddrBaseKind base_kind = AddrBaseKind::None;
int base_index = -1;
std::string symbol;
std::int64_t const_offset = 0;
std::vector<std::pair<int, std::int64_t>> scaled_vregs;
bool operator==(const AddressKey& rhs) const {
return base_kind == rhs.base_kind && base_index == rhs.base_index &&
symbol == rhs.symbol && const_offset == rhs.const_offset &&
scaled_vregs == rhs.scaled_vregs;
}
};
struct AddressKeyHash {
std::size_t operator()(const AddressKey& key) const {
std::size_t h = static_cast<std::size_t>(key.base_kind);
h ^= std::hash<int>{}(key.base_index) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::string>{}(key.symbol) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::int64_t>{}(key.const_offset) + 0x9e3779b9 + (h << 6) + (h >> 2);
for (const auto& term : key.scaled_vregs) {
h ^= std::hash<int>{}(term.first) + 0x9e3779b9 + (h << 6) + (h >> 2);
h ^= std::hash<std::int64_t>{}(term.second) + 0x9e3779b9 + (h << 6) + (h >> 2);
}
return h;
}
};
struct MemoryState {
MachineOperand value;
ValueType type = ValueType::Void;
int pending_store_index = -1;
};
using MemoryMap = std::unordered_map<AddressKey, MemoryState, AddressKeyHash>;
bool IsImm(const MachineOperand& operand, std::int64_t value) {
return operand.GetKind() == OperandKind::Imm && operand.GetImm() == value;
}
bool SameExactOperand(const MachineOperand& lhs, const MachineOperand& rhs) {
if (lhs.GetKind() != rhs.GetKind()) {
return false;
}
switch (lhs.GetKind()) {
case OperandKind::Invalid:
return true;
case OperandKind::VReg:
return lhs.GetVReg() == rhs.GetVReg();
case OperandKind::Imm:
return lhs.GetImm() == rhs.GetImm();
case OperandKind::Block:
case OperandKind::Symbol:
return lhs.GetText() == rhs.GetText();
}
return false;
}
bool SameResolvedLocation(const MachineFunction& function, int lhs_vreg, int rhs_vreg) {
if (lhs_vreg == rhs_vreg) {
return true;
}
const auto& lhs = function.GetAllocation(lhs_vreg);
const auto& rhs = function.GetAllocation(rhs_vreg);
if (lhs.kind == Allocation::Kind::Unassigned || rhs.kind == Allocation::Kind::Unassigned ||
lhs.kind != rhs.kind) {
return false;
}
if (lhs.kind == Allocation::Kind::PhysReg) {
return lhs.phys == rhs.phys;
}
if (lhs.kind == Allocation::Kind::Spill) {
return lhs.stack_object == rhs.stack_object;
}
return false;
}
bool SameResolvedOperand(const MachineFunction& function, const MachineOperand& lhs,
const MachineOperand& rhs) {
if (SameExactOperand(lhs, rhs)) {
return true;
}
if (lhs.GetKind() == OperandKind::VReg && rhs.GetKind() == OperandKind::VReg) {
return SameResolvedLocation(function, lhs.GetVReg(), rhs.GetVReg());
}
return false;
}
MachineOperand ResolveAlias(const AliasMap& aliases, const MachineOperand& operand) {
if (operand.GetKind() != OperandKind::VReg) {
return operand;
}
int current = operand.GetVReg();
std::unordered_set<int> visited;
visited.insert(current);
while (true) {
auto it = aliases.find(current);
if (it == aliases.end()) {
return MachineOperand::VReg(current);
}
if (it->second.GetKind() != OperandKind::VReg) {
return it->second;
}
const int next = it->second.GetVReg();
if (!visited.insert(next).second) {
return MachineOperand::VReg(current);
}
current = next;
}
}
bool RewriteOperand(MachineOperand& operand, const AliasMap& aliases) {
const auto rewritten = ResolveAlias(aliases, operand);
if (SameExactOperand(rewritten, operand)) {
return false;
}
operand = rewritten;
return true;
}
bool RewriteAddress(AddressExpr& address, const AliasMap& aliases) {
bool changed = false;
if (address.base_kind == AddrBaseKind::VReg && address.base_index >= 0) {
const auto rewritten = ResolveAlias(aliases, MachineOperand::VReg(address.base_index));
if (rewritten.GetKind() == OperandKind::VReg &&
rewritten.GetVReg() != address.base_index) {
address.base_index = rewritten.GetVReg();
changed = true;
}
}
std::vector<std::pair<int, std::int64_t>> rewritten_scaled;
rewritten_scaled.reserve(address.scaled_vregs.size());
for (const auto& term : address.scaled_vregs) {
const auto rewritten = ResolveAlias(aliases, MachineOperand::VReg(term.first));
if (rewritten.GetKind() == OperandKind::Imm) {
address.const_offset += rewritten.GetImm() * term.second;
changed = true;
continue;
}
if (rewritten.GetKind() == OperandKind::VReg && rewritten.GetVReg() != term.first) {
rewritten_scaled.push_back({rewritten.GetVReg(), term.second});
changed = true;
continue;
}
rewritten_scaled.push_back(term);
}
if (rewritten_scaled.size() != address.scaled_vregs.size()) {
changed = true;
}
address.scaled_vregs = std::move(rewritten_scaled);
return changed;
}
bool RewriteUses(MachineInstr& inst, const AliasMap& aliases) {
bool changed = false;
auto& operands = inst.GetOperands();
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Copy:
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
case MachineInstr::Opcode::FNeg:
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
break;
case MachineInstr::Opcode::Store:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv:
case MachineInstr::Opcode::ICmp:
case MachineInstr::Opcode::FCmp:
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
if (operands.size() >= 3) {
changed |= RewriteOperand(operands[2], aliases);
}
break;
case MachineInstr::Opcode::CondBr:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Call: {
const size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
for (size_t i = arg_begin; i < operands.size(); ++i) {
changed |= RewriteOperand(operands[i], aliases);
}
break;
}
case MachineInstr::Opcode::Ret:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
break;
case MachineInstr::Opcode::Memset:
if (!operands.empty()) {
changed |= RewriteOperand(operands[0], aliases);
}
if (operands.size() >= 2) {
changed |= RewriteOperand(operands[1], aliases);
}
break;
case MachineInstr::Opcode::Arg:
case MachineInstr::Opcode::Load:
case MachineInstr::Opcode::Lea:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::Unreachable:
break;
}
if (inst.HasAddress()) {
changed |= RewriteAddress(inst.GetAddress(), aliases);
}
return changed;
}
MachineInstr MakeCopyLike(const MachineInstr& inst, MachineOperand source) {
return MachineInstr(MachineInstr::Opcode::Copy,
{inst.GetOperands()[0], std::move(source)});
}
bool SimplifyCopy(const MachineFunction& function, MachineInstr& inst) {
if (inst.GetOpcode() != MachineInstr::Opcode::Copy) {
return false;
}
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[0].GetKind() != OperandKind::VReg) {
return false;
}
return SameResolvedOperand(function, operands[0], operands[1]);
}
bool SimplifyZExt(MachineInstr& inst) {
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[1].GetKind() != OperandKind::Imm) {
return false;
}
inst = MakeCopyLike(inst, MachineOperand::Imm(operands[1].GetImm() != 0 ? 1 : 0));
return true;
}
bool SimplifyIntegerBinary(MachineInstr& inst) {
const auto opcode = inst.GetOpcode();
const auto& operands = inst.GetOperands();
if (operands.size() < 3) {
return false;
}
const auto& lhs = operands[1];
const auto& rhs = operands[2];
switch (opcode) {
case MachineInstr::Opcode::Add:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
return false;
case MachineInstr::Opcode::Sub:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
case MachineInstr::Opcode::Mul:
if (IsImm(rhs, 1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 1)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
if (IsImm(rhs, 0) || IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, MachineOperand::Imm(0));
return true;
}
return false;
case MachineInstr::Opcode::Div:
if (IsImm(rhs, 1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
case MachineInstr::Opcode::And:
if (IsImm(rhs, -1)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, -1)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
if (IsImm(rhs, 0) || IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, MachineOperand::Imm(0));
return true;
}
return false;
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
if (IsImm(lhs, 0)) {
inst = MakeCopyLike(inst, rhs);
return true;
}
return false;
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
if (IsImm(rhs, 0)) {
inst = MakeCopyLike(inst, lhs);
return true;
}
return false;
default:
return false;
}
}
bool SimplifyCondBr(MachineInstr& inst) {
auto& operands = inst.GetOperands();
if (operands.size() < 3) {
return false;
}
if (operands[1].GetKind() == OperandKind::Block &&
operands[2].GetKind() == OperandKind::Block &&
operands[1].GetText() == operands[2].GetText()) {
inst = MachineInstr(MachineInstr::Opcode::Br, {operands[1]});
return true;
}
if (operands[0].GetKind() != OperandKind::Imm) {
return false;
}
inst = MachineInstr(MachineInstr::Opcode::Br,
{operands[0].GetImm() != 0 ? operands[1] : operands[2]});
return true;
}
bool SimplifyInstruction(MachineInstr& inst) {
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::ZExt:
return SimplifyZExt(inst);
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
return SimplifyIntegerBinary(inst);
case MachineInstr::Opcode::CondBr:
return SimplifyCondBr(inst);
default:
return false;
}
}
bool TrackAlias(const MachineInstr& inst, AliasMap& aliases) {
if (inst.GetOpcode() != MachineInstr::Opcode::Copy) {
return false;
}
const auto& operands = inst.GetOperands();
if (operands.size() < 2 || operands[0].GetKind() != OperandKind::VReg) {
return false;
}
aliases[operands[0].GetVReg()] = operands[1];
return true;
}
AddressKey MakeAddressKey(const AddressExpr& address) {
return {address.base_kind, address.base_index, address.symbol, address.const_offset,
address.scaled_vregs};
}
bool HasTrackedAddress(const MachineInstr& inst) {
return inst.HasAddress() && inst.GetAddress().base_kind != AddrBaseKind::None;
}
const ir::Function* LookupSourceCallee(const MachineModule& module,
const MachineInstr& inst) {
if (inst.GetOpcode() != MachineInstr::Opcode::Call || inst.GetCallee().empty()) {
return nullptr;
}
return module.GetSourceModule().GetFunction(inst.GetCallee());
}
bool CallMayReadMemory(const MachineModule& module, const MachineInstr& inst) {
auto* callee = LookupSourceCallee(module, inst);
return callee == nullptr || callee->MayReadMemory();
}
bool CallMayWriteMemory(const MachineModule& module, const MachineInstr& inst) {
auto* callee = LookupSourceCallee(module, inst);
return callee == nullptr || callee->MayWriteMemory();
}
bool SameMemoryStateValue(const MemoryState& lhs, const MemoryState& rhs) {
return lhs.type == rhs.type && SameExactOperand(lhs.value, rhs.value);
}
bool SameMemoryMap(const MemoryMap& lhs, const MemoryMap& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto& [key, value] : lhs) {
auto it = rhs.find(key);
if (it == rhs.end() || !SameMemoryStateValue(value, it->second)) {
return false;
}
}
return true;
}
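// Dataflow meet: keeps only the entries on which every predecessor agrees
// (same address key, same type, same value). No predecessors yields an empty map.
MemoryMap MeetMemoryStates(const std::vector<const MemoryMap*>& predecessors) {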
if (predecessors.empty()) {
return {};
}
MemoryMap in = *predecessors.front();
for (auto it = in.begin(); it != in.end();) {
bool keep = true;
for (std::size_t i = 1; i < predecessors.size(); ++i) {
auto pred_it = predecessors[i]->find(it->first);
if (pred_it == predecessors[i]->end() ||
!SameMemoryStateValue(it->second, pred_it->second)) {
keep = false;
break;
}
}
if (!keep) {
it = in.erase(it);
continue;
}
++it;
}
return in;
}
CFGInfo BuildCFG(const MachineFunction& function) {
CFGInfo cfg;
const auto& blocks = function.GetBlocks();
cfg.predecessors.resize(blocks.size());
cfg.successors.resize(blocks.size());
std::unordered_map<std::string, int> name_to_index;
for (std::size_t i = 0; i < blocks.size(); ++i) {
name_to_index.emplace(blocks[i]->GetName(), static_cast<int>(i));
}
auto add_edge = [&](int pred, const std::string& succ_name) {
auto it = name_to_index.find(succ_name);
if (it == name_to_index.end()) {
return;
}
cfg.successors[static_cast<std::size_t>(pred)].push_back(it->second);
cfg.predecessors[static_cast<std::size_t>(it->second)].push_back(pred);
};
for (std::size_t i = 0; i < blocks.size(); ++i) {
const auto& instructions = blocks[i]->GetInstructions();
if (instructions.empty()) {
continue;
}
const auto& terminator = instructions.back();
if (terminator.GetOpcode() == MachineInstr::Opcode::Br &&
!terminator.GetOperands().empty()) {
add_edge(static_cast<int>(i), terminator.GetOperands()[0].GetText());
} else if (terminator.GetOpcode() == MachineInstr::Opcode::CondBr &&
terminator.GetOperands().size() >= 3) {
add_edge(static_cast<int>(i), terminator.GetOperands()[1].GetText());
add_edge(static_cast<int>(i), terminator.GetOperands()[2].GetText());
}
auto& succs = cfg.successors[i];
std::sort(succs.begin(), succs.end());
succs.erase(std::unique(succs.begin(), succs.end()), succs.end());
}
for (auto& preds : cfg.predecessors) {
std::sort(preds.begin(), preds.end());
preds.erase(std::unique(preds.begin(), preds.end()), preds.end());
}
return cfg;
}
bool SameBaseObject(const AddressKey& lhs, const AddressKey& rhs) {
if (lhs.base_kind != rhs.base_kind) {
return false;
}
switch (lhs.base_kind) {
case AddrBaseKind::FrameObject:
case AddrBaseKind::VReg:
return lhs.base_index == rhs.base_index;
case AddrBaseKind::Global:
return lhs.symbol == rhs.symbol;
case AddrBaseKind::None:
return false;
}
return false;
}
// Drops every tracked entry that a store to `store_key` may clobber.
// A null key or a store through a vreg base conservatively clears everything.
void InvalidateMemoryState(MemoryMap& states,
const AddressKey* store_key) {
if (store_key == nullptr) {
states.clear();
return;
}
if (store_key->base_kind == AddrBaseKind::VReg) {
states.clear();
return;
}
for (auto it = states.begin(); it != states.end();) {
if (it->first.base_kind == AddrBaseKind::VReg || SameBaseObject(it->first, *store_key)) {
it = states.erase(it);
continue;
}
++it;
}
}
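// Marks all tracked stores as observed, so a later store to the same address
// no longer deletes them as dead.
void ObservePendingStores(MemoryMap& states) {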
for (auto& [_, state] : states) {
state.pending_store_index = -1;
}
}
bool TryOptimizeMemoryInstruction(
const MachineModule& module, const MachineFunction& function,
MachineInstr& inst,
MemoryMap& states,
std::vector<bool>& removed,
std::size_t current_index,
bool* remove_current) {
*remove_current = false;
if (inst.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayWriteMemory(module, inst)) {
InvalidateMemoryState(states, nullptr);
}
return false;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Memset) {
InvalidateMemoryState(states, nullptr);
return false;
}
if (!HasTrackedAddress(inst)) {
return false;
}
const AddressKey key = MakeAddressKey(inst.GetAddress());
if (inst.GetOpcode() == MachineInstr::Opcode::Load) {
ValueType load_type = ValueType::Void;
if (!inst.GetOperands().empty() && inst.GetOperands()[0].GetKind() == OperandKind::VReg) {
load_type = function.GetVRegInfo(inst.GetOperands()[0].GetVReg()).type;
}
auto it = states.find(key);
if (it != states.end() && it->second.type == load_type) {
inst = MakeCopyLike(inst, it->second.value);
it->second.pending_store_index = -1;
return true;
}
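// Guard against a malformed load without a destination operand.
if (inst.GetOperands().empty()) {
return false;
}
auto dest = inst.GetOperands()[0];
states[key] = {dest, load_type, -1};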
return false;
}
if (inst.GetOpcode() != MachineInstr::Opcode::Store) {
return false;
}
const auto value = inst.GetOperands()[0];
auto existing = states.find(key);
if (existing != states.end() && existing->second.type == inst.GetValueType() &&
SameExactOperand(existing->second.value, value)) {
*remove_current = true;
return true;
}
if (existing != states.end() && existing->second.pending_store_index >= 0) {
removed[static_cast<std::size_t>(existing->second.pending_store_index)] = true;
}
InvalidateMemoryState(states, &key);
states[key] = {value, inst.GetValueType(), static_cast<int>(current_index)};
return false;
}
void ApplyMemoryDataflowInstruction(const MachineModule& module, const MachineInstr& inst,
MemoryMap& states) {
if (inst.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayWriteMemory(module, inst)) {
InvalidateMemoryState(states, nullptr);
}
return;
}
if (inst.GetOpcode() == MachineInstr::Opcode::Memset) {
InvalidateMemoryState(states, nullptr);
return;
}
if (!HasTrackedAddress(inst)) {
return;
}
const AddressKey key = MakeAddressKey(inst.GetAddress());
if (inst.GetOpcode() == MachineInstr::Opcode::Store) {
InvalidateMemoryState(states, &key);
states[key] = {inst.GetOperands()[0], inst.GetValueType(), -1};
return;
}
}
MemoryMap SimulateBlockMemory(const MachineModule& module, const MachineBasicBlock& block,
const MemoryMap& in_state) {
MemoryMap state = in_state;
for (const auto& inst : block.GetInstructions()) {
ApplyMemoryDataflowInstruction(module, inst, state);
}
return state;
}
bool RunPeepholeOnBlock(const MachineModule& module, const MachineFunction& function,
MachineBasicBlock& block, const MemoryMap& in_state) {
bool changed = false;
AliasMap aliases;
MemoryMap memory_states = in_state;
std::vector<MachineInstr> rewritten;
std::vector<bool> removed;
rewritten.reserve(block.GetInstructions().size());
removed.reserve(block.GetInstructions().size());
for (const auto& original : block.GetInstructions()) {
MachineInstr inst = original;
changed |= RewriteUses(inst, aliases);
changed |= SimplifyInstruction(inst);
if (SimplifyCopy(function, inst)) {
changed = true;
continue;
}
rewritten.push_back(std::move(inst));
removed.push_back(false);
MachineInstr& current = rewritten.back();
bool remove_current = false;
changed |= TryOptimizeMemoryInstruction(module, function, current, memory_states, removed,
rewritten.size() - 1, &remove_current);
if (remove_current) {
removed.back() = true;
changed = true;
continue;
}
changed |= SimplifyInstruction(current);
if (SimplifyCopy(function, current)) {
removed.back() = true;
changed = true;
continue;
}
if (current.GetOpcode() == MachineInstr::Opcode::Call) {
if (CallMayReadMemory(module, current) || CallMayWriteMemory(module, current)) {
ObservePendingStores(memory_states);
}
} else if (current.GetOpcode() == MachineInstr::Opcode::Memset) {
ObservePendingStores(memory_states);
}
TrackAlias(current, aliases);
}
std::vector<MachineInstr> compacted;
compacted.reserve(rewritten.size());
for (std::size_t i = 0; i < rewritten.size(); ++i) {
if (!removed[i]) {
compacted.push_back(std::move(rewritten[i]));
} else {
changed = true;
}
}
if (compacted.size() != block.GetInstructions().size()) {
changed = true;
}
if (changed) {
block.GetInstructions() = std::move(compacted);
}
return changed;
}
bool IsSideEffectFree(const MachineInstr& inst) {
switch (inst.GetOpcode()) {
case MachineInstr::Opcode::Arg:
case MachineInstr::Opcode::Copy:
case MachineInstr::Opcode::Load:
case MachineInstr::Opcode::Lea:
case MachineInstr::Opcode::Add:
case MachineInstr::Opcode::Sub:
case MachineInstr::Opcode::Mul:
case MachineInstr::Opcode::Div:
case MachineInstr::Opcode::Rem:
case MachineInstr::Opcode::And:
case MachineInstr::Opcode::Or:
case MachineInstr::Opcode::Xor:
case MachineInstr::Opcode::Shl:
case MachineInstr::Opcode::AShr:
case MachineInstr::Opcode::LShr:
case MachineInstr::Opcode::FAdd:
case MachineInstr::Opcode::FSub:
case MachineInstr::Opcode::FMul:
case MachineInstr::Opcode::FDiv:
case MachineInstr::Opcode::FNeg:
case MachineInstr::Opcode::ICmp:
case MachineInstr::Opcode::FCmp:
case MachineInstr::Opcode::ZExt:
case MachineInstr::Opcode::ItoF:
case MachineInstr::Opcode::FtoI:
return true;
case MachineInstr::Opcode::Store:
case MachineInstr::Opcode::Br:
case MachineInstr::Opcode::CondBr:
case MachineInstr::Opcode::Call:
case MachineInstr::Opcode::Ret:
case MachineInstr::Opcode::Memset:
case MachineInstr::Opcode::Unreachable:
return false;
}
return false;
}
bool RunDeadInstrElimination(MachineFunction& function) {
bool changed = false;
while (true) {
std::unordered_map<int, int> use_counts;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
for (int use : inst.GetUses()) {
++use_counts[use];
}
}
}
bool local_changed = false;
for (auto& block : function.GetBlocks()) {
// Track changes per block so untouched blocks are not reassigned.
bool block_changed = false;
std::vector<MachineInstr> rewritten;
rewritten.reserve(block->GetInstructions().size());
for (auto& inst : block->GetInstructions()) {
const auto defs = inst.GetDefs();
const bool has_live_def =
defs.empty() || use_counts.find(defs.front()) != use_counts.end();
if (has_live_def || !IsSideEffectFree(inst)) {
rewritten.push_back(inst);
continue;
}
block_changed = true;
}
if (block_changed) {
block->GetInstructions() = std::move(rewritten);
local_changed = true;
}
}
if (!local_changed) {
break;
}
changed = true;
}
return changed;
}
bool HasAssignedAllocations(const MachineFunction& function) {
for (const auto& vreg : function.GetVRegs()) {
if (function.GetAllocation(vreg.id).kind != Allocation::Kind::Unassigned) {
return true;
}
}
return false;
}
} // namespace
bool RunPeephole(MachineModule& module) {
bool changed = false;
for (auto& function : module.GetFunctions()) {
if (!function) {
continue;
}
bool function_changed = false;
const auto cfg = BuildCFG(*function);
std::vector<MemoryMap> in_states(function->GetBlocks().size());
std::vector<MemoryMap> out_states(function->GetBlocks().size());
bool dataflow_changed = true;
while (dataflow_changed) {
dataflow_changed = false;
for (std::size_t i = 0; i < function->GetBlocks().size(); ++i) {
MemoryMap in_state;
if (i != 0) {
std::vector<const MemoryMap*> predecessors;
for (int pred : cfg.predecessors[i]) {
predecessors.push_back(&out_states[static_cast<std::size_t>(pred)]);
}
in_state = MeetMemoryStates(predecessors);
}
auto out_state =
SimulateBlockMemory(module, *function->GetBlocks()[i], in_state);
if (!SameMemoryMap(in_states[i], in_state)) {
in_states[i] = std::move(in_state);
dataflow_changed = true;
}
if (!SameMemoryMap(out_states[i], out_state)) {
out_states[i] = std::move(out_state);
dataflow_changed = true;
}
}
}
for (std::size_t i = 0; i < function->GetBlocks().size(); ++i) {
function_changed |=
RunPeepholeOnBlock(module, *function, *function->GetBlocks()[i], in_states[i]);
}
if (!HasAssignedAllocations(*function)) {
function_changed |= RunDeadInstrElimination(*function);
}
changed |= function_changed;
}
return changed;
}
} // namespace mir
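
The cross-block memory dataflow in `RunPeephole` computes each block's entry state as the intersection of its predecessors' exit states, keeping only facts that hold on every incoming path. A reduced sketch of that meet step is shown here, with a plain `std::map<std::string, int>` standing in for `MemoryMap`. The `State` alias and the `Meet` function are illustrative assumptions, not the repository's types.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <map>
#include <string>
#include <vector>

// Stand-in for the real memory state: address name -> known value.
using State = std::map<std::string, int>;

// Keeps an entry only if every predecessor has the same key with the
// same value. An empty predecessor list yields an empty state.
inline State Meet(const std::vector<const State*>& preds) {
  if (preds.empty()) {
    return {};
  }
  State in = *preds.front();
  for (auto it = in.begin(); it != in.end();) {
    bool keep = true;
    for (std::size_t i = 1; i < preds.size(); ++i) {
      auto other = preds[i]->find(it->first);
      if (other == preds[i]->end() || other->second != it->second) {
        keep = false;
        break;
      }
    }
    it = keep ? std::next(it) : in.erase(it);
  }
  return in;
}
```

Starting every block from an empty state and intersecting at joins is the pessimistic choice: a loop back edge can only remove facts, so the result stays sound even before the iteration converges.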

@ -1,6 +1,685 @@
#include "sem/Sema.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace {
enum class MemoryRoot {
None,
Local,
Global,
Param,
Unknown,
};
struct SymbolInfo {
SemanticType type = SemanticType::Int;
bool is_array = false;
bool is_param_array = false;
MemoryRoot root = MemoryRoot::Local;
std::vector<int> dims;
};
struct ExprInfo {
MemoryRoot root = MemoryRoot::None;
bool is_array = false;
};
struct CallSiteInfo {
std::string callee;
std::vector<MemoryRoot> arg_roots;
};
struct DirectFunctionAnalysis {
FunctionSemanticInfo info;
std::vector<CallSiteInfo> calls;
};
class ScopedSymbols {
public:
void EnterScope() { scopes_.emplace_back(); }
void ExitScope() {
if (!scopes_.empty()) {
scopes_.pop_back();
}
}
bool Insert(const std::string& name, const SymbolInfo& info) {
if (scopes_.empty()) {
EnterScope();
}
return scopes_.back().emplace(name, info).second;
}
const SymbolInfo* Lookup(const std::string& name) const {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto found = it->find(name);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
}
private:
std::vector<std::unordered_map<std::string, SymbolInfo>> scopes_;
};
std::string ExpectIdent(antlr4::tree::TerminalNode* ident) {
return ident == nullptr ? std::string{} : ident->getText();
}
SemanticType ParseBType(SysYParser::BTypeContext* ctx) {
if (ctx != nullptr && ctx->FLOAT() != nullptr) {
return SemanticType::Float;
}
return SemanticType::Int;
}
SemanticType ParseFuncType(SysYParser::FuncTypeContext* ctx) {
if (ctx == nullptr || ctx->VOID() != nullptr) {
return SemanticType::Void;
}
if (ctx->FLOAT() != nullptr) {
return SemanticType::Float;
}
return SemanticType::Int;
}
std::vector<int> MakeShape(std::size_t rank) {
return std::vector<int>(rank, -1);
}
void RegisterBuiltinFunctions(SemanticContext& context) {
struct BuiltinSpec {
const char* name;
SemanticType return_type;
std::vector<bool> param_is_array;
bool reads_global_memory = false;
bool writes_global_memory = false;
bool reads_param_memory = false;
bool writes_param_memory = false;
bool has_io = false;
};
const std::vector<BuiltinSpec> builtins = {
{"getint", SemanticType::Int, {}, false, false, false, false, true},
{"getch", SemanticType::Int, {}, false, false, false, false, true},
{"getfloat", SemanticType::Float, {}, false, false, false, false, true},
{"getarray", SemanticType::Int, {true}, false, false, false, true, true},
{"getfarray", SemanticType::Int, {true}, false, false, false, true, true},
{"putint", SemanticType::Void, {false}, false, false, false, false, true},
{"putch", SemanticType::Void, {false}, false, false, false, false, true},
{"putfloat", SemanticType::Void, {false}, false, false, false, false, true},
{"putarray", SemanticType::Void, {false, true}, false, false, true, false, true},
{"putfarray", SemanticType::Void, {false, true}, false, false, true, false, true},
{"starttime", SemanticType::Void, {}, false, false, false, false, true},
{"stoptime", SemanticType::Void, {}, false, false, false, false, true},
};
for (const auto& builtin : builtins) {
auto& info = context.UpsertFunction(builtin.name);
info.return_type = builtin.return_type;
info.param_is_array = builtin.param_is_array;
info.is_builtin = true;
info.is_defined = false;
info.reads_global_memory = builtin.reads_global_memory;
info.writes_global_memory = builtin.writes_global_memory;
info.reads_param_memory = builtin.reads_param_memory;
info.writes_param_memory = builtin.writes_param_memory;
info.has_io = builtin.has_io;
info.has_unknown_effects = false;
info.is_recursive = false;
info.direct_callees.clear();
}
}
void CollectGlobalDecl(SemanticContext& context, SysYParser::DeclContext& ctx) {
if (auto* const_decl = ctx.constDecl()) {
const auto type = ParseBType(const_decl->bType());
for (auto* def : const_decl->constDef()) {
auto& info = context.UpsertGlobal(ExpectIdent(def->Ident()));
info.type = type;
info.is_const = true;
info.is_array = !def->constExp().empty();
info.dims = MakeShape(def->constExp().size());
}
return;
}
if (auto* var_decl = ctx.varDecl()) {
const auto type = ParseBType(var_decl->bType());
for (auto* def : var_decl->varDef()) {
auto& info = context.UpsertGlobal(ExpectIdent(def->Ident()));
info.type = type;
info.is_const = false;
info.is_array = !def->constExp().empty();
info.dims = MakeShape(def->constExp().size());
}
}
}
void CollectFunctionSignature(SemanticContext& context, SysYParser::FuncDefContext& ctx) {
auto& info = context.UpsertFunction(ExpectIdent(ctx.Ident()));
info.return_type = ParseFuncType(ctx.funcType());
info.param_is_array.clear();
info.is_builtin = false;
info.is_defined = true;
info.reads_global_memory = false;
info.writes_global_memory = false;
info.reads_param_memory = false;
info.writes_param_memory = false;
info.has_io = false;
info.has_unknown_effects = false;
info.is_recursive = false;
info.direct_callees.clear();
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
info.param_is_array.push_back(!param->LBRACK().empty());
}
}
}
bool SameInfo(const FunctionSemanticInfo& lhs, const FunctionSemanticInfo& rhs) {
return lhs.return_type == rhs.return_type &&
lhs.param_is_array == rhs.param_is_array &&
lhs.is_builtin == rhs.is_builtin && lhs.is_defined == rhs.is_defined &&
lhs.reads_global_memory == rhs.reads_global_memory &&
lhs.writes_global_memory == rhs.writes_global_memory &&
lhs.reads_param_memory == rhs.reads_param_memory &&
lhs.writes_param_memory == rhs.writes_param_memory &&
lhs.has_io == rhs.has_io &&
lhs.has_unknown_effects == rhs.has_unknown_effects &&
lhs.is_recursive == rhs.is_recursive &&
lhs.direct_callees == rhs.direct_callees;
}
class FunctionAnalyzer {
public:
FunctionAnalyzer(const SemanticContext& context, DirectFunctionAnalysis& analysis)
: context_(context), analysis_(analysis) {}
void Analyze(SysYParser::FuncDefContext& ctx) {
symbols_.EnterScope();
if (auto* params = ctx.funcFParams()) {
for (auto* param : params->funcFParam()) {
SymbolInfo info;
info.type = ParseBType(param->bType());
info.is_array = !param->LBRACK().empty();
info.is_param_array = info.is_array;
info.root = info.is_array ? MemoryRoot::Param : MemoryRoot::Local;
info.dims = MakeShape(param->exp().size());
symbols_.Insert(ExpectIdent(param->Ident()), info);
}
}
AnalyzeBlock(*ctx.block(), false);
symbols_.ExitScope();
}
private:
struct LValueShape {
MemoryRoot root = MemoryRoot::None;
bool is_array = false;
};
const SymbolInfo* LookupSymbol(const std::string& name) const {
if (const auto* local = symbols_.Lookup(name)) {
return local;
}
if (const auto* global = context_.LookupGlobal(name)) {
static thread_local SymbolInfo scratch;
scratch.type = global->type;
scratch.is_array = global->is_array;
scratch.is_param_array = false;
scratch.root = MemoryRoot::Global;
scratch.dims = global->dims;
return &scratch;
}
return nullptr;
}
void AnalyzeBlock(SysYParser::BlockContext& ctx, bool create_scope) {
if (create_scope) {
symbols_.EnterScope();
}
for (auto* item : ctx.blockItem()) {
AnalyzeBlockItem(*item);
}
if (create_scope) {
symbols_.ExitScope();
}
}
void AnalyzeBlockItem(SysYParser::BlockItemContext& ctx) {
if (auto* decl = ctx.decl()) {
AnalyzeDecl(*decl);
} else if (auto* stmt = ctx.stmt()) {
AnalyzeStmt(*stmt);
}
}
void AnalyzeDecl(SysYParser::DeclContext& ctx) {
if (auto* const_decl = ctx.constDecl()) {
AnalyzeConstDecl(*const_decl);
} else if (auto* var_decl = ctx.varDecl()) {
AnalyzeVarDecl(*var_decl);
}
}
void AnalyzeVarDecl(SysYParser::VarDeclContext& ctx) {
const auto type = ParseBType(ctx.bType());
for (auto* def : ctx.varDef()) {
SymbolInfo info;
info.type = type;
info.is_array = !def->constExp().empty();
info.root = MemoryRoot::Local;
info.dims = MakeShape(def->constExp().size());
symbols_.Insert(ExpectIdent(def->Ident()), info);
if (def->initVal() != nullptr) {
AnalyzeInitVal(def->initVal());
}
}
}
void AnalyzeConstDecl(SysYParser::ConstDeclContext& ctx) {
const auto type = ParseBType(ctx.bType());
for (auto* def : ctx.constDef()) {
SymbolInfo info;
info.type = type;
info.is_array = !def->constExp().empty();
info.root = MemoryRoot::Local;
info.dims = MakeShape(def->constExp().size());
symbols_.Insert(ExpectIdent(def->Ident()), info);
AnalyzeConstInitVal(def->constInitVal());
}
}
void AnalyzeInitVal(SysYParser::InitValContext* ctx) {
if (ctx == nullptr) {
return;
}
if (ctx->exp() != nullptr) {
AnalyzeExp(*ctx->exp());
return;
}
for (auto* child : ctx->initVal()) {
AnalyzeInitVal(child);
}
}
void AnalyzeConstInitVal(SysYParser::ConstInitValContext* ctx) {
if (ctx == nullptr) {
return;
}
if (ctx->constExp() != nullptr) {
AnalyzeAddExp(*ctx->constExp()->addExp());
return;
}
for (auto* child : ctx->constInitVal()) {
AnalyzeConstInitVal(child);
}
}
void AnalyzeStmt(SysYParser::StmtContext& ctx) {
if (ctx.lVal() != nullptr && ctx.ASSIGN() != nullptr) {
AnalyzeLValWrite(*ctx.lVal());
AnalyzeExp(*ctx.exp());
return;
}
if (ctx.block() != nullptr) {
AnalyzeBlock(*ctx.block(), true);
return;
}
if (ctx.IF() != nullptr) {
if (ctx.cond() != nullptr) {
AnalyzeLOrExp(*ctx.cond()->lOrExp());
}
if (!ctx.stmt().empty()) {
AnalyzeStmt(*ctx.stmt()[0]);
}
if (ctx.stmt().size() > 1 && ctx.stmt()[1] != nullptr) {
AnalyzeStmt(*ctx.stmt()[1]);
}
return;
}
if (ctx.WHILE() != nullptr) {
if (ctx.cond() != nullptr) {
AnalyzeLOrExp(*ctx.cond()->lOrExp());
}
if (!ctx.stmt().empty() && ctx.stmt()[0] != nullptr) {
AnalyzeStmt(*ctx.stmt()[0]);
}
return;
}
if (ctx.RETURN() != nullptr) {
if (ctx.exp() != nullptr) {
AnalyzeExp(*ctx.exp());
}
return;
}
if (ctx.exp() != nullptr) {
AnalyzeExp(*ctx.exp());
}
}
ExprInfo AnalyzeExp(SysYParser::ExpContext& ctx) {
return AnalyzeAddExp(*ctx.addExp());
}
ExprInfo AnalyzeAddExp(SysYParser::AddExpContext& ctx) {
if (ctx.addExp() != nullptr) {
AnalyzeAddExp(*ctx.addExp());
AnalyzeMulExp(*ctx.mulExp());
return {};
}
return AnalyzeMulExp(*ctx.mulExp());
}
ExprInfo AnalyzeMulExp(SysYParser::MulExpContext& ctx) {
if (ctx.mulExp() != nullptr) {
AnalyzeMulExp(*ctx.mulExp());
AnalyzeUnaryExp(*ctx.unaryExp());
return {};
}
return AnalyzeUnaryExp(*ctx.unaryExp());
}
ExprInfo AnalyzeUnaryExp(SysYParser::UnaryExpContext& ctx) {
if (ctx.primaryExp() != nullptr) {
return AnalyzePrimaryExp(*ctx.primaryExp());
}
if (ctx.Ident() != nullptr) {
const auto name = ExpectIdent(ctx.Ident());
CallSiteInfo call;
call.callee = name;
const auto* callee = context_.LookupFunction(name);
const auto args = ctx.funcRParams() == nullptr ? std::vector<SysYParser::ExpContext*>{}
: ctx.funcRParams()->exp();
call.arg_roots.resize(args.size(), MemoryRoot::None);
for (std::size_t i = 0; i < args.size(); ++i) {
auto arg_info = AnalyzeExp(*args[i]);
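// Only arguments bound to array parameters can expose caller memory to the
// callee. A non-array expression in such a position gets an Unknown root.
if (callee != nullptr && i < callee->param_is_array.size() && callee->param_is_array[i]) {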
call.arg_roots[i] = arg_info.is_array ? arg_info.root : MemoryRoot::Unknown;
}
}
if (callee == nullptr) {
analysis_.info.has_unknown_effects = true;
}
analysis_.calls.push_back(std::move(call));
return {};
}
if (ctx.unaryExp() != nullptr) {
AnalyzeUnaryExp(*ctx.unaryExp());
}
return {};
}
ExprInfo AnalyzePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp() != nullptr) {
return AnalyzeExp(*ctx.exp());
}
if (ctx.lVal() != nullptr) {
return AnalyzeLValRead(*ctx.lVal());
}
return {};
}
ExprInfo AnalyzeRelExp(SysYParser::RelExpContext& ctx) {
if (ctx.relExp() != nullptr) {
AnalyzeRelExp(*ctx.relExp());
AnalyzeAddExp(*ctx.addExp());
return {};
}
return AnalyzeAddExp(*ctx.addExp());
}
ExprInfo AnalyzeEqExp(SysYParser::EqExpContext& ctx) {
if (ctx.eqExp() != nullptr) {
AnalyzeEqExp(*ctx.eqExp());
AnalyzeRelExp(*ctx.relExp());
return {};
}
return AnalyzeRelExp(*ctx.relExp());
}
ExprInfo AnalyzeLAndExp(SysYParser::LAndExpContext& ctx) {
if (ctx.lAndExp() != nullptr) {
AnalyzeLAndExp(*ctx.lAndExp());
AnalyzeEqExp(*ctx.eqExp());
return {};
}
return AnalyzeEqExp(*ctx.eqExp());
}
ExprInfo AnalyzeLOrExp(SysYParser::LOrExpContext& ctx) {
if (ctx.lOrExp() != nullptr) {
AnalyzeLOrExp(*ctx.lOrExp());
AnalyzeLAndExp(*ctx.lAndExp());
return {};
}
return AnalyzeLAndExp(*ctx.lAndExp());
}
LValueShape DescribeLVal(SysYParser::LValContext& ctx) {
for (auto* index : ctx.exp()) {
AnalyzeExp(*index);
}
const auto* symbol = LookupSymbol(ExpectIdent(ctx.Ident()));
if (symbol == nullptr) {
return {};
}
if (!symbol->is_array) {
return {symbol->root, false};
}
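const auto index_count = ctx.exp().size();
// For an array parameter, dims is built from the explicitly sized dimensions
// only (see Analyze), so an access indexed by up to dims.size() expressions
// is still treated as an array. For other arrays, the access stays an array
// only while fewer indices than dims.size() are given.
const bool still_array = symbol->is_param_array
? index_count <= symbol->dims.size()
: index_count < symbol->dims.size();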
return {symbol->root, still_array};
}
ExprInfo AnalyzeLValRead(SysYParser::LValContext& ctx) {
const auto shape = DescribeLVal(ctx);
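// A partially indexed array only denotes an address, so it is not counted as
// a memory read here. Only scalar accesses are recorded.
if (!shape.is_array) {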
if (shape.root == MemoryRoot::Global) {
analysis_.info.reads_global_memory = true;
} else if (shape.root == MemoryRoot::Param) {
analysis_.info.reads_param_memory = true;
}
}
return {shape.root, shape.is_array};
}
void AnalyzeLValWrite(SysYParser::LValContext& ctx) {
const auto shape = DescribeLVal(ctx);
if (shape.root == MemoryRoot::Global) {
analysis_.info.writes_global_memory = true;
} else if (shape.root == MemoryRoot::Param) {
analysis_.info.writes_param_memory = true;
}
}
const SemanticContext& context_;
DirectFunctionAnalysis& analysis_;
ScopedSymbols symbols_;
};
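// Merges callee effects into each caller and repeats until no summary changes
// (a fixed point), so effects flow through call chains of any depth. A callee's
// effects on its array parameters are mapped onto the root of the caller's
// corresponding argument (global, parameter, or unknown).
void PropagateCallEffects(SemanticContext& context,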
const std::unordered_map<std::string, DirectFunctionAnalysis>& analyses) {
bool changed = true;
while (changed) {
changed = false;
for (const auto& [name, analysis] : analyses) {
auto next = analysis.info;
std::unordered_set<std::string> callees_seen;
for (const auto& call : analysis.calls) {
if (!call.callee.empty()) {
callees_seen.insert(call.callee);
}
const auto* callee = context.LookupFunction(call.callee);
if (callee == nullptr) {
next.has_unknown_effects = true;
continue;
}
next.has_io = next.has_io || callee->has_io;
next.has_unknown_effects = next.has_unknown_effects || callee->has_unknown_effects;
next.reads_global_memory =
next.reads_global_memory || callee->reads_global_memory;
next.writes_global_memory =
next.writes_global_memory || callee->writes_global_memory;
const auto arg_count = std::min(call.arg_roots.size(), callee->param_is_array.size());
for (std::size_t i = 0; i < arg_count; ++i) {
if (!callee->param_is_array[i]) {
continue;
}
switch (call.arg_roots[i]) {
case MemoryRoot::Global:
next.reads_global_memory =
next.reads_global_memory || callee->reads_param_memory;
next.writes_global_memory =
next.writes_global_memory || callee->writes_param_memory;
break;
case MemoryRoot::Param:
next.reads_param_memory =
next.reads_param_memory || callee->reads_param_memory;
next.writes_param_memory =
next.writes_param_memory || callee->writes_param_memory;
break;
case MemoryRoot::Unknown:
if (callee->reads_param_memory || callee->writes_param_memory) {
next.has_unknown_effects = true;
}
break;
case MemoryRoot::None:
case MemoryRoot::Local:
break;
}
}
}
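// The callee list is sorted below so that the SameInfo comparison does not
// depend on hash-set iteration order.
next.direct_callees.assign(callees_seen.begin(), callees_seen.end());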
std::sort(next.direct_callees.begin(), next.direct_callees.end());
auto* current = context.LookupFunction(name);
if (current == nullptr || !SameInfo(next, *current)) {
context.UpsertFunction(name) = std::move(next);
changed = true;
}
}
}
}
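// Depth-first search over direct_callees. Returns true if `root` is reachable
// from `current` through at least one call edge. `visiting` records functions
// that were already expanded so that cycles terminate.
bool ReachesSelf(const SemanticContext& context, const std::string& root,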
const std::string& current,
std::unordered_set<std::string>& visiting) {
const auto* info = context.LookupFunction(current);
if (info == nullptr) {
return false;
}
if (!visiting.insert(current).second) {
return false;
}
for (const auto& callee : info->direct_callees) {
if (callee == root) {
return true;
}
if (ReachesSelf(context, root, callee, visiting)) {
return true;
}
}
return false;
}
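// Marks every non-builtin function that can reach itself through the call
// graph. Names are collected first so that UpsertFunction is not called while
// iterating over GetFunctions().
void MarkRecursiveFunctions(SemanticContext& context) {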
std::vector<std::string> function_names;
function_names.reserve(context.GetFunctions().size());
for (const auto& [name, info] : context.GetFunctions()) {
if (!info.is_builtin) {
function_names.push_back(name);
}
}
for (const auto& name : function_names) {
std::unordered_set<std::string> visiting;
if (ReachesSelf(context, name, name, visiting)) {
context.UpsertFunction(name).is_recursive = true;
}
}
}
} // namespace
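// Entry point: collects global declarations and function signatures, analyzes
// each function body for its direct effects, then propagates effects across
// calls and marks recursive functions.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {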
SemanticContext context;
RegisterBuiltinFunctions(context);
for (auto* child : comp_unit.children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
CollectGlobalDecl(context, *decl);
} else if (auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
CollectFunctionSignature(context, *func);
}
}
std::unordered_map<std::string, DirectFunctionAnalysis> analyses;
for (auto* child : comp_unit.children) {
auto* func = dynamic_cast<SysYParser::FuncDefContext*>(child);
if (func == nullptr) {
continue;
}
const auto name = ExpectIdent(func->Ident());
auto* existing = context.LookupFunction(name);
if (existing == nullptr) {
continue;
}
DirectFunctionAnalysis analysis;
analysis.info = *existing;
analysis.info.reads_global_memory = false;
analysis.info.writes_global_memory = false;
analysis.info.reads_param_memory = false;
analysis.info.writes_param_memory = false;
analysis.info.has_io = false;
analysis.info.has_unknown_effects = false;
analysis.info.is_recursive = false;
analysis.info.direct_callees.clear();
FunctionAnalyzer analyzer(context, analysis);
analyzer.Analyze(*func);
analyses.emplace(name, std::move(analysis));
}
for (const auto& [name, analysis] : analyses) {
context.UpsertFunction(name) = analysis.info;
}
PropagateCallEffects(context, analyses);
MarkRecursiveFunctions(context);
return context;
}
