diff --git a/src/main.cpp b/src/main.cpp index 643a987e..6fb62ae4 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -53,6 +53,7 @@ int main(int argc, char** argv) { auto machine_module = mir::LowerModuleToMIR(*module); mir::RunRegAlloc(*machine_module); mir::RunFrameLowering(*machine_module); + mir::RunBlockLayout(*machine_module); mir::RunPeephole(*machine_module); std::ostringstream asm_ss; diff --git a/src/mir/passes/BlockLayoutOpt.cpp b/src/mir/passes/BlockLayoutOpt.cpp new file mode 100644 index 00000000..2c827eef --- /dev/null +++ b/src/mir/passes/BlockLayoutOpt.cpp @@ -0,0 +1,341 @@ +#include "mir/MIR.h" +#include "mir/analysis/CFGAnalysis.h" + +#include +#include +#include + +#include "utils/Log.h" + +namespace mir +{ +namespace +{ + +// ============================================================ +// Pettis-Hansen 基本块重排序 (PLDI 1990) +// 目标:使热路径 fallthrough,最大化 Peephole FixFallThrough 效果 +// ============================================================ + +struct Chain +{ + std::vector blocks; + double total_freq = 0.0; +}; + +// 检测循环回边:用 DFS 找后向边 (back edge) +struct LoopInfo +{ + std::map dfs_num; + std::map in_stack; + int counter = 0; + + bool IsBackEdge(MachineBasicBlock *src, MachineBasicBlock *dst) const + { + auto sit = dfs_num.find(src); + auto dit = dfs_num.find(dst); + if (sit == dfs_num.end() || dit == dfs_num.end()) + return false; + // 回边:dst 的 DFS 编号 ≤ src,且 dst 在栈上 + return dit->second <= sit->second; + } +}; + +static void DFS(MachineBasicBlock *block, + const CFGAnalysisResult &cfg, + LoopInfo &li) +{ + li.dfs_num[block] = li.counter++; + li.in_stack[block] = true; + + auto it = cfg.successors.find(const_cast(block)); + if (it != cfg.successors.end()) + { + for (auto *succ : it->second) + { + if (li.dfs_num.find(succ) == li.dfs_num.end()) + { + DFS(succ, cfg, li); + } + } + } + li.in_stack[block] = false; +} + +static LoopInfo DetectLoops(MachineFunction &function, + const CFGAnalysisResult &cfg) +{ + LoopInfo li; + auto *entry = function.GetEntryPtr(); + if (entry) + DFS(entry, cfg, li); + return li; +} + +// Pettis-Hansen 核心:链合并 +static std::vector PettisHansenChains( + MachineFunction &function, + const CFGAnalysisResult &cfg, + const LoopInfo &li) +{ + auto &blocks = function.GetBlocks(); + + // 1. 初始化:每个基本块一个链 + std::vector chains; + std::map block_to_chain; + for (auto &block : blocks) + { + if (!block) + continue; + Chain c; + c.blocks.push_back(block.get()); + auto fit = cfg.block_freq.find(block.get()); + c.total_freq = (fit != cfg.block_freq.end()) ? fit->second : 0.0; + block_to_chain[block.get()] = chains.size(); + chains.push_back(std::move(c)); + } + + // 2. 收集所有边,分配权重 + struct WeightedEdge + { + MachineBasicBlock *src; + MachineBasicBlock *dst; + double weight; + }; + std::vector edges; + + for (auto &kv : cfg.successors) + { + auto *src = kv.first; + const auto &succs = kv.second; + for (size_t i = 0; i < succs.size(); ++i) + { + auto *dst = succs[i]; + double w = 0.0; + + // 回边权重最高(循环热点路径) + if (li.IsBackEdge(src, dst)) + { + w = 100.0; + } + // 第一个后继(if-then 的 then 分支)权重中等 + else if (i == 0) + { + w = 10.0; + } + // 其余后继(else 分支、break 等)权重较低 + else + { + w = 1.0; + } + + // 乘以基本块频率 + auto fit = cfg.block_freq.find(src); + if (fit != cfg.block_freq.end()) + w *= (1.0 + fit->second); + + edges.push_back({src, dst, w}); + } + } + + // 3. 按权重降序排序 + std::sort(edges.begin(), edges.end(), + [](const WeightedEdge &a, const WeightedEdge &b) + { return a.weight > b.weight; }); + + // 4. 贪心合并链 + for (const auto &edge : edges) + { + auto sit = block_to_chain.find(edge.src); + auto dit = block_to_chain.find(edge.dst); + if (sit == block_to_chain.end() || dit == block_to_chain.end()) + continue; + + size_t src_idx = sit->second; + size_t dst_idx = dit->second; + if (src_idx == dst_idx) + continue; // 已在同一链中 + + Chain &src_chain = chains[src_idx]; + Chain &dst_chain = chains[dst_idx]; + + // Pettis-Hansen 约束:src 必须是链尾,dst 必须是链头 + if (src_chain.blocks.back() != edge.src) + continue; + if (dst_chain.blocks.front() != edge.dst) + continue; + + // 合并:src_chain + dst_chain + src_chain.total_freq += dst_chain.total_freq; + for (auto *b : dst_chain.blocks) + { + src_chain.blocks.push_back(b); + block_to_chain[b] = src_idx; + } + dst_chain.blocks.clear(); + } + + // 5. 过滤空链,按总频率降序排列 + std::vector result; + for (auto &c : chains) + { + if (!c.blocks.empty()) + result.push_back(std::move(c)); + } + std::sort(result.begin(), result.end(), + [](const Chain &a, const Chain &b) + { return a.total_freq > b.total_freq; }); + + return result; +} + +static void ApplyBlockOrder(MachineFunction &function, + const std::vector &chains) +{ + // 收集所有块的新顺序,确保入口块始终在第一位 + std::vector new_order; + auto *entry = function.GetEntryPtr(); + + // 先放入口块 + if (entry) + new_order.push_back(entry); + + for (const auto &chain : chains) + { + for (auto *block : chain.blocks) + { + if (block && block != entry) + new_order.push_back(block); + } + } + + auto &blocks = function.GetBlocks(); + if (new_order.size() != blocks.size()) + return; // 安全网 + + // 重新排列 unique_ptr 向量 + std::map> ptr_map; + for (auto &bp : blocks) + { + if (bp) + ptr_map[bp.get()] = std::move(bp); + } + + for (size_t i = 0; i < new_order.size(); ++i) + { + auto it = ptr_map.find(new_order[i]); + if (it != ptr_map.end()) + blocks[i] = std::move(it->second); + } +} + +static void FixFallThroughAfterLayout(MachineFunction &function) +{ + // 块顺序改变后,重新运行 fallthrough 消除 + for (auto &block : function.GetBlocks()) + { + if (!block) + continue; + auto &insts = block->GetInstructions(); + if (insts.empty()) + continue; + + // 找到下一个块的 label + const auto &all_blocks = function.GetBlocks(); + int next_label = -1; + for (size_t bi = 0; bi < all_blocks.size(); ++bi) + { + if (all_blocks[bi].get() == block.get() && bi + 1 < all_blocks.size()) + { + next_label = all_blocks[bi + 1]->GetLabelId(); + break; + } + } + if (next_label < 0) + continue; + + // CondBr + Br 模式 + if (insts.size() >= 2) + { + auto br_it = insts.end() - 1; + auto cond_it = insts.end() - 2; + if (br_it->GetOpcode() == Opcode::Br && + cond_it->GetOpcode() == Opcode::CondBr && + br_it->GetOperands().size() >= 1 && + br_it->GetOperands()[0].GetKind() == Operand::Kind::Label && + cond_it->GetOperands().size() >= 2 && + cond_it->GetOperands()[1].GetKind() == Operand::Kind::Label) + { + int cond_target = cond_it->GetOperands()[1].GetLabel(); + int br_target = br_it->GetOperands()[0].GetLabel(); + + if (cond_target == next_label) + { + // CondBr 目标已是 fallthrough → 反转条件 + CondCode old_cc = + static_cast(cond_it->GetOperands()[0].GetImm()); + CondCode new_cc = old_cc; + switch (old_cc) + { + case CondCode::EQ: new_cc = CondCode::NE; break; + case CondCode::NE: new_cc = CondCode::EQ; break; + case CondCode::LT: new_cc = CondCode::GE; break; + case CondCode::LE: new_cc = CondCode::GT; break; + case CondCode::GT: new_cc = CondCode::LE; break; + case CondCode::GE: new_cc = CondCode::LT; break; + } + const_cast(cond_it->GetOperands()[0]) = + Operand::Imm(static_cast(new_cc)); + const_cast(cond_it->GetOperands()[1]) = + Operand::Label(br_target); + insts.pop_back(); // 删除 Br + } + else if (br_target == next_label) + { + insts.pop_back(); // Br 目标已是 fallthrough + } + } + } + + // 单独 Br 模式 + if (!insts.empty()) + { + auto &last = insts.back(); + if (last.GetOpcode() == Opcode::Br && + last.GetOperands().size() >= 1 && + last.GetOperands()[0].GetKind() == Operand::Kind::Label && + last.GetOperands()[0].GetLabel() == next_label) + { + insts.pop_back(); + } + } + } +} + +} // namespace + +void RunBlockLayout(MachineFunction &function) +{ + if (function.GetBlocks().size() <= 2) + return; + + auto cfg = AnalyzeCFG(function); + auto li = DetectLoops(function, cfg); + + auto chains = PettisHansenChains(function, cfg, li); + ApplyBlockOrder(function, chains); + + // 块重排后消除新产生的 fallthrough(Peephole 会处理) + // FixFallThroughAfterLayout(function); +} + +void RunBlockLayout(MachineModule &module) +{ + for (auto &function : module.GetFunctions()) + { + if (function) + RunBlockLayout(*function); + } +} + +} // namespace mir diff --git a/src/mir/passes/CMakeLists.txt b/src/mir/passes/CMakeLists.txt index 3b97b4f7..fdf46a93 100644 --- a/src/mir/passes/CMakeLists.txt +++ b/src/mir/passes/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(mir_passes STATIC PassManager.cpp Peephole.cpp + BlockLayoutOpt.cpp ) target_link_libraries(mir_passes PUBLIC