// IR -> MIR 指令选择(Lab5): // - 为每个 IR 值分配虚拟寄存器(GPR / FPR 两类) // - alloca -> 栈对象;gep/global -> 地址计算 // - 完整覆盖算术、比较、分支、调用、访存、类型转换、浮点 #include "mir/MIR.h" #include #include #include #include #include #include #include "ir/IR.h" #include "utils/Log.h" namespace mir { namespace { int TypeSize(const ir::Type& t) { switch (t.GetKind()) { case ir::Type::Kind::Int1: case ir::Type::Kind::Int32: case ir::Type::Kind::Float: return 4; case ir::Type::Kind::Pointer: return 8; case ir::Type::Kind::Array: return static_cast(t.GetArraySize()) * TypeSize(*t.GetElementType()); default: return 8; } } RegClass ClassOf(const ir::Type& t) { return t.IsFloat() ? RegClass::FPR : RegClass::GPR; } int BytesOf(const ir::Type& t) { return t.IsPointer() ? 8 : 4; } bool IsPow2(long long v) { return v > 0 && (v & (v - 1)) == 0; } int Log2(long long v) { int n = 0; while ((1LL << n) < v) ++n; return n; } class Lowerer { public: Lowerer(const ir::Module& m, MachineModule& out) : ir_(m), out_(out) {} void Run(); private: const ir::Module& ir_; MachineModule& out_; MachineFunction* mf_ = nullptr; MachineBasicBlock* mbb_ = nullptr; std::unordered_map vmap_; std::unordered_map bmap_; std::unordered_map allocas_; int label_id_ = 0; void Emit(Opcode op, std::vector ops, int defs, Cond c = Cond::AL) { mbb_->Add(MachineInstr(op, std::move(ops), defs, c)); } Operand NewG(int bytes = 4) { return mf_->NewVRegOp(RegClass::GPR, bytes); } Operand NewF() { return mf_->NewVRegOp(RegClass::FPR, 4); } void LowerFunction(const ir::Function& f); void LowerBlock(const ir::BasicBlock& bb); void LowerInst(const ir::Instruction& inst); void LowerInstMem(const ir::Instruction& inst); Operand GetReg(const ir::Value* v); // 取值(必要时物化常量/地址) Operand MaterializeInt(int v); Operand AddressOf(const ir::Value* ptr); long long GepConst(const ir::GepInst& gep, Operand* out_base); void LowerGlobals(); Cond ICmpCond(ir::ICmpPredicate p); Cond FCmpCond(ir::FCmpPredicate p); }; void Lowerer::Run() { LowerGlobals(); for (const auto& f : ir_.GetFunctions()) { if (f->IsDeclaration()) continue; LowerFunction(*f); } } void Lowerer::LowerGlobals() { for (const auto& g : ir_.GetGlobals()) { MachineGlobal mg; mg.name = g->GetName(); const ir::Type& vt = *g->GetValueType(); mg.size = TypeSize(vt); mg.align = vt.IsArray() ? 16 : 4; mg.is_const = g->IsConst(); ir::ConstantValue* init = g->GetInitializer(); mg.zero_init = true; int nwords = (mg.size + 3) / 4; mg.words.assign(nwords, 0u); // 收集初始化位(递归展开数组)。 std::vector flat; std::function walk = [&](ir::ConstantValue* c) { if (!c) return; if (auto* ci = dynamic_cast(c)) { flat.push_back(static_cast(ci->GetValue())); } else if (auto* cf = dynamic_cast(c)) { float v = cf->GetValue(); unsigned bits; std::memcpy(&bits, &v, 4); flat.push_back(bits); } else if (auto* ca = dynamic_cast(c)) { for (auto* e : ca->GetElements()) walk(e); } }; walk(init); for (size_t i = 0; i < flat.size() && i < mg.words.size(); ++i) { mg.words[i] = flat[i]; if (flat[i] != 0) mg.zero_init = false; } out_.Globals().push_back(std::move(mg)); } } void Lowerer::LowerFunction(const ir::Function& f) { out_.Functions().push_back(std::make_unique(f.GetName())); mf_ = out_.Functions().back().get(); vmap_.clear(); bmap_.clear(); allocas_.clear(); for (const auto& bb : f.GetBlocks()) { bmap_[bb.get()] = mf_->CreateBlock(bb->GetName()); } // 记录后继,便于活跃性分析。 for (const auto& bb : f.GetBlocks()) { MachineBasicBlock* mb = bmap_[bb.get()]; for (auto* s : bb->GetSuccessors()) mb->Succs().push_back(bmap_[s]); } mbb_ = bmap_[f.GetEntry()]; // 形参:整型走 x0.., 浮点走 s0..,超过 8 个的从栈读取(测试未用,简化)。 int ig = 0, fg = 0; std::vector arg_copies; for (size_t i = 0; i < f.GetNumArgs(); ++i) { ir::Argument* a = const_cast(f).GetArg(i); const ir::Type& at = *a->GetType(); if (at.IsFloat()) { Operand dst = NewF(); arg_copies.push_back(MachineInstr( Opcode::FMov, {dst, Operand::PReg(fg++, RegClass::FPR, 4)}, 1)); vmap_[a] = dst; } else { int bytes = at.IsPointer() ? 8 : 4; Operand dst = NewG(bytes); arg_copies.push_back(MachineInstr( Opcode::Mov, {dst, Operand::PReg(ig++, RegClass::GPR, bytes)}, 1)); vmap_[a] = dst; } } mf_->SetArgCounts(ig, fg); for (auto& mi : arg_copies) mbb_->Add(std::move(mi)); for (const auto& bb : f.GetBlocks()) { mbb_ = bmap_[bb.get()]; LowerBlock(*bb); } } void Lowerer::LowerBlock(const ir::BasicBlock& bb) { for (const auto& inst : bb.GetInstructions()) { LowerInst(*inst); } } Operand Lowerer::MaterializeInt(int v) { Operand d = NewG(4); Emit(Opcode::MovImm, {d, Operand::Imm(v)}, 1); return d; } Operand Lowerer::GetReg(const ir::Value* v) { auto it = vmap_.find(v); if (it != vmap_.end()) return it->second; if (auto* ci = dynamic_cast(v)) { return MaterializeInt(ci->GetValue()); } if (auto* cf = dynamic_cast(v)) { float f = cf->GetValue(); unsigned bits; std::memcpy(&bits, &f, 4); Operand d = NewF(); Emit(Opcode::FMovImm, {d, Operand::Imm(bits)}, 1); return d; } // 兜底:未知值视为 0。 return MaterializeInt(0); } // 计算 gep 的常量字节偏移;返回偏移并把基址写入 *out_base。 // 若存在变量下标,直接生成地址计算指令并把最终地址放入 *out_base、返回 0。 long long Lowerer::GepConst(const ir::GepInst& gep, Operand* out_base) { Operand base = AddressOf(gep.GetBasePtr()); // 推断逐层元素类型:基址指针指向的类型。 std::shared_ptr cur = gep.GetBasePtr()->GetType()->GetElementType(); long long const_off = 0; Operand addr = base; bool addr_dirty = false; const auto& idxs = gep.GetIndices(); for (size_t i = 0; i < idxs.size(); ++i) { int elem_size = cur ? TypeSize(*cur) : 4; ir::Value* iv = idxs[i]; if (auto* ci = dynamic_cast(iv)) { const_off += static_cast(ci->GetValue()) * elem_size; } else { // addr += index * elem_size Operand idx = GetReg(iv); Operand idx64 = NewG(8); Emit(Opcode::Sxtw, {idx64, idx}, 1); Operand scaled = NewG(8); if (IsPow2(elem_size)) { Emit(Opcode::LslImm, {scaled, idx64, Operand::Imm(Log2(elem_size))}, 1); } else { Operand sz = NewG(8); Emit(Opcode::MovImm, {sz, Operand::Imm(elem_size)}, 1); Emit(Opcode::Mul, {scaled, idx64, sz}, 1); } Operand na = NewG(8); Emit(Opcode::Add, {na, addr, scaled}, 1); addr = na; addr_dirty = true; } if (cur && cur->IsArray()) cur = cur->GetElementType(); } if (addr_dirty) { *out_base = addr; return const_off; } *out_base = base; return const_off; } // 返回某个指针型 IR 值对应的“地址”寄存器。 Operand Lowerer::AddressOf(const ir::Value* ptr) { auto it = vmap_.find(ptr); if (it != vmap_.end()) return it->second; if (auto* a = dynamic_cast(ptr)) { int idx; auto ai = allocas_.find(a); if (ai == allocas_.end()) { idx = mf_->CreateStackObject(TypeSize(*a->GetAllocatedType()), a->GetAllocatedType()->IsArray() ? 16 : 4); allocas_[a] = idx; } else { idx = ai->second; } Operand d = NewG(8); Emit(Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1); vmap_[ptr] = d; return d; } // 全局变量 Operand d = NewG(8); Emit(Opcode::AddrGlobal, {d, Operand::Global(ptr->GetName())}, 1); return d; } Cond Lowerer::ICmpCond(ir::ICmpPredicate p) { switch (p) { case ir::ICmpPredicate::Eq: return Cond::EQ; case ir::ICmpPredicate::Ne: return Cond::NE; case ir::ICmpPredicate::Slt: return Cond::LT; case ir::ICmpPredicate::Sle: return Cond::LE; case ir::ICmpPredicate::Sgt: return Cond::GT; case ir::ICmpPredicate::Sge: return Cond::GE; } return Cond::EQ; } Cond Lowerer::FCmpCond(ir::FCmpPredicate p) { switch (p) { case ir::FCmpPredicate::Oeq: return Cond::EQ; case ir::FCmpPredicate::One: return Cond::NE; case ir::FCmpPredicate::Olt: return Cond::MI; case ir::FCmpPredicate::Ole: return Cond::LS; case ir::FCmpPredicate::Ogt: return Cond::GT; case ir::FCmpPredicate::Oge: return Cond::GE; } return Cond::EQ; } void Lowerer::LowerInst(const ir::Instruction& inst) { using ir::Opcode; switch (inst.GetOpcode()) { case Opcode::Add: case Opcode::Sub: case Opcode::Mul: case Opcode::SDiv: case Opcode::SRem: { auto& b = static_cast(inst); Operand l = GetReg(b.GetLhs()); Operand r = GetReg(b.GetRhs()); Operand d = NewG(4); mir::Opcode mop = mir::Opcode::Add; if (inst.GetOpcode() == Opcode::Add) mop = mir::Opcode::Add; else if (inst.GetOpcode() == Opcode::Sub) mop = mir::Opcode::Sub; else if (inst.GetOpcode() == Opcode::Mul) mop = mir::Opcode::Mul; else mop = mir::Opcode::SDiv; if (inst.GetOpcode() == Opcode::SRem) { Operand q = NewG(4); Emit(mir::Opcode::SDiv, {q, l, r}, 1); Emit(mir::Opcode::MSub, {d, q, r, l}, 1); // d = l - q*r } else { Emit(mop, {d, l, r}, 1); } vmap_[&inst] = d; break; } case Opcode::FAdd: case Opcode::FSub: case Opcode::FMul: case Opcode::FDiv: { auto& b = static_cast(inst); Operand l = GetReg(b.GetLhs()); Operand r = GetReg(b.GetRhs()); Operand d = NewF(); mir::Opcode mop = mir::Opcode::FAdd; if (inst.GetOpcode() == Opcode::FSub) mop = mir::Opcode::FSub; else if (inst.GetOpcode() == Opcode::FMul) mop = mir::Opcode::FMul; else if (inst.GetOpcode() == Opcode::FDiv) mop = mir::Opcode::FDiv; Emit(mop, {d, l, r}, 1); vmap_[&inst] = d; break; } case Opcode::SIToFP: { auto& c = static_cast(inst); Operand s = GetReg(c.GetValue()); Operand d = NewF(); Emit(mir::Opcode::SCvtF, {d, s}, 1); vmap_[&inst] = d; break; } case Opcode::FPToSI: { auto& c = static_cast(inst); Operand s = GetReg(c.GetValue()); Operand d = NewG(4); Emit(mir::Opcode::FCvtZS, {d, s}, 1); vmap_[&inst] = d; break; } case Opcode::ZExt: { auto& c = static_cast(inst); Operand s = GetReg(c.GetValue()); Operand d = NewG(4); Emit(mir::Opcode::Mov, {d, s}, 1); // i1->i32:cset 已产出 0/1 vmap_[&inst] = d; break; } case Opcode::ICmp: { auto& c = static_cast(inst); Operand l = GetReg(c.GetLhs()); Operand r = GetReg(c.GetRhs()); Emit(mir::Opcode::Cmp, {l, r}, 0); Operand d = NewG(4); Emit(mir::Opcode::CSet, {d}, 1, ICmpCond(c.GetPredicate())); vmap_[&inst] = d; break; } case Opcode::FCmp: { auto& c = static_cast(inst); Operand l = GetReg(c.GetLhs()); Operand r = GetReg(c.GetRhs()); Emit(mir::Opcode::FCmp, {l, r}, 0); Operand d = NewG(4); Emit(mir::Opcode::CSet, {d}, 1, FCmpCond(c.GetPredicate())); vmap_[&inst] = d; break; } default: LowerInstMem(inst); break; } } void Lowerer::LowerInstMem(const ir::Instruction& inst) { using ir::Opcode; switch (inst.GetOpcode()) { case Opcode::Alloca: { auto& a = static_cast(inst); int idx = mf_->CreateStackObject(TypeSize(*a.GetAllocatedType()), a.GetAllocatedType()->IsArray() ? 16 : 4); allocas_[&a] = idx; Operand d = NewG(8); Emit(mir::Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1); vmap_[&inst] = d; break; } case Opcode::Load: { auto& ld = static_cast(inst); Operand base; long long off = 0; if (auto* gep = dynamic_cast(ld.GetPtr())) { off = GepConst(*gep, &base); } else { base = AddressOf(ld.GetPtr()); } bool is_f = inst.GetType()->IsFloat(); Operand d = is_f ? NewF() : NewG(BytesOf(*inst.GetType())); Emit(mir::Opcode::Ldr, {d, base, Operand::Imm(off)}, 1); vmap_[&inst] = d; break; } case Opcode::Store: { auto& st = static_cast(inst); Operand val = GetReg(st.GetValue()); Operand base; long long off = 0; if (auto* gep = dynamic_cast(st.GetPtr())) { off = GepConst(*gep, &base); } else { base = AddressOf(st.GetPtr()); } Emit(mir::Opcode::Str, {val, base, Operand::Imm(off)}, 0); break; } case Opcode::Gep: { auto& gep = static_cast(inst); Operand base; long long off = GepConst(gep, &base); Operand d = NewG(8); if (off == 0) { Emit(mir::Opcode::Mov, {d, base}, 1); } else { Emit(mir::Opcode::AddImm, {d, base, Operand::Imm(off)}, 1); } vmap_[&inst] = d; break; } case Opcode::Call: { auto& call = static_cast(inst); // 先把所有实参算入虚拟寄存器,再连续搬入物理参数寄存器, // 避免计算后续实参时分配器复用 x0..x7 破坏已就绪的参数。 std::vector vals; for (auto* arg : call.GetArgs()) vals.push_back(GetReg(arg)); int ig = 0, fg = 0; std::vector arg_uses; for (size_t i = 0; i < call.GetArgs().size(); ++i) { ir::Value* arg = call.GetArgs()[i]; if (arg->GetType()->IsFloat()) { Operand p = Operand::PReg(fg++, RegClass::FPR, 4); Emit(mir::Opcode::FMov, {p, vals[i]}, 1); arg_uses.push_back(p); } else { int bytes = arg->GetType()->IsPointer() ? 8 : 4; Operand p = Operand::PReg(ig++, RegClass::GPR, bytes); Emit(mir::Opcode::Mov, {p, vals[i]}, 1); arg_uses.push_back(p); } } std::vector ops; ops.push_back(Operand::Global(call.GetCallee()->GetName())); for (auto& u : arg_uses) ops.push_back(u); Emit(mir::Opcode::Bl, ops, 0); if (!inst.GetType()->IsVoid()) { if (inst.GetType()->IsFloat()) { Operand d = NewF(); Emit(mir::Opcode::FMov, {d, Operand::PReg(0, RegClass::FPR, 4)}, 1); vmap_[&inst] = d; } else { Operand d = NewG(BytesOf(*inst.GetType())); Emit(mir::Opcode::Mov, {d, Operand::PReg(0, RegClass::GPR, d.GetBytes())}, 1); vmap_[&inst] = d; } } break; } case Opcode::Br: { auto& br = static_cast(inst); Emit(mir::Opcode::B, {Operand::Label(br.GetDest()->GetName())}, 0); break; } case Opcode::CondBr: { auto& cbr = static_cast(inst); Operand c = GetReg(cbr.GetCond()); Emit(mir::Opcode::CmpImm, {c, Operand::Imm(0)}, 0); Emit(mir::Opcode::BCond, {Operand::Label(cbr.GetTrueDest()->GetName())}, 0, Cond::NE); Emit(mir::Opcode::B, {Operand::Label(cbr.GetFalseDest()->GetName())}, 0); break; } case Opcode::Ret: { auto& ret = static_cast(inst); if (ret.HasReturnValue()) { Operand v = GetReg(ret.GetValue()); if (ret.GetValue()->GetType()->IsFloat()) { Emit(mir::Opcode::FMov, {Operand::PReg(0, RegClass::FPR, 4), v}, 1); } else { Emit(mir::Opcode::Mov, {Operand::PReg(0, RegClass::GPR, v.GetBytes()), v}, 1); } } Emit(mir::Opcode::Ret, {}, 0); break; } default: break; } } } // namespace std::unique_ptr LowerToMIR(const ir::Module& module) { auto out = std::make_unique(); Lowerer lo(module, *out); lo.Run(); return out; } } // namespace mir