diff --git a/src/mir/InstLiveness.cpp b/src/mir/InstLiveness.cpp new file mode 100644 index 00000000..0e7f789c --- /dev/null +++ b/src/mir/InstLiveness.cpp @@ -0,0 +1,426 @@ +#include "mir/MIR.h" + +#include +#include +#include +#include +#include + +#include "utils/Log.h" + +namespace mir +{ + namespace + { + + // ---- Phase 1 helpers ------------------------------------------------- + + /// Return true if opcode has a VReg def (always operands[0]). + static bool HasVRegDef(Opcode opcode) + { + switch (opcode) + { + case Opcode::MovImm: + case Opcode::LoadStack: + case Opcode::LoadGlobal: + case Opcode::LoadGlobalAddr: + case Opcode::LoadStackAddr: + case Opcode::LoadMem: + case Opcode::AddRR: + case Opcode::SubRR: + case Opcode::AddImm: + case Opcode::SubImm: + case Opcode::MulRR: + case Opcode::DivRR: + case Opcode::ModRR: + case Opcode::AndRR: + case Opcode::OrRR: + case Opcode::XorRR: + case Opcode::ShlRR: + case Opcode::ShrRR: + case Opcode::AsrRR: + case Opcode::Asr64RR: + case Opcode::Uxtw: + case Opcode::Sxtw: + case Opcode::CSet: + case Opcode::Csel: + case Opcode::Smull: + case Opcode::Msub: + case Opcode::NegRR: + case Opcode::FAddRR: + case Opcode::FSubRR: + case Opcode::FMulRR: + case Opcode::FDivRR: + case Opcode::Scvtf: + case Opcode::FCvtzs: + case Opcode::FMovWS: + case Opcode::MovReg: + return true; + default: + return false; + } + } + + /// Extract def VReg (operands[0] if VReg) and use VRegs from one instruction. + static void ExtractDefUse(const MachineInstr &inst, int &def_vreg, + std::vector &use_vregs) + { + def_vreg = -1; + use_vregs.clear(); + + const auto &ops = inst.GetOperands(); + const auto opcode = inst.GetOpcode(); + + if (HasVRegDef(opcode) && !ops.empty() && + ops[0].GetKind() == Operand::Kind::VReg) + { + def_vreg = ops[0].GetVRegId(); + } + + // All other VReg operands are uses + for (size_t i = 0; i < ops.size(); ++i) + { + // For def-producing instructions, operands[0] is the def (already handled) + if (HasVRegDef(opcode) && i == 0) + continue; + if (ops[i].GetKind() == Operand::Kind::VReg) + use_vregs.push_back(ops[i].GetVRegId()); + } + } + + } // anonymous namespace +} // namespace mir + +namespace mir +{ + + // ---- Block-level dataflow structures -------------------------------- + + struct BlockLiveInfo + { + std::unordered_set def; + std::unordered_set use; + std::unordered_set live_in; + std::unordered_set live_out; + std::vector successors; // block indices + std::vector predecessors; // block indices + }; + + std::vector ComputeInstLiveness(MachineFunction &func) + { + auto &blocks = func.GetBlocks(); + const int num_blocks = static_cast(blocks.size()); + + // ================================================================ + // Phase 1: Block-level backward liveness (fixpoint iteration) + // ================================================================ + + // 1a. Build label → block-index mapping + std::unordered_map label_to_idx; + for (int i = 0; i < num_blocks; ++i) + { + if (!blocks[i]) + continue; + label_to_idx[blocks[i]->GetLabelId()] = i; + } + + // 1b. Compute per-block def/use + successors + std::vector blk_info(num_blocks); + + for (int i = 0; i < num_blocks; ++i) + { + if (!blocks[i]) + continue; + auto &info = blk_info[i]; + auto &insts = blocks[i]->GetInstructions(); + + for (const auto &inst : insts) + { + int def_vreg; + std::vector use_vregs; + ExtractDefUse(inst, def_vreg, use_vregs); + + // All uses are added first, then def is added. This avoids + // counting "def first, then use" in the same block incorrectly. + for (int u : use_vregs) + { + if (info.def.count(u) == 0) + info.use.insert(u); + } + if (def_vreg >= 0) + { + // A vreg used before being defined in this block stays in use set + if (info.use.count(def_vreg) == 0) + info.def.insert(def_vreg); + } + } + + // ---- Determine successors ---- + bool has_br = false; + bool has_condbr = false; + int br_target_label = -1; + int condbr_target_label = -1; + bool has_ret = false; + + for (const auto &inst : insts) + { + const auto opcode = inst.GetOpcode(); + const auto &ops = inst.GetOperands(); + + if (opcode == Opcode::Br && !ops.empty() && + ops[0].GetKind() == Operand::Kind::Label) + { + has_br = true; + br_target_label = ops[0].GetLabel(); + } + else if (opcode == Opcode::CondBr && ops.size() >= 2 && + ops[1].GetKind() == Operand::Kind::Label) + { + has_condbr = true; + condbr_target_label = ops[1].GetLabel(); + } + else if (opcode == Opcode::Ret) + { + has_ret = true; + } + } + + auto add_succ = [&](int label) + { + auto it = label_to_idx.find(label); + if (it != label_to_idx.end()) + info.successors.push_back(it->second); + }; + + if (has_ret) + { + // No successors — function exit + } + else if (has_br) + { + // Unconditional branch: target covers the only outgoing path. + add_succ(br_target_label); + // If there's also a CondBr, its target is taken when condition is + // true — the Br covers the false path. + if (has_condbr) + add_succ(condbr_target_label); + } + else if (has_condbr) + { + // Conditional branch without Br: true path = target, false path = + // falls through to next block in insertion order. + add_succ(condbr_target_label); + if (i + 1 < num_blocks) + info.successors.push_back(i + 1); + } + else + { + // Ordinary block — falls through to next block. + if (i + 1 < num_blocks) + info.successors.push_back(i + 1); + } + } + + // 1c. Build predecessor lists + for (int i = 0; i < num_blocks; ++i) + { + for (int s : blk_info[i].successors) + { + if (s >= 0 && s < num_blocks) + blk_info[s].predecessors.push_back(i); + } + } + + // 1d. Worklist fixpoint + // Initialise live_in with use sets + for (int i = 0; i < num_blocks; ++i) + { + blk_info[i].live_in = blk_info[i].use; + } + + std::queue worklist; + std::vector in_queue(num_blocks, false); + for (int i = 0; i < num_blocks; ++i) + { + if (blocks[i]) + { + worklist.push(i); + in_queue[i] = true; + } + } + + while (!worklist.empty()) + { + int b = worklist.front(); + worklist.pop(); + in_queue[b] = false; + + // Compute new live_out = union of successors' live_in + std::unordered_set new_live_out; + for (int s : blk_info[b].successors) + { + if (s < 0 || s >= num_blocks) + continue; + for (int v : blk_info[s].live_in) + new_live_out.insert(v); + } + + // Compute new live_in = use ∪ (live_out - def) + std::unordered_set new_live_in = blk_info[b].use; + for (int v : new_live_out) + { + if (blk_info[b].def.count(v) == 0) + new_live_in.insert(v); + } + + if (new_live_in != blk_info[b].live_in) + { + blk_info[b].live_out = std::move(new_live_out); + blk_info[b].live_in = std::move(new_live_in); + + // Enqueue all predecessors (their live_out depends on us) + for (int p : blk_info[b].predecessors) + { + if (!in_queue[p]) + { + in_queue[p] = true; + worklist.push(p); + } + } + } + } + + // ================================================================ + // Phase 2: Instruction-level interval computation (reverse scan) + // ================================================================ + std::unordered_map vreg_start; + std::unordered_map vreg_end; + + // Assign global instruction positions + int global_pos = 0; + // Map global_pos → (block_idx, local_instr_idx) for the reverse scan + struct PosInfo + { + int block_idx; + int instr_count; // number of instructions in this block + }; + std::vector pos_to_block; + + for (int i = 0; i < num_blocks; ++i) + { + if (!blocks[i]) + continue; + int count = static_cast(blocks[i]->GetInstructions().size()); + for (int j = 0; j < count; ++j) + pos_to_block.push_back({i, count}); + global_pos += count; + } + + const int total_instrs = global_pos; + + // Reverse scan: process blocks in reverse order + for (int bi = num_blocks - 1; bi >= 0; --bi) + { + if (!blocks[bi]) + continue; + auto &insts = blocks[bi]->GetInstructions(); + const int num_instrs = static_cast(insts.size()); + if (num_instrs == 0) + continue; + + // Compute the starting global position of the first instruction in + // this block + int block_start_pos = 0; + for (int pi = 0; pi < bi; ++pi) + { + if (blocks[pi]) + block_start_pos += static_cast(blocks[pi]->GetInstructions().size()); + } + + // Start with live_out of this block + std::unordered_set live = blk_info[bi].live_out; + + // Process instructions from last to first. + // Correct backward order: uses first (add to live), then record + // (interval extends to this position), then defs (remove from live). + // This ensures that a vreg used at this position IS recorded as + // live here, even if it was not previously in the live set. + for (int j = num_instrs - 1; j >= 0; --j) + { + int pos = block_start_pos + j; + + const auto &inst = insts[j]; + int def_vreg; + std::vector use_vregs; + ExtractDefUse(inst, def_vreg, use_vregs); + + // Uses: going backward, uses make the vreg live before this + // instruction. + for (int u : use_vregs) + live.insert(u); + + // Record: all vregs currently live extend their interval to this + // position. + for (int v : live) + { + auto sit = vreg_start.find(v); + if (sit == vreg_start.end() || pos < sit->second) + vreg_start[v] = pos; + auto eit = vreg_end.find(v); + if (eit == vreg_end.end() || pos > eit->second) + vreg_end[v] = pos; + } + + // Def: going backward, the def is the beginning of the live range + // — remove from live so that earlier positions don't see it + // (unless a later use re-adds it for the prior value). + if (def_vreg >= 0) + live.erase(def_vreg); + } + + // After processing all instructions, live should equal live_in. + // Any vreg still in live for the entry block (block 0) is live-in + // at function entry → set start = 0. + if (bi == 0) + { + for (int v : live) + { + vreg_start[v] = 0; + // Also ensure end is at least 0 + auto eit = vreg_end.find(v); + if (eit == vreg_end.end() || 0 > eit->second) + vreg_end[v] = 0; + } + } + } + + // ================================================================ + // Phase 3: Build LiveInterval objects + // ================================================================ + const int num_vregs = func.GetNumVRegs(); + std::vector intervals; + + for (int v = 0; v < num_vregs; ++v) + { + auto sit = vreg_start.find(v); + auto eit = vreg_end.find(v); + + int start = (sit != vreg_start.end()) ? sit->second : 0; + int end = (eit != vreg_end.end()) ? eit->second : 0; + + // Filter out unused vregs + if (start > end) + continue; + + LiveInterval li; + li.vreg = v; + li.start = start; + li.end = end; + li.vreg_class = func.GetVRegClass(v); + intervals.push_back(li); + } + + return intervals; + } + +} // namespace mir