diff --git a/Reference/lab03-code generation-2026.pdf b/Reference/lab3/lab03-code generation-2026.pdf similarity index 100% rename from Reference/lab03-code generation-2026.pdf rename to Reference/lab3/lab03-code generation-2026.pdf diff --git a/Reference/lecture05-instruction selection-169.pdf b/Reference/lab3/lecture05-instruction selection-169.pdf similarity index 100% rename from Reference/lecture05-instruction selection-169.pdf rename to Reference/lab3/lecture05-instruction selection-169.pdf diff --git a/Reference/lecture11-register allocation-part2-169.pdf b/Reference/lab3/lecture11-register allocation-part2-169.pdf similarity index 100% rename from Reference/lecture11-register allocation-part2-169.pdf rename to Reference/lab3/lecture11-register allocation-part2-169.pdf diff --git a/Reference/lab4-6/Lecture06-scalar-opt.pdf b/Reference/lab4-6/Lecture06-scalar-opt.pdf new file mode 100644 index 0000000..139b532 Binary files /dev/null and b/Reference/lab4-6/Lecture06-scalar-opt.pdf differ diff --git a/Reference/lab4-6/lecture08-dependence analysis-I-169.pdf b/Reference/lab4-6/lecture08-dependence analysis-I-169.pdf new file mode 100644 index 0000000..cb6f951 Binary files /dev/null and b/Reference/lab4-6/lecture08-dependence analysis-I-169.pdf differ diff --git a/Reference/lab4-6/lecture09-dependence analysis-II-169.pdf b/Reference/lab4-6/lecture09-dependence analysis-II-169.pdf new file mode 100644 index 0000000..e812947 Binary files /dev/null and b/Reference/lab4-6/lecture09-dependence analysis-II-169.pdf differ diff --git a/Reference/lab4-6/lecture13-loop transformation-I-169.pdf b/Reference/lab4-6/lecture13-loop transformation-I-169.pdf new file mode 100644 index 0000000..36b5068 Binary files /dev/null and b/Reference/lab4-6/lecture13-loop transformation-I-169.pdf differ diff --git a/Reference/lab4-6/lecture14-loop transformation-II-169.pdf b/Reference/lab4-6/lecture14-loop transformation-II-169.pdf new file mode 100644 index 0000000..32d0876 Binary files /dev/null and b/Reference/lab4-6/lecture14-loop transformation-II-169.pdf differ diff --git a/doc/lab4-6.md b/doc/lab4-6.md new file mode 100644 index 0000000..6aacdad --- /dev/null +++ b/doc/lab4-6.md @@ -0,0 +1,220 @@ +# Lab4-Lab6 完成情况说明 + +## 1. 文档目的 + +本文档用于对照 `doc/Lab4-基本标量优化.md`、`doc/Lab5-寄存器分配.md`、`doc/Lab6-并行与循环优化.md`,说明当前编译器在 Lab4、Lab5、Lab6 三个阶段的完成情况,并补充最近一轮围绕比赛级目标所做的修改与优化。 + +## 2. 总体结论 + +从当前代码状态看: + +- Lab4:已完成,且已经超过文档中的基础标量优化要求。 +- Lab5:已完成,且已经形成真实可运行的后端寄存器分配与后端优化链路,不再是示例级后端。 +- Lab6:主体已完成,已经具备比赛可用的单线程循环优化能力;循环并行分析基础已接入,但未实现真正的多线程运行时并行执行。 + +当前主线已经是: + +`SysY -> IR 生成 -> IR 优化 -> MIR lowering -> MIR 优化 -> 寄存器分配 -> 栈帧落地 -> AArch64 汇编输出` + +## 3. 对照完成情况 + +### 3.1 Lab4:基本标量优化 + +Lab4 文档要求的核心是: + +1. 先做 `mem2reg`,把局部变量提升到 SSA。 +2. 实现基础标量优化,如常量折叠、常量传播、DCE、CFG 简化、CSE。 +3. 把这些优化接入 `PassManager`,形成可重复执行的优化流程。 +4. 通过测试确认优化前后语义一致。 + +当前实现情况: + +- `Mem2Reg` 已接入优化流水线,并作为标量优化前置步骤执行。 +- `ConstProp`、`ConstFold`、`DCE`、`CFGSimplify`、`CSE` 均已实现并接入。 +- 在文档要求之外,又新增了 `GVN` 与 `LoadStoreElim`,进一步加强了内存相关和跨块冗余消除能力。 +- `PassManager` 已形成迭代优化流程,而不是单次串行跑一遍后结束。 + +当前 `IR` 流水线在 `src/ir/passes/PassManager.cpp` 中会迭代执行: + +- `RunFunctionInlining` +- `RunConstProp` +- `RunConstFold` +- `RunGVN` +- `RunLoadStoreElim` +- `RunCSE` +- `RunDCE` +- `RunCFGSimplify` +- `RunLICM` +- `RunLoopStrengthReduction` +- `RunLoopFission` +- `RunLoopUnroll` + +完成判断: + +- Lab4 已完成。 +- 严格按文档要求看,不仅满足“基础标量优化”要求,而且已经扩展到了更强的中端优化框架。 + +### 3.2 Lab5:寄存器分配与后端优化 + +Lab5 文档要求的核心是: + +1. MIR 不再固定使用少量物理寄存器,而是先生成虚拟寄存器。 +2. 实现真实寄存器分配,并处理 spill/reload、callee-saved、栈槽等问题。 +3. 接入后端局部优化流程,减少冗余 `copy/move`、冗余 `load/store` 和明显恒等指令。 +4. 在全部测试上验证正确性,并尽量提升生成代码质量。 + +当前实现情况: + +- `Lowering` 已经输出虚拟寄存器 MIR,而不是固定寄存器模板。 +- `RegAlloc` 已实现真实寄存器分配,当前采用图着色风格分配流程,并处理了: + - 活跃性分析 + - 干涉关系 + - `copy` 合并 + - spill 栈槽分配 + - callee-saved 保存恢复信息回填 + - live-across-call 约束 +- `FrameLowering` 与 `AsmPrinter` 已经能够围绕 RA 结果完成最终栈帧和汇编输出。 +- `MIR` 优化流水线已经真正接入主链: + - `PreRA`:`AddressHoisting + Peephole` + - `PostRA`:`Peephole` + +后端局部优化目前已经覆盖: + +- 冗余 `copy` 消除 +- 恒等算术指令消除 +- 条件跳转简化 +- 局部冗余 `load/store` 消除 +- 同块内 store-to-load forwarding +- 同地址重复 `store` 删除 +- 基于 CFG 的跨块 memory dataflow + +最近一轮后端进一步做了两件关键事情: + +1. `MIR Peephole` 从“单基本块局部优化”提升到“带 CFG 数据流的跨块内存优化”。 +2. `MIR Lowering` 调整为按可达 CFG 顺序 lowering,修复了内联后复杂 CFG 下 SSA 值先用后定义导致的 lowering 失败。 + +说明: + +- 曾尝试扩展 `v16-v18` 作为额外 FPR 可分配寄存器,但在浮点重调用样例上出现错误,因此最终回退,保留稳定寄存器集合。这一调整没有留在主线中。 + +完成判断: + +- Lab5 已完成。 +- 与文档中的“最小后端推进到真实后端”目标相比,当前实现已经超过课程最低线。 + +### 3.3 Lab6:并行与循环优化 + +Lab6 文档要求的核心是: + +1. 建立循环分析基础,识别循环头、循环体、前置块、退出块、回边等结构。 +2. 实现有效循环优化,并接入 `PassManager`。 +3. 与 Lab4 标量优化协同工作。 +4. 若希望进一步提升性能,可继续尝试可并行循环识别与并行化。 + +当前实现情况: + +- 已实现 `DominatorTree` 与 `LoopInfo`,可识别自然循环及其层次关系。 +- 已补齐循环变换所需的 `LoopPassUtils`。 +- 已接入的循环优化包括: + - `LICM` + - `LoopStrengthReduction` + - `LoopUnroll` + - `LoopFission` +- `LoopMemoryUtils` 已从较弱的循环地址分析,升级为结合: + - simple induction variable + - affine 地址表达 + - exact-address key + - root-aware alias/mod-ref + - 非逃逸局部对象分析 + 的更强版本。 +- `LICM` 已经可以更积极地 hoist 安全的 `load`,并对同地址的 hoisted load 做去重合并。 + +关于“并行与循环优化”中的并行部分: + +- 当前已经具备可并行循环识别与依赖分析基础。 +- 但没有继续接入真正的多线程并行 runtime,也没有把循环改写为可直接并发执行的运行时调用。 +- 结合文档表述,这部分更像“继续深入方向”,而不是 Lab6 基础完成线的硬要求。 + +完成判断: + +- Lab6 主体已完成。 +- 从比赛级编译器角度,当前已经具备较完整的单线程循环优化能力。 +- 若以“真正运行时并行执行”作为额外目标,则这一部分仍可继续扩展,但不影响当前对 Lab6 主体完成的判断。 + +## 4. 最近一轮修改与优化 + +这一轮围绕比赛级目标,主要新增和加强了以下内容。 + +### 4.1 中端新增与增强 + +- 新增 `GVN`,用于更大范围复用纯表达式结果。 +- 新增 `LoadStoreElim`,支持跨块冗余 `load` 消除、store-to-load forwarding、死 `store` 删除。 +- 强化 `LoopMemoryUtils`,让循环内存优化不再只依赖很保守的规则。 +- 强化 `LICM`,使其对安全 `load` 的外提更积极,并能对 hoisted load 做合并。 +- 新增 IR 级小函数内联,使收益更早反馈到 `ConstProp`、`GVN`、`DCE`、`LICM` 等中端优化。 + +### 4.2 后端新增与增强 + +- `MIR Peephole` 从局部块内优化,扩展到基于 CFG 的跨块内存状态传播。 +- `Call` 现在会按源 `IR Function` 的 effect 信息进行 `read/write` 边界判断,不再统一按最粗粒度处理。 +- 修复了内联后复杂控制流下 MIR lowering 的块顺序问题。 +- 完整回归后保留稳定 FPR 集合,放弃了不稳定的 `v16-v18` 扩容方案。 + +### 4.3 这轮优化的实际意义 + +这意味着最近的修改已经不只是“补课程实验功能”,而是开始面向比赛收益去提升: + +- 中端:更强的冗余消除、内存优化、函数级优化、循环优化协同 +- 后端:更强的 `copy/load/store` 消除与更稳定的 RA 后局部优化 + +## 5. 当前验证情况 + +本次回归中,已经完成以下验证: + +### 5.1 全量正确性回归 + +执行: + +```bash +./scripts/lab3_build_test.sh test/test_case/functional test/test_case/h_functional +``` + +结果: + +- `134 PASS / 0 FAIL / total 134` + +这说明当前 Lab4-Lab6 优化接入后,完整 `asm` 路径在 `functional + h_functional` 上保持正确。 + +### 5.2 性能热点抽测 + +执行并通过: + +- `test/test_case/h_performance/fft2.sy` +- `test/test_case/h_performance/matmul3.sy` +- `test/test_case/h_performance/transpose2.sy` +- `test/test_case/h_performance/gameoflife-gosper.sy` + +这些样例覆盖了: + +- 重循环 +- 重访存 +- 浮点运算 +- 矩阵访问 +- 较复杂控制流 + +可以说明当前新增优化至少在一批代表性性能样例上保持了可运行与结果正确。 + +## 6. 结论 + +综合来看,当前编译器在 Lab4、Lab5、Lab6 上的完成情况可以概括为: + +- Lab4:完成,并已扩展到更强的中端优化。 +- Lab5:完成,并已形成真实可运行的后端优化链路。 +- Lab6:主体完成,单线程循环优化能力已经达到比赛可用水平。 + +如果后续继续朝比赛方向推进,最值得继续做的事情不再是“补实验是否完成”,而是: + +1. 针对 `h_performance` 做系统 profiling。 +2. 按性能热点继续优化中端内存/循环变换。 +3. 继续提升后端 spill、copy、访存质量。 +4. 如需继续深入 Lab6,可进一步尝试真正的并行 runtime 接入。 diff --git a/include/ir/Analysis.h b/include/ir/Analysis.h new file mode 100644 index 0000000..4c0aadf --- /dev/null +++ b/include/ir/Analysis.h @@ -0,0 +1,73 @@ +#pragma once + +#include "ir/IR.h" + +#include +#include +#include +#include +#include + +namespace ir { + +class DominatorTree { + public: + explicit DominatorTree(Function& function); + + void Recalculate(); + + Function& GetFunction() const { return *function_; } + bool IsReachable(BasicBlock* block) const; + bool Dominates(BasicBlock* dom, BasicBlock* node) const; + bool Dominates(Instruction* dom, Instruction* user) const; + BasicBlock* GetIDom(BasicBlock* block) const; + const std::vector& GetChildren(BasicBlock* block) const; + const std::vector& GetReversePostOrder() const { + return reverse_post_order_; + } + + private: + Function* function_ = nullptr; + std::vector reverse_post_order_; + std::unordered_map block_index_; + std::vector> dominates_; + std::unordered_map immediate_dominator_; + std::unordered_map> dom_children_; +}; + +struct Loop { + BasicBlock* header = nullptr; + std::unordered_set blocks; + std::vector block_list; + std::vector latches; + std::vector exiting_blocks; + std::vector exit_blocks; + BasicBlock* preheader = nullptr; + Loop* parent = nullptr; + std::vector subloops; + + bool Contains(BasicBlock* block) const; + bool Contains(const Loop* other) const; + bool IsInnermost() const; +}; + +class LoopInfo { + public: + LoopInfo(Function& function, const DominatorTree& dom_tree); + + void Recalculate(); + + const std::vector>& GetLoops() const { return loops_; } + std::vector GetTopLevelLoops() const; + std::vector GetLoopsInPostOrder() const; + Loop* GetLoopFor(BasicBlock* block) const; + + private: + Function* function_ = nullptr; + const DominatorTree* dom_tree_ = nullptr; + std::vector> loops_; + std::vector top_level_loops_; + std::unordered_map block_to_loop_; +}; + +} // namespace ir diff --git a/include/ir/IR.h b/include/ir/IR.h index 7d04ba6..5ba3f0f 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -612,6 +612,40 @@ class Function : public Value { bool IsExternal() const { return is_external_; } void SetExternal(bool is_external) { is_external_ = is_external; } + void SetEffectInfo(bool reads_global_memory, bool writes_global_memory, + bool reads_param_memory, bool writes_param_memory, + bool has_io, bool has_unknown_effects, bool is_recursive) { + reads_global_memory_ = reads_global_memory; + writes_global_memory_ = writes_global_memory; + reads_param_memory_ = reads_param_memory; + writes_param_memory_ = writes_param_memory; + has_io_ = has_io; + has_unknown_effects_ = has_unknown_effects; + is_recursive_ = is_recursive; + } + bool ReadsGlobalMemory() const { return reads_global_memory_; } + bool WritesGlobalMemory() const { return writes_global_memory_; } + bool ReadsParamMemory() const { return reads_param_memory_; } + bool WritesParamMemory() const { return writes_param_memory_; } + bool HasIO() const { return has_io_; } + bool HasUnknownEffects() const { return has_unknown_effects_; } + bool IsRecursive() const { return is_recursive_; } + bool MayReadMemory() const { + return has_unknown_effects_ || reads_global_memory_ || writes_global_memory_ || + reads_param_memory_ || writes_param_memory_; + } + bool MayWriteMemory() const { + return has_unknown_effects_ || writes_global_memory_ || writes_param_memory_; + } + bool HasObservableSideEffects() const { + return has_unknown_effects_ || writes_global_memory_ || + writes_param_memory_ || has_io_; + } + bool CanDiscardUnusedCall() const { + return !has_unknown_effects_ && !writes_global_memory_ && + !writes_param_memory_ && !has_io_ && !is_recursive_; + } + BasicBlock* GetEntryBlock() const { return entry_; } BasicBlock* GetEntry() const { return entry_; } void SetEntryBlock(BasicBlock* bb) { entry_ = bb; } @@ -633,6 +667,13 @@ class Function : public Value { std::vector> param_types_; std::vector> arguments_; bool is_external_ = false; + bool reads_global_memory_ = false; + bool writes_global_memory_ = false; + bool reads_param_memory_ = false; + bool writes_param_memory_ = false; + bool has_io_ = false; + bool has_unknown_effects_ = true; + bool is_recursive_ = false; BasicBlock* entry_ = nullptr; std::vector> blocks_; }; diff --git a/include/ir/PassManager.h b/include/ir/PassManager.h index 71cabf0..54b8d27 100644 --- a/include/ir/PassManager.h +++ b/include/ir/PassManager.h @@ -5,6 +5,18 @@ namespace ir { class Module; void RunMem2Reg(Module& module); +bool RunConstFold(Module& module); +bool RunConstProp(Module& module); +bool RunFunctionInlining(Module& module); +bool RunCSE(Module& module); +bool RunGVN(Module& module); +bool RunLoadStoreElim(Module& module); +bool RunDCE(Module& module); +bool RunCFGSimplify(Module& module); +bool RunLICM(Module& module); +bool RunLoopStrengthReduction(Module& module); +bool RunLoopUnroll(Module& module); +bool RunLoopFission(Module& module); void RunIRPassPipeline(Module& module); -} // namespace ir \ No newline at end of file +} // namespace ir diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 333c60b..193ea0a 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -53,6 +53,7 @@ class IRGenImpl final : public SysYBaseVisitor { [[noreturn]] void ThrowError(const antlr4::ParserRuleContext* ctx, const std::string& message) const; + void ApplyFunctionSema(const std::string& name, ir::Function& function); void RegisterBuiltinFunctions(); void PredeclareTopLevel(SysYParser::CompUnitContext& ctx); void PredeclareFunction(SysYParser::FuncDefContext& ctx); @@ -124,6 +125,11 @@ class IRGenImpl final : public SysYBaseVisitor { ConstantValue EvalConstPrimaryExp(SysYParser::PrimaryExpContext& ctx); ConstantValue EvalConstLVal(SysYParser::LValContext& ctx); ConstantValue ConvertConst(ConstantValue value, SemanticType target_type) const; + bool IsZeroConstant(const ConstantValue& value) const; + bool IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx, + SemanticType base_type); + bool IsExplicitZeroInitVal(SysYParser::InitValContext* ctx, + SemanticType base_type); std::vector FlattenConstInitVal(SysYParser::ConstInitValContext* ctx, SemanticType base_type, diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 2af3c8a..fffeb73 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -277,11 +277,14 @@ class MachineModule { std::vector> functions_; }; -std::unique_ptr LowerToMIR(const ir::Module& module); -void RunAddressHoisting(MachineModule& module); -void RunRegAlloc(MachineModule& module); -void RunFrameLowering(MachineModule& module); -void PrintAsm(const MachineModule& module, std::ostream& os); +std::unique_ptr LowerToMIR(const ir::Module& module); +bool RunPeephole(MachineModule& module); +void RunMIRPreRegAllocPassPipeline(MachineModule& module); +void RunMIRPostRegAllocPassPipeline(MachineModule& module); +void RunAddressHoisting(MachineModule& module); +void RunRegAlloc(MachineModule& module); +void RunFrameLowering(MachineModule& module); +void PrintAsm(const MachineModule& module, std::ostream& os); } // namespace mir diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 37b299b..9bd0e3c 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -1,7 +1,94 @@ #pragma once #include "SysYParser.h" +#include "sem/SymbolTable.h" -class SemanticContext {}; +#include +#include +#include -SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); +struct GlobalSemanticInfo { + SemanticType type = SemanticType::Int; + bool is_const = false; + bool is_array = false; + std::vector dims; +}; + +struct FunctionSemanticInfo { + SemanticType return_type = SemanticType::Void; + std::vector param_is_array; + bool is_builtin = false; + bool is_defined = false; + bool reads_global_memory = false; + bool writes_global_memory = false; + bool reads_param_memory = false; + bool writes_param_memory = false; + bool has_io = false; + bool has_unknown_effects = true; + bool is_recursive = false; + std::vector direct_callees; + + bool MayReadMemory() const { + return has_unknown_effects || reads_global_memory || writes_global_memory || + reads_param_memory || writes_param_memory; + } + + bool MayWriteMemory() const { + return has_unknown_effects || writes_global_memory || writes_param_memory; + } + + bool HasObservableSideEffects() const { + return has_unknown_effects || writes_global_memory || writes_param_memory || + has_io; + } + + bool CanDiscardUnusedCall() const { + return !has_unknown_effects && !writes_global_memory && + !writes_param_memory && !has_io && !is_recursive; + } +}; + +class SemanticContext { + public: + FunctionSemanticInfo* LookupFunction(const std::string& name) { + auto it = functions_.find(name); + return it == functions_.end() ? nullptr : &it->second; + } + + const FunctionSemanticInfo* LookupFunction(const std::string& name) const { + auto it = functions_.find(name); + return it == functions_.end() ? nullptr : &it->second; + } + + GlobalSemanticInfo* LookupGlobal(const std::string& name) { + auto it = globals_.find(name); + return it == globals_.end() ? nullptr : &it->second; + } + + const GlobalSemanticInfo* LookupGlobal(const std::string& name) const { + auto it = globals_.find(name); + return it == globals_.end() ? nullptr : &it->second; + } + + FunctionSemanticInfo& UpsertFunction(const std::string& name) { + return functions_[name]; + } + + GlobalSemanticInfo& UpsertGlobal(const std::string& name) { + return globals_[name]; + } + + const std::unordered_map& GetFunctions() const { + return functions_; + } + + const std::unordered_map& GetGlobals() const { + return globals_; + } + + private: + std::unordered_map functions_; + std::unordered_map globals_; +}; + +SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index 3c9e605..695372d 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -48,6 +48,7 @@ struct SymbolEntry { std::optional const_scalar; std::vector const_array; + bool const_array_all_zero = false; FunctionTypeInfo function_type; }; diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 9f8d3da..c72c7fb 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -54,6 +54,10 @@ void User::RemoveOperand(size_t index) { } operands_.erase(operands_.begin() + static_cast(index)); for (size_t i = index; i < operands_.size(); ++i) { + if (auto* value = operands_[i].GetValue()) { + value->RemoveUse(this, i + 1); + value->AddUse(this, i); + } operands_[i].SetOperandIndex(i); } } diff --git a/src/ir/analysis/DominatorTree.cpp b/src/ir/analysis/DominatorTree.cpp index eaf7269..34771e6 100644 --- a/src/ir/analysis/DominatorTree.cpp +++ b/src/ir/analysis/DominatorTree.cpp @@ -1,4 +1,167 @@ -// 支配树分析: -// - 构建/查询 Dominator Tree 及相关关系 -// - 为 mem2reg、CFG 优化与循环分析提供基础能力 +#include "ir/Analysis.h" +#include +#include + +namespace ir { +namespace { + +std::vector BuildReversePostOrder(Function& function) { + std::vector post_order; + auto* entry = function.GetEntryBlock(); + if (!entry) { + return post_order; + } + + std::unordered_set visited; + std::function dfs = [&](BasicBlock* block) { + if (!block || !visited.insert(block).second) { + return; + } + for (auto* succ : block->GetSuccessors()) { + dfs(succ); + } + post_order.push_back(block); + }; + dfs(entry); + std::reverse(post_order.begin(), post_order.end()); + return post_order; +} + +} // namespace + +DominatorTree::DominatorTree(Function& function) : function_(&function) { + Recalculate(); +} + +void DominatorTree::Recalculate() { + reverse_post_order_ = BuildReversePostOrder(*function_); + block_index_.clear(); + dominates_.clear(); + immediate_dominator_.clear(); + dom_children_.clear(); + + const auto num_blocks = reverse_post_order_.size(); + for (std::size_t i = 0; i < num_blocks; ++i) { + block_index_.emplace(reverse_post_order_[i], i); + } + if (num_blocks == 0) { + return; + } + + dominates_.assign(num_blocks, std::vector(num_blocks, 1)); + dominates_[0].assign(num_blocks, 0); + dominates_[0][0] = 1; + + bool changed = true; + while (changed) { + changed = false; + for (std::size_t i = 1; i < num_blocks; ++i) { + auto* block = reverse_post_order_[i]; + std::vector next(num_blocks, 1); + bool has_reachable_pred = false; + for (auto* pred : block->GetPredecessors()) { + auto pred_it = block_index_.find(pred); + if (pred_it == block_index_.end()) { + continue; + } + has_reachable_pred = true; + const auto& pred_dom = dominates_[pred_it->second]; + for (std::size_t bit = 0; bit < num_blocks; ++bit) { + next[bit] &= pred_dom[bit]; + } + } + if (!has_reachable_pred) { + next.assign(num_blocks, 0); + } + next[i] = 1; + if (next != dominates_[i]) { + dominates_[i] = std::move(next); + changed = true; + } + } + } + + for (std::size_t i = 1; i < num_blocks; ++i) { + auto* block = reverse_post_order_[i]; + BasicBlock* idom = nullptr; + for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) { + if (candidate == i || !dominates_[i][candidate]) { + continue; + } + auto* candidate_block = reverse_post_order_[candidate]; + bool immediate = true; + for (std::size_t other = 0; other < num_blocks; ++other) { + if (other == i || other == candidate || !dominates_[i][other]) { + continue; + } + if (Dominates(reverse_post_order_[other], candidate_block)) { + immediate = false; + break; + } + } + if (immediate) { + idom = candidate_block; + break; + } + } + immediate_dominator_.emplace(block, idom); + if (idom) { + dom_children_[idom].push_back(block); + } + } +} + +bool DominatorTree::IsReachable(BasicBlock* block) const { + return block != nullptr && block_index_.find(block) != block_index_.end(); +} + +bool DominatorTree::Dominates(BasicBlock* dom, BasicBlock* node) const { + if (!dom || !node) { + return false; + } + const auto dom_it = block_index_.find(dom); + const auto node_it = block_index_.find(node); + if (dom_it == block_index_.end() || node_it == block_index_.end()) { + return false; + } + return dominates_[node_it->second][dom_it->second] != 0; +} + +bool DominatorTree::Dominates(Instruction* dom, Instruction* user) const { + if (!dom || !user) { + return false; + } + if (dom == user) { + return true; + } + + auto* dom_block = dom->GetParent(); + auto* user_block = user->GetParent(); + if (dom_block != user_block) { + return Dominates(dom_block, user_block); + } + + for (const auto& inst_ptr : dom_block->GetInstructions()) { + if (inst_ptr.get() == dom) { + return true; + } + if (inst_ptr.get() == user) { + return false; + } + } + return false; +} + +BasicBlock* DominatorTree::GetIDom(BasicBlock* block) const { + auto it = immediate_dominator_.find(block); + return it == immediate_dominator_.end() ? nullptr : it->second; +} + +const std::vector& DominatorTree::GetChildren(BasicBlock* block) const { + static const std::vector kEmpty; + auto it = dom_children_.find(block); + return it == dom_children_.end() ? kEmpty : it->second; +} + +} // namespace ir diff --git a/src/ir/analysis/LoopInfo.cpp b/src/ir/analysis/LoopInfo.cpp index 9793dc6..f7612d8 100644 --- a/src/ir/analysis/LoopInfo.cpp +++ b/src/ir/analysis/LoopInfo.cpp @@ -1,4 +1,214 @@ -// 循环分析: -// - 识别循环结构与层级关系 -// - 为后续优化(可选)提供循环信息 +#include "ir/Analysis.h" +#include +#include + +namespace ir { +namespace { + +std::vector CollectNaturalLoopBlocks(BasicBlock* header, + BasicBlock* latch) { + std::vector stack{latch}; + std::unordered_set loop_blocks{header, latch}; + while (!stack.empty()) { + auto* block = stack.back(); + stack.pop_back(); + for (auto* pred : block->GetPredecessors()) { + if (!pred || !loop_blocks.insert(pred).second) { + continue; + } + stack.push_back(pred); + } + } + return {loop_blocks.begin(), loop_blocks.end()}; +} + +} // namespace + +bool Loop::Contains(BasicBlock* block) const { + return block != nullptr && blocks.find(block) != blocks.end(); +} + +bool Loop::Contains(const Loop* other) const { + if (!other) { + return false; + } + for (auto* block : other->blocks) { + if (!Contains(block)) { + return false; + } + } + return true; +} + +bool Loop::IsInnermost() const { return subloops.empty(); } + +LoopInfo::LoopInfo(Function& function, const DominatorTree& dom_tree) + : function_(&function), dom_tree_(&dom_tree) { + Recalculate(); +} + +void LoopInfo::Recalculate() { + loops_.clear(); + top_level_loops_.clear(); + block_to_loop_.clear(); + + std::unordered_map loops_by_header; + for (auto* block : dom_tree_->GetReversePostOrder()) { + for (auto* succ : block->GetSuccessors()) { + if (!dom_tree_->Dominates(succ, block)) { + continue; + } + + Loop* loop = nullptr; + auto it = loops_by_header.find(succ); + if (it == loops_by_header.end()) { + auto new_loop = std::make_unique(); + new_loop->header = succ; + loop = new_loop.get(); + loops_.push_back(std::move(new_loop)); + loops_by_header.emplace(succ, loop); + } else { + loop = it->second; + } + + if (std::find(loop->latches.begin(), loop->latches.end(), block) == + loop->latches.end()) { + loop->latches.push_back(block); + } + for (auto* natural_block : CollectNaturalLoopBlocks(succ, block)) { + loop->blocks.insert(natural_block); + } + } + } + + std::unordered_map function_order; + for (std::size_t i = 0; i < function_->GetBlocks().size(); ++i) { + function_order.emplace(function_->GetBlocks()[i].get(), i); + } + + for (const auto& loop_ptr : loops_) { + auto& loop = *loop_ptr; + loop.block_list.clear(); + loop.exiting_blocks.clear(); + loop.exit_blocks.clear(); + loop.subloops.clear(); + loop.parent = nullptr; + + for (const auto& block_ptr : function_->GetBlocks()) { + if (loop.Contains(block_ptr.get())) { + loop.block_list.push_back(block_ptr.get()); + } + } + + std::sort(loop.latches.begin(), loop.latches.end(), + [&](BasicBlock* lhs, BasicBlock* rhs) { + return function_order[lhs] < function_order[rhs]; + }); + + std::vector outside_preds; + for (auto* pred : loop.header->GetPredecessors()) { + if (!loop.Contains(pred)) { + outside_preds.push_back(pred); + } + } + if (outside_preds.size() == 1 && + outside_preds.front()->GetSuccessors().size() == 1) { + loop.preheader = outside_preds.front(); + } else { + loop.preheader = nullptr; + } + + std::unordered_set exiting_seen; + std::unordered_set exit_seen; + for (auto* block : loop.block_list) { + for (auto* succ : block->GetSuccessors()) { + if (loop.Contains(succ)) { + continue; + } + if (exiting_seen.insert(block).second) { + loop.exiting_blocks.push_back(block); + } + if (exit_seen.insert(succ).second) { + loop.exit_blocks.push_back(succ); + } + } + } + + std::sort(loop.exiting_blocks.begin(), loop.exiting_blocks.end(), + [&](BasicBlock* lhs, BasicBlock* rhs) { + return function_order[lhs] < function_order[rhs]; + }); + std::sort(loop.exit_blocks.begin(), loop.exit_blocks.end(), + [&](BasicBlock* lhs, BasicBlock* rhs) { + return function_order[lhs] < function_order[rhs]; + }); + } + + for (const auto& loop_ptr : loops_) { + auto* loop = loop_ptr.get(); + Loop* parent = nullptr; + for (const auto& candidate_ptr : loops_) { + auto* candidate = candidate_ptr.get(); + if (candidate == loop || !candidate->Contains(loop)) { + continue; + } + if (!parent || candidate->blocks.size() < parent->blocks.size()) { + parent = candidate; + } + } + loop->parent = parent; + if (parent) { + parent->subloops.push_back(loop); + } else { + top_level_loops_.push_back(loop); + } + } + + auto loop_order = [&](Loop* lhs, Loop* rhs) { + return function_order[lhs->header] < function_order[rhs->header]; + }; + std::sort(top_level_loops_.begin(), top_level_loops_.end(), loop_order); + for (const auto& loop_ptr : loops_) { + std::sort(loop_ptr->subloops.begin(), loop_ptr->subloops.end(), loop_order); + } + + for (const auto& block_ptr : function_->GetBlocks()) { + Loop* innermost = nullptr; + for (const auto& loop_ptr : loops_) { + auto* loop = loop_ptr.get(); + if (!loop->Contains(block_ptr.get())) { + continue; + } + if (!innermost || loop->blocks.size() < innermost->blocks.size()) { + innermost = loop; + } + } + if (innermost) { + block_to_loop_.emplace(block_ptr.get(), innermost); + } + } +} + +std::vector LoopInfo::GetTopLevelLoops() const { return top_level_loops_; } + +std::vector LoopInfo::GetLoopsInPostOrder() const { + std::vector ordered; + std::function dfs = [&](Loop* loop) { + for (auto* subloop : loop->subloops) { + dfs(subloop); + } + ordered.push_back(loop); + }; + for (auto* loop : top_level_loops_) { + dfs(loop); + } + return ordered; +} + +Loop* LoopInfo::GetLoopFor(BasicBlock* block) const { + auto it = block_to_loop_.find(block); + return it == block_to_loop_.end() ? nullptr : it->second; +} + +} // namespace ir diff --git a/src/ir/passes/CFGSimplify.cpp b/src/ir/passes/CFGSimplify.cpp index 3779397..66fa11b 100644 --- a/src/ir/passes/CFGSimplify.cpp +++ b/src/ir/passes/CFGSimplify.cpp @@ -1,4 +1,107 @@ -// CFG 简化: -// - 删除不可达块、合并空块、简化分支等 -// - 改善 IR 结构,便于后续优化与后端生成 +#include "ir/PassManager.h" +#include "ir/IR.h" +#include "PassUtils.h" + +#include + +namespace ir { +namespace { + +bool TryGetConstBranchTarget(CondBrInst* br, BasicBlock*& target, BasicBlock*& removed) { + if (!br) { + return false; + } + auto* then_block = br->GetThenBlock(); + auto* else_block = br->GetElseBlock(); + if (then_block == else_block) { + target = then_block; + removed = nullptr; + return true; + } + if (auto* cond = dyncast(br->GetCondition())) { + target = cond->GetValue() ? then_block : else_block; + removed = cond->GetValue() ? else_block : then_block; + return true; + } + return false; +} + +bool SimplifyBlockTerminator(BasicBlock* block) { + if (!block || block->GetInstructions().empty()) { + return false; + } + auto* term = block->GetInstructions().back().get(); + auto* condbr = dyncast(term); + if (!condbr) { + return false; + } + + BasicBlock* target = nullptr; + BasicBlock* removed = nullptr; + if (!TryGetConstBranchTarget(condbr, target, removed)) { + return false; + } + + if (removed) { + passutils::RemoveIncomingFromSuccessor(removed, block); + removed->RemovePredecessor(block); + block->RemoveSuccessor(removed); + } + passutils::ReplaceTerminatorWithBr(block, target); + return true; +} + +bool SimplifyPhiNodes(Function& function) { + bool changed = false; + for (const auto& block_ptr : function.GetBlocks()) { + bool local_changed = true; + while (local_changed) { + local_changed = false; + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + if (!passutils::SimplifyPhiInst(phi)) { + continue; + } + local_changed = true; + changed = true; + break; + } + } + } + return changed; +} + +bool RunCFGSimplifyOnFunction(Function& function) { + if (function.IsExternal() || function.GetEntryBlock() == nullptr) { + return false; + } + + bool changed = false; + + for (const auto& block_ptr : function.GetBlocks()) { + changed |= SimplifyBlockTerminator(block_ptr.get()); + } + + changed |= passutils::RemoveUnreachableBlocks(function); + changed |= SimplifyPhiNodes(function); + + return changed; +} + +} // namespace + +bool RunCFGSimplify(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunCFGSimplifyOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/CMakeLists.txt b/src/ir/passes/CMakeLists.txt index 98867f5..3466527 100644 --- a/src/ir/passes/CMakeLists.txt +++ b/src/ir/passes/CMakeLists.txt @@ -3,9 +3,16 @@ add_library(ir_passes STATIC Mem2Reg.cpp ConstFold.cpp ConstProp.cpp + Inline.cpp CSE.cpp + GVN.cpp + LoadStoreElim.cpp DCE.cpp CFGSimplify.cpp + LICM.cpp + LoopStrengthReduction.cpp + LoopUnroll.cpp + LoopFission.cpp ) target_link_libraries(ir_passes PUBLIC diff --git a/src/ir/passes/CSE.cpp b/src/ir/passes/CSE.cpp index 4b24dd0..c7c02f3 100644 --- a/src/ir/passes/CSE.cpp +++ b/src/ir/passes/CSE.cpp @@ -1,4 +1,141 @@ -// 公共子表达式消除(CSE): -// - 识别并复用重复计算的等价表达式 -// - 典型放置在 ConstFold 之后、DCE 之前 -// - 当前为 Lab4 的框架占位,具体算法由实验实现 +#include "ir/PassManager.h" + +#include "ir/IR.h" +#include "PassUtils.h" + +#include +#include +#include +#include +#include + +namespace ir { +namespace { + +struct ExprKey { + Opcode opcode = Opcode::Add; + std::vector operands; + + bool operator==(const ExprKey& rhs) const { + return opcode == rhs.opcode && operands == rhs.operands; + } +}; + +struct ExprKeyHash { + std::size_t operator()(const ExprKey& key) const { + std::size_t h = static_cast(key.opcode); + for (auto operand : key.operands) { + h ^= std::hash{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +bool IsSupportedCSEInstruction(Instruction* inst) { + if (!inst || inst->IsVoid()) { + return false; + } + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Rem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::FRem: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: + case Opcode::Neg: + case Opcode::Not: + case Opcode::FNeg: + case Opcode::FtoI: + case Opcode::IToF: + case Opcode::Zext: + return true; + default: + return false; + } +} + +ExprKey BuildExprKey(Instruction* inst) { + ExprKey key; + key.opcode = inst->GetOpcode(); + key.operands.reserve(inst->GetNumOperands()); + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + key.operands.push_back( + reinterpret_cast(inst->GetOperand(i))); + } + if (inst->GetNumOperands() == 2 && passutils::IsCommutativeOpcode(inst->GetOpcode()) && + key.operands[1] < key.operands[0]) { + std::swap(key.operands[0], key.operands[1]); + } + return key; +} + +bool RunCSEOnFunction(Function& function) { + if (function.IsExternal()) { + return false; + } + + bool changed = false; + for (const auto& block_ptr : function.GetBlocks()) { + std::unordered_map available_exprs; + std::vector to_remove; + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (!IsSupportedCSEInstruction(inst)) { + continue; + } + + const auto key = BuildExprKey(inst); + auto it = available_exprs.find(key); + if (it == available_exprs.end()) { + available_exprs.emplace(key, inst); + continue; + } + + inst->ReplaceAllUsesWith(it->second); + to_remove.push_back(inst); + changed = true; + } + + for (auto* inst : to_remove) { + block_ptr->EraseInstruction(inst); + } + } + + return changed; +} + +} // namespace + +bool RunCSE(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunCSEOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/ConstFold.cpp b/src/ir/passes/ConstFold.cpp index 19f2d43..f003127 100644 --- a/src/ir/passes/ConstFold.cpp +++ b/src/ir/passes/ConstFold.cpp @@ -1,4 +1,469 @@ -// IR 常量折叠: -// - 折叠可判定的常量表达式 -// - 简化常量控制流分支(按实现范围裁剪) +#include "ir/PassManager.h" +#include "ir/IR.h" +#include "PassUtils.h" + +#include +#include +#include +#include + +namespace ir { +namespace { + +Value* GetInt32Const(Context& ctx, std::int32_t value) { + return ctx.GetConstInt(static_cast(value)); +} + +Value* GetBoolConst(Context& ctx, bool value) { return ctx.GetConstBool(value); } + +Value* GetFloatConst(float value) { + return new ConstantFloat(Type::GetFloatType(), value); +} + +bool TryGetInt32(Value* value, std::int32_t& out) { + if (auto* ci = dyncast(value)) { + out = static_cast(ci->GetValue()); + return true; + } + return false; +} + +bool TryGetBool(Value* value, bool& out) { + if (auto* cb = dyncast(value)) { + out = cb->GetValue(); + return true; + } + return false; +} + +bool TryGetFloat(Value* value, float& out) { + if (auto* cf = dyncast(value)) { + out = cf->GetValue(); + return true; + } + return false; +} + +bool IsZeroValue(Value* value) { + std::int32_t i32 = 0; + bool i1 = false; + float f32 = 0.0f; + return (TryGetInt32(value, i32) && i32 == 0) || (TryGetBool(value, i1) && !i1) || + (TryGetFloat(value, f32) && passutils::FloatBits(f32) == 0); +} + +bool IsOneValue(Value* value) { + std::int32_t i32 = 0; + bool i1 = false; + float f32 = 0.0f; + return (TryGetInt32(value, i32) && i32 == 1) || (TryGetBool(value, i1) && i1) || + (TryGetFloat(value, f32) && + passutils::FloatBits(f32) == passutils::FloatBits(1.0f)); +} + +bool IsAllOnesInt(Value* value) { + std::int32_t i32 = 0; + return TryGetInt32(value, i32) && i32 == -1; +} + +std::int32_t WrapInt32(std::uint32_t value) { + return static_cast(value); +} + +Value* FoldBinary(Context& ctx, BinaryInst* inst) { + const auto opcode = inst->GetOpcode(); + auto* lhs = inst->GetLhs(); + auto* rhs = inst->GetRhs(); + + std::int32_t lhs_i32 = 0; + std::int32_t rhs_i32 = 0; + bool lhs_i1 = false; + bool rhs_i1 = false; + float lhs_f32 = 0.0f; + float rhs_f32 = 0.0f; + + const bool has_lhs_i32 = TryGetInt32(lhs, lhs_i32); + const bool has_rhs_i32 = TryGetInt32(rhs, rhs_i32); + const bool has_lhs_i1 = TryGetBool(lhs, lhs_i1); + const bool has_rhs_i1 = TryGetBool(rhs, rhs_i1); + const bool has_lhs_f32 = TryGetFloat(lhs, lhs_f32); + const bool has_rhs_f32 = TryGetFloat(rhs, rhs_f32); + + if (has_lhs_i32 && has_rhs_i32) { + switch (opcode) { + case Opcode::Add: + return GetInt32Const( + ctx, WrapInt32(static_cast(lhs_i32) + + static_cast(rhs_i32))); + case Opcode::Sub: + return GetInt32Const( + ctx, WrapInt32(static_cast(lhs_i32) - + static_cast(rhs_i32))); + case Opcode::Mul: + return GetInt32Const( + ctx, WrapInt32(static_cast(lhs_i32) * + static_cast(rhs_i32))); + case Opcode::Div: + if (rhs_i32 == 0 || + (lhs_i32 == std::numeric_limits::min() && rhs_i32 == -1)) { + return nullptr; + } + return GetInt32Const(ctx, lhs_i32 / rhs_i32); + case Opcode::Rem: + if (rhs_i32 == 0 || + (lhs_i32 == std::numeric_limits::min() && rhs_i32 == -1)) { + return nullptr; + } + return GetInt32Const(ctx, lhs_i32 % rhs_i32); + case Opcode::And: + return GetInt32Const(ctx, lhs_i32 & rhs_i32); + case Opcode::Or: + return GetInt32Const(ctx, lhs_i32 | rhs_i32); + case Opcode::Xor: + return GetInt32Const(ctx, lhs_i32 ^ rhs_i32); + case Opcode::Shl: + if (rhs_i32 < 0 || rhs_i32 >= 32) { + return nullptr; + } + return GetInt32Const(ctx, WrapInt32(static_cast(lhs_i32) + << rhs_i32)); + case Opcode::AShr: + if (rhs_i32 < 0 || rhs_i32 >= 32) { + return nullptr; + } + return GetInt32Const(ctx, lhs_i32 >> rhs_i32); + case Opcode::LShr: + if (rhs_i32 < 0 || rhs_i32 >= 32) { + return nullptr; + } + return GetInt32Const( + ctx, + WrapInt32(static_cast(lhs_i32) >> rhs_i32)); + case Opcode::ICmpEQ: + return GetBoolConst(ctx, lhs_i32 == rhs_i32); + case Opcode::ICmpNE: + return GetBoolConst(ctx, lhs_i32 != rhs_i32); + case Opcode::ICmpLT: + return GetBoolConst(ctx, lhs_i32 < rhs_i32); + case Opcode::ICmpGT: + return GetBoolConst(ctx, lhs_i32 > rhs_i32); + case Opcode::ICmpLE: + return GetBoolConst(ctx, lhs_i32 <= rhs_i32); + case Opcode::ICmpGE: + return GetBoolConst(ctx, lhs_i32 >= rhs_i32); + default: + break; + } + } + + if (has_lhs_i1 && has_rhs_i1) { + switch (opcode) { + case Opcode::And: + return GetBoolConst(ctx, lhs_i1 && rhs_i1); + case Opcode::Or: + return GetBoolConst(ctx, lhs_i1 || rhs_i1); + case Opcode::Xor: + return GetBoolConst(ctx, lhs_i1 != rhs_i1); + case Opcode::ICmpEQ: + return GetBoolConst(ctx, lhs_i1 == rhs_i1); + case Opcode::ICmpNE: + return GetBoolConst(ctx, lhs_i1 != rhs_i1); + case Opcode::ICmpLT: + return GetBoolConst(ctx, static_cast(lhs_i1) < static_cast(rhs_i1)); + case Opcode::ICmpGT: + return GetBoolConst(ctx, static_cast(lhs_i1) > static_cast(rhs_i1)); + case Opcode::ICmpLE: + return GetBoolConst(ctx, static_cast(lhs_i1) <= static_cast(rhs_i1)); + case Opcode::ICmpGE: + return GetBoolConst(ctx, static_cast(lhs_i1) >= static_cast(rhs_i1)); + default: + break; + } + } + + if (has_lhs_f32 && has_rhs_f32) { + switch (opcode) { + case Opcode::FAdd: + return GetFloatConst(lhs_f32 + rhs_f32); + case Opcode::FSub: + return GetFloatConst(lhs_f32 - rhs_f32); + case Opcode::FMul: + return GetFloatConst(lhs_f32 * rhs_f32); + case Opcode::FDiv: + return GetFloatConst(lhs_f32 / rhs_f32); + case Opcode::FRem: + return GetFloatConst(std::fmod(lhs_f32, rhs_f32)); + case Opcode::FCmpEQ: + return GetBoolConst( + ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 == rhs_f32); + case Opcode::FCmpNE: + return GetBoolConst( + ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 != rhs_f32); + case Opcode::FCmpLT: + return GetBoolConst( + ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 < rhs_f32); + case Opcode::FCmpGT: + return GetBoolConst( + ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 > rhs_f32); + case Opcode::FCmpLE: + return GetBoolConst( + ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 <= rhs_f32); + case Opcode::FCmpGE: + return GetBoolConst( + ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 >= rhs_f32); + default: + break; + } + } + + switch (opcode) { + case Opcode::Add: + if (IsZeroValue(rhs)) { + return lhs; + } + if (IsZeroValue(lhs)) { + return rhs; + } + break; + case Opcode::Sub: + if (IsZeroValue(rhs)) { + return lhs; + } + break; + case Opcode::Mul: + if (IsOneValue(rhs)) { + return lhs; + } + if (IsOneValue(lhs)) { + return rhs; + } + if (IsZeroValue(lhs) || IsZeroValue(rhs)) { + return GetInt32Const(ctx, 0); + } + break; + case Opcode::Div: + if (IsOneValue(rhs)) { + return lhs; + } + if (IsZeroValue(lhs) && !IsZeroValue(rhs)) { + return GetInt32Const(ctx, 0); + } + break; + case Opcode::Rem: + if ((has_rhs_i32 && (rhs_i32 == 1 || rhs_i32 == -1)) || + (has_rhs_i1 && rhs_i1)) { + return GetInt32Const(ctx, 0); + } + if (IsZeroValue(lhs) && !IsZeroValue(rhs)) { + return GetInt32Const(ctx, 0); + } + break; + case Opcode::And: + if (IsZeroValue(lhs) || IsZeroValue(rhs)) { + return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false) + : GetInt32Const(ctx, 0); + } + if (has_lhs_i1 && lhs_i1) { + return rhs; + } + if (has_rhs_i1 && rhs_i1) { + return lhs; + } + if (IsAllOnesInt(lhs)) { + return rhs; + } + if (IsAllOnesInt(rhs)) { + return lhs; + } + if (passutils::AreEquivalentValues(lhs, rhs)) { + return lhs; + } + break; + case Opcode::Or: + if (IsZeroValue(lhs)) { + return rhs; + } + if (IsZeroValue(rhs)) { + return lhs; + } + if (has_lhs_i1 && lhs_i1) { + return GetBoolConst(ctx, true); + } + if (has_rhs_i1 && rhs_i1) { + return GetBoolConst(ctx, true); + } + if (passutils::AreEquivalentValues(lhs, rhs)) { + return lhs; + } + break; + case Opcode::Xor: + if (IsZeroValue(lhs)) { + return rhs; + } + if (IsZeroValue(rhs)) { + return lhs; + } + if (passutils::AreEquivalentValues(lhs, rhs)) { + return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false) + : GetInt32Const(ctx, 0); + } + break; + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + if (IsZeroValue(rhs)) { + return lhs; + } + if (IsZeroValue(lhs)) { + return GetInt32Const(ctx, 0); + } + break; + case Opcode::FAdd: + if (IsZeroValue(lhs)) { + return rhs; + } + if (IsZeroValue(rhs)) { + return lhs; + } + break; + case Opcode::FSub: + if (IsZeroValue(rhs)) { + return lhs; + } + break; + case Opcode::FMul: + if (IsOneValue(lhs)) { + return rhs; + } + if (IsOneValue(rhs)) { + return lhs; + } + break; + case Opcode::FDiv: + if (IsOneValue(rhs)) { + return lhs; + } + break; + default: + break; + } + + return nullptr; +} + +Value* FoldUnary(Context& ctx, UnaryInst* inst) { + auto* operand = inst->GetOprd(); + std::int32_t i32 = 0; + bool i1 = false; + float f32 = 0.0f; + switch (inst->GetOpcode()) { + case Opcode::Neg: + if (TryGetInt32(operand, i32)) { + return GetInt32Const(ctx, WrapInt32(0u - static_cast(i32))); + } + break; + case Opcode::Not: + if (TryGetBool(operand, i1)) { + return GetBoolConst(ctx, !i1); + } + if (TryGetInt32(operand, i32)) { + return GetInt32Const(ctx, i32 ^ 1); + } + break; + case Opcode::FNeg: + if (TryGetFloat(operand, f32)) { + return GetFloatConst(-f32); + } + break; + case Opcode::FtoI: + if (TryGetFloat(operand, f32)) { + return GetInt32Const(ctx, static_cast(f32)); + } + break; + case Opcode::IToF: + if (TryGetInt32(operand, i32)) { + return GetFloatConst(static_cast(i32)); + } + if (TryGetBool(operand, i1)) { + return GetFloatConst(i1 ? 1.0f : 0.0f); + } + break; + default: + break; + } + return nullptr; +} + +Value* FoldZext(Context& ctx, ZextInst* inst) { + auto* value = inst->GetValue(); + bool i1 = false; + std::int32_t i32 = 0; + if (inst->GetType()->IsInt1()) { + if (TryGetBool(value, i1)) { + return GetBoolConst(ctx, i1); + } + if (TryGetInt32(value, i32)) { + return GetBoolConst(ctx, i32 != 0); + } + } + if (inst->GetType()->IsInt32()) { + if (TryGetBool(value, i1)) { + return GetInt32Const(ctx, i1 ? 1 : 0); + } + if (TryGetInt32(value, i32)) { + return GetInt32Const(ctx, i32); + } + } + return nullptr; +} + +bool FoldFunction(Function& function, Context& ctx) { + if (function.IsExternal()) { + return false; + } + + bool changed = false; + for (const auto& block_ptr : function.GetBlocks()) { + std::vector to_remove; + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* inst = inst_ptr.get(); + Value* replacement = nullptr; + if (auto* binary = dyncast(inst)) { + replacement = FoldBinary(ctx, binary); + } else if (auto* unary = dyncast(inst)) { + replacement = FoldUnary(ctx, unary); + } else if (auto* zext = dyncast(inst)) { + replacement = FoldZext(ctx, zext); + } + + if (!replacement || replacement == inst) { + continue; + } + inst->ReplaceAllUsesWith(replacement); + to_remove.push_back(inst); + changed = true; + } + + for (auto* inst : to_remove) { + block_ptr->EraseInstruction(inst); + } + } + + return changed; +} + +} // namespace + +bool RunConstFold(Module& module) { + bool changed = false; + auto& ctx = module.GetContext(); + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= FoldFunction(*function, ctx); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/ConstProp.cpp b/src/ir/passes/ConstProp.cpp index 1768b71..15c9153 100644 --- a/src/ir/passes/ConstProp.cpp +++ b/src/ir/passes/ConstProp.cpp @@ -1,5 +1,550 @@ -// 常量传播(Constant Propagation): -// - 沿 use-def 关系传播已知常量 -// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会 -// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用 +#include "ir/PassManager.h" +#include "ir/IR.h" +#include "PassUtils.h" + +#include +#include +#include +#include +#include + +namespace ir { +namespace { + +enum class LatticeKind { Unknown, Constant, Overdefined }; + +struct ConstantValue { + enum class Kind { Int32, Bool, Float }; + + Kind kind = Kind::Int32; + std::int32_t int32_value = 0; + bool bool_value = false; + float float_value = 0.0f; +}; + +struct LatticeValue { + LatticeKind kind = LatticeKind::Unknown; + ConstantValue constant; +}; + +bool EqualConstants(const ConstantValue& lhs, const ConstantValue& rhs) { + if (lhs.kind != rhs.kind) { + return false; + } + switch (lhs.kind) { + case ConstantValue::Kind::Int32: + return lhs.int32_value == rhs.int32_value; + case ConstantValue::Kind::Bool: + return lhs.bool_value == rhs.bool_value; + case ConstantValue::Kind::Float: + return passutils::FloatBits(lhs.float_value) == + passutils::FloatBits(rhs.float_value); + } + return false; +} + +Value* MaterializeConstant(Context& ctx, const ConstantValue& constant) { + switch (constant.kind) { + case ConstantValue::Kind::Int32: + return ctx.GetConstInt(static_cast(constant.int32_value)); + case ConstantValue::Kind::Bool: + return ctx.GetConstBool(constant.bool_value); + case ConstantValue::Kind::Float: + return new ConstantFloat(Type::GetFloatType(), constant.float_value); + } + return nullptr; +} + +bool TryGetConstantValue(Value* value, ConstantValue& out) { + if (auto* ci = dyncast(value)) { + out.kind = ConstantValue::Kind::Int32; + out.int32_value = static_cast(ci->GetValue()); + return true; + } + if (auto* cb = dyncast(value)) { + out.kind = ConstantValue::Kind::Bool; + out.bool_value = cb->GetValue(); + return true; + } + if (auto* cf = dyncast(value)) { + out.kind = ConstantValue::Kind::Float; + out.float_value = cf->GetValue(); + return true; + } + return false; +} + +LatticeValue ConstantLattice(const ConstantValue& constant) { + LatticeValue value; + value.kind = LatticeKind::Constant; + value.constant = constant; + return value; +} + +LatticeValue OverdefinedLattice() { + LatticeValue value; + value.kind = LatticeKind::Overdefined; + return value; +} + +LatticeValue GetValueState( + Value* value, const std::unordered_map& states) { + ConstantValue constant; + if (TryGetConstantValue(value, constant)) { + return ConstantLattice(constant); + } + auto it = states.find(value); + if (it != states.end()) { + return it->second; + } + return OverdefinedLattice(); +} + +LatticeValue Meet(LatticeValue lhs, const LatticeValue& rhs) { + if (lhs.kind == LatticeKind::Overdefined || rhs.kind == LatticeKind::Overdefined) { + return OverdefinedLattice(); + } + if (lhs.kind == LatticeKind::Unknown) { + return rhs; + } + if (rhs.kind == LatticeKind::Unknown) { + return lhs; + } + if (EqualConstants(lhs.constant, rhs.constant)) { + return lhs; + } + return OverdefinedLattice(); +} + +bool EvaluateUnary(Opcode opcode, const ConstantValue& operand, + ConstantValue& result) { + switch (opcode) { + case Opcode::Neg: + if (operand.kind != ConstantValue::Kind::Int32) { + return false; + } + result.kind = ConstantValue::Kind::Int32; + result.int32_value = static_cast( + 0u - static_cast(operand.int32_value)); + return true; + case Opcode::Not: + if (operand.kind == ConstantValue::Kind::Bool) { + result.kind = ConstantValue::Kind::Bool; + result.bool_value = !operand.bool_value; + return true; + } + if (operand.kind == ConstantValue::Kind::Int32) { + result.kind = ConstantValue::Kind::Int32; + result.int32_value = operand.int32_value ^ 1; + return true; + } + return false; + case Opcode::FNeg: + if (operand.kind != ConstantValue::Kind::Float) { + return false; + } + result.kind = ConstantValue::Kind::Float; + result.float_value = -operand.float_value; + return true; + case Opcode::FtoI: + if (operand.kind != ConstantValue::Kind::Float) { + return false; + } + result.kind = ConstantValue::Kind::Int32; + result.int32_value = static_cast(operand.float_value); + return true; + case Opcode::IToF: + if (operand.kind == ConstantValue::Kind::Int32) { + result.kind = ConstantValue::Kind::Float; + result.float_value = static_cast(operand.int32_value); + return true; + } + if (operand.kind == ConstantValue::Kind::Bool) { + result.kind = ConstantValue::Kind::Float; + result.float_value = operand.bool_value ? 1.0f : 0.0f; + return true; + } + return false; + default: + return false; + } +} + +bool EvaluateBinary(Opcode opcode, const ConstantValue& lhs, + const ConstantValue& rhs, ConstantValue& result) { + if (lhs.kind == ConstantValue::Kind::Int32 && + rhs.kind == ConstantValue::Kind::Int32) { + const auto left = lhs.int32_value; + const auto right = rhs.int32_value; + switch (opcode) { + case Opcode::Add: + result.kind = ConstantValue::Kind::Int32; + result.int32_value = static_cast( + static_cast(left) + static_cast(right)); + return true; + case Opcode::Sub: + result.kind = ConstantValue::Kind::Int32; + result.int32_value = static_cast( + static_cast(left) - static_cast(right)); + return true; + case Opcode::Mul: + result.kind = ConstantValue::Kind::Int32; + result.int32_value = static_cast( + static_cast(left) * static_cast(right)); + return true; + case Opcode::Div: + if (right == 0 || + (left == std::numeric_limits::min() && right == -1)) { + return false; + } + result.kind = ConstantValue::Kind::Int32; + result.int32_value = left / right; + return true; + case Opcode::Rem: + if (right == 0 || + (left == std::numeric_limits::min() && right == -1)) { + return false; + } + result.kind = ConstantValue::Kind::Int32; + result.int32_value = left % right; + return true; + case Opcode::And: + result.kind = ConstantValue::Kind::Int32; + result.int32_value = left & right; + return true; + case Opcode::Or: + result.kind = ConstantValue::Kind::Int32; + result.int32_value = left | right; + return true; + case Opcode::Xor: + result.kind = ConstantValue::Kind::Int32; + result.int32_value = left ^ right; + return true; + case Opcode::Shl: + if (right < 0 || right >= 32) { + return false; + } + result.kind = ConstantValue::Kind::Int32; + result.int32_value = + static_cast(static_cast(left) << right); + return true; + case Opcode::AShr: + if (right < 0 || right >= 32) { + return false; + } + result.kind = ConstantValue::Kind::Int32; + result.int32_value = left >> right; + return true; + case Opcode::LShr: + if (right < 0 || right >= 32) { + return false; + } + result.kind = ConstantValue::Kind::Int32; + result.int32_value = static_cast( + static_cast(left) >> right); + return true; + case Opcode::ICmpEQ: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left == right; + return true; + case Opcode::ICmpNE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left != right; + return true; + case Opcode::ICmpLT: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left < right; + return true; + case Opcode::ICmpGT: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left > right; + return true; + case Opcode::ICmpLE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left <= right; + return true; + case Opcode::ICmpGE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left >= right; + return true; + default: + break; + } + } + + if (lhs.kind == ConstantValue::Kind::Bool && rhs.kind == ConstantValue::Kind::Bool) { + const auto left = lhs.bool_value; + const auto right = rhs.bool_value; + switch (opcode) { + case Opcode::And: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left && right; + return true; + case Opcode::Or: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left || right; + return true; + case Opcode::Xor: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left != right; + return true; + case Opcode::ICmpEQ: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left == right; + return true; + case Opcode::ICmpNE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = left != right; + return true; + case Opcode::ICmpLT: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = static_cast(left) < static_cast(right); + return true; + case Opcode::ICmpGT: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = static_cast(left) > static_cast(right); + return true; + case Opcode::ICmpLE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = static_cast(left) <= static_cast(right); + return true; + case Opcode::ICmpGE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = static_cast(left) >= static_cast(right); + return true; + default: + break; + } + } + + if (lhs.kind == ConstantValue::Kind::Float && + rhs.kind == ConstantValue::Kind::Float) { + const auto left = lhs.float_value; + const auto right = rhs.float_value; + switch (opcode) { + case Opcode::FAdd: + result.kind = ConstantValue::Kind::Float; + result.float_value = left + right; + return true; + case Opcode::FSub: + result.kind = ConstantValue::Kind::Float; + result.float_value = left - right; + return true; + case Opcode::FMul: + result.kind = ConstantValue::Kind::Float; + result.float_value = left * right; + return true; + case Opcode::FDiv: + result.kind = ConstantValue::Kind::Float; + result.float_value = left / right; + return true; + case Opcode::FRem: + result.kind = ConstantValue::Kind::Float; + result.float_value = std::fmod(left, right); + return true; + case Opcode::FCmpEQ: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = !std::isnan(left) && !std::isnan(right) && left == right; + return true; + case Opcode::FCmpNE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = !std::isnan(left) && !std::isnan(right) && left != right; + return true; + case Opcode::FCmpLT: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = !std::isnan(left) && !std::isnan(right) && left < right; + return true; + case Opcode::FCmpGT: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = !std::isnan(left) && !std::isnan(right) && left > right; + return true; + case Opcode::FCmpLE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = !std::isnan(left) && !std::isnan(right) && left <= right; + return true; + case Opcode::FCmpGE: + result.kind = ConstantValue::Kind::Bool; + result.bool_value = !std::isnan(left) && !std::isnan(right) && left >= right; + return true; + default: + break; + } + } + + return false; +} + +LatticeValue EvaluateInstruction( + Instruction* inst, const std::unordered_map& states) { + if (!inst || inst->IsVoid()) { + return OverdefinedLattice(); + } + + if (auto* phi = dyncast(inst)) { + LatticeValue merged; + for (int i = 0; i < phi->GetNumIncomings(); ++i) { + merged = Meet(merged, GetValueState(phi->GetIncomingValue(i), states)); + if (merged.kind == LatticeKind::Overdefined) { + break; + } + } + return merged; + } + + if (auto* binary = dyncast(inst)) { + const auto lhs = GetValueState(binary->GetLhs(), states); + const auto rhs = GetValueState(binary->GetRhs(), states); + if (lhs.kind == LatticeKind::Overdefined || rhs.kind == LatticeKind::Overdefined) { + return OverdefinedLattice(); + } + if (lhs.kind != LatticeKind::Constant || rhs.kind != LatticeKind::Constant) { + return {}; + } + ConstantValue folded; + if (!EvaluateBinary(binary->GetOpcode(), lhs.constant, rhs.constant, folded)) { + return OverdefinedLattice(); + } + return ConstantLattice(folded); + } + + if (auto* unary = dyncast(inst)) { + const auto operand = GetValueState(unary->GetOprd(), states); + if (operand.kind == LatticeKind::Overdefined) { + return OverdefinedLattice(); + } + if (operand.kind != LatticeKind::Constant) { + return {}; + } + ConstantValue folded; + if (!EvaluateUnary(unary->GetOpcode(), operand.constant, folded)) { + return OverdefinedLattice(); + } + return ConstantLattice(folded); + } + + if (auto* zext = dyncast(inst)) { + const auto operand = GetValueState(zext->GetValue(), states); + if (operand.kind == LatticeKind::Overdefined) { + return OverdefinedLattice(); + } + if (operand.kind != LatticeKind::Constant) { + return {}; + } + ConstantValue folded; + if (zext->GetType()->IsInt1()) { + folded.kind = ConstantValue::Kind::Bool; + if (operand.constant.kind == ConstantValue::Kind::Bool) { + folded.bool_value = operand.constant.bool_value; + return ConstantLattice(folded); + } + if (operand.constant.kind == ConstantValue::Kind::Int32) { + folded.bool_value = operand.constant.int32_value != 0; + return ConstantLattice(folded); + } + return OverdefinedLattice(); + } + if (zext->GetType()->IsInt32()) { + folded.kind = ConstantValue::Kind::Int32; + if (operand.constant.kind == ConstantValue::Kind::Bool) { + folded.int32_value = operand.constant.bool_value ? 1 : 0; + return ConstantLattice(folded); + } + if (operand.constant.kind == ConstantValue::Kind::Int32) { + folded.int32_value = operand.constant.int32_value; + return ConstantLattice(folded); + } + } + return OverdefinedLattice(); + } + + return OverdefinedLattice(); +} + +bool RewriteFunction(Function& function, Context& ctx) { + if (function.IsExternal()) { + return false; + } + + std::unordered_map states; + for (const auto& block_ptr : function.GetBlocks()) { + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (!inst->IsVoid()) { + states[inst] = {}; + } + } + } + + bool changed = true; + while (changed) { + changed = false; + for (const auto& block_ptr : function.GetBlocks()) { + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsVoid()) { + continue; + } + const auto evaluated = EvaluateInstruction(inst, states); + if (evaluated.kind != states[inst].kind || + (evaluated.kind == LatticeKind::Constant && + !EqualConstants(evaluated.constant, states[inst].constant))) { + states[inst] = evaluated; + changed = true; + } + } + } + } + + bool rewritten = false; + for (const auto& block_ptr : function.GetBlocks()) { + std::vector to_remove; + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* inst = inst_ptr.get(); + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + auto* operand = inst->GetOperand(i); + if (isa(operand) || isa(operand) || operand->IsConstant()) { + continue; + } + const auto state = GetValueState(operand, states); + if (state.kind != LatticeKind::Constant) { + continue; + } + inst->SetOperand(i, MaterializeConstant(ctx, state.constant)); + rewritten = true; + } + + if (inst->IsVoid()) { + continue; + } + const auto state = states[inst]; + if (state.kind != LatticeKind::Constant) { + continue; + } + inst->ReplaceAllUsesWith(MaterializeConstant(ctx, state.constant)); + to_remove.push_back(inst); + rewritten = true; + } + + for (auto* inst : to_remove) { + block_ptr->EraseInstruction(inst); + } + } + + return rewritten; +} + +} // namespace + +bool RunConstProp(Module& module) { + bool changed = false; + auto& ctx = module.GetContext(); + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RewriteFunction(*function, ctx); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/DCE.cpp b/src/ir/passes/DCE.cpp index 5a0db91..84b3535 100644 --- a/src/ir/passes/DCE.cpp +++ b/src/ir/passes/DCE.cpp @@ -1,4 +1,55 @@ -// 死代码删除(DCE): -// - 删除无用指令与无用基本块 -// - 通常与 CFG 简化配合使用 +#include "ir/PassManager.h" +#include "ir/IR.h" +#include "PassUtils.h" + +#include + +namespace ir { +namespace { + +bool RunDCEOnFunction(Function& function) { + if (function.IsExternal()) { + return false; + } + + bool changed = false; + bool local_changed = true; + while (local_changed) { + local_changed = false; + for (const auto& block_ptr : function.GetBlocks()) { + std::vector to_remove; + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (!passutils::IsTriviallyDead(inst)) { + continue; + } + to_remove.push_back(inst); + } + if (to_remove.empty()) { + continue; + } + for (auto* inst : to_remove) { + block_ptr->EraseInstruction(inst); + } + local_changed = true; + changed = true; + } + } + + return changed; +} + +} // namespace + +bool RunDCE(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunDCEOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/GVN.cpp b/src/ir/passes/GVN.cpp new file mode 100644 index 0000000..c5568a7 --- /dev/null +++ b/src/ir/passes/GVN.cpp @@ -0,0 +1,196 @@ +#include "ir/PassManager.h" + +#include "ir/Analysis.h" +#include "ir/IR.h" +#include "MemoryUtils.h" +#include "PassUtils.h" + +#include +#include +#include +#include + +namespace ir { +namespace { + +struct ExprKey { + Opcode opcode = Opcode::Add; + std::uintptr_t result_type = 0; + std::uintptr_t aux_type = 0; + std::vector operands; + + bool operator==(const ExprKey& rhs) const { + return opcode == rhs.opcode && result_type == rhs.result_type && + aux_type == rhs.aux_type && operands == rhs.operands; + } +}; + +struct ExprKeyHash { + std::size_t operator()(const ExprKey& key) const { + std::size_t h = static_cast(key.opcode); + h ^= std::hash{}(key.result_type) + 0x9e3779b9 + (h << 6) + + (h >> 2); + h ^= std::hash{}(key.aux_type) + 0x9e3779b9 + (h << 6) + + (h >> 2); + for (auto operand : key.operands) { + h ^= std::hash{}(operand) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +struct ScopedExpr { + ExprKey key; + Value* previous = nullptr; + bool had_previous = false; +}; + +bool IsSupportedGVNInstruction(Instruction* inst) { + if (!inst || inst->IsVoid()) { + return false; + } + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Rem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::FRem: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: + case Opcode::Neg: + case Opcode::Not: + case Opcode::FNeg: + case Opcode::FtoI: + case Opcode::IToF: + case Opcode::GetElementPtr: + case Opcode::Zext: + return true; + case Opcode::Call: + return memutils::IsPureCall(dyncast(inst)); + default: + return false; + } +} + +ExprKey BuildExprKey(Instruction* inst) { + ExprKey key; + key.opcode = inst->GetOpcode(); + key.result_type = + reinterpret_cast(inst->GetType().get()); + if (auto* gep = dyncast(inst)) { + key.aux_type = reinterpret_cast(gep->GetSourceType().get()); + } + key.operands.reserve(inst->GetNumOperands()); + for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) { + key.operands.push_back( + reinterpret_cast(inst->GetOperand(i))); + } + if (inst->GetNumOperands() == 2 && + passutils::IsCommutativeOpcode(inst->GetOpcode()) && + key.operands[1] < key.operands[0]) { + std::swap(key.operands[0], key.operands[1]); + } + return key; +} + +bool RunGVNInDomSubtree( + BasicBlock* block, const DominatorTree& dom_tree, + std::unordered_map& available) { + if (!block) { + return false; + } + + bool changed = false; + std::vector scope; + std::vector to_remove; + + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (!IsSupportedGVNInstruction(inst)) { + continue; + } + + const auto key = BuildExprKey(inst); + auto it = available.find(key); + if (it != available.end()) { + inst->ReplaceAllUsesWith(it->second); + to_remove.push_back(inst); + changed = true; + continue; + } + + ScopedExpr scoped{key, nullptr, false}; + auto existing = available.find(key); + if (existing != available.end()) { + scoped.previous = existing->second; + scoped.had_previous = true; + existing->second = inst; + } else { + available.emplace(key, inst); + } + scope.push_back(std::move(scoped)); + } + + for (auto* inst : to_remove) { + block->EraseInstruction(inst); + } + + for (auto* child : dom_tree.GetChildren(block)) { + changed |= RunGVNInDomSubtree(child, dom_tree, available); + } + + for (auto it = scope.rbegin(); it != scope.rend(); ++it) { + if (it->had_previous) { + available[it->key] = it->previous; + } else { + available.erase(it->key); + } + } + + return changed; +} + +bool RunGVNOnFunction(Function& function) { + if (function.IsExternal() || function.GetEntryBlock() == nullptr) { + return false; + } + + DominatorTree dom_tree(function); + std::unordered_map available; + return RunGVNInDomSubtree(function.GetEntryBlock(), dom_tree, available); +} + +} // namespace + +bool RunGVN(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunGVNOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/Inline.cpp b/src/ir/passes/Inline.cpp new file mode 100644 index 0000000..adcbef0 --- /dev/null +++ b/src/ir/passes/Inline.cpp @@ -0,0 +1,403 @@ +#include "ir/PassManager.h" + +#include "ir/IR.h" +#include "LoopPassUtils.h" + +#include +#include +#include +#include + +namespace ir { +namespace { + +struct InlineCandidateInfo { + bool valid = false; + int cost = 0; + bool has_nested_call = false; +}; + +bool IsInlineableInstruction(const Instruction* inst) { + if (!inst) { + return false; + } + + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Rem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::FRem: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: + case Opcode::Neg: + case Opcode::Not: + case Opcode::FNeg: + case Opcode::FtoI: + case Opcode::IToF: + case Opcode::Load: + case Opcode::Store: + case Opcode::GetElementPtr: + case Opcode::Zext: + case Opcode::Memset: + case Opcode::Call: + case Opcode::Return: + return true; + default: + return false; + } +} + +int EstimateInstructionCost(const Instruction* inst) { + if (!inst) { + return 0; + } + switch (inst->GetOpcode()) { + case Opcode::Return: + return 0; + case Opcode::Load: + case Opcode::Store: + case Opcode::Memset: + return 3; + case Opcode::Call: + return 8; + case Opcode::GetElementPtr: + return 2; + default: + return 1; + } +} + +InlineCandidateInfo AnalyzeInlineCandidate(const Function& function) { + InlineCandidateInfo info; + if (function.IsExternal() || function.IsRecursive() || function.GetBlocks().size() != 1) { + return info; + } + + const auto& block = function.GetBlocks().front(); + if (!block || block->GetInstructions().empty()) { + return info; + } + + bool saw_return = false; + for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) { + auto* inst = block->GetInstructions()[i].get(); + if (!IsInlineableInstruction(inst)) { + return {}; + } + if (dyncast(inst)) { + if (i + 1 != block->GetInstructions().size()) { + return {}; + } + saw_return = true; + continue; + } + if (dyncast(inst)) { + info.has_nested_call = true; + } + info.cost += EstimateInstructionCost(inst); + } + + if (!saw_return) { + return {}; + } + + info.valid = true; + return info; +} + +std::unordered_map CountDirectCalls(Module& module) { + std::unordered_map counts; + for (const auto& function_ptr : module.GetFunctions()) { + if (!function_ptr) { + continue; + } + for (const auto& block_ptr : function_ptr->GetBlocks()) { + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + if (auto* call = dyncast(inst_ptr.get())) { + if (auto* callee = call->GetCallee()) { + ++counts[callee]; + } + } + } + } + } + return counts; +} + +bool ShouldInlineCallSite(const Function& caller, const CallInst& call, + const InlineCandidateInfo& callee_info, int call_count) { + auto* callee = call.GetCallee(); + if (!callee || callee == &caller || !callee_info.valid) { + return false; + } + + int budget = callee->CanDiscardUnusedCall() ? 40 : 24; + if (call_count <= 1) { + budget += 12; + } + if (callee_info.has_nested_call) { + budget -= 8; + } + if (callee->MayWriteMemory()) { + budget -= 4; + } + return callee_info.cost <= budget; +} + +Instruction* CloneInstructionAt(Function& function, Instruction* inst, BasicBlock* dest, + std::size_t insert_index, + std::unordered_map& remap) { + if (!inst || !dest) { + return nullptr; + } + + const auto name = inst->IsVoid() ? std::string() + : looputils::NextSyntheticName(function, "inline."); + auto remap_operand = [&](Value* value) { return looputils::RemapValue(remap, value); }; + auto remember = [&](Instruction* clone) { + if (clone && !inst->IsVoid()) { + remap[inst] = clone; + } + return clone; + }; + + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Rem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::FRem: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: { + auto* bin = static_cast(inst); + return remember(dest->Insert(insert_index, inst->GetOpcode(), inst->GetType(), + remap_operand(bin->GetLhs()), + remap_operand(bin->GetRhs()), nullptr, name)); + } + case Opcode::Neg: + case Opcode::Not: + case Opcode::FNeg: + case Opcode::FtoI: + case Opcode::IToF: { + auto* un = static_cast(inst); + return remember(dest->Insert(insert_index, inst->GetOpcode(), inst->GetType(), + remap_operand(un->GetOprd()), nullptr, name)); + } + case Opcode::Load: { + auto* load = static_cast(inst); + return remember(dest->Insert(insert_index, inst->GetType(), + remap_operand(load->GetPtr()), nullptr, name)); + } + case Opcode::Store: { + auto* store = static_cast(inst); + return dest->Insert(insert_index, remap_operand(store->GetValue()), + remap_operand(store->GetPtr()), nullptr); + } + case Opcode::Memset: { + auto* memset = static_cast(inst); + return dest->Insert(insert_index, remap_operand(memset->GetDest()), + remap_operand(memset->GetValue()), + remap_operand(memset->GetLength()), + remap_operand(memset->GetIsVolatile()), nullptr); + } + case Opcode::GetElementPtr: { + auto* gep = static_cast(inst); + std::vector indices; + indices.reserve(gep->GetNumIndices()); + for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) { + indices.push_back(remap_operand(gep->GetIndex(i))); + } + return remember(dest->Insert( + insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()), indices, nullptr, + name)); + } + case Opcode::Zext: { + auto* zext = static_cast(inst); + return remember(dest->Insert(insert_index, remap_operand(zext->GetValue()), + inst->GetType(), nullptr, name)); + } + case Opcode::Call: { + auto* call = static_cast(inst); + std::vector args; + const auto original_args = call->GetArguments(); + args.reserve(original_args.size()); + for (auto* arg : original_args) { + args.push_back(remap_operand(arg)); + } + return remember( + dest->Insert(insert_index, call->GetCallee(), args, nullptr, name)); + } + case Opcode::Return: + case Opcode::Alloca: + case Opcode::Phi: + case Opcode::Br: + case Opcode::CondBr: + case Opcode::Unreachable: + break; + } + return nullptr; +} + +bool InlineCallSite(Function& caller, CallInst* call) { + if (!call) { + return false; + } + auto* callee = call->GetCallee(); + if (!callee || callee->GetBlocks().size() != 1) { + return false; + } + + const auto& callee_args = callee->GetArguments(); + const auto call_args = call->GetArguments(); + if (callee_args.size() != call_args.size()) { + return false; + } + + auto* block = call->GetParent(); + if (!block) { + return false; + } + + auto& instructions = block->GetInstructions(); + auto call_it = std::find_if(instructions.begin(), instructions.end(), + [&](const std::unique_ptr& current) { + return current.get() == call; + }); + if (call_it == instructions.end()) { + return false; + } + + std::size_t insert_index = static_cast(call_it - instructions.begin()); + std::unordered_map remap; + for (std::size_t i = 0; i < call_args.size(); ++i) { + remap[callee_args[i].get()] = call_args[i]; + } + + Value* return_value = nullptr; + for (const auto& inst_ptr : callee->GetBlocks().front()->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (auto* ret = dyncast(inst)) { + if (ret->HasReturnValue()) { + return_value = looputils::RemapValue(remap, ret->GetReturnValue()); + } + break; + } + if (!CloneInstructionAt(caller, inst, block, insert_index, remap)) { + return false; + } + ++insert_index; + } + + if (!call->GetType()->IsVoid()) { + if (!return_value) { + return false; + } + call->ReplaceAllUsesWith(return_value); + } + + block->EraseInstruction(call); + return true; +} + +bool RunFunctionInliningOnFunction( + Function& function, + const std::unordered_map& callee_info, + const std::unordered_map& call_counts) { + if (function.IsExternal()) { + return false; + } + + bool changed = false; + for (auto& block_ptr : function.GetBlocks()) { + std::vector calls; + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + if (auto* call = dyncast(inst_ptr.get())) { + calls.push_back(call); + } + } + + for (auto* call : calls) { + auto* callee = call->GetCallee(); + if (!callee) { + continue; + } + auto info_it = callee_info.find(callee); + if (info_it == callee_info.end()) { + continue; + } + const int call_count = + call_counts.count(callee) != 0 ? call_counts.at(callee) : 0; + if (!ShouldInlineCallSite(function, *call, info_it->second, call_count)) { + continue; + } + changed |= InlineCallSite(function, call); + } + } + + return changed; +} + +} // namespace + +bool RunFunctionInlining(Module& module) { + std::unordered_map callee_info; + for (const auto& function_ptr : module.GetFunctions()) { + if (function_ptr) { + callee_info.emplace(function_ptr.get(), AnalyzeInlineCandidate(*function_ptr)); + } + } + + const auto call_counts = CountDirectCalls(module); + bool changed = false; + for (const auto& function_ptr : module.GetFunctions()) { + if (function_ptr) { + changed |= RunFunctionInliningOnFunction(*function_ptr, callee_info, call_counts); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/LICM.cpp b/src/ir/passes/LICM.cpp new file mode 100644 index 0000000..d4fd53e --- /dev/null +++ b/src/ir/passes/LICM.cpp @@ -0,0 +1,236 @@ +#include "ir/PassManager.h" + +#include "ir/Analysis.h" +#include "ir/IR.h" +#include "LoopMemoryUtils.h" +#include "LoopPassUtils.h" +#include "MemoryUtils.h" + +#include +#include +#include +#include +#include + +namespace ir { +namespace { + +struct HoistedLoadKey { + memutils::AddressKey address; + std::uintptr_t type_id = 0; + + bool operator==(const HoistedLoadKey& rhs) const { + return type_id == rhs.type_id && address == rhs.address; + } +}; + +struct HoistedLoadKeyHash { + std::size_t operator()(const HoistedLoadKey& key) const { + std::size_t h = memutils::AddressKeyHash{}(key.address); + h ^= std::hash{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; + } +}; + +bool IsHoistableInstruction(const Instruction* inst) { + if (!inst || inst->IsTerminator() || inst->IsVoid()) { + return false; + } + + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: + case Opcode::Neg: + case Opcode::Not: + case Opcode::FNeg: + case Opcode::FtoI: + case Opcode::IToF: + case Opcode::GetElementPtr: + case Opcode::Zext: + case Opcode::Load: + return true; + default: + return false; + } +} + +bool IsLoopInvariantInstruction( + const Loop& loop, Instruction* inst, + const std::unordered_set& invariant_insts, + PhiInst* iv, int iv_stride, + const std::vector& accesses, + const memutils::EscapeSummary& escapes) { + if (!IsHoistableInstruction(inst)) { + return false; + } + + for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) { + auto* operand = inst->GetOperand(i); + auto* operand_inst = dyncast(operand); + if (!operand_inst) { + continue; + } + if (!loop.Contains(operand_inst->GetParent())) { + continue; + } + if (invariant_insts.find(operand_inst) == invariant_insts.end()) { + return false; + } + } + + if (auto* load = dyncast(inst)) { + return loopmem::IsSafeInvariantLoadToHoist(loop, load, iv, iv_stride, accesses, &escapes); + } + return true; +} + +bool HoistLoopInvariants(Function& function, const Loop& loop, + BasicBlock* preheader) { + if (!preheader) { + return false; + } + + loopmem::SimpleInductionVar induction_var; + PhiInst* iv = nullptr; + int iv_stride = 1; + for (const auto& inst_ptr : loop.header->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + if (loopmem::MatchSimpleInductionVariable(loop, preheader, phi, induction_var)) { + iv = induction_var.phi; + iv_stride = induction_var.stride; + break; + } + } + const auto escapes = memutils::AnalyzeEscapes(function); + const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes); + + std::unordered_set invariant_insts; + std::vector hoist_list; + + bool progress = true; + while (progress) { + progress = false; + for (const auto& block_ptr : function.GetBlocks()) { + auto* block = block_ptr.get(); + if (!loop.Contains(block) || block == preheader) { + continue; + } + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (invariant_insts.find(inst) != invariant_insts.end()) { + continue; + } + if (!IsLoopInvariantInstruction(loop, inst, invariant_insts, iv, iv_stride, + accesses, escapes)) { + continue; + } + invariant_insts.insert(inst); + hoist_list.push_back(inst); + progress = true; + } + } + } + + bool changed = false; + std::unordered_map hoisted_loads; + for (auto* inst : hoist_list) { + if (auto* load = dyncast(inst)) { + auto ptr = loopmem::AnalyzePointer(load->GetPtr(), iv, loop, + load->GetType()->GetSize(), &escapes); + if (ptr.exact_key_valid) { + HoistedLoadKey key{ptr.exact_key, + reinterpret_cast(load->GetType().get())}; + auto it = hoisted_loads.find(key); + if (it != hoisted_loads.end()) { + load->ReplaceAllUsesWith(it->second); + load->GetParent()->EraseInstruction(load); + changed = true; + continue; + } + auto* moved = dyncast( + looputils::MoveInstructionBeforeTerminator(load, preheader)); + if (moved) { + hoisted_loads.emplace(std::move(key), moved); + changed = true; + } + continue; + } + } + if (looputils::MoveInstructionBeforeTerminator(inst, preheader)) { + changed = true; + } + } + return changed; +} + +bool RunLICMOnFunction(Function& function) { + if (function.IsExternal() || function.GetEntryBlock() == nullptr) { + return false; + } + + bool changed = false; + while (true) { + DominatorTree dom_tree(function); + LoopInfo loop_info(function, dom_tree); + bool local_changed = false; + + for (auto* loop : loop_info.GetLoopsInPostOrder()) { + auto* old_preheader = loop->preheader; + auto* preheader = looputils::EnsurePreheader(function, *loop); + bool loop_changed = preheader != old_preheader; + loop_changed |= HoistLoopInvariants(function, *loop, preheader); + if (!loop_changed) { + continue; + } + changed = true; + local_changed = true; + break; + } + + if (!local_changed) { + break; + } + } + + return changed; +} + +} // namespace + +bool RunLICM(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunLICMOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/LoadStoreElim.cpp b/src/ir/passes/LoadStoreElim.cpp new file mode 100644 index 0000000..6caa64f --- /dev/null +++ b/src/ir/passes/LoadStoreElim.cpp @@ -0,0 +1,319 @@ +#include "ir/PassManager.h" + +#include "ir/IR.h" +#include "MemoryUtils.h" +#include "PassUtils.h" + +#include +#include +#include + +namespace ir { +namespace { + +struct AvailableValue { + Value* value = nullptr; + + bool operator==(const AvailableValue& rhs) const { + return passutils::AreEquivalentValues(value, rhs.value) || value == rhs.value; + } +}; + +using MemoryState = + std::unordered_map; + +bool SameAvailableValue(const AvailableValue& lhs, const AvailableValue& rhs) { + return lhs == rhs; +} + +bool SameMemoryState(const MemoryState& lhs, const MemoryState& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto& [key, value] : lhs) { + auto it = rhs.find(key); + if (it == rhs.end() || !SameAvailableValue(value, it->second)) { + return false; + } + } + return true; +} + +MemoryState MeetMemoryStates(const std::vector& predecessors) { + if (predecessors.empty()) { + return {}; + } + + MemoryState in = *predecessors.front(); + for (auto it = in.begin(); it != in.end();) { + bool keep = true; + for (std::size_t i = 1; i < predecessors.size(); ++i) { + auto pred_it = predecessors[i]->find(it->first); + if (pred_it == predecessors[i]->end() || + !SameAvailableValue(it->second, pred_it->second)) { + keep = false; + break; + } + } + if (!keep) { + it = in.erase(it); + continue; + } + ++it; + } + return in; +} + +void InvalidateAliasStates(MemoryState& state, + const memutils::AddressKey& key) { + for (auto it = state.begin(); it != state.end();) { + if (memutils::MayAliasConservatively(it->first, key)) { + it = state.erase(it); + continue; + } + ++it; + } +} + +void InvalidateStatesForCall(MemoryState& state, Function* callee) { + for (auto it = state.begin(); it != state.end();) { + if (memutils::CallMayWriteRoot(callee, it->first.kind)) { + it = state.erase(it); + continue; + } + ++it; + } +} + +void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* inst, + MemoryState& state) { + if (!inst) { + return; + } + + if (auto* load = dyncast(inst)) { + memutils::AddressKey key; + if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) { + state.clear(); + } + return; + } + + if (auto* store = dyncast(inst)) { + memutils::AddressKey key; + if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) { + state.clear(); + return; + } + InvalidateAliasStates(state, key); + state[key] = {store->GetValue()}; + return; + } + + if (auto* call = dyncast(inst)) { + InvalidateStatesForCall(state, call->GetCallee()); + return; + } + + if (auto* memset = dyncast(inst)) { + memutils::AddressKey key; + if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) { + state.clear(); + return; + } + InvalidateAliasStates(state, key); + return; + } +} + +MemoryState SimulateBlock(const memutils::EscapeSummary& escapes, BasicBlock* block, + const MemoryState& in_state) { + MemoryState state = in_state; + for (const auto& inst_ptr : block->GetInstructions()) { + SimulateInstruction(escapes, inst_ptr.get(), state); + } + return state; +} + +bool MarkLoadObserved( + const memutils::AddressKey& key, + std::unordered_map& + pending_stores) { + bool changed = false; + for (auto it = pending_stores.begin(); it != pending_stores.end();) { + if (memutils::MayAliasConservatively(it->first, key)) { + it = pending_stores.erase(it); + changed = true; + continue; + } + ++it; + } + return changed; +} + +void InvalidatePendingForCall( + std::unordered_map& + pending_stores, + Function* callee) { + for (auto it = pending_stores.begin(); it != pending_stores.end();) { + if (memutils::CallMayReadRoot(callee, it->first.kind) || + memutils::CallMayWriteRoot(callee, it->first.kind)) { + it = pending_stores.erase(it); + continue; + } + ++it; + } +} + +bool OptimizeBlock( + const memutils::EscapeSummary& escapes, BasicBlock* block, + const MemoryState& in_state) { + bool changed = false; + MemoryState state = in_state; + std::unordered_map + pending_stores; + std::vector to_remove; + + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + + if (auto* load = dyncast(inst)) { + memutils::AddressKey key; + if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) { + state.clear(); + pending_stores.clear(); + continue; + } + + MarkLoadObserved(key, pending_stores); + auto it = state.find(key); + if (it != state.end() && it->second.value != load) { + load->ReplaceAllUsesWith(it->second.value); + to_remove.push_back(load); + changed = true; + continue; + } + // Keep block-local load reuse, but do not expose load results to cross-block + // dataflow because the defining load itself may be removed later. + state[key] = {load}; + continue; + } + + if (auto* store = dyncast(inst)) { + memutils::AddressKey key; + if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) { + state.clear(); + pending_stores.clear(); + continue; + } + + for (auto it = pending_stores.begin(); it != pending_stores.end();) { + if (!memutils::MayAliasConservatively(it->first, key)) { + ++it; + continue; + } + if (it->first == key) { + to_remove.push_back(it->second); + changed = true; + } + it = pending_stores.erase(it); + } + + pending_stores.emplace(key, store); + InvalidateAliasStates(state, key); + state[key] = {store->GetValue()}; + continue; + } + + if (auto* call = dyncast(inst)) { + InvalidateStatesForCall(state, call->GetCallee()); + InvalidatePendingForCall(pending_stores, call->GetCallee()); + continue; + } + + if (auto* memset = dyncast(inst)) { + memutils::AddressKey key; + if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) { + state.clear(); + pending_stores.clear(); + continue; + } + InvalidateAliasStates(state, key); + MarkLoadObserved(key, pending_stores); + continue; + } + } + + for (auto* inst : to_remove) { + if (inst->GetParent() == block) { + block->EraseInstruction(inst); + } + } + return changed; +} + +bool RunLoadStoreElimOnFunction(Function& function) { + if (function.IsExternal() || function.GetEntryBlock() == nullptr) { + return false; + } + + const auto escapes = memutils::AnalyzeEscapes(function); + const auto reachable_blocks = passutils::CollectReachableBlocks(function); + if (reachable_blocks.empty()) { + return false; + } + + std::unordered_map in_states; + std::unordered_map out_states; + + bool dataflow_changed = true; + while (dataflow_changed) { + dataflow_changed = false; + for (auto* block : reachable_blocks) { + MemoryState in_state; + if (block != function.GetEntryBlock()) { + std::vector predecessors; + for (auto* pred : block->GetPredecessors()) { + auto it = out_states.find(pred); + if (it != out_states.end()) { + predecessors.push_back(&it->second); + } + } + in_state = MeetMemoryStates(predecessors); + } + + auto out_state = SimulateBlock(escapes, block, in_state); + auto in_it = in_states.find(block); + if (in_it == in_states.end() || !SameMemoryState(in_it->second, in_state)) { + in_states[block] = in_state; + dataflow_changed = true; + } + auto out_it = out_states.find(block); + if (out_it == out_states.end() || !SameMemoryState(out_it->second, out_state)) { + out_states[block] = std::move(out_state); + dataflow_changed = true; + } + } + } + + bool changed = false; + for (auto* block : reachable_blocks) { + changed |= OptimizeBlock(escapes, block, in_states[block]); + } + return changed; +} + +} // namespace + +bool RunLoadStoreElim(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunLoadStoreElimOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/LoopFission.cpp b/src/ir/passes/LoopFission.cpp new file mode 100644 index 0000000..366de8b --- /dev/null +++ b/src/ir/passes/LoopFission.cpp @@ -0,0 +1,326 @@ +#include "ir/PassManager.h" + +#include "ir/Analysis.h" +#include "ir/IR.h" +#include "LoopMemoryUtils.h" +#include "LoopPassUtils.h" + +#include +#include +#include + +namespace ir { +namespace { + +struct FissionLoopInfo { + Loop* loop = nullptr; + BasicBlock* preheader = nullptr; + BasicBlock* header = nullptr; + BasicBlock* body = nullptr; + BasicBlock* exit = nullptr; + CondBrInst* branch = nullptr; + BinaryInst* compare = nullptr; + Opcode compare_opcode = Opcode::ICmpLT; + Value* bound = nullptr; + loopmem::SimpleInductionVar induction_var; + PhiInst* iv = nullptr; + BinaryInst* step_inst = nullptr; +}; + +bool HasSyntheticLoopTag(const std::string& name) { + return name.find("unroll.") != std::string::npos || + name.find("fission.") != std::string::npos; +} + +bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) { + if (!loop.preheader || !loop.header || !body) { + return true; + } + return HasSyntheticLoopTag(loop.preheader->GetName()) || + HasSyntheticLoopTag(loop.header->GetName()) || + HasSyntheticLoopTag(body->GetName()); +} + +Opcode SwapCompareOpcode(Opcode opcode) { + switch (opcode) { + case Opcode::ICmpLT: + return Opcode::ICmpGT; + case Opcode::ICmpLE: + return Opcode::ICmpGE; + case Opcode::ICmpGT: + return Opcode::ICmpLT; + case Opcode::ICmpGE: + return Opcode::ICmpLE; + default: + return opcode; + } +} + +bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) { + if (!loop.preheader || !loop.header || !loop.IsInnermost()) { + return false; + } + + BasicBlock* body = nullptr; + BasicBlock* exit = nullptr; + if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) { + return false; + } + if (IsAlreadyTransformedLoop(loop, body)) { + return false; + } + + std::vector phis; + loopmem::SimpleInductionVar induction_var; + bool found_iv = false; + for (const auto& inst_ptr : loop.header->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + phis.push_back(phi); + if (!found_iv && + loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) { + found_iv = true; + } + } + if (!found_iv || phis.size() != 1) { + return false; + } + + auto* branch = dyncast(looputils::GetTerminator(loop.header)); + auto* compare = branch ? dyncast(branch->GetCondition()) : nullptr; + if (!branch || branch->GetThenBlock() != body || !compare) { + return false; + } + + Opcode compare_opcode = compare->GetOpcode(); + Value* bound = nullptr; + if (compare->GetLhs() == induction_var.phi && + looputils::IsLoopInvariantValue(loop, compare->GetRhs())) { + bound = compare->GetRhs(); + } else if (compare->GetRhs() == induction_var.phi && + looputils::IsLoopInvariantValue(loop, compare->GetLhs())) { + bound = compare->GetLhs(); + compare_opcode = SwapCompareOpcode(compare_opcode); + } else { + return false; + } + + auto* step_inst = dyncast(induction_var.latch_value); + if (!step_inst || step_inst->GetParent() != body) { + return false; + } + + for (const auto& inst_ptr : body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator() || inst == step_inst) { + continue; + } + if (!looputils::IsCloneableInstruction(inst) || dyncast(inst) || + dyncast(inst) || dyncast(inst)) { + return false; + } + } + + info.loop = &loop; + info.preheader = loop.preheader; + info.header = loop.header; + info.body = body; + info.exit = exit; + info.branch = branch; + info.compare = compare; + info.compare_opcode = compare_opcode; + info.bound = bound; + info.induction_var = induction_var; + info.iv = induction_var.phi; + info.step_inst = step_inst; + return true; +} + +bool ContainsInterestingPayload(const std::vector& group) { + bool has_memory = false; + for (auto* inst : group) { + if (dyncast(inst) || dyncast(inst)) { + has_memory = true; + } + } + return has_memory; +} + +Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) { + if (value == old_iv) { + return new_iv; + } + return value; +} + +bool BuildSecondLoop(Function& function, const FissionLoopInfo& info, + const std::vector& second_group) { + auto* second_header = + function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header")); + auto* second_body = + function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body")); + + const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader); + if (preheader_index < 0) { + return false; + } + auto* second_iv = second_header->Append( + info.iv->GetType(), nullptr, + looputils::NextSyntheticName(function, "fission.iv.")); + second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header); + + auto* second_cmp = second_header->Append( + info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr, + looputils::NextSyntheticName(function, "fission.cmp.")); + second_header->Append(second_cmp, second_body, info.exit, nullptr); + second_header->AddPredecessor(info.header); + second_header->AddSuccessor(second_body); + second_header->AddSuccessor(info.exit); + + std::unordered_map remap; + remap[info.iv] = second_iv; + std::unordered_set selected(second_group.begin(), second_group.end()); + selected.insert(info.step_inst); + for (const auto& inst_ptr : info.body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator() || selected.find(inst) == selected.end()) { + continue; + } + looputils::CloneInstruction(function, inst, second_body, remap, "fission."); + } + auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst); + if (!cloned_step_value) { + return false; + } + second_iv->AddIncoming(cloned_step_value, second_body); + second_body->Append(second_header, nullptr); + second_body->AddPredecessor(second_header); + second_body->AddSuccessor(second_header); + second_header->AddPredecessor(second_body); + + if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) { + return false; + } + info.exit->RemovePredecessor(info.header); + info.exit->AddPredecessor(second_header); + + for (const auto& inst_ptr : info.exit->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + const int incoming = looputils::GetPhiIncomingIndex(phi, info.header); + if (incoming < 0) { + continue; + } + phi->SetOperand(static_cast(2 * incoming), + RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv)); + phi->SetOperand(static_cast(2 * incoming + 1), second_header); + } + + return true; +} + +bool RunLoopFissionOnFunction(Function& function) { + if (function.IsExternal() || !function.GetEntryBlock()) { + return false; + } + + bool changed = false; + while (true) { + DominatorTree dom_tree(function); + LoopInfo loop_info(function, dom_tree); + bool local_changed = false; + + for (auto* loop : loop_info.GetLoopsInPostOrder()) { + FissionLoopInfo info; + if (!MatchFissionLoop(*loop, info)) { + continue; + } + + const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv); + std::vector payload; + for (const auto& inst_ptr : info.body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator() || inst == info.step_inst) { + continue; + } + payload.push_back(inst); + } + if (payload.size() < 2) { + continue; + } + + int chosen_cut = -1; + std::vector first_group; + std::vector second_group; + for (std::size_t cut = 1; cut < payload.size(); ++cut) { + std::vector first(payload.begin(), payload.begin() + static_cast(cut)); + std::vector second(payload.begin() + static_cast(cut), + payload.end()); + if (!ContainsInterestingPayload(first) || !ContainsInterestingPayload(second)) { + continue; + } + + std::unordered_set first_set(first.begin(), first.end()); + std::unordered_set second_set(second.begin(), second.end()); + if (loopmem::HasScalarDependenceAcrossCut(first, second_set) || + loopmem::HasMemoryDependenceAcrossCut(accesses, first_set, second_set, + info.induction_var.stride)) { + continue; + } + + chosen_cut = static_cast(cut); + first_group = std::move(first); + second_group = std::move(second); + break; + } + + if (chosen_cut < 0) { + continue; + } + + std::unordered_set keep(first_group.begin(), first_group.end()); + keep.insert(info.step_inst); + std::vector to_remove; + for (const auto& inst_ptr : info.body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator() || keep.find(inst) != keep.end()) { + continue; + } + to_remove.push_back(inst); + } + if (!BuildSecondLoop(function, info, second_group)) { + continue; + } + for (auto* inst : to_remove) { + info.body->EraseInstruction(inst); + } + + changed = true; + local_changed = true; + break; + } + + if (!local_changed) { + break; + } + } + return changed; +} + +} // namespace + +bool RunLoopFission(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunLoopFissionOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/LoopMemoryUtils.h b/src/ir/passes/LoopMemoryUtils.h new file mode 100644 index 0000000..c3c5cf2 --- /dev/null +++ b/src/ir/passes/LoopMemoryUtils.h @@ -0,0 +1,506 @@ +#pragma once + +#include "LoopPassUtils.h" +#include "MemoryUtils.h" + +#include +#include +#include +#include +#include + +namespace ir::loopmem { + +struct SimpleInductionVar { + PhiInst* phi = nullptr; + Value* start = nullptr; + Value* latch_value = nullptr; + BasicBlock* latch = nullptr; + int stride = 0; +}; + +inline bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader, + PhiInst* phi, SimpleInductionVar& info) { + if (!phi || !preheader || phi->GetParent() != loop.header || + !phi->GetType()->IsInt32() || phi->GetNumIncomings() != 2 || + loop.latches.size() != 1) { + return false; + } + + auto* latch = loop.latches.front(); + const int preheader_index = looputils::GetPhiIncomingIndex(phi, preheader); + const int latch_index = looputils::GetPhiIncomingIndex(phi, latch); + if (preheader_index < 0 || latch_index < 0) { + return false; + } + + auto* step_inst = dyncast(phi->GetIncomingValue(latch_index)); + if (!step_inst || step_inst->GetParent() != latch) { + return false; + } + + int stride = 0; + if (step_inst->GetOpcode() == Opcode::Add) { + if (step_inst->GetLhs() == phi) { + auto* delta = dyncast(step_inst->GetRhs()); + if (!delta) { + return false; + } + stride = delta->GetValue(); + } else if (step_inst->GetRhs() == phi) { + auto* delta = dyncast(step_inst->GetLhs()); + if (!delta) { + return false; + } + stride = delta->GetValue(); + } else { + return false; + } + } else if (step_inst->GetOpcode() == Opcode::Sub) { + if (step_inst->GetLhs() != phi) { + return false; + } + auto* delta = dyncast(step_inst->GetRhs()); + if (!delta) { + return false; + } + stride = -delta->GetValue(); + } else { + return false; + } + + if (stride == 0) { + return false; + } + + info.phi = phi; + info.start = phi->GetIncomingValue(preheader_index); + info.latch_value = phi->GetIncomingValue(latch_index); + info.latch = latch; + info.stride = stride; + return true; +} + +inline bool GetCanonicalLoopBlocks(const Loop& loop, BasicBlock*& body, + BasicBlock*& exit) { + body = nullptr; + exit = nullptr; + if (!loop.header || loop.latches.size() != 1 || loop.block_list.size() != 2) { + return false; + } + + auto* condbr = dyncast(looputils::GetTerminator(loop.header)); + if (!condbr) { + return false; + } + + auto* then_block = condbr->GetThenBlock(); + auto* else_block = condbr->GetElseBlock(); + const bool then_in_loop = loop.Contains(then_block); + const bool else_in_loop = loop.Contains(else_block); + if (then_in_loop == else_in_loop) { + return false; + } + + body = then_in_loop ? then_block : else_block; + exit = then_in_loop ? else_block : then_block; + if (!body || !exit || body != loop.latches.front() || + body->GetSuccessors().size() != 1 || body->GetSuccessors().front() != loop.header) { + return false; + } + return true; +} + +struct AffineExpr { + bool valid = false; + Value* var = nullptr; + std::int64_t coeff = 0; + std::int64_t constant = 0; +}; + +inline AffineExpr MakeConst(std::int64_t value) { + return {true, nullptr, 0, value}; +} + +inline AffineExpr Scale(const AffineExpr& expr, std::int64_t factor) { + if (!expr.valid) { + return {}; + } + return {true, expr.var, expr.coeff * factor, expr.constant * factor}; +} + +inline AffineExpr Combine(const AffineExpr& lhs, const AffineExpr& rhs, int sign) { + if (!lhs.valid || !rhs.valid) { + return {}; + } + if (lhs.var != nullptr && rhs.var != nullptr && lhs.var != rhs.var) { + return {}; + } + AffineExpr out; + out.valid = true; + out.var = lhs.var ? lhs.var : rhs.var; + out.coeff = lhs.coeff + sign * rhs.coeff; + out.constant = lhs.constant + sign * rhs.constant; + return out; +} + +inline AffineExpr AnalyzeAffine(Value* value, PhiInst* iv, const Loop& loop) { + if (!value) { + return {}; + } + if (auto* ci = dyncast(value)) { + return MakeConst(ci->GetValue()); + } + if (value == iv) { + return {true, iv, 1, 0}; + } + if (looputils::IsLoopInvariantValue(loop, value)) { + return {}; + } + + if (auto* zext = dyncast(value)) { + return AnalyzeAffine(zext->GetValue(), iv, loop); + } + auto* inst = dyncast(value); + if (!inst) { + return {}; + } + + switch (inst->GetOpcode()) { + case Opcode::Add: + return Combine(AnalyzeAffine(static_cast(inst)->GetLhs(), iv, loop), + AnalyzeAffine(static_cast(inst)->GetRhs(), iv, loop), +1); + case Opcode::Sub: + return Combine(AnalyzeAffine(static_cast(inst)->GetLhs(), iv, loop), + AnalyzeAffine(static_cast(inst)->GetRhs(), iv, loop), -1); + case Opcode::Mul: { + auto* bin = static_cast(inst); + auto lhs = AnalyzeAffine(bin->GetLhs(), iv, loop); + auto rhs = AnalyzeAffine(bin->GetRhs(), iv, loop); + if (lhs.valid && lhs.var == nullptr && rhs.valid) { + return Scale(rhs, lhs.constant); + } + if (rhs.valid && rhs.var == nullptr && lhs.valid) { + return Scale(lhs, rhs.constant); + } + return {}; + } + case Opcode::Neg: + return Scale(AnalyzeAffine(static_cast(inst)->GetOprd(), iv, loop), -1); + default: + return {}; + } +} + +struct PointerInfo { + Value* base = nullptr; + AffineExpr byte_offset; + bool invariant_address = false; + bool distinct_root = false; + bool argument_root = false; + bool readonly_root = false; + bool exact_key_valid = false; + memutils::PointerRootKind root_kind = memutils::PointerRootKind::Unknown; + memutils::AddressKey exact_key; + int access_size = 0; +}; + +inline Value* StripPointerBase(Value* pointer) { + auto* value = pointer; + while (auto* gep = dyncast(value)) { + value = gep->GetPointer(); + } + return value; +} + +inline std::shared_ptr AdvanceGEPType(std::shared_ptr current) { + if (current && current->IsArray()) { + return current->GetElementType(); + } + return current; +} + +inline PointerInfo AnalyzePointer(Value* pointer, PhiInst* iv, const Loop& loop, + int access_size, + const memutils::EscapeSummary* escapes = nullptr) { + PointerInfo info; + info.access_size = access_size; + info.base = StripPointerBase(pointer); + info.root_kind = memutils::ClassifyRoot(info.base, escapes); + info.argument_root = info.root_kind == memutils::PointerRootKind::Param; + info.readonly_root = info.root_kind == memutils::PointerRootKind::ReadonlyGlobal; + info.distinct_root = info.root_kind == memutils::PointerRootKind::Local || + info.root_kind == memutils::PointerRootKind::Global || + info.root_kind == memutils::PointerRootKind::ReadonlyGlobal; + info.exact_key_valid = + escapes != nullptr && memutils::BuildExactAddressKey(pointer, escapes, info.exact_key); + + info.invariant_address = looputils::IsLoopInvariantValue(loop, pointer); + if (!dyncast(pointer)) { + info.byte_offset = MakeConst(0); + return info; + } + + auto* gep = static_cast(pointer); + std::shared_ptr current = gep->GetSourceType(); + AffineExpr total = MakeConst(0); + bool all_indices_loop_invariant = looputils::IsLoopInvariantValue(loop, gep->GetPointer()); + for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) { + auto* index = gep->GetIndex(i); + all_indices_loop_invariant &= looputils::IsLoopInvariantValue(loop, index); + const std::int64_t stride = current ? current->GetSize() : 0; + auto term = AnalyzeAffine(index, iv, loop); + if (!term.valid) { + total = {}; + } else if (total.valid) { + total = Combine(total, Scale(term, stride), +1); + } + current = AdvanceGEPType(current); + } + info.invariant_address = all_indices_loop_invariant; + info.byte_offset = total; + return info; +} + +struct MemoryAccessInfo { + Instruction* inst = nullptr; + Value* pointer = nullptr; + PointerInfo ptr; + bool is_read = false; + bool is_write = false; +}; + +inline std::vector CollectMemoryAccesses(const Loop& loop, + PhiInst* iv, + const memutils::EscapeSummary* escapes = + nullptr) { + std::vector accesses; + for (auto* block : loop.block_list) { + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (auto* load = dyncast(inst)) { + accesses.push_back( + {inst, load->GetPtr(), + AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes), + true, + false}); + } else if (auto* store = dyncast(inst)) { + accesses.push_back({inst, store->GetPtr(), + AnalyzePointer(store->GetPtr(), iv, loop, + store->GetValue()->GetType()->GetSize(), escapes), + false, true}); + } else if (auto* memset = dyncast(inst)) { + accesses.push_back( + {inst, memset->GetDest(), + AnalyzePointer(memset->GetDest(), iv, loop, 1, escapes), false, true}); + } + } + } + return accesses; +} + +inline bool SameAffineAddress(const PointerInfo& lhs, const PointerInfo& rhs) { + return lhs.base == rhs.base && lhs.byte_offset.valid && rhs.byte_offset.valid && + lhs.byte_offset.var == rhs.byte_offset.var && + lhs.byte_offset.coeff == rhs.byte_offset.coeff && + lhs.byte_offset.constant == rhs.byte_offset.constant; +} + +inline bool MayAliasSameIteration(const PointerInfo& lhs, const PointerInfo& rhs) { + if (lhs.exact_key_valid && rhs.exact_key_valid) { + return memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key); + } + if (!lhs.base || !rhs.base) { + return true; + } + if (lhs.base != rhs.base) { + if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) { + return false; + } + return true; + } + if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) { + return true; + } + if (lhs.byte_offset.var != rhs.byte_offset.var) { + return true; + } + if (lhs.byte_offset.coeff != rhs.byte_offset.coeff) { + return true; + } + const auto diff = std::llabs(lhs.byte_offset.constant - rhs.byte_offset.constant); + const auto overlap = std::min(lhs.access_size, rhs.access_size); + return diff < overlap; +} + +inline bool HasCrossIterationDependence(const PointerInfo& lhs, const PointerInfo& rhs, + int iv_stride) { + if (lhs.exact_key_valid && rhs.exact_key_valid && + !memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key)) { + return false; + } + if (!lhs.base || !rhs.base) { + return true; + } + if (lhs.base != rhs.base) { + if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) { + return false; + } + return true; + } + if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) { + return true; + } + if (lhs.byte_offset.var != rhs.byte_offset.var) { + return true; + } + + const auto lhs_step = lhs.byte_offset.coeff * iv_stride; + const auto rhs_step = rhs.byte_offset.coeff * iv_stride; + if (lhs_step == 0 && rhs_step == 0) { + return MayAliasSameIteration(lhs, rhs); + } + if (lhs_step == rhs_step && lhs_step != 0) { + const auto diff = rhs.byte_offset.constant - lhs.byte_offset.constant; + return diff != 0 && diff % std::llabs(lhs_step) == 0; + } + return true; +} + +inline bool CallMayWritePointer(Function* callee, const PointerInfo& ptr) { + if (ptr.readonly_root) { + return false; + } + return memutils::CallMayWriteRoot(callee, ptr.root_kind); +} + +inline bool IsSafeInvariantLoadToHoist(const Loop& loop, LoadInst* load, PhiInst* iv, + int iv_stride, + const std::vector& accesses, + const memutils::EscapeSummary* escapes = nullptr) { + if (!load) { + return false; + } + auto ptr = AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes); + if (!ptr.invariant_address) { + return false; + } + if (ptr.readonly_root) { + return true; + } + + for (auto* block : loop.block_list) { + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst == load) { + continue; + } + if (auto* call = dyncast(inst)) { + if (CallMayWritePointer(call->GetCallee(), ptr)) { + return false; + } + } + } + } + + for (const auto& access : accesses) { + if (access.inst == load || !access.is_write) { + continue; + } + if (MayAliasSameIteration(ptr, access.ptr)) { + return false; + } + if (HasCrossIterationDependence(ptr, access.ptr, iv_stride)) { + return false; + } + } + return true; +} + +inline bool HasScalarDependenceAcrossCut(const std::vector& first_group, + const std::unordered_set& second_set) { + for (auto* inst : first_group) { + if (!inst || inst->IsVoid()) { + continue; + } + for (const auto& use : inst->GetUses()) { + auto* user = dyncast(use.GetUser()); + if (user && second_set.find(user) != second_set.end()) { + return true; + } + } + } + return false; +} + +inline bool HasMemoryDependenceAcrossCut(const std::vector& accesses, + const std::unordered_set& first_set, + const std::unordered_set& second_set, + int iv_stride) { + for (const auto& lhs : accesses) { + if (first_set.find(lhs.inst) == first_set.end()) { + continue; + } + for (const auto& rhs : accesses) { + if (second_set.find(rhs.inst) == second_set.end()) { + continue; + } + if (!lhs.is_write && !rhs.is_write) { + continue; + } + if (MayAliasSameIteration(lhs.ptr, rhs.ptr) || + HasCrossIterationDependence(lhs.ptr, rhs.ptr, iv_stride)) { + return true; + } + } + } + return false; +} + +inline bool IsLoopParallelizable(const Loop& loop, PhiInst* iv, int iv_stride, + const std::vector& accesses) { + for (const auto& inst_ptr : loop.header->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + if (phi != iv) { + return false; + } + } + + for (auto* block : loop.block_list) { + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (auto* call = dyncast(inst)) { + auto* callee = call->GetCallee(); + if (callee == nullptr || callee->HasObservableSideEffects() || callee->IsRecursive()) { + return false; + } + for (const auto& access : accesses) { + if (CallMayWritePointer(callee, access.ptr)) { + return false; + } + } + continue; + } + if (dyncast(inst)) { + return false; + } + } + } + + for (std::size_t i = 0; i < accesses.size(); ++i) { + for (std::size_t j = i + 1; j < accesses.size(); ++j) { + if (!accesses[i].is_write && !accesses[j].is_write) { + continue; + } + if (HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr, iv_stride)) { + return false; + } + } + } + return true; +} + +} // namespace ir::loopmem diff --git a/src/ir/passes/LoopPassUtils.h b/src/ir/passes/LoopPassUtils.h new file mode 100644 index 0000000..ed17bf1 --- /dev/null +++ b/src/ir/passes/LoopPassUtils.h @@ -0,0 +1,440 @@ +#pragma once + +#include "ir/Analysis.h" +#include "ir/IR.h" + +#include +#include +#include +#include +#include +#include + +namespace ir::looputils { + +inline Instruction* GetTerminator(BasicBlock* block) { + if (!block || block->GetInstructions().empty()) { + return nullptr; + } + auto* inst = block->GetInstructions().back().get(); + return inst && inst->IsTerminator() ? inst : nullptr; +} + +inline std::size_t GetTerminatorIndex(BasicBlock* block) { + if (!block) { + return 0; + } + const auto size = block->GetInstructions().size(); + if (!block->HasTerminator()) { + return size; + } + return size == 0 ? 0 : size - 1; +} + +inline std::size_t GetFirstNonPhiIndex(BasicBlock* block) { + if (!block) { + return 0; + } + std::size_t index = 0; + for (const auto& inst_ptr : block->GetInstructions()) { + if (!dyncast(inst_ptr.get())) { + break; + } + ++index; + } + return index; +} + +inline std::string NextSyntheticName(Function& function, const std::string& prefix) { + static std::unordered_map counters; + const int id = ++counters[&function]; + return "%" + prefix + std::to_string(id); +} + +inline std::string NextSyntheticBlockName(Function& function, + const std::string& prefix) { + static std::unordered_map counters; + const int id = ++counters[&function]; + return prefix + "." + std::to_string(id); +} + +inline ConstantInt* ConstInt(int value) { + return new ConstantInt(Type::GetInt32Type(), value); +} + +inline int GetPhiIncomingIndex(PhiInst* phi, BasicBlock* block) { + if (!phi || !block) { + return -1; + } + for (int i = 0; i < phi->GetNumIncomings(); ++i) { + if (phi->GetIncomingBlock(i) == block) { + return i; + } + } + return -1; +} + +inline bool ReplacePhiIncoming(PhiInst* phi, BasicBlock* old_block, + Value* new_value, BasicBlock* new_block) { + if (!phi || !old_block || !new_value || !new_block) { + return false; + } + const int index = GetPhiIncomingIndex(phi, old_block); + if (index < 0) { + return false; + } + phi->SetOperand(static_cast(2 * index), new_value); + phi->SetOperand(static_cast(2 * index + 1), new_block); + return true; +} + +inline bool RedirectSuccessorEdge(BasicBlock* pred, BasicBlock* old_succ, + BasicBlock* new_succ) { + auto* terminator = GetTerminator(pred); + if (!terminator || !old_succ || !new_succ) { + return false; + } + + if (auto* br = dyncast(terminator)) { + if (br->GetDest() != old_succ) { + return false; + } + br->SetOperand(0, new_succ); + } else if (auto* condbr = dyncast(terminator)) { + bool changed = false; + if (condbr->GetThenBlock() == old_succ) { + condbr->SetOperand(1, new_succ); + changed = true; + } + if (condbr->GetElseBlock() == old_succ) { + condbr->SetOperand(2, new_succ); + changed = true; + } + if (!changed) { + return false; + } + } else { + return false; + } + + pred->RemoveSuccessor(old_succ); + pred->AddSuccessor(new_succ); + return true; +} + +inline Instruction* MoveInstructionBeforeTerminator(Instruction* inst, + BasicBlock* dest) { + if (!inst || !dest) { + return nullptr; + } + auto* src = inst->GetParent(); + if (!src || src == dest) { + return inst; + } + + auto& src_insts = src->GetInstructions(); + auto src_it = std::find_if(src_insts.begin(), src_insts.end(), + [&](const std::unique_ptr& current) { + return current.get() == inst; + }); + if (src_it == src_insts.end()) { + return nullptr; + } + + auto moved = std::move(*src_it); + src_insts.erase(src_it); + moved->SetParent(dest); + + auto& dest_insts = dest->GetInstructions(); + auto insert_it = dest_insts.begin() + + static_cast(GetTerminatorIndex(dest)); + auto* ptr = moved.get(); + dest_insts.insert(insert_it, std::move(moved)); + return ptr; +} + +inline bool IsLoopInvariantValue(const Loop& loop, Value* value) { + auto* inst = dyncast(value); + return inst == nullptr || !loop.Contains(inst->GetParent()); +} + +inline Value* RemapValue(const std::unordered_map& remap, + Value* value) { + auto it = remap.find(value); + return it == remap.end() ? value : it->second; +} + +inline bool IsCloneableInstruction(const Instruction* inst) { + if (!inst || inst->IsTerminator() || inst->GetOpcode() == Opcode::Phi) { + return false; + } + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Rem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::FRem: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: + case Opcode::Neg: + case Opcode::Not: + case Opcode::FNeg: + case Opcode::FtoI: + case Opcode::IToF: + case Opcode::Alloca: + case Opcode::Load: + case Opcode::Store: + case Opcode::Memset: + case Opcode::GetElementPtr: + case Opcode::Zext: + case Opcode::Call: + return true; + default: + return false; + } +} + +inline Instruction* CloneInstruction(Function& function, Instruction* inst, + BasicBlock* dest, + std::unordered_map& remap, + const std::string& prefix) { + if (!inst || !dest || !IsCloneableInstruction(inst)) { + return nullptr; + } + + const auto insert_index = GetTerminatorIndex(dest); + const auto name = inst->IsVoid() ? std::string() + : NextSyntheticName(function, prefix); + + auto remap_operand = [&](Value* value) { return RemapValue(remap, value); }; + auto remember = [&](Instruction* clone) { + if (clone && !inst->IsVoid()) { + remap[inst] = clone; + } + return clone; + }; + + switch (inst->GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::Div: + case Opcode::Rem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::FRem: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::Shl: + case Opcode::AShr: + case Opcode::LShr: + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: { + auto* bin = static_cast(inst); + return remember(dest->Insert( + insert_index, inst->GetOpcode(), inst->GetType(), + remap_operand(bin->GetLhs()), remap_operand(bin->GetRhs()), nullptr, + name)); + } + case Opcode::Neg: + case Opcode::Not: + case Opcode::FNeg: + case Opcode::FtoI: + case Opcode::IToF: { + auto* un = static_cast(inst); + return remember(dest->Insert(insert_index, inst->GetOpcode(), + inst->GetType(), + remap_operand(un->GetOprd()), + nullptr, name)); + } + case Opcode::Alloca: { + auto* alloca = static_cast(inst); + return remember(dest->Insert(insert_index, + alloca->GetAllocatedType(), + nullptr, name)); + } + case Opcode::Load: { + auto* load = static_cast(inst); + return remember(dest->Insert(insert_index, inst->GetType(), + remap_operand(load->GetPtr()), + nullptr, name)); + } + case Opcode::Store: { + auto* store = static_cast(inst); + return dest->Insert(insert_index, + remap_operand(store->GetValue()), + remap_operand(store->GetPtr()), nullptr); + } + case Opcode::Memset: { + auto* memset = static_cast(inst); + return dest->Insert(insert_index, + remap_operand(memset->GetDest()), + remap_operand(memset->GetValue()), + remap_operand(memset->GetLength()), + remap_operand(memset->GetIsVolatile()), + nullptr); + } + case Opcode::GetElementPtr: { + auto* gep = static_cast(inst); + std::vector indices; + indices.reserve(gep->GetNumIndices()); + for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) { + indices.push_back(remap_operand(gep->GetIndex(i))); + } + return remember(dest->Insert( + insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()), + indices, nullptr, name)); + } + case Opcode::Zext: { + auto* zext = static_cast(inst); + return remember(dest->Insert(insert_index, + remap_operand(zext->GetValue()), + inst->GetType(), nullptr, name)); + } + case Opcode::Call: { + auto* call = static_cast(inst); + std::vector args; + const auto original_args = call->GetArguments(); + args.reserve(original_args.size()); + for (auto* arg : original_args) { + args.push_back(remap_operand(arg)); + } + return remember(dest->Insert(insert_index, call->GetCallee(), + args, nullptr, name)); + } + case Opcode::Phi: + case Opcode::Br: + case Opcode::CondBr: + case Opcode::Return: + case Opcode::Unreachable: + break; + } + return nullptr; +} + +inline BasicBlock* EnsurePreheader(Function& function, Loop& loop) { + if (loop.preheader) { + return loop.preheader; + } + + auto* header = loop.header; + if (!header) { + return nullptr; + } + + std::vector outside_preds; + for (auto* pred : header->GetPredecessors()) { + if (!loop.Contains(pred)) { + outside_preds.push_back(pred); + } + } + if (outside_preds.empty()) { + return nullptr; + } + + if (outside_preds.size() == 1 && + outside_preds.front()->GetSuccessors().size() == 1) { + loop.preheader = outside_preds.front(); + return loop.preheader; + } + + auto* preheader = function.CreateBlock( + NextSyntheticBlockName(function, header->GetName() + ".preheader")); + + std::size_t phi_insert_index = 0; + for (const auto& inst_ptr : header->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + + std::vector outside_incomings; + for (int i = 0; i < phi->GetNumIncomings(); ++i) { + if (!loop.Contains(phi->GetIncomingBlock(i))) { + outside_incomings.push_back(i); + } + } + if (outside_incomings.empty()) { + continue; + } + + Value* merged_value = nullptr; + if (outside_incomings.size() == 1) { + merged_value = phi->GetIncomingValue(outside_incomings.front()); + } else { + auto new_phi = std::make_unique( + phi->GetType(), nullptr, + NextSyntheticName(function, "preheader.phi.")); + auto* new_phi_ptr = new_phi.get(); + new_phi_ptr->SetParent(preheader); + auto& preheader_insts = preheader->GetInstructions(); + preheader_insts.insert(preheader_insts.begin() + + static_cast(phi_insert_index), + std::move(new_phi)); + ++phi_insert_index; + + for (int incoming_index : outside_incomings) { + new_phi_ptr->AddIncoming(phi->GetIncomingValue(incoming_index), + phi->GetIncomingBlock(incoming_index)); + } + merged_value = new_phi_ptr; + } + + for (auto it = outside_incomings.rbegin(); it != outside_incomings.rend(); + ++it) { + phi->RemoveOperand(static_cast(2 * *it + 1)); + phi->RemoveOperand(static_cast(2 * *it)); + } + phi->AddIncoming(merged_value, preheader); + } + + preheader->Append(header, nullptr); + preheader->AddSuccessor(header); + header->AddPredecessor(preheader); + + for (auto* pred : outside_preds) { + if (RedirectSuccessorEdge(pred, header, preheader)) { + preheader->AddPredecessor(pred); + header->RemovePredecessor(pred); + } + } + + loop.preheader = preheader; + return preheader; +} + +} // namespace ir::looputils diff --git a/src/ir/passes/LoopStrengthReduction.cpp b/src/ir/passes/LoopStrengthReduction.cpp new file mode 100644 index 0000000..c9a4899 --- /dev/null +++ b/src/ir/passes/LoopStrengthReduction.cpp @@ -0,0 +1,295 @@ +#include "ir/PassManager.h" + +#include "ir/Analysis.h" +#include "ir/IR.h" +#include "LoopPassUtils.h" + +#include +#include +#include + +namespace ir { +namespace { + +struct InductionVarInfo { + PhiInst* phi = nullptr; + Value* start = nullptr; + BasicBlock* latch = nullptr; + int stride = 0; +}; + +Value* BuildMulValue(Function& function, BasicBlock* block, Value* lhs, Value* rhs, + const std::string& prefix) { + if (auto* lhs_const = dyncast(lhs)) { + if (lhs_const->GetValue() == 0) { + return looputils::ConstInt(0); + } + if (lhs_const->GetValue() == 1) { + return rhs; + } + } + if (auto* rhs_const = dyncast(rhs)) { + if (rhs_const->GetValue() == 0) { + return looputils::ConstInt(0); + } + if (rhs_const->GetValue() == 1) { + return lhs; + } + } + if (auto* lhs_const = dyncast(lhs)) { + if (auto* rhs_const = dyncast(rhs)) { + return looputils::ConstInt(lhs_const->GetValue() * rhs_const->GetValue()); + } + } + return block->Insert(looputils::GetTerminatorIndex(block), Opcode::Mul, + Type::GetInt32Type(), lhs, rhs, nullptr, + looputils::NextSyntheticName(function, prefix)); +} + +Value* BuildScaledValue(Function& function, BasicBlock* block, Value* base, + int factor, const std::string& prefix) { + if (factor == 0) { + return looputils::ConstInt(0); + } + if (factor == 1) { + return base; + } + if (auto* base_const = dyncast(base)) { + return looputils::ConstInt(base_const->GetValue() * factor); + } + if (factor == -1) { + return block->Insert(looputils::GetTerminatorIndex(block), Opcode::Neg, + base->GetType(), base, nullptr, + looputils::NextSyntheticName(function, prefix)); + } + return BuildMulValue(function, block, base, looputils::ConstInt(factor), prefix); +} + +bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader, + PhiInst* phi, InductionVarInfo& info) { + if (!phi || !phi->GetType() || !phi->GetType()->IsInt32() || + phi->GetParent() != loop.header || phi->GetNumIncomings() != 2 || + loop.latches.size() != 1) { + return false; + } + + auto* latch = loop.latches.front(); + int preheader_index = -1; + int latch_index = -1; + for (int i = 0; i < phi->GetNumIncomings(); ++i) { + if (phi->GetIncomingBlock(i) == preheader) { + preheader_index = i; + } else if (phi->GetIncomingBlock(i) == latch) { + latch_index = i; + } + } + if (preheader_index < 0 || latch_index < 0) { + return false; + } + + auto* step_inst = dyncast(phi->GetIncomingValue(latch_index)); + if (!step_inst || step_inst->GetParent() != latch) { + return false; + } + + int stride = 0; + if (step_inst->GetOpcode() == Opcode::Add) { + if (step_inst->GetLhs() == phi) { + auto* delta = dyncast(step_inst->GetRhs()); + if (!delta) { + return false; + } + stride = delta->GetValue(); + } else if (step_inst->GetRhs() == phi) { + auto* delta = dyncast(step_inst->GetLhs()); + if (!delta) { + return false; + } + stride = delta->GetValue(); + } else { + return false; + } + } else if (step_inst->GetOpcode() == Opcode::Sub) { + if (step_inst->GetLhs() != phi) { + return false; + } + auto* delta = dyncast(step_inst->GetRhs()); + if (!delta) { + return false; + } + stride = -delta->GetValue(); + } else { + return false; + } + + if (stride == 0) { + return false; + } + + info.phi = phi; + info.start = phi->GetIncomingValue(preheader_index); + info.latch = latch; + info.stride = stride; + return true; +} + +bool IsMulCandidate(const Loop& loop, Instruction* inst, PhiInst* phi, Value*& factor) { + auto* mul = dyncast(inst); + if (!mul || mul->GetOpcode() != Opcode::Mul || !mul->GetType()->IsInt32()) { + return false; + } + + if (mul->GetLhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetRhs())) { + factor = mul->GetRhs(); + return true; + } + if (mul->GetRhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetLhs())) { + factor = mul->GetLhs(); + return true; + } + return false; +} + +Value* CreateReducedPhi(Function& function, BasicBlock* header, BasicBlock* preheader, + const InductionVarInfo& iv, Value* factor) { + auto* reduced_phi = header->Insert( + looputils::GetFirstNonPhiIndex(header), Type::GetInt32Type(), nullptr, + looputils::NextSyntheticName(function, "lsr.phi.")); + + Value* init = BuildMulValue(function, preheader, iv.start, factor, "lsr.init."); + reduced_phi->AddIncoming(init, preheader); + + Value* step = BuildScaledValue(function, preheader, factor, std::abs(iv.stride), + "lsr.step."); + Instruction* next = nullptr; + if (iv.stride > 0) { + next = iv.latch->Insert( + looputils::GetTerminatorIndex(iv.latch), Opcode::Add, Type::GetInt32Type(), + reduced_phi, step, nullptr, + looputils::NextSyntheticName(function, "lsr.next.")); + } else { + next = iv.latch->Insert( + looputils::GetTerminatorIndex(iv.latch), Opcode::Sub, Type::GetInt32Type(), + reduced_phi, step, nullptr, + looputils::NextSyntheticName(function, "lsr.next.")); + } + reduced_phi->AddIncoming(next, iv.latch); + return reduced_phi; +} + +bool ReduceLoopMultiplications(Function& function, const Loop& loop, + BasicBlock* preheader) { + if (!preheader || loop.latches.size() != 1) { + return false; + } + + std::vector induction_vars; + for (const auto& inst_ptr : loop.header->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + InductionVarInfo info; + if (MatchSimpleInductionVariable(loop, preheader, phi, info)) { + induction_vars.push_back(info); + } + } + if (induction_vars.empty()) { + return false; + } + + bool changed = false; + std::vector to_remove; + for (const auto& iv : induction_vars) { + std::vector> candidates; + for (auto* block : loop.block_list) { + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst == iv.phi) { + continue; + } + + Value* factor = nullptr; + if (IsMulCandidate(loop, inst, iv.phi, factor)) { + candidates.push_back({inst, factor}); + } + } + } + + if (candidates.empty()) { + continue; + } + + std::unordered_map reduced_cache; + for (const auto& candidate : candidates) { + auto* inst = candidate.first; + auto* factor = candidate.second; + + auto cache_it = reduced_cache.find(factor); + Value* replacement = nullptr; + if (cache_it != reduced_cache.end()) { + replacement = cache_it->second; + } else { + replacement = CreateReducedPhi(function, loop.header, preheader, iv, factor); + reduced_cache.emplace(factor, replacement); + } + + inst->ReplaceAllUsesWith(replacement); + to_remove.push_back(inst); + changed = true; + } + } + + for (auto* inst : to_remove) { + if (inst && inst->GetParent()) { + inst->GetParent()->EraseInstruction(inst); + } + } + return changed; +} + +bool RunLoopStrengthReductionOnFunction(Function& function) { + if (function.IsExternal() || function.GetEntryBlock() == nullptr) { + return false; + } + + bool changed = false; + while (true) { + DominatorTree dom_tree(function); + LoopInfo loop_info(function, dom_tree); + bool local_changed = false; + + for (auto* loop : loop_info.GetLoopsInPostOrder()) { + auto* old_preheader = loop->preheader; + auto* preheader = looputils::EnsurePreheader(function, *loop); + bool loop_changed = preheader != old_preheader; + loop_changed |= ReduceLoopMultiplications(function, *loop, preheader); + if (!loop_changed) { + continue; + } + changed = true; + local_changed = true; + break; + } + + if (!local_changed) { + break; + } + } + + return changed; +} + +} // namespace + +bool RunLoopStrengthReduction(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunLoopStrengthReductionOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/LoopUnroll.cpp b/src/ir/passes/LoopUnroll.cpp new file mode 100644 index 0000000..5336651 --- /dev/null +++ b/src/ir/passes/LoopUnroll.cpp @@ -0,0 +1,400 @@ +#include "ir/PassManager.h" + +#include "ir/Analysis.h" +#include "ir/IR.h" +#include "LoopMemoryUtils.h" +#include "LoopPassUtils.h" + +#include +#include + +namespace ir { +namespace { + +struct CountedLoopInfo { + Loop* loop = nullptr; + BasicBlock* preheader = nullptr; + BasicBlock* header = nullptr; + BasicBlock* body = nullptr; + BasicBlock* exit = nullptr; + CondBrInst* branch = nullptr; + BinaryInst* compare = nullptr; + Opcode compare_opcode = Opcode::ICmpLT; + Value* bound = nullptr; + loopmem::SimpleInductionVar induction_var; + std::vector phis; +}; + +bool HasSyntheticLoopTag(const std::string& name) { + return name.find("unroll.") != std::string::npos; +} + +bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) { + if (!loop.preheader || !loop.header || !body) { + return true; + } + if (HasSyntheticLoopTag(loop.preheader->GetName()) || + HasSyntheticLoopTag(loop.header->GetName()) || + HasSyntheticLoopTag(body->GetName())) { + return true; + } + for (const auto& inst_ptr : loop.header->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + const int incoming = looputils::GetPhiIncomingIndex(phi, loop.preheader); + if (incoming < 0) { + continue; + } + auto* incoming_phi = dyncast(phi->GetIncomingValue(incoming)); + if (incoming_phi && incoming_phi->GetParent() && + HasSyntheticLoopTag(incoming_phi->GetParent()->GetName())) { + return true; + } + } + return false; +} + +bool IsSupportedCompareOpcode(Opcode opcode) { + switch (opcode) { + case Opcode::ICmpLT: + case Opcode::ICmpLE: + case Opcode::ICmpGT: + case Opcode::ICmpGE: + return true; + default: + return false; + } +} + +Opcode SwapCompareOpcode(Opcode opcode) { + switch (opcode) { + case Opcode::ICmpLT: + return Opcode::ICmpGT; + case Opcode::ICmpLE: + return Opcode::ICmpGE; + case Opcode::ICmpGT: + return Opcode::ICmpLT; + case Opcode::ICmpGE: + return Opcode::ICmpLE; + default: + return opcode; + } +} + +int CountPayloadInstructions(BasicBlock* block) { + int count = 0; + if (!block) { + return 0; + } + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator()) { + break; + } + ++count; + } + return count; +} + +int ChooseUnrollFactor(BasicBlock* body) { + const int inst_count = CountPayloadInstructions(body); + int mem_ops = 0; + for (const auto& inst_ptr : body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator()) { + break; + } + if (dyncast(inst) || dyncast(inst)) { + ++mem_ops; + } + } + if (inst_count >= 2 && inst_count <= 6 && mem_ops <= 2) { + return 4; + } + if (inst_count >= 2 && inst_count <= 18) { + return 2; + } + return 1; +} + +bool HasUnsafeLoopCarriedMemoryDependence( + const std::vector& accesses, int iv_stride) { + for (std::size_t i = 0; i < accesses.size(); ++i) { + if (accesses[i].is_write && + loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[i].ptr, + iv_stride)) { + return true; + } + for (std::size_t j = i + 1; j < accesses.size(); ++j) { + if (!accesses[i].is_write && !accesses[j].is_write) { + continue; + } + if (loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr, + iv_stride)) { + return true; + } + } + } + return false; +} + +bool MatchCountedLoop(Loop& loop, CountedLoopInfo& info) { + if (!loop.preheader || !loop.header || !loop.IsInnermost()) { + return false; + } + + BasicBlock* body = nullptr; + BasicBlock* exit = nullptr; + if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) { + return false; + } + if (IsAlreadyTransformedLoop(loop, body)) { + return false; + } + + auto* branch = dyncast(looputils::GetTerminator(loop.header)); + if (!branch || branch->GetThenBlock() != body) { + return false; + } + + auto* compare = dyncast(branch->GetCondition()); + if (!compare || !compare->GetType()->IsBool() || + !IsSupportedCompareOpcode(compare->GetOpcode())) { + return false; + } + + bool found_iv = false; + loopmem::SimpleInductionVar induction_var; + std::vector phis; + for (const auto& inst_ptr : loop.header->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + phis.push_back(phi); + if (!found_iv && + loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) { + found_iv = true; + } + } + if (!found_iv) { + return false; + } + + Opcode compare_opcode = compare->GetOpcode(); + Value* bound = nullptr; + if (compare->GetLhs() == induction_var.phi && + looputils::IsLoopInvariantValue(loop, compare->GetRhs())) { + bound = compare->GetRhs(); + } else if (compare->GetRhs() == induction_var.phi && + looputils::IsLoopInvariantValue(loop, compare->GetLhs())) { + bound = compare->GetLhs(); + compare_opcode = SwapCompareOpcode(compare_opcode); + } else { + return false; + } + + if (!bound) { + return false; + } + if ((induction_var.stride > 0 && + !(compare_opcode == Opcode::ICmpLT || compare_opcode == Opcode::ICmpLE)) || + (induction_var.stride < 0 && + !(compare_opcode == Opcode::ICmpGT || compare_opcode == Opcode::ICmpGE))) { + return false; + } + + const auto accesses = loopmem::CollectMemoryAccesses(loop, induction_var.phi); + if (HasUnsafeLoopCarriedMemoryDependence(accesses, induction_var.stride)) { + return false; + } + + for (const auto& inst_ptr : loop.header->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (dyncast(inst) || inst == compare || inst->IsTerminator()) { + continue; + } + return false; + } + + for (const auto& inst_ptr : body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator()) { + continue; + } + if (!looputils::IsCloneableInstruction(inst) || dyncast(inst) || + dyncast(inst) || dyncast(inst)) { + return false; + } + } + + info.loop = &loop; + info.preheader = loop.preheader; + info.header = loop.header; + info.body = body; + info.exit = exit; + info.branch = branch; + info.compare = compare; + info.compare_opcode = compare_opcode; + info.bound = bound; + info.induction_var = induction_var; + info.phis = std::move(phis); + return true; +} + +Value* BuildAdjustedBound(Function& function, BasicBlock* preheader, Value* bound, + int stride, int factor) { + const int delta = std::abs(stride) * (factor - 1); + if (delta == 0) { + return bound; + } + if (auto* ci = dyncast(bound)) { + return looputils::ConstInt(stride > 0 ? ci->GetValue() - delta : ci->GetValue() + delta); + } + return preheader->Insert( + looputils::GetTerminatorIndex(preheader), + stride > 0 ? Opcode::Sub : Opcode::Add, Type::GetInt32Type(), bound, + looputils::ConstInt(delta), nullptr, + looputils::NextSyntheticName(function, "unroll.bound.")); +} + +bool RunLoopUnrollOnFunction(Function& function) { + if (function.IsExternal() || !function.GetEntryBlock()) { + return false; + } + + bool changed = false; + while (true) { + DominatorTree dom_tree(function); + LoopInfo loop_info(function, dom_tree); + bool local_changed = false; + + for (auto* loop : loop_info.GetLoopsInPostOrder()) { + CountedLoopInfo info; + if (!MatchCountedLoop(*loop, info)) { + continue; + } + + const int factor = ChooseUnrollFactor(info.body); + if (factor <= 1) { + continue; + } + + auto* unrolled_header = + function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.header")); + auto* unrolled_body = + function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.body")); + auto* unrolled_exit = + function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.exit")); + + std::unordered_map remap; + std::unordered_map unrolled_phis; + std::unordered_map exit_phis; + std::unordered_map current_phi_values; + std::unordered_map latch_values; + + for (auto* phi : info.phis) { + auto* cloned_phi = unrolled_header->Append( + phi->GetType(), nullptr, + looputils::NextSyntheticName(function, "unroll.phi.")); + const int preheader_index = looputils::GetPhiIncomingIndex(phi, info.preheader); + const int latch_index = looputils::GetPhiIncomingIndex(phi, info.body); + if (preheader_index < 0 || latch_index < 0) { + continue; + } + cloned_phi->AddIncoming(phi->GetIncomingValue(preheader_index), info.preheader); + remap[phi] = cloned_phi; + unrolled_phis.emplace(phi, cloned_phi); + current_phi_values.emplace(phi, cloned_phi); + latch_values.emplace(phi, phi->GetIncomingValue(latch_index)); + } + + auto* adjusted_bound = BuildAdjustedBound(function, info.preheader, info.bound, + info.induction_var.stride, factor); + auto* unrolled_cond = unrolled_header->Append( + info.compare_opcode, Type::GetBoolType(), unrolled_phis[info.induction_var.phi], + adjusted_bound, nullptr, + looputils::NextSyntheticName(function, "unroll.cmp.")); + unrolled_header->Append(unrolled_cond, unrolled_body, unrolled_exit, nullptr); + unrolled_header->AddPredecessor(info.preheader); + unrolled_header->AddSuccessor(unrolled_body); + unrolled_header->AddSuccessor(unrolled_exit); + + for (int lane = 0; lane < factor; ++lane) { + for (const auto& inst_ptr : info.body->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (inst->IsTerminator()) { + continue; + } + looputils::CloneInstruction(function, inst, unrolled_body, remap, + "unroll." + std::to_string(lane) + "."); + } + + std::unordered_map next_phi_values; + for (const auto& entry : latch_values) { + next_phi_values.emplace(entry.first, + looputils::RemapValue(remap, entry.second)); + } + for (const auto& entry : next_phi_values) { + remap[entry.first] = entry.second; + current_phi_values[entry.first] = entry.second; + } + } + + for (const auto& entry : unrolled_phis) { + entry.second->AddIncoming(current_phi_values[entry.first], unrolled_body); + } + unrolled_body->Append(unrolled_header, nullptr); + unrolled_body->AddPredecessor(unrolled_header); + unrolled_body->AddSuccessor(unrolled_header); + unrolled_header->AddPredecessor(unrolled_body); + + for (const auto& entry : unrolled_phis) { + auto* exit_phi = unrolled_exit->Append( + entry.first->GetType(), nullptr, + looputils::NextSyntheticName(function, "unroll.exit.")); + exit_phi->AddIncoming(entry.second, unrolled_header); + exit_phis.emplace(entry.first, exit_phi); + } + unrolled_exit->Append(info.header, nullptr); + unrolled_exit->AddPredecessor(unrolled_header); + unrolled_exit->AddSuccessor(info.header); + + if (!looputils::RedirectSuccessorEdge(info.preheader, info.header, unrolled_header)) { + continue; + } + info.header->RemovePredecessor(info.preheader); + info.header->AddPredecessor(unrolled_exit); + + for (auto* phi : info.phis) { + looputils::ReplacePhiIncoming(phi, info.preheader, exit_phis[phi], unrolled_exit); + } + + changed = true; + local_changed = true; + break; + } + + if (!local_changed) { + break; + } + } + + return changed; +} + +} // namespace + +bool RunLoopUnroll(Module& module) { + bool changed = false; + for (const auto& function : module.GetFunctions()) { + if (function) { + changed |= RunLoopUnrollOnFunction(*function); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/MemoryUtils.h b/src/ir/passes/MemoryUtils.h new file mode 100644 index 0000000..bcd8198 --- /dev/null +++ b/src/ir/passes/MemoryUtils.h @@ -0,0 +1,260 @@ +#pragma once + +#include "ir/IR.h" +#include "PassUtils.h" + +#include +#include +#include + +namespace ir::memutils { + +enum class PointerRootKind { + Local, + Global, + ReadonlyGlobal, + Param, + Unknown, +}; + +struct AddressComponent { + bool is_constant = false; + std::int64_t constant = 0; + Value* value = nullptr; + + bool operator==(const AddressComponent& rhs) const { + return is_constant == rhs.is_constant && constant == rhs.constant && + value == rhs.value; + } +}; + +struct AddressKey { + PointerRootKind kind = PointerRootKind::Unknown; + Value* root = nullptr; + std::vector components; + + bool operator==(const AddressKey& rhs) const { + return kind == rhs.kind && root == rhs.root && components == rhs.components; + } +}; + +struct AddressKeyHash { + std::size_t operator()(const AddressKey& key) const { + std::size_t h = static_cast(key.kind); + h ^= std::hash{}(key.root) + 0x9e3779b9 + (h << 6) + (h >> 2); + for (const auto& component : key.components) { + h ^= std::hash{}(component.is_constant) + 0x9e3779b9 + (h << 6) + (h >> 2); + if (component.is_constant) { + h ^= std::hash{}(component.constant) + 0x9e3779b9 + (h << 6) + + (h >> 2); + } else { + h ^= std::hash{}(component.value) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + } + return h; + } +}; + +struct EscapeSummary { + std::unordered_set escaped_locals; + + bool IsEscaped(Value* value) const { + return value != nullptr && escaped_locals.find(value) != escaped_locals.end(); + } +}; + +inline bool IsNoEscapePointerUse(Value* current, Instruction* user) { + if (!current || !user) { + return false; + } + if (auto* load = dyncast(user)) { + return load->GetPtr() == current; + } + if (auto* store = dyncast(user)) { + return store->GetPtr() == current; + } + if (auto* memset = dyncast(user)) { + return memset->GetDest() == current; + } + return false; +} + +inline bool PointerValueEscapes(Value* current, Value* root, + std::unordered_set& visiting) { + if (!current || !root || !visiting.insert(current).second) { + return false; + } + + for (const auto& use : current->GetUses()) { + auto* user = dyncast(use.GetUser()); + if (!user) { + return true; + } + if (auto* gep = dyncast(user)) { + if (gep->GetPointer() == current && + PointerValueEscapes(gep, root, visiting)) { + return true; + } + continue; + } + if (IsNoEscapePointerUse(current, user)) { + continue; + } + return true; + } + + return false; +} + +inline EscapeSummary AnalyzeEscapes(Function& function) { + EscapeSummary summary; + for (const auto& block_ptr : function.GetBlocks()) { + for (const auto& inst_ptr : block_ptr->GetInstructions()) { + auto* alloca = dyncast(inst_ptr.get()); + if (!alloca) { + continue; + } + std::unordered_set visiting; + if (PointerValueEscapes(alloca, alloca, visiting)) { + summary.escaped_locals.insert(alloca); + } + } + } + return summary; +} + +inline PointerRootKind ClassifyRoot(Value* root, const EscapeSummary* summary) { + if (root == nullptr) { + return PointerRootKind::Unknown; + } + if (auto* global = dyncast(root)) { + return global->IsConstant() ? PointerRootKind::ReadonlyGlobal + : PointerRootKind::Global; + } + if (isa(root)) { + return PointerRootKind::Param; + } + if (isa(root)) { + if (summary != nullptr && summary->IsEscaped(root)) { + return PointerRootKind::Unknown; + } + return PointerRootKind::Local; + } + return PointerRootKind::Unknown; +} + +inline Value* StripPointerRoot(Value* pointer) { + auto* current = pointer; + while (auto* gep = dyncast(current)) { + current = gep->GetPointer(); + } + return current; +} + +inline AddressComponent MakeAddressComponent(Value* value) { + if (auto* ci = dyncast(value)) { + return {true, ci->GetValue(), nullptr}; + } + if (auto* cb = dyncast(value)) { + return {true, cb->GetValue() ? 1 : 0, nullptr}; + } + return {false, 0, value}; +} + +inline bool BuildExactAddressKey(Value* pointer, const EscapeSummary* summary, + AddressKey& key) { + if (!pointer) { + return false; + } + if (auto* gep = dyncast(pointer)) { + if (!BuildExactAddressKey(gep->GetPointer(), summary, key)) { + return false; + } + for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) { + key.components.push_back(MakeAddressComponent(gep->GetIndex(i))); + } + return true; + } + + key.kind = ClassifyRoot(pointer, summary); + key.root = pointer; + key.components.clear(); + return true; +} + +inline bool HasOnlyConstantComponents(const AddressKey& key) { + for (const auto& component : key.components) { + if (!component.is_constant) { + return false; + } + } + return true; +} + +inline bool MayAliasConservatively(const AddressKey& lhs, const AddressKey& rhs) { + if (lhs.kind == PointerRootKind::Unknown || rhs.kind == PointerRootKind::Unknown) { + return true; + } + if (lhs.kind != rhs.kind || lhs.root != rhs.root) { + return false; + } + if (lhs.components == rhs.components) { + return true; + } + if (HasOnlyConstantComponents(lhs) && HasOnlyConstantComponents(rhs)) { + return false; + } + return true; +} + +inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) { + if (!callee) { + return true; + } + if (callee->HasUnknownEffects()) { + return true; + } + switch (kind) { + case PointerRootKind::ReadonlyGlobal: + return callee->ReadsGlobalMemory(); + case PointerRootKind::Global: + return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory(); + case PointerRootKind::Param: + return callee->ReadsParamMemory() || callee->WritesParamMemory(); + case PointerRootKind::Local: + return false; + case PointerRootKind::Unknown: + return callee->MayReadMemory(); + } + return true; +} + +inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) { + if (!callee) { + return true; + } + if (callee->HasUnknownEffects()) { + return true; + } + switch (kind) { + case PointerRootKind::ReadonlyGlobal: + return false; + case PointerRootKind::Global: + return callee->WritesGlobalMemory(); + case PointerRootKind::Param: + return callee->WritesParamMemory(); + case PointerRootKind::Local: + return false; + case PointerRootKind::Unknown: + return callee->MayWriteMemory(); + } + return true; +} + +inline bool IsPureCall(const CallInst* call) { + auto* callee = call == nullptr ? nullptr : call->GetCallee(); + return callee != nullptr && callee->CanDiscardUnusedCall() && + !callee->MayReadMemory(); +} + +} // namespace ir::memutils diff --git a/src/ir/passes/PassManager.cpp b/src/ir/passes/PassManager.cpp index c1f9ccc..2620d5e 100644 --- a/src/ir/passes/PassManager.cpp +++ b/src/ir/passes/PassManager.cpp @@ -11,7 +11,35 @@ void RunIRPassPipeline(Module& module) { if (disable_mem2reg != nullptr && disable_mem2reg[0] != '\0' && disable_mem2reg[0] != '0') { return; } + RunMem2Reg(module); + + constexpr int kMaxIterations = 8; + for (int iteration = 0; iteration < kMaxIterations; ++iteration) { + bool changed = false; + changed |= RunFunctionInlining(module); + changed |= RunConstProp(module); + changed |= RunConstFold(module); + changed |= RunGVN(module); + changed |= RunLoadStoreElim(module); + changed |= RunCSE(module); + changed |= RunDCE(module); + changed |= RunCFGSimplify(module); + changed |= RunLICM(module); + changed |= RunLoopStrengthReduction(module); + changed |= RunLoopFission(module); + changed |= RunLoopUnroll(module); + changed |= RunConstProp(module); + changed |= RunConstFold(module); + changed |= RunGVN(module); + changed |= RunLoadStoreElim(module); + changed |= RunCSE(module); + changed |= RunDCE(module); + changed |= RunCFGSimplify(module); + if (!changed) { + break; + } + } } -} // namespace ir \ No newline at end of file +} // namespace ir diff --git a/src/ir/passes/PassUtils.h b/src/ir/passes/PassUtils.h new file mode 100644 index 0000000..9d42dc6 --- /dev/null +++ b/src/ir/passes/PassUtils.h @@ -0,0 +1,234 @@ +#pragma once + +#include "ir/IR.h" + +#include +#include +#include +#include +#include +#include + +namespace ir::passutils { + +inline std::uint32_t FloatBits(float value) { + std::uint32_t bits = 0; + std::memcpy(&bits, &value, sizeof(bits)); + return bits; +} + +inline bool AreEquivalentValues(Value* lhs, Value* rhs) { + if (lhs == rhs) { + return true; + } + auto* lhs_i32 = dyncast(lhs); + auto* rhs_i32 = dyncast(rhs); + if (lhs_i32 && rhs_i32) { + return lhs_i32->GetValue() == rhs_i32->GetValue(); + } + auto* lhs_i1 = dyncast(lhs); + auto* rhs_i1 = dyncast(rhs); + if (lhs_i1 && rhs_i1) { + return lhs_i1->GetValue() == rhs_i1->GetValue(); + } + auto* lhs_f32 = dyncast(lhs); + auto* rhs_f32 = dyncast(rhs); + if (lhs_f32 && rhs_f32) { + return FloatBits(lhs_f32->GetValue()) == FloatBits(rhs_f32->GetValue()); + } + return false; +} + +inline std::vector CollectReachableBlocks(Function& function) { + std::vector order; + auto* entry = function.GetEntryBlock(); + if (!entry) { + return order; + } + + std::unordered_set visited; + std::vector stack{entry}; + while (!stack.empty()) { + auto* block = stack.back(); + stack.pop_back(); + if (!block || !visited.insert(block).second) { + continue; + } + order.push_back(block); + const auto& succs = block->GetSuccessors(); + for (auto it = succs.rbegin(); it != succs.rend(); ++it) { + if (*it != nullptr) { + stack.push_back(*it); + } + } + } + + return order; +} + +inline bool IsSideEffectingInstruction(const Instruction* inst) { + if (!inst) { + return false; + } + switch (inst->GetOpcode()) { + case Opcode::Store: + case Opcode::Memset: + case Opcode::Br: + case Opcode::CondBr: + case Opcode::Return: + case Opcode::Unreachable: + return true; + case Opcode::Call: { + auto* call = dyncast(inst); + auto* callee = call == nullptr ? nullptr : call->GetCallee(); + return callee == nullptr || !callee->CanDiscardUnusedCall(); + } + default: + return false; + } +} + +inline bool IsTriviallyDead(Instruction* inst) { + return inst != nullptr && !IsSideEffectingInstruction(inst) && + inst->GetUses().empty(); +} + +inline void RemoveIncomingForBlock(PhiInst* phi, BasicBlock* block) { + if (!phi || !block) { + return; + } + for (int i = phi->GetNumIncomings() - 1; i >= 0; --i) { + if (phi->GetIncomingBlock(i) != block) { + continue; + } + phi->RemoveOperand(static_cast(2 * i + 1)); + phi->RemoveOperand(static_cast(2 * i)); + } +} + +inline void RemoveIncomingFromSuccessor(BasicBlock* succ, BasicBlock* pred) { + if (!succ || !pred) { + return; + } + for (const auto& inst_ptr : succ->GetInstructions()) { + auto* phi = dyncast(inst_ptr.get()); + if (!phi) { + break; + } + RemoveIncomingForBlock(phi, pred); + } +} + +inline void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) { + auto& instructions = block->GetInstructions(); + if (instructions.empty() || !instructions.back()->IsTerminator()) { + return; + } + instructions.back()->ClearAllOperands(); + auto branch = std::make_unique(dest, nullptr); + branch->SetParent(block); + instructions.back() = std::move(branch); +} + +inline bool SimplifyPhiInst(PhiInst* phi) { + if (!phi) { + return false; + } + + Value* unique_value = nullptr; + for (int i = 0; i < phi->GetNumIncomings(); ++i) { + auto* incoming = phi->GetIncomingValue(i); + if (incoming == phi) { + continue; + } + if (unique_value == nullptr) { + unique_value = incoming; + continue; + } + if (!AreEquivalentValues(unique_value, incoming)) { + return false; + } + } + + if (unique_value == nullptr) { + return false; + } + + auto* parent = phi->GetParent(); + phi->ReplaceAllUsesWith(unique_value); + parent->EraseInstruction(phi); + return true; +} + +inline void EraseBlock(Function& function, BasicBlock* block) { + if (!block) { + return; + } + auto& blocks = function.GetBlocks(); + blocks.erase(std::remove_if(blocks.begin(), blocks.end(), + [&](const std::unique_ptr& current) { + return current.get() == block; + }), + blocks.end()); +} + +inline bool RemoveUnreachableBlocks(Function& function) { + auto reachable = CollectReachableBlocks(function); + std::unordered_set reachable_set(reachable.begin(), reachable.end()); + std::vector dead_blocks; + for (const auto& block_ptr : function.GetBlocks()) { + auto* block = block_ptr.get(); + if (reachable_set.find(block) == reachable_set.end()) { + dead_blocks.push_back(block); + } + } + + if (dead_blocks.empty()) { + return false; + } + + for (auto* block : dead_blocks) { + auto preds = block->GetPredecessors(); + auto succs = block->GetSuccessors(); + for (auto* succ : succs) { + RemoveIncomingFromSuccessor(succ, block); + succ->RemovePredecessor(block); + } + for (auto* pred : preds) { + pred->RemoveSuccessor(block); + } + } + + for (auto* block : dead_blocks) { + for (const auto& inst_ptr : block->GetInstructions()) { + inst_ptr->ClearAllOperands(); + } + } + + for (auto* block : dead_blocks) { + EraseBlock(function, block); + } + + return true; +} + +inline bool IsCommutativeOpcode(Opcode opcode) { + switch (opcode) { + case Opcode::Add: + case Opcode::Mul: + case Opcode::And: + case Opcode::Or: + case Opcode::Xor: + case Opcode::FAdd: + case Opcode::FMul: + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + return true; + default: + return false; + } +} + +} // namespace ir::passutils diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 2a7e060..a3bde5b 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -1,6 +1,8 @@ #include "irgen/IRGen.h" #include +#include +#include #include #include @@ -213,6 +215,16 @@ void IRGenImpl::EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType ty symbol->dims = ParseArrayDims(ctx.constExp()); if (symbol->is_array) { + // Leave uninitialized globals as zeroinitializer instead of materializing + // an explicit all-zero constant array, which can explode memory usage. + if (ctx.initVal() == nullptr) { + global->SetInitializer(nullptr); + return; + } + if (IsExplicitZeroInitVal(ctx.initVal(), type)) { + global->SetInitializer(nullptr); + return; + } auto flat = FlattenInitVal(ctx.initVal(), type, symbol->dims); std::vector elements; elements.reserve(flat.size()); @@ -250,7 +262,14 @@ void IRGenImpl::EmitGlobalConstDef(SysYParser::ConstDefContext& ctx, global->SetConstant(true); if (symbol->is_array) { + if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) { + symbol->const_array.clear(); + symbol->const_array_all_zero = true; + global->SetInitializer(nullptr); + return; + } symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims); + symbol->const_array_all_zero = false; std::vector elements; elements.reserve(symbol->const_array.size()); for (const auto& value : symbol->const_array) { @@ -359,7 +378,15 @@ void IRGenImpl::EmitLocalConstDef(SysYParser::ConstDefContext& ctx, return; } + if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) { + symbol->const_array.clear(); + symbol->const_array_all_zero = true; + ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims); + return; + } + symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims); + symbol->const_array_all_zero = false; ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims); for (size_t i = 0; i < symbol->const_array.size(); ++i) { if (symbol->const_array[i].type == SemanticType::Int && symbol->const_array[i].int_value == 0) { @@ -537,6 +564,54 @@ size_t IRGenImpl::FlattenIndices(const std::vector& dims, return offset; } +bool IRGenImpl::IsZeroConstant(const ConstantValue& value) const { + switch (value.type) { + case SemanticType::Int: + return value.int_value == 0; + case SemanticType::Float: { + std::uint32_t bits = 0; + std::memcpy(&bits, &value.float_value, sizeof(bits)); + return bits == 0; + } + case SemanticType::Void: + return false; + } + return false; +} + +bool IRGenImpl::IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx, + SemanticType base_type) { + if (ctx == nullptr) { + return true; + } + if (ctx->constExp() != nullptr) { + return IsZeroConstant( + ConvertConst(EvalConstAddExp(*ctx->constExp()->addExp()), base_type)); + } + for (auto* child : ctx->constInitVal()) { + if (!IsExplicitZeroConstInitVal(child, base_type)) { + return false; + } + } + return true; +} + +bool IRGenImpl::IsExplicitZeroInitVal(SysYParser::InitValContext* ctx, + SemanticType base_type) { + if (ctx == nullptr) { + return true; + } + if (ctx->exp() != nullptr) { + return IsZeroConstant(ConvertConst(EvalConstExp(*ctx->exp()), base_type)); + } + for (auto* child : ctx->initVal()) { + if (!IsExplicitZeroInitVal(child, base_type)) { + return false; + } + } + return true; +} + ConstantValue IRGenImpl::ZeroConst(SemanticType type) const { ConstantValue value; value.type = type; diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 5b1e213..6f94b2e 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -280,6 +280,9 @@ ConstantValue IRGenImpl::EvalConstLVal(SysYParser::LValContext& ctx) { } } const auto offset = FlattenIndices(symbol->dims, indices); + if (symbol->const_array_all_zero) { + return ZeroConst(symbol->type); + } if (offset >= symbol->const_array.size()) { ThrowError(&ctx, "???????????: " + name); } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 30c52b5..5fc70fc 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -25,6 +25,17 @@ std::string IRGenImpl::NextBlockName(const std::string& prefix) { return module_.GetContext().NextBlockName(prefix); } +void IRGenImpl::ApplyFunctionSema(const std::string& name, ir::Function& function) { + const auto* info = sema_.LookupFunction(name); + if (info == nullptr) { + return; + } + function.SetEffectInfo(info->reads_global_memory, info->writes_global_memory, + info->reads_param_memory, info->writes_param_memory, + info->has_io, info->has_unknown_effects, + info->is_recursive); +} + std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (ctx == nullptr) { ThrowError(ctx, "??????"); @@ -96,6 +107,7 @@ void IRGenImpl::RegisterBuiltinFunctions() { auto* function = module_.CreateFunction( builtin.name, GetIRScalarType(builtin.return_type), ir_param_types, ir_param_names, true); + ApplyFunctionSema(builtin.name, *function); SymbolEntry entry; entry.kind = SymbolKind::Function; @@ -171,6 +183,7 @@ void IRGenImpl::PredeclareFunction(SysYParser::FuncDefContext& ctx) { auto* function = module_.CreateFunction( name, GetIRScalarType(function_type.return_type), BuildFunctionIRParamTypes(function_type), BuildFunctionIRParamNames(ctx), false); + ApplyFunctionSema(name, *function); SymbolEntry entry; entry.kind = SymbolKind::Function; diff --git a/src/main.cpp b/src/main.cpp index ad6b901..95e835a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -60,14 +60,16 @@ int main(int argc, char** argv) { need_blank_line = true; } - if (opts.emit_asm) { - auto machine_module = mir::LowerToMIR(*asm_module); - mir::RunRegAlloc(*machine_module); - mir::RunFrameLowering(*machine_module); - if (need_blank_line) { - std::cout << "\n"; - } - mir::PrintAsm(*machine_module, std::cout); + if (opts.emit_asm) { + auto machine_module = mir::LowerToMIR(*asm_module); + mir::RunMIRPreRegAllocPassPipeline(*machine_module); + mir::RunRegAlloc(*machine_module); + mir::RunMIRPostRegAllocPassPipeline(*machine_module); + mir::RunFrameLowering(*machine_module); + if (need_blank_line) { + std::cout << "\n"; + } + mir::PrintAsm(*machine_module, std::cout); } #else if (opts.emit_ir || opts.emit_asm) { diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index b1ff8e8..013b668 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -605,8 +605,8 @@ bool TryEmitDirectMemoryAccess(const MachineFunction& function, const AddressExp void EmitAddressExpr(const MachineFunction& function, const AddressExpr& address, const char* dst_reg, std::ostream& os) { - std::unordered_map preserved_index_regs; - int next_preserve_scratch = 15; + int preserved_index_vreg = -1; + std::string preserved_index_reg; for (const auto& term : address.scaled_vregs) { const int index_vreg = term.first; const auto& alloc = function.GetAllocation(index_vreg); @@ -617,16 +617,15 @@ void EmitAddressExpr(const MachineFunction& function, const AddressExpr& address if (std::string(GetPhysRegName(alloc.phys, ValueType::Ptr)) != dst_reg) { continue; } - auto it = preserved_index_regs.find(index_vreg); - if (it != preserved_index_regs.end()) { - continue; + if (preserved_index_vreg >= 0 && preserved_index_vreg != index_vreg) { + throw std::runtime_error( + FormatError("mir", "multiple address indices conflict with lea destination")); } - const auto scratch = PhysReg{RegClass::GPR, next_preserve_scratch--}; - const char* scratch_name = - GetPhysRegName(scratch, function.GetVRegInfo(index_vreg).type); - EmitCopy(os, scratch_name, + preserved_index_vreg = index_vreg; + preserved_index_reg = + GetPhysRegName({RegClass::GPR, 12}, function.GetVRegInfo(index_vreg).type); + EmitCopy(os, preserved_index_reg.c_str(), GetPhysRegName(alloc.phys, function.GetVRegInfo(index_vreg).type), false); - preserved_index_regs.emplace(index_vreg, scratch_name); } switch (address.base_kind) { @@ -656,9 +655,8 @@ void EmitAddressExpr(const MachineFunction& function, const AddressExpr& address for (const auto& term : address.scaled_vregs) { std::string index_reg; - auto preserved_it = preserved_index_regs.find(term.first); - if (preserved_it != preserved_index_regs.end()) { - index_reg = preserved_it->second; + if (term.first == preserved_index_vreg) { + index_reg = preserved_index_reg; } else { index_reg = MaterializeGprUse(function, MachineOperand::VReg(term.first), ValueType::I32, 10, os); diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index fce5590..ef9b7fd 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -3,20 +3,54 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include "ir/IR.h" #include "utils/Log.h" -namespace mir { -namespace { - -enum class LoweredKind { Invalid, VReg, StackObject, Global }; - -struct LoweredValue { +namespace mir { +namespace { + +enum class LoweredKind { Invalid, VReg, StackObject, Global }; + +std::vector CollectLoweringOrder(ir::Function& function) { + std::vector order; + auto* entry = function.GetEntryBlock(); + if (!entry) { + return order; + } + + std::unordered_set visited; + std::vector stack{entry}; + while (!stack.empty()) { + auto* block = stack.back(); + stack.pop_back(); + if (!block || !visited.insert(block).second) { + continue; + } + order.push_back(block); + + const auto& succs = block->GetSuccessors(); + for (auto it = succs.rbegin(); it != succs.rend(); ++it) { + if (*it != nullptr) { + stack.push_back(*it); + } + } + } + + for (const auto& block : function.GetBlocks()) { + if (block && visited.insert(block.get()).second) { + order.push_back(block.get()); + } + } + return order; +} + +struct LoweredValue { LoweredKind kind = LoweredKind::Invalid; ValueType type = ValueType::Void; int index = -1; @@ -905,13 +939,14 @@ class Lowerer { param_types.push_back(LowerType(type)); } - auto machine_function = std::make_unique( - function.GetName(), LowerType(function.GetReturnType()), std::move(param_types)); - current_ir_function_ = &function; - current_function_ = machine_function.get(); - - for (const auto& block : function.GetBlocks()) { - blocks_[block.get()] = current_function_->CreateBlock(block->GetName()); + auto machine_function = std::make_unique( + function.GetName(), LowerType(function.GetReturnType()), std::move(param_types)); + current_ir_function_ = &function; + current_function_ = machine_function.get(); + const auto ordered_blocks = CollectLoweringOrder(function); + + for (const auto& block : function.GetBlocks()) { + blocks_[block.get()] = current_function_->CreateBlock(block->GetName()); } if (!function.GetBlocks().empty()) { @@ -924,14 +959,14 @@ class Lowerer { values_[argument.get()] = lowered; } } - - PreparePhiResults(function); - - for (const auto& block : function.GetBlocks()) { - current_block_ = blocks_.at(block.get()); - for (const auto& inst : block->GetInstructions()) { - LowerInstruction(*inst); - } + + PreparePhiResults(function); + + for (auto* block : ordered_blocks) { + current_block_ = blocks_.at(block); + for (const auto& inst : block->GetInstructions()) { + LowerInstruction(*inst); + } } EmitPhiCopies(function); diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index 10c4244..3c9996a 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -36,26 +36,37 @@ bool BelongsToClass(ValueType type, RegClass reg_class) { return IsFPR(type) ? reg_class == RegClass::FPR : reg_class == RegClass::GPR; } -bool IsCalleeSaved(PhysReg reg) { - if (reg.reg_class == RegClass::GPR) { - return reg.index >= 19 && reg.index <= 28; - } - return reg.index >= 8 && reg.index <= 15; -} - -std::vector GetAllocatableRegs(RegClass reg_class) { - std::vector regs; - if (reg_class == RegClass::FPR) { - for (int i = 8; i <= 15; ++i) { - regs.push_back({RegClass::FPR, i}); - } - return regs; - } - for (int i = 19; i <= 28; ++i) { - regs.push_back({RegClass::GPR, i}); - } - return regs; -} +bool IsCalleeSaved(PhysReg reg) { + if (reg.reg_class == RegClass::GPR) { + return reg.index >= 19 && reg.index <= 28; + } + return reg.index >= 8 && reg.index <= 15; +} + +bool IsCallerSaved(PhysReg reg) { + return !IsCalleeSaved(reg); +} + +std::vector GetAllocatableRegs(RegClass reg_class) { + std::vector regs; + if (reg_class == RegClass::FPR) { + for (int i = 19; i <= 31; ++i) { + regs.push_back({RegClass::FPR, i}); + } + for (int i = 8; i <= 15; ++i) { + regs.push_back({RegClass::FPR, i}); + } + return regs; + } + regs.push_back({RegClass::GPR, 8}); + for (int i = 13; i <= 15; ++i) { + regs.push_back({RegClass::GPR, i}); + } + for (int i = 19; i <= 28; ++i) { + regs.push_back({RegClass::GPR, i}); + } + return regs; +} int CreateSpillSlot(MachineFunction& function, int vreg) { const auto type = function.GetVRegInfo(vreg).type; @@ -171,12 +182,13 @@ class GeorgeColoringAllocator { reg_class_(reg_class), regs_(GetAllocatableRegs(reg_class)), k_(static_cast(regs_.size())), - block_infos_(block_infos), - num_vregs_(static_cast(function.GetVRegs().size())), - in_class_(static_cast(num_vregs_), 0), - adjacency_(static_cast(num_vregs_)), - degree_(static_cast(num_vregs_), 0), - spill_cost_(static_cast(num_vregs_), 0.0), + block_infos_(block_infos), + num_vregs_(static_cast(function.GetVRegs().size())), + in_class_(static_cast(num_vregs_), 0), + live_across_call_(static_cast(num_vregs_), 0), + adjacency_(static_cast(num_vregs_)), + degree_(static_cast(num_vregs_), 0), + spill_cost_(static_cast(num_vregs_), 0.0), move_list_(static_cast(num_vregs_)), alias_(static_cast(num_vregs_), -1), color_index_(static_cast(num_vregs_), -1), @@ -235,15 +247,25 @@ class GeorgeColoringAllocator { } } - const auto& instructions = block->GetInstructions(); - for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { - const auto& inst = *it; + const auto& instructions = block->GetInstructions(); + for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { + const auto& inst = *it; auto defs = FilterClass(inst.GetDefs()); auto uses = FilterClass(inst.GetUses()); + if (inst.GetOpcode() == MachineInstr::Opcode::Call || + inst.GetOpcode() == MachineInstr::Opcode::Memset) { + for (int vreg = 0; vreg < num_vregs_; ++vreg) { + if (live[static_cast(vreg)] && + in_class_[static_cast(vreg)]) { + live_across_call_[static_cast(vreg)] = 1; + } + } + } + for (int def : defs) { spill_cost_[static_cast(def)] += block_weight; - } + } for (int use : uses) { spill_cost_[static_cast(use)] += block_weight; } @@ -389,12 +411,19 @@ class GeorgeColoringAllocator { select_stack_.pop_back(); in_select_stack_[static_cast(node)] = 0; - std::vector ok_colors(static_cast(regs_.size()), 1); - for (int neighbor : adjacency_[static_cast(node)]) { - const int alias = GetAlias(neighbor); - if (!is_colored_[static_cast(alias)]) { - continue; - } + std::vector ok_colors(static_cast(regs_.size()), 1); + if (live_across_call_[static_cast(node)]) { + for (size_t i = 0; i < regs_.size(); ++i) { + if (IsCallerSaved(regs_[i])) { + ok_colors[i] = 0; + } + } + } + for (int neighbor : adjacency_[static_cast(node)]) { + const int alias = GetAlias(neighbor); + if (!is_colored_[static_cast(alias)]) { + continue; + } const int color = color_index_[static_cast(alias)]; if (color >= 0 && color < static_cast(regs_.size())) { ok_colors[static_cast(color)] = 0; @@ -506,12 +535,14 @@ class GeorgeColoringAllocator { void Combine(int keep, int remove) { simplify_worklist_[static_cast(remove)] = 0; freeze_worklist_[static_cast(remove)] = 0; - spill_worklist_[static_cast(remove)] = 0; - is_coalesced_[static_cast(remove)] = 1; - alias_[static_cast(remove)] = keep; - - auto& keep_moves = move_list_[static_cast(keep)]; - const auto& remove_moves = move_list_[static_cast(remove)]; + spill_worklist_[static_cast(remove)] = 0; + is_coalesced_[static_cast(remove)] = 1; + alias_[static_cast(remove)] = keep; + live_across_call_[static_cast(keep)] |= + live_across_call_[static_cast(remove)]; + + auto& keep_moves = move_list_[static_cast(keep)]; + const auto& remove_moves = move_list_[static_cast(remove)]; keep_moves.insert(keep_moves.end(), remove_moves.begin(), remove_moves.end()); EnableMoves({remove}); @@ -679,11 +710,12 @@ class GeorgeColoringAllocator { std::vector regs_; int k_ = 0; const std::vector& block_infos_; - int num_vregs_ = 0; - - std::vector in_class_; - std::vector> adjacency_; - std::vector degree_; + int num_vregs_ = 0; + + std::vector in_class_; + std::vector live_across_call_; + std::vector> adjacency_; + std::vector degree_; std::vector spill_cost_; std::vector> move_list_; std::vector moves_; diff --git a/src/mir/passes/PassManager.cpp b/src/mir/passes/PassManager.cpp index c510460..b5976c4 100644 --- a/src/mir/passes/PassManager.cpp +++ b/src/mir/passes/PassManager.cpp @@ -1,4 +1,25 @@ -// MIR Pass 管理: -// - 组织后端 pass 的运行顺序(PreRA/PostRA/PEI 等阶段) -// - 统一运行 pass 与调试输出(按需要扩展) +#include "mir/MIR.h" +namespace mir { + +void RunMIRPreRegAllocPassPipeline(MachineModule& module) { + RunAddressHoisting(module); + + constexpr int kMaxIterations = 4; + for (int iteration = 0; iteration < kMaxIterations; ++iteration) { + if (!RunPeephole(module)) { + break; + } + } +} + +void RunMIRPostRegAllocPassPipeline(MachineModule& module) { + constexpr int kMaxIterations = 2; + for (int iteration = 0; iteration < kMaxIterations; ++iteration) { + if (!RunPeephole(module)) { + break; + } + } +} + +} // namespace mir diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index c6d9ab7..f46d6a4 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -1,4 +1,904 @@ -// 窥孔优化(Peephole): -// - 删除冗余 move、合并常见指令模式 -// - 提升最终汇编质量(按实现范围裁剪) +#include "mir/MIR.h" +#include "ir/IR.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace mir { +namespace { + +using AliasMap = std::unordered_map; + +struct CFGInfo { + std::vector> predecessors; + std::vector> successors; +}; + +struct AddressKey { + AddrBaseKind base_kind = AddrBaseKind::None; + int base_index = -1; + std::string symbol; + std::int64_t const_offset = 0; + std::vector> scaled_vregs; + + bool operator==(const AddressKey& rhs) const { + return base_kind == rhs.base_kind && base_index == rhs.base_index && + symbol == rhs.symbol && const_offset == rhs.const_offset && + scaled_vregs == rhs.scaled_vregs; + } +}; + +struct AddressKeyHash { + std::size_t operator()(const AddressKey& key) const { + std::size_t h = static_cast(key.base_kind); + h ^= std::hash{}(key.base_index) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(key.symbol) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(key.const_offset) + 0x9e3779b9 + (h << 6) + (h >> 2); + for (const auto& term : key.scaled_vregs) { + h ^= std::hash{}(term.first) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(term.second) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +struct MemoryState { + MachineOperand value; + ValueType type = ValueType::Void; + int pending_store_index = -1; +}; + +using MemoryMap = std::unordered_map; + +bool IsImm(const MachineOperand& operand, std::int64_t value) { + return operand.GetKind() == OperandKind::Imm && operand.GetImm() == value; +} + +bool SameExactOperand(const MachineOperand& lhs, const MachineOperand& rhs) { + if (lhs.GetKind() != rhs.GetKind()) { + return false; + } + switch (lhs.GetKind()) { + case OperandKind::Invalid: + return true; + case OperandKind::VReg: + return lhs.GetVReg() == rhs.GetVReg(); + case OperandKind::Imm: + return lhs.GetImm() == rhs.GetImm(); + case OperandKind::Block: + case OperandKind::Symbol: + return lhs.GetText() == rhs.GetText(); + } + return false; +} + +bool SameResolvedLocation(const MachineFunction& function, int lhs_vreg, int rhs_vreg) { + if (lhs_vreg == rhs_vreg) { + return true; + } + + const auto& lhs = function.GetAllocation(lhs_vreg); + const auto& rhs = function.GetAllocation(rhs_vreg); + if (lhs.kind == Allocation::Kind::Unassigned || rhs.kind == Allocation::Kind::Unassigned || + lhs.kind != rhs.kind) { + return false; + } + if (lhs.kind == Allocation::Kind::PhysReg) { + return lhs.phys == rhs.phys; + } + if (lhs.kind == Allocation::Kind::Spill) { + return lhs.stack_object == rhs.stack_object; + } + return false; +} + +bool SameResolvedOperand(const MachineFunction& function, const MachineOperand& lhs, + const MachineOperand& rhs) { + if (SameExactOperand(lhs, rhs)) { + return true; + } + if (lhs.GetKind() == OperandKind::VReg && rhs.GetKind() == OperandKind::VReg) { + return SameResolvedLocation(function, lhs.GetVReg(), rhs.GetVReg()); + } + return false; +} + +MachineOperand ResolveAlias(const AliasMap& aliases, const MachineOperand& operand) { + if (operand.GetKind() != OperandKind::VReg) { + return operand; + } + + int current = operand.GetVReg(); + std::unordered_set visited; + visited.insert(current); + + while (true) { + auto it = aliases.find(current); + if (it == aliases.end()) { + return MachineOperand::VReg(current); + } + + if (it->second.GetKind() != OperandKind::VReg) { + return it->second; + } + + const int next = it->second.GetVReg(); + if (!visited.insert(next).second) { + return MachineOperand::VReg(current); + } + current = next; + } +} + +bool RewriteOperand(MachineOperand& operand, const AliasMap& aliases) { + const auto rewritten = ResolveAlias(aliases, operand); + if (SameExactOperand(rewritten, operand)) { + return false; + } + + operand = rewritten; + return true; +} + +bool RewriteAddress(AddressExpr& address, const AliasMap& aliases) { + bool changed = false; + + if (address.base_kind == AddrBaseKind::VReg && address.base_index >= 0) { + const auto rewritten = ResolveAlias(aliases, MachineOperand::VReg(address.base_index)); + if (rewritten.GetKind() == OperandKind::VReg && + rewritten.GetVReg() != address.base_index) { + address.base_index = rewritten.GetVReg(); + changed = true; + } + } + + std::vector> rewritten_scaled; + rewritten_scaled.reserve(address.scaled_vregs.size()); + for (const auto& term : address.scaled_vregs) { + const auto rewritten = ResolveAlias(aliases, MachineOperand::VReg(term.first)); + if (rewritten.GetKind() == OperandKind::Imm) { + address.const_offset += rewritten.GetImm() * term.second; + changed = true; + continue; + } + + if (rewritten.GetKind() == OperandKind::VReg && rewritten.GetVReg() != term.first) { + rewritten_scaled.push_back({rewritten.GetVReg(), term.second}); + changed = true; + continue; + } + + rewritten_scaled.push_back(term); + } + + if (rewritten_scaled.size() != address.scaled_vregs.size()) { + changed = true; + } + address.scaled_vregs = std::move(rewritten_scaled); + return changed; +} + +bool RewriteUses(MachineInstr& inst, const AliasMap& aliases) { + bool changed = false; + auto& operands = inst.GetOperands(); + + switch (inst.GetOpcode()) { + case MachineInstr::Opcode::Copy: + case MachineInstr::Opcode::ZExt: + case MachineInstr::Opcode::ItoF: + case MachineInstr::Opcode::FtoI: + case MachineInstr::Opcode::FNeg: + if (operands.size() >= 2) { + changed |= RewriteOperand(operands[1], aliases); + } + break; + case MachineInstr::Opcode::Store: + if (!operands.empty()) { + changed |= RewriteOperand(operands[0], aliases); + } + break; + case MachineInstr::Opcode::Add: + case MachineInstr::Opcode::Sub: + case MachineInstr::Opcode::Mul: + case MachineInstr::Opcode::Div: + case MachineInstr::Opcode::Rem: + case MachineInstr::Opcode::And: + case MachineInstr::Opcode::Or: + case MachineInstr::Opcode::Xor: + case MachineInstr::Opcode::Shl: + case MachineInstr::Opcode::AShr: + case MachineInstr::Opcode::LShr: + case MachineInstr::Opcode::FAdd: + case MachineInstr::Opcode::FSub: + case MachineInstr::Opcode::FMul: + case MachineInstr::Opcode::FDiv: + case MachineInstr::Opcode::ICmp: + case MachineInstr::Opcode::FCmp: + if (operands.size() >= 2) { + changed |= RewriteOperand(operands[1], aliases); + } + if (operands.size() >= 3) { + changed |= RewriteOperand(operands[2], aliases); + } + break; + case MachineInstr::Opcode::CondBr: + if (!operands.empty()) { + changed |= RewriteOperand(operands[0], aliases); + } + break; + case MachineInstr::Opcode::Call: { + const size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 0 : 1; + for (size_t i = arg_begin; i < operands.size(); ++i) { + changed |= RewriteOperand(operands[i], aliases); + } + break; + } + case MachineInstr::Opcode::Ret: + if (!operands.empty()) { + changed |= RewriteOperand(operands[0], aliases); + } + break; + case MachineInstr::Opcode::Memset: + if (!operands.empty()) { + changed |= RewriteOperand(operands[0], aliases); + } + if (operands.size() >= 2) { + changed |= RewriteOperand(operands[1], aliases); + } + break; + case MachineInstr::Opcode::Arg: + case MachineInstr::Opcode::Load: + case MachineInstr::Opcode::Lea: + case MachineInstr::Opcode::Br: + case MachineInstr::Opcode::Unreachable: + break; + } + + if (inst.HasAddress()) { + changed |= RewriteAddress(inst.GetAddress(), aliases); + } + return changed; +} + +MachineInstr MakeCopyLike(const MachineInstr& inst, MachineOperand source) { + return MachineInstr(MachineInstr::Opcode::Copy, + {inst.GetOperands()[0], std::move(source)}); +} + +bool SimplifyCopy(const MachineFunction& function, MachineInstr& inst) { + if (inst.GetOpcode() != MachineInstr::Opcode::Copy) { + return false; + } + const auto& operands = inst.GetOperands(); + if (operands.size() < 2 || operands[0].GetKind() != OperandKind::VReg) { + return false; + } + return SameResolvedOperand(function, operands[0], operands[1]); +} + +bool SimplifyZExt(MachineInstr& inst) { + const auto& operands = inst.GetOperands(); + if (operands.size() < 2 || operands[1].GetKind() != OperandKind::Imm) { + return false; + } + + inst = MakeCopyLike(inst, MachineOperand::Imm(operands[1].GetImm() != 0 ? 1 : 0)); + return true; +} + +bool SimplifyIntegerBinary(MachineInstr& inst) { + const auto opcode = inst.GetOpcode(); + const auto& operands = inst.GetOperands(); + if (operands.size() < 3) { + return false; + } + + const auto& lhs = operands[1]; + const auto& rhs = operands[2]; + switch (opcode) { + case MachineInstr::Opcode::Add: + if (IsImm(rhs, 0)) { + inst = MakeCopyLike(inst, lhs); + return true; + } + if (IsImm(lhs, 0)) { + inst = MakeCopyLike(inst, rhs); + return true; + } + return false; + case MachineInstr::Opcode::Sub: + if (IsImm(rhs, 0)) { + inst = MakeCopyLike(inst, lhs); + return true; + } + return false; + case MachineInstr::Opcode::Mul: + if (IsImm(rhs, 1)) { + inst = MakeCopyLike(inst, lhs); + return true; + } + if (IsImm(lhs, 1)) { + inst = MakeCopyLike(inst, rhs); + return true; + } + if (IsImm(rhs, 0) || IsImm(lhs, 0)) { + inst = MakeCopyLike(inst, MachineOperand::Imm(0)); + return true; + } + return false; + case MachineInstr::Opcode::Div: + if (IsImm(rhs, 1)) { + inst = MakeCopyLike(inst, lhs); + return true; + } + return false; + case MachineInstr::Opcode::And: + if (IsImm(rhs, -1)) { + inst = MakeCopyLike(inst, lhs); + return true; + } + if (IsImm(lhs, -1)) { + inst = MakeCopyLike(inst, rhs); + return true; + } + if (IsImm(rhs, 0) || IsImm(lhs, 0)) { + inst = MakeCopyLike(inst, MachineOperand::Imm(0)); + return true; + } + return false; + case MachineInstr::Opcode::Or: + case MachineInstr::Opcode::Xor: + if (IsImm(rhs, 0)) { + inst = MakeCopyLike(inst, lhs); + return true; + } + if (IsImm(lhs, 0)) { + inst = MakeCopyLike(inst, rhs); + return true; + } + return false; + case MachineInstr::Opcode::Shl: + case MachineInstr::Opcode::AShr: + case MachineInstr::Opcode::LShr: + if (IsImm(rhs, 0)) { + inst = MakeCopyLike(inst, lhs); + return true; + } + return false; + default: + return false; + } +} + +bool SimplifyCondBr(MachineInstr& inst) { + auto& operands = inst.GetOperands(); + if (operands.size() < 3) { + return false; + } + + if (operands[1].GetKind() == OperandKind::Block && + operands[2].GetKind() == OperandKind::Block && + operands[1].GetText() == operands[2].GetText()) { + inst = MachineInstr(MachineInstr::Opcode::Br, {operands[1]}); + return true; + } + + if (operands[0].GetKind() != OperandKind::Imm) { + return false; + } + + inst = MachineInstr(MachineInstr::Opcode::Br, + {operands[0].GetImm() != 0 ? operands[1] : operands[2]}); + return true; +} + +bool SimplifyInstruction(MachineInstr& inst) { + switch (inst.GetOpcode()) { + case MachineInstr::Opcode::ZExt: + return SimplifyZExt(inst); + case MachineInstr::Opcode::Add: + case MachineInstr::Opcode::Sub: + case MachineInstr::Opcode::Mul: + case MachineInstr::Opcode::Div: + case MachineInstr::Opcode::And: + case MachineInstr::Opcode::Or: + case MachineInstr::Opcode::Xor: + case MachineInstr::Opcode::Shl: + case MachineInstr::Opcode::AShr: + case MachineInstr::Opcode::LShr: + return SimplifyIntegerBinary(inst); + case MachineInstr::Opcode::CondBr: + return SimplifyCondBr(inst); + default: + return false; + } +} + +bool TrackAlias(const MachineInstr& inst, AliasMap& aliases) { + if (inst.GetOpcode() != MachineInstr::Opcode::Copy) { + return false; + } + + const auto& operands = inst.GetOperands(); + if (operands.size() < 2 || operands[0].GetKind() != OperandKind::VReg) { + return false; + } + + aliases[operands[0].GetVReg()] = operands[1]; + return true; +} + +AddressKey MakeAddressKey(const AddressExpr& address) { + return {address.base_kind, address.base_index, address.symbol, address.const_offset, + address.scaled_vregs}; +} + +bool HasTrackedAddress(const MachineInstr& inst) { + return inst.HasAddress() && inst.GetAddress().base_kind != AddrBaseKind::None; +} + +const ir::Function* LookupSourceCallee(const MachineModule& module, + const MachineInstr& inst) { + if (inst.GetOpcode() != MachineInstr::Opcode::Call || inst.GetCallee().empty()) { + return nullptr; + } + return module.GetSourceModule().GetFunction(inst.GetCallee()); +} + +bool CallMayReadMemory(const MachineModule& module, const MachineInstr& inst) { + auto* callee = LookupSourceCallee(module, inst); + return callee == nullptr || callee->MayReadMemory(); +} + +bool CallMayWriteMemory(const MachineModule& module, const MachineInstr& inst) { + auto* callee = LookupSourceCallee(module, inst); + return callee == nullptr || callee->MayWriteMemory(); +} + +bool SameMemoryStateValue(const MemoryState& lhs, const MemoryState& rhs) { + return lhs.type == rhs.type && SameExactOperand(lhs.value, rhs.value); +} + +bool SameMemoryMap(const MemoryMap& lhs, const MemoryMap& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto& [key, value] : lhs) { + auto it = rhs.find(key); + if (it == rhs.end() || !SameMemoryStateValue(value, it->second)) { + return false; + } + } + return true; +} + +MemoryMap MeetMemoryStates(const std::vector& predecessors) { + if (predecessors.empty()) { + return {}; + } + + MemoryMap in = *predecessors.front(); + for (auto it = in.begin(); it != in.end();) { + bool keep = true; + for (std::size_t i = 1; i < predecessors.size(); ++i) { + auto pred_it = predecessors[i]->find(it->first); + if (pred_it == predecessors[i]->end() || + !SameMemoryStateValue(it->second, pred_it->second)) { + keep = false; + break; + } + } + if (!keep) { + it = in.erase(it); + continue; + } + ++it; + } + return in; +} + +CFGInfo BuildCFG(const MachineFunction& function) { + CFGInfo cfg; + const auto& blocks = function.GetBlocks(); + cfg.predecessors.resize(blocks.size()); + cfg.successors.resize(blocks.size()); + + std::unordered_map name_to_index; + for (std::size_t i = 0; i < blocks.size(); ++i) { + name_to_index.emplace(blocks[i]->GetName(), static_cast(i)); + } + + auto add_edge = [&](int pred, const std::string& succ_name) { + auto it = name_to_index.find(succ_name); + if (it == name_to_index.end()) { + return; + } + cfg.successors[static_cast(pred)].push_back(it->second); + cfg.predecessors[static_cast(it->second)].push_back(pred); + }; + + for (std::size_t i = 0; i < blocks.size(); ++i) { + const auto& instructions = blocks[i]->GetInstructions(); + if (instructions.empty()) { + continue; + } + const auto& terminator = instructions.back(); + if (terminator.GetOpcode() == MachineInstr::Opcode::Br && + !terminator.GetOperands().empty()) { + add_edge(static_cast(i), terminator.GetOperands()[0].GetText()); + } else if (terminator.GetOpcode() == MachineInstr::Opcode::CondBr && + terminator.GetOperands().size() >= 3) { + add_edge(static_cast(i), terminator.GetOperands()[1].GetText()); + add_edge(static_cast(i), terminator.GetOperands()[2].GetText()); + } + + auto& succs = cfg.successors[i]; + std::sort(succs.begin(), succs.end()); + succs.erase(std::unique(succs.begin(), succs.end()), succs.end()); + } + + for (auto& preds : cfg.predecessors) { + std::sort(preds.begin(), preds.end()); + preds.erase(std::unique(preds.begin(), preds.end()), preds.end()); + } + return cfg; +} + +bool SameBaseObject(const AddressKey& lhs, const AddressKey& rhs) { + if (lhs.base_kind != rhs.base_kind) { + return false; + } + + switch (lhs.base_kind) { + case AddrBaseKind::FrameObject: + case AddrBaseKind::VReg: + return lhs.base_index == rhs.base_index; + case AddrBaseKind::Global: + return lhs.symbol == rhs.symbol; + case AddrBaseKind::None: + return false; + } + return false; +} + +void InvalidateMemoryState(std::unordered_map& states, + const AddressKey* store_key) { + if (store_key == nullptr) { + states.clear(); + return; + } + + if (store_key->base_kind == AddrBaseKind::VReg) { + states.clear(); + return; + } + + for (auto it = states.begin(); it != states.end();) { + if (it->first.base_kind == AddrBaseKind::VReg || SameBaseObject(it->first, *store_key)) { + it = states.erase(it); + continue; + } + ++it; + } +} + +void ObservePendingStores(MemoryMap& states) { + for (auto& [_, state] : states) { + state.pending_store_index = -1; + } +} + +bool TryOptimizeMemoryInstruction( + const MachineModule& module, const MachineFunction& function, + MachineInstr& inst, + MemoryMap& states, + std::vector& removed, + std::size_t current_index, + bool* remove_current) { + *remove_current = false; + if (inst.GetOpcode() == MachineInstr::Opcode::Call) { + if (CallMayWriteMemory(module, inst)) { + InvalidateMemoryState(states, nullptr); + } + return false; + } + if (inst.GetOpcode() == MachineInstr::Opcode::Memset) { + InvalidateMemoryState(states, nullptr); + return false; + } + + if (!HasTrackedAddress(inst)) { + return false; + } + + const AddressKey key = MakeAddressKey(inst.GetAddress()); + if (inst.GetOpcode() == MachineInstr::Opcode::Load) { + ValueType load_type = ValueType::Void; + if (!inst.GetOperands().empty() && inst.GetOperands()[0].GetKind() == OperandKind::VReg) { + load_type = function.GetVRegInfo(inst.GetOperands()[0].GetVReg()).type; + } + auto it = states.find(key); + if (it != states.end() && it->second.type == load_type) { + inst = MakeCopyLike(inst, it->second.value); + it->second.pending_store_index = -1; + return true; + } + + auto dest = inst.GetOperands()[0]; + states[key] = {dest, load_type, -1}; + return false; + } + + if (inst.GetOpcode() != MachineInstr::Opcode::Store) { + return false; + } + + const auto value = inst.GetOperands()[0]; + auto existing = states.find(key); + if (existing != states.end() && existing->second.type == inst.GetValueType() && + SameExactOperand(existing->second.value, value)) { + *remove_current = true; + return true; + } + if (existing != states.end() && existing->second.pending_store_index >= 0) { + removed[static_cast(existing->second.pending_store_index)] = true; + } + + InvalidateMemoryState(states, &key); + states[key] = {value, inst.GetValueType(), static_cast(current_index)}; + return false; +} + +void ApplyMemoryDataflowInstruction(const MachineModule& module, const MachineInstr& inst, + MemoryMap& states) { + if (inst.GetOpcode() == MachineInstr::Opcode::Call) { + if (CallMayWriteMemory(module, inst)) { + InvalidateMemoryState(states, nullptr); + } + return; + } + if (inst.GetOpcode() == MachineInstr::Opcode::Memset) { + InvalidateMemoryState(states, nullptr); + return; + } + + if (!HasTrackedAddress(inst)) { + return; + } + + const AddressKey key = MakeAddressKey(inst.GetAddress()); + if (inst.GetOpcode() == MachineInstr::Opcode::Store) { + InvalidateMemoryState(states, &key); + states[key] = {inst.GetOperands()[0], inst.GetValueType(), -1}; + return; + } +} + +MemoryMap SimulateBlockMemory(const MachineModule& module, const MachineBasicBlock& block, + const MemoryMap& in_state) { + MemoryMap state = in_state; + for (const auto& inst : block.GetInstructions()) { + ApplyMemoryDataflowInstruction(module, inst, state); + } + return state; +} + +bool RunPeepholeOnBlock(const MachineModule& module, const MachineFunction& function, + MachineBasicBlock& block, const MemoryMap& in_state) { + bool changed = false; + AliasMap aliases; + MemoryMap memory_states = in_state; + std::vector rewritten; + std::vector removed; + rewritten.reserve(block.GetInstructions().size()); + removed.reserve(block.GetInstructions().size()); + + for (const auto& original : block.GetInstructions()) { + MachineInstr inst = original; + changed |= RewriteUses(inst, aliases); + changed |= SimplifyInstruction(inst); + + if (SimplifyCopy(function, inst)) { + changed = true; + continue; + } + + rewritten.push_back(std::move(inst)); + removed.push_back(false); + + MachineInstr& current = rewritten.back(); + bool remove_current = false; + changed |= TryOptimizeMemoryInstruction(module, function, current, memory_states, removed, + rewritten.size() - 1, &remove_current); + if (remove_current) { + removed.back() = true; + changed = true; + continue; + } + changed |= SimplifyInstruction(current); + + if (SimplifyCopy(function, current)) { + removed.back() = true; + changed = true; + continue; + } + + if (current.GetOpcode() == MachineInstr::Opcode::Call) { + if (CallMayReadMemory(module, current) || CallMayWriteMemory(module, current)) { + ObservePendingStores(memory_states); + } + } else if (current.GetOpcode() == MachineInstr::Opcode::Memset) { + ObservePendingStores(memory_states); + } + + TrackAlias(current, aliases); + } + + std::vector compacted; + compacted.reserve(rewritten.size()); + for (std::size_t i = 0; i < rewritten.size(); ++i) { + if (!removed[i]) { + compacted.push_back(std::move(rewritten[i])); + } else { + changed = true; + } + } + + if (compacted.size() != block.GetInstructions().size()) { + changed = true; + } + if (changed) { + block.GetInstructions() = std::move(compacted); + } + return changed; +} + +bool IsSideEffectFree(const MachineInstr& inst) { + switch (inst.GetOpcode()) { + case MachineInstr::Opcode::Arg: + case MachineInstr::Opcode::Copy: + case MachineInstr::Opcode::Load: + case MachineInstr::Opcode::Lea: + case MachineInstr::Opcode::Add: + case MachineInstr::Opcode::Sub: + case MachineInstr::Opcode::Mul: + case MachineInstr::Opcode::Div: + case MachineInstr::Opcode::Rem: + case MachineInstr::Opcode::And: + case MachineInstr::Opcode::Or: + case MachineInstr::Opcode::Xor: + case MachineInstr::Opcode::Shl: + case MachineInstr::Opcode::AShr: + case MachineInstr::Opcode::LShr: + case MachineInstr::Opcode::FAdd: + case MachineInstr::Opcode::FSub: + case MachineInstr::Opcode::FMul: + case MachineInstr::Opcode::FDiv: + case MachineInstr::Opcode::FNeg: + case MachineInstr::Opcode::ICmp: + case MachineInstr::Opcode::FCmp: + case MachineInstr::Opcode::ZExt: + case MachineInstr::Opcode::ItoF: + case MachineInstr::Opcode::FtoI: + return true; + case MachineInstr::Opcode::Store: + case MachineInstr::Opcode::Br: + case MachineInstr::Opcode::CondBr: + case MachineInstr::Opcode::Call: + case MachineInstr::Opcode::Ret: + case MachineInstr::Opcode::Memset: + case MachineInstr::Opcode::Unreachable: + return false; + } + return false; +} + +bool RunDeadInstrElimination(MachineFunction& function) { + bool changed = false; + + while (true) { + std::unordered_map use_counts; + for (const auto& block : function.GetBlocks()) { + for (const auto& inst : block->GetInstructions()) { + for (int use : inst.GetUses()) { + ++use_counts[use]; + } + } + } + + bool local_changed = false; + for (auto& block : function.GetBlocks()) { + std::vector rewritten; + rewritten.reserve(block->GetInstructions().size()); + for (auto& inst : block->GetInstructions()) { + const auto defs = inst.GetDefs(); + const bool has_live_def = + defs.empty() || use_counts.find(defs.front()) != use_counts.end(); + if (has_live_def || !IsSideEffectFree(inst)) { + rewritten.push_back(inst); + continue; + } + local_changed = true; + } + if (local_changed) { + block->GetInstructions() = std::move(rewritten); + } + } + + if (!local_changed) { + break; + } + changed = true; + } + + return changed; +} + +bool HasAssignedAllocations(const MachineFunction& function) { + for (const auto& vreg : function.GetVRegs()) { + if (function.GetAllocation(vreg.id).kind != Allocation::Kind::Unassigned) { + return true; + } + } + return false; +} + +} // namespace + +bool RunPeephole(MachineModule& module) { + bool changed = false; + for (auto& function : module.GetFunctions()) { + if (!function) { + continue; + } + + bool function_changed = false; + const auto cfg = BuildCFG(*function); + std::vector in_states(function->GetBlocks().size()); + std::vector out_states(function->GetBlocks().size()); + + bool dataflow_changed = true; + while (dataflow_changed) { + dataflow_changed = false; + for (std::size_t i = 0; i < function->GetBlocks().size(); ++i) { + MemoryMap in_state; + if (i != 0) { + std::vector predecessors; + for (int pred : cfg.predecessors[i]) { + predecessors.push_back(&out_states[static_cast(pred)]); + } + in_state = MeetMemoryStates(predecessors); + } + + auto out_state = + SimulateBlockMemory(module, *function->GetBlocks()[i], in_state); + if (!SameMemoryMap(in_states[i], in_state)) { + in_states[i] = std::move(in_state); + dataflow_changed = true; + } + if (!SameMemoryMap(out_states[i], out_state)) { + out_states[i] = std::move(out_state); + dataflow_changed = true; + } + } + } + + for (std::size_t i = 0; i < function->GetBlocks().size(); ++i) { + function_changed |= + RunPeepholeOnBlock(module, *function, *function->GetBlocks()[i], in_states[i]); + } + if (!HasAssignedAllocations(*function)) { + function_changed |= RunDeadInstrElimination(*function); + } + changed |= function_changed; + } + return changed; +} + +} // namespace mir diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 772406b..feeccca 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -1,6 +1,685 @@ #include "sem/Sema.h" +#include +#include +#include +#include +#include +#include + +namespace { + +enum class MemoryRoot { + None, + Local, + Global, + Param, + Unknown, +}; + +struct SymbolInfo { + SemanticType type = SemanticType::Int; + bool is_array = false; + bool is_param_array = false; + MemoryRoot root = MemoryRoot::Local; + std::vector dims; +}; + +struct ExprInfo { + MemoryRoot root = MemoryRoot::None; + bool is_array = false; +}; + +struct CallSiteInfo { + std::string callee; + std::vector arg_roots; +}; + +struct DirectFunctionAnalysis { + FunctionSemanticInfo info; + std::vector calls; +}; + +class ScopedSymbols { + public: + void EnterScope() { scopes_.emplace_back(); } + + void ExitScope() { + if (!scopes_.empty()) { + scopes_.pop_back(); + } + } + + bool Insert(const std::string& name, const SymbolInfo& info) { + if (scopes_.empty()) { + EnterScope(); + } + return scopes_.back().emplace(name, info).second; + } + + const SymbolInfo* Lookup(const std::string& name) const { + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto found = it->find(name); + if (found != it->end()) { + return &found->second; + } + } + return nullptr; + } + + private: + std::vector> scopes_; +}; + +std::string ExpectIdent(antlr4::tree::TerminalNode* ident) { + return ident == nullptr ? std::string{} : ident->getText(); +} + +SemanticType ParseBType(SysYParser::BTypeContext* ctx) { + if (ctx != nullptr && ctx->FLOAT() != nullptr) { + return SemanticType::Float; + } + return SemanticType::Int; +} + +SemanticType ParseFuncType(SysYParser::FuncTypeContext* ctx) { + if (ctx == nullptr || ctx->VOID() != nullptr) { + return SemanticType::Void; + } + if (ctx->FLOAT() != nullptr) { + return SemanticType::Float; + } + return SemanticType::Int; +} + +std::vector MakeShape(std::size_t rank) { + return std::vector(rank, -1); +} + +void RegisterBuiltinFunctions(SemanticContext& context) { + struct BuiltinSpec { + const char* name; + SemanticType return_type; + std::vector param_is_array; + bool reads_global_memory = false; + bool writes_global_memory = false; + bool reads_param_memory = false; + bool writes_param_memory = false; + bool has_io = false; + }; + + const std::vector builtins = { + {"getint", SemanticType::Int, {}, false, false, false, false, true}, + {"getch", SemanticType::Int, {}, false, false, false, false, true}, + {"getfloat", SemanticType::Float, {}, false, false, false, false, true}, + {"getarray", SemanticType::Int, {true}, false, false, false, true, true}, + {"getfarray", SemanticType::Int, {true}, false, false, false, true, true}, + {"putint", SemanticType::Void, {false}, false, false, false, false, true}, + {"putch", SemanticType::Void, {false}, false, false, false, false, true}, + {"putfloat", SemanticType::Void, {false}, false, false, false, false, true}, + {"putarray", SemanticType::Void, {false, true}, false, false, true, false, true}, + {"putfarray", SemanticType::Void, {false, true}, false, false, true, false, true}, + {"starttime", SemanticType::Void, {}, false, false, false, false, true}, + {"stoptime", SemanticType::Void, {}, false, false, false, false, true}, + }; + + for (const auto& builtin : builtins) { + auto& info = context.UpsertFunction(builtin.name); + info.return_type = builtin.return_type; + info.param_is_array = builtin.param_is_array; + info.is_builtin = true; + info.is_defined = false; + info.reads_global_memory = builtin.reads_global_memory; + info.writes_global_memory = builtin.writes_global_memory; + info.reads_param_memory = builtin.reads_param_memory; + info.writes_param_memory = builtin.writes_param_memory; + info.has_io = builtin.has_io; + info.has_unknown_effects = false; + info.is_recursive = false; + info.direct_callees.clear(); + } +} + +void CollectGlobalDecl(SemanticContext& context, SysYParser::DeclContext& ctx) { + if (auto* const_decl = ctx.constDecl()) { + const auto type = ParseBType(const_decl->bType()); + for (auto* def : const_decl->constDef()) { + auto& info = context.UpsertGlobal(ExpectIdent(def->Ident())); + info.type = type; + info.is_const = true; + info.is_array = !def->constExp().empty(); + info.dims = MakeShape(def->constExp().size()); + } + return; + } + + if (auto* var_decl = ctx.varDecl()) { + const auto type = ParseBType(var_decl->bType()); + for (auto* def : var_decl->varDef()) { + auto& info = context.UpsertGlobal(ExpectIdent(def->Ident())); + info.type = type; + info.is_const = false; + info.is_array = !def->constExp().empty(); + info.dims = MakeShape(def->constExp().size()); + } + } +} + +void CollectFunctionSignature(SemanticContext& context, SysYParser::FuncDefContext& ctx) { + auto& info = context.UpsertFunction(ExpectIdent(ctx.Ident())); + info.return_type = ParseFuncType(ctx.funcType()); + info.param_is_array.clear(); + info.is_builtin = false; + info.is_defined = true; + info.reads_global_memory = false; + info.writes_global_memory = false; + info.reads_param_memory = false; + info.writes_param_memory = false; + info.has_io = false; + info.has_unknown_effects = false; + info.is_recursive = false; + info.direct_callees.clear(); + + if (auto* params = ctx.funcFParams()) { + for (auto* param : params->funcFParam()) { + info.param_is_array.push_back(!param->LBRACK().empty()); + } + } +} + +bool SameInfo(const FunctionSemanticInfo& lhs, const FunctionSemanticInfo& rhs) { + return lhs.return_type == rhs.return_type && + lhs.param_is_array == rhs.param_is_array && + lhs.is_builtin == rhs.is_builtin && lhs.is_defined == rhs.is_defined && + lhs.reads_global_memory == rhs.reads_global_memory && + lhs.writes_global_memory == rhs.writes_global_memory && + lhs.reads_param_memory == rhs.reads_param_memory && + lhs.writes_param_memory == rhs.writes_param_memory && + lhs.has_io == rhs.has_io && + lhs.has_unknown_effects == rhs.has_unknown_effects && + lhs.is_recursive == rhs.is_recursive && + lhs.direct_callees == rhs.direct_callees; +} + +class FunctionAnalyzer { + public: + FunctionAnalyzer(const SemanticContext& context, DirectFunctionAnalysis& analysis) + : context_(context), analysis_(analysis) {} + + void Analyze(SysYParser::FuncDefContext& ctx) { + symbols_.EnterScope(); + + if (auto* params = ctx.funcFParams()) { + for (auto* param : params->funcFParam()) { + SymbolInfo info; + info.type = ParseBType(param->bType()); + info.is_array = !param->LBRACK().empty(); + info.is_param_array = info.is_array; + info.root = info.is_array ? MemoryRoot::Param : MemoryRoot::Local; + info.dims = MakeShape(param->exp().size()); + symbols_.Insert(ExpectIdent(param->Ident()), info); + } + } + + AnalyzeBlock(*ctx.block(), false); + symbols_.ExitScope(); + } + + private: + struct LValueShape { + MemoryRoot root = MemoryRoot::None; + bool is_array = false; + }; + + const SymbolInfo* LookupSymbol(const std::string& name) const { + if (const auto* local = symbols_.Lookup(name)) { + return local; + } + if (const auto* global = context_.LookupGlobal(name)) { + static thread_local SymbolInfo scratch; + scratch.type = global->type; + scratch.is_array = global->is_array; + scratch.is_param_array = false; + scratch.root = MemoryRoot::Global; + scratch.dims = global->dims; + return &scratch; + } + return nullptr; + } + + void AnalyzeBlock(SysYParser::BlockContext& ctx, bool create_scope) { + if (create_scope) { + symbols_.EnterScope(); + } + for (auto* item : ctx.blockItem()) { + AnalyzeBlockItem(*item); + } + if (create_scope) { + symbols_.ExitScope(); + } + } + + void AnalyzeBlockItem(SysYParser::BlockItemContext& ctx) { + if (auto* decl = ctx.decl()) { + AnalyzeDecl(*decl); + } else if (auto* stmt = ctx.stmt()) { + AnalyzeStmt(*stmt); + } + } + + void AnalyzeDecl(SysYParser::DeclContext& ctx) { + if (auto* const_decl = ctx.constDecl()) { + AnalyzeConstDecl(*const_decl); + } else if (auto* var_decl = ctx.varDecl()) { + AnalyzeVarDecl(*var_decl); + } + } + + void AnalyzeVarDecl(SysYParser::VarDeclContext& ctx) { + const auto type = ParseBType(ctx.bType()); + for (auto* def : ctx.varDef()) { + SymbolInfo info; + info.type = type; + info.is_array = !def->constExp().empty(); + info.root = MemoryRoot::Local; + info.dims = MakeShape(def->constExp().size()); + symbols_.Insert(ExpectIdent(def->Ident()), info); + if (def->initVal() != nullptr) { + AnalyzeInitVal(def->initVal()); + } + } + } + + void AnalyzeConstDecl(SysYParser::ConstDeclContext& ctx) { + const auto type = ParseBType(ctx.bType()); + for (auto* def : ctx.constDef()) { + SymbolInfo info; + info.type = type; + info.is_array = !def->constExp().empty(); + info.root = MemoryRoot::Local; + info.dims = MakeShape(def->constExp().size()); + symbols_.Insert(ExpectIdent(def->Ident()), info); + AnalyzeConstInitVal(def->constInitVal()); + } + } + + void AnalyzeInitVal(SysYParser::InitValContext* ctx) { + if (ctx == nullptr) { + return; + } + if (ctx->exp() != nullptr) { + AnalyzeExp(*ctx->exp()); + return; + } + for (auto* child : ctx->initVal()) { + AnalyzeInitVal(child); + } + } + + void AnalyzeConstInitVal(SysYParser::ConstInitValContext* ctx) { + if (ctx == nullptr) { + return; + } + if (ctx->constExp() != nullptr) { + AnalyzeAddExp(*ctx->constExp()->addExp()); + return; + } + for (auto* child : ctx->constInitVal()) { + AnalyzeConstInitVal(child); + } + } + + void AnalyzeStmt(SysYParser::StmtContext& ctx) { + if (ctx.lVal() != nullptr && ctx.ASSIGN() != nullptr) { + AnalyzeLValWrite(*ctx.lVal()); + AnalyzeExp(*ctx.exp()); + return; + } + if (ctx.block() != nullptr) { + AnalyzeBlock(*ctx.block(), true); + return; + } + if (ctx.IF() != nullptr) { + if (ctx.cond() != nullptr) { + AnalyzeLOrExp(*ctx.cond()->lOrExp()); + } + if (!ctx.stmt().empty()) { + AnalyzeStmt(*ctx.stmt()[0]); + } + if (ctx.stmt().size() > 1 && ctx.stmt()[1] != nullptr) { + AnalyzeStmt(*ctx.stmt()[1]); + } + return; + } + if (ctx.WHILE() != nullptr) { + if (ctx.cond() != nullptr) { + AnalyzeLOrExp(*ctx.cond()->lOrExp()); + } + if (!ctx.stmt().empty() && ctx.stmt()[0] != nullptr) { + AnalyzeStmt(*ctx.stmt()[0]); + } + return; + } + if (ctx.RETURN() != nullptr) { + if (ctx.exp() != nullptr) { + AnalyzeExp(*ctx.exp()); + } + return; + } + if (ctx.exp() != nullptr) { + AnalyzeExp(*ctx.exp()); + } + } + + ExprInfo AnalyzeExp(SysYParser::ExpContext& ctx) { + return AnalyzeAddExp(*ctx.addExp()); + } + + ExprInfo AnalyzeAddExp(SysYParser::AddExpContext& ctx) { + if (ctx.addExp() != nullptr) { + AnalyzeAddExp(*ctx.addExp()); + AnalyzeMulExp(*ctx.mulExp()); + return {}; + } + return AnalyzeMulExp(*ctx.mulExp()); + } + + ExprInfo AnalyzeMulExp(SysYParser::MulExpContext& ctx) { + if (ctx.mulExp() != nullptr) { + AnalyzeMulExp(*ctx.mulExp()); + AnalyzeUnaryExp(*ctx.unaryExp()); + return {}; + } + return AnalyzeUnaryExp(*ctx.unaryExp()); + } + + ExprInfo AnalyzeUnaryExp(SysYParser::UnaryExpContext& ctx) { + if (ctx.primaryExp() != nullptr) { + return AnalyzePrimaryExp(*ctx.primaryExp()); + } + + if (ctx.Ident() != nullptr) { + const auto name = ExpectIdent(ctx.Ident()); + CallSiteInfo call; + call.callee = name; + + const auto* callee = context_.LookupFunction(name); + const auto args = ctx.funcRParams() == nullptr ? std::vector{} + : ctx.funcRParams()->exp(); + call.arg_roots.resize(args.size(), MemoryRoot::None); + for (std::size_t i = 0; i < args.size(); ++i) { + auto arg_info = AnalyzeExp(*args[i]); + if (callee != nullptr && i < callee->param_is_array.size() && callee->param_is_array[i]) { + call.arg_roots[i] = arg_info.is_array ? arg_info.root : MemoryRoot::Unknown; + } + } + if (callee == nullptr) { + analysis_.info.has_unknown_effects = true; + } + analysis_.calls.push_back(std::move(call)); + return {}; + } + + if (ctx.unaryExp() != nullptr) { + AnalyzeUnaryExp(*ctx.unaryExp()); + } + return {}; + } + + ExprInfo AnalyzePrimaryExp(SysYParser::PrimaryExpContext& ctx) { + if (ctx.exp() != nullptr) { + return AnalyzeExp(*ctx.exp()); + } + if (ctx.lVal() != nullptr) { + return AnalyzeLValRead(*ctx.lVal()); + } + return {}; + } + + ExprInfo AnalyzeRelExp(SysYParser::RelExpContext& ctx) { + if (ctx.relExp() != nullptr) { + AnalyzeRelExp(*ctx.relExp()); + AnalyzeAddExp(*ctx.addExp()); + return {}; + } + return AnalyzeAddExp(*ctx.addExp()); + } + + ExprInfo AnalyzeEqExp(SysYParser::EqExpContext& ctx) { + if (ctx.eqExp() != nullptr) { + AnalyzeEqExp(*ctx.eqExp()); + AnalyzeRelExp(*ctx.relExp()); + return {}; + } + return AnalyzeRelExp(*ctx.relExp()); + } + + ExprInfo AnalyzeLAndExp(SysYParser::LAndExpContext& ctx) { + if (ctx.lAndExp() != nullptr) { + AnalyzeLAndExp(*ctx.lAndExp()); + AnalyzeEqExp(*ctx.eqExp()); + return {}; + } + return AnalyzeEqExp(*ctx.eqExp()); + } + + ExprInfo AnalyzeLOrExp(SysYParser::LOrExpContext& ctx) { + if (ctx.lOrExp() != nullptr) { + AnalyzeLOrExp(*ctx.lOrExp()); + AnalyzeLAndExp(*ctx.lAndExp()); + return {}; + } + return AnalyzeLAndExp(*ctx.lAndExp()); + } + + LValueShape DescribeLVal(SysYParser::LValContext& ctx) { + for (auto* index : ctx.exp()) { + AnalyzeExp(*index); + } + + const auto* symbol = LookupSymbol(ExpectIdent(ctx.Ident())); + if (symbol == nullptr) { + return {}; + } + if (!symbol->is_array) { + return {symbol->root, false}; + } + + const auto index_count = ctx.exp().size(); + bool still_array = false; + if (symbol->is_param_array) { + if (index_count == 0) { + still_array = true; + } else if (index_count <= symbol->dims.size()) { + still_array = true; + } + } else { + still_array = index_count < symbol->dims.size(); + } + + return {symbol->root, still_array}; + } + + ExprInfo AnalyzeLValRead(SysYParser::LValContext& ctx) { + const auto shape = DescribeLVal(ctx); + if (!shape.is_array) { + if (shape.root == MemoryRoot::Global) { + analysis_.info.reads_global_memory = true; + } else if (shape.root == MemoryRoot::Param) { + analysis_.info.reads_param_memory = true; + } + } + return {shape.root, shape.is_array}; + } + + void AnalyzeLValWrite(SysYParser::LValContext& ctx) { + const auto shape = DescribeLVal(ctx); + if (shape.root == MemoryRoot::Global) { + analysis_.info.writes_global_memory = true; + } else if (shape.root == MemoryRoot::Param) { + analysis_.info.writes_param_memory = true; + } + } + + const SemanticContext& context_; + DirectFunctionAnalysis& analysis_; + ScopedSymbols symbols_; +}; + +void PropagateCallEffects(SemanticContext& context, + const std::unordered_map& analyses) { + bool changed = true; + while (changed) { + changed = false; + for (const auto& [name, analysis] : analyses) { + auto next = analysis.info; + std::unordered_set callees_seen; + for (const auto& call : analysis.calls) { + if (!call.callee.empty()) { + callees_seen.insert(call.callee); + } + const auto* callee = context.LookupFunction(call.callee); + if (callee == nullptr) { + next.has_unknown_effects = true; + continue; + } + + next.has_io = next.has_io || callee->has_io; + next.has_unknown_effects = next.has_unknown_effects || callee->has_unknown_effects; + next.reads_global_memory = + next.reads_global_memory || callee->reads_global_memory; + next.writes_global_memory = + next.writes_global_memory || callee->writes_global_memory; + + const auto arg_count = std::min(call.arg_roots.size(), callee->param_is_array.size()); + for (std::size_t i = 0; i < arg_count; ++i) { + if (!callee->param_is_array[i]) { + continue; + } + switch (call.arg_roots[i]) { + case MemoryRoot::Global: + next.reads_global_memory = + next.reads_global_memory || callee->reads_param_memory; + next.writes_global_memory = + next.writes_global_memory || callee->writes_param_memory; + break; + case MemoryRoot::Param: + next.reads_param_memory = + next.reads_param_memory || callee->reads_param_memory; + next.writes_param_memory = + next.writes_param_memory || callee->writes_param_memory; + break; + case MemoryRoot::Unknown: + if (callee->reads_param_memory || callee->writes_param_memory) { + next.has_unknown_effects = true; + } + break; + case MemoryRoot::None: + case MemoryRoot::Local: + break; + } + } + } + + next.direct_callees.assign(callees_seen.begin(), callees_seen.end()); + std::sort(next.direct_callees.begin(), next.direct_callees.end()); + + auto* current = context.LookupFunction(name); + if (current == nullptr || !SameInfo(next, *current)) { + context.UpsertFunction(name) = std::move(next); + changed = true; + } + } + } +} + +bool ReachesSelf(const SemanticContext& context, const std::string& root, + const std::string& current, + std::unordered_set& visiting) { + const auto* info = context.LookupFunction(current); + if (info == nullptr) { + return false; + } + if (!visiting.insert(current).second) { + return false; + } + + for (const auto& callee : info->direct_callees) { + if (callee == root) { + return true; + } + if (ReachesSelf(context, root, callee, visiting)) { + return true; + } + } + return false; +} + +void MarkRecursiveFunctions(SemanticContext& context) { + std::vector function_names; + function_names.reserve(context.GetFunctions().size()); + for (const auto& [name, info] : context.GetFunctions()) { + if (!info.is_builtin) { + function_names.push_back(name); + } + } + + for (const auto& name : function_names) { + std::unordered_set visiting; + if (ReachesSelf(context, name, name, visiting)) { + context.UpsertFunction(name).is_recursive = true; + } + } +} + +} // namespace + SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) { - (void)comp_unit; - return SemanticContext{}; -} + SemanticContext context; + RegisterBuiltinFunctions(context); + + for (auto* child : comp_unit.children) { + if (auto* decl = dynamic_cast(child)) { + CollectGlobalDecl(context, *decl); + } else if (auto* func = dynamic_cast(child)) { + CollectFunctionSignature(context, *func); + } + } + + std::unordered_map analyses; + for (auto* child : comp_unit.children) { + auto* func = dynamic_cast(child); + if (func == nullptr) { + continue; + } + + const auto name = ExpectIdent(func->Ident()); + auto* existing = context.LookupFunction(name); + if (existing == nullptr) { + continue; + } + + DirectFunctionAnalysis analysis; + analysis.info = *existing; + analysis.info.reads_global_memory = false; + analysis.info.writes_global_memory = false; + analysis.info.reads_param_memory = false; + analysis.info.writes_param_memory = false; + analysis.info.has_io = false; + analysis.info.has_unknown_effects = false; + analysis.info.is_recursive = false; + analysis.info.direct_callees.clear(); + + FunctionAnalyzer analyzer(context, analysis); + analyzer.Analyze(*func); + analyses.emplace(name, std::move(analysis)); + } + + for (const auto& [name, analysis] : analyses) { + context.UpsertFunction(name) = analysis.info; + } + + PropagateCallEffects(context, analyses); + MarkRecursiveFunctions(context); + return context; +}