You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

526 lines
17 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

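// IR -> MIR instruction selection (Lab5)
// - Assigns a virtual register to every IR value (two classes: GPR / FPR)
// - alloca -> stack object; gep/global -> address computation
// - Covers arithmetic, comparison, branch, call, memory access, type conversion and floating point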
#include "mir/MIR.h"
#include <cstring>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
// Returns the storage size of an IR type in bytes.
//   Int1 / Int32 / Float -> 4
//   Array                -> element count * size of the element type (recursive)
//   Pointer and any other kind -> 8
int TypeSize(const ir::Type& t) {
  const ir::Type::Kind kind = t.GetKind();
  if (kind == ir::Type::Kind::Int1 || kind == ir::Type::Kind::Int32 ||
      kind == ir::Type::Kind::Float) {
    return 4;
  }
  if (kind == ir::Type::Kind::Array) {
    const int element_count = static_cast<int>(t.GetArraySize());
    return element_count * TypeSize(*t.GetElementType());
  }
  // Pointers are 8 bytes; every remaining kind is also sized as 8.
  return 8;
}
// Picks the register class for a value of type `t`: FPR for float types,
// GPR for everything else.
RegClass ClassOf(const ir::Type& t) {
  if (t.IsFloat()) {
    return RegClass::FPR;
  }
  return RegClass::GPR;
}
// Register width in bytes for a scalar of type `t`: 8 for pointers, 4 otherwise.
int BytesOf(const ir::Type& t) {
  if (t.IsPointer()) {
    return 8;
  }
  return 4;
}
// True when `v` is a positive power of two (exactly one bit set).
bool IsPow2(long long v) {
  if (v <= 0) {
    return false;
  }
  // Clearing the lowest set bit leaves zero only for powers of two.
  return (v & (v - 1)) == 0;
}
// Returns the smallest n such that 2^n >= v (i.e. ceil(log2(v))); for an exact
// power of two this is its exponent. Values <= 1 yield 0.
//
// The comparison is done in unsigned 64-bit arithmetic so that inputs above
// 2^62 never require shifting a signed 1 into or past the sign bit, which
// would be undefined behaviour and could prevent the loop from terminating.
int Log2(long long v) {
  if (v <= 1) return 0;
  const unsigned long long target = static_cast<unsigned long long>(v);
  int n = 0;
  // target <= 2^63 - 1, so the loop stops at n == 63 at the latest.
  while ((1ULL << n) < target) ++n;
  return n;
}
class Lowerer {
public:
Lowerer(const ir::Module& m, MachineModule& out) : ir_(m), out_(out) {}
void Run();
private:
const ir::Module& ir_;
MachineModule& out_;
MachineFunction* mf_ = nullptr;
MachineBasicBlock* mbb_ = nullptr;
std::unordered_map<const ir::Value*, Operand> vmap_;
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bmap_;
std::unordered_map<const ir::AllocaInst*, int> allocas_;
int label_id_ = 0;
void Emit(Opcode op, std::vector<Operand> ops, int defs, Cond c = Cond::AL) {
mbb_->Add(MachineInstr(op, std::move(ops), defs, c));
}
Operand NewG(int bytes = 4) { return mf_->NewVRegOp(RegClass::GPR, bytes); }
Operand NewF() { return mf_->NewVRegOp(RegClass::FPR, 4); }
void LowerFunction(const ir::Function& f);
void LowerBlock(const ir::BasicBlock& bb);
void LowerInst(const ir::Instruction& inst);
void LowerInstMem(const ir::Instruction& inst);
Operand GetReg(const ir::Value* v); // 取值(必要时物化常量/地址)
Operand MaterializeInt(int v);
Operand AddressOf(const ir::Value* ptr);
long long GepConst(const ir::GepInst& gep, Operand* out_base);
void LowerGlobals();
Cond ICmpCond(ir::ICmpPredicate p);
Cond FCmpCond(ir::FCmpPredicate p);
};
// Entry point of the lowering: emits all globals, then lowers every function
// that has a body (declarations are skipped).
void Lowerer::Run() {
  LowerGlobals();
  for (const auto& fn : ir_.GetFunctions()) {
    if (!fn->IsDeclaration()) {
      LowerFunction(*fn);
    }
  }
}
void Lowerer::LowerGlobals() {
for (const auto& g : ir_.GetGlobals()) {
MachineGlobal mg;
mg.name = g->GetName();
const ir::Type& vt = *g->GetValueType();
mg.size = TypeSize(vt);
mg.align = vt.IsArray() ? 16 : 4;
mg.is_const = g->IsConst();
ir::ConstantValue* init = g->GetInitializer();
mg.zero_init = true;
int nwords = (mg.size + 3) / 4;
mg.words.assign(nwords, 0u);
// 收集初始化位(递归展开数组)。
std::vector<unsigned> flat;
std::function<void(ir::ConstantValue*)> walk = [&](ir::ConstantValue* c) {
if (!c) return;
if (auto* ci = dynamic_cast<ir::ConstantInt*>(c)) {
flat.push_back(static_cast<unsigned>(ci->GetValue()));
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(c)) {
float v = cf->GetValue();
unsigned bits;
std::memcpy(&bits, &v, 4);
flat.push_back(bits);
} else if (auto* ca = dynamic_cast<ir::ConstantArray*>(c)) {
for (auto* e : ca->GetElements()) walk(e);
}
};
walk(init);
for (size_t i = 0; i < flat.size() && i < mg.words.size(); ++i) {
mg.words[i] = flat[i];
if (flat[i] != 0) mg.zero_init = false;
}
out_.Globals().push_back(std::move(mg));
}
}
// Lowers one IR function into a new MachineFunction appended to the output.
// Steps: create a machine block per IR block, record successor edges, copy
// incoming arguments from physical registers into virtual registers at the
// top of the entry block, then lower every block in order.
void Lowerer::LowerFunction(const ir::Function& f) {
  auto owned = std::make_unique<MachineFunction>(f.GetName());
  mf_ = owned.get();
  out_.Functions().push_back(std::move(owned));

  // Per-function state starts empty.
  vmap_.clear();
  bmap_.clear();
  allocas_.clear();

  // All machine blocks must exist before successor edges can be wired up.
  for (const auto& block : f.GetBlocks()) {
    bmap_[block.get()] = mf_->CreateBlock(block->GetName());
  }
  // Record successors so later liveness analysis can use them.
  for (const auto& block : f.GetBlocks()) {
    MachineBasicBlock* machine_block = bmap_[block.get()];
    for (auto* succ : block->GetSuccessors()) {
      machine_block->Succs().push_back(bmap_[succ]);
    }
  }
  mbb_ = bmap_[f.GetEntry()];

  // Formal parameters: integer/pointer arguments are taken from GPR physical
  // registers 0, 1, ... and float arguments from FPR physical registers
  // 0, 1, ... in declaration order.
  // NOTE(review): there is no stack-passing path here, so functions with more
  // register arguments than the calling convention provides are not handled.
  int next_gpr = 0;
  int next_fpr = 0;
  std::vector<MachineInstr> incoming;
  const size_t arg_count = f.GetNumArgs();
  for (size_t i = 0; i < arg_count; ++i) {
    ir::Argument* arg = const_cast<ir::Function&>(f).GetArg(i);
    const ir::Type& arg_type = *arg->GetType();
    const bool is_fp = arg_type.IsFloat();
    const int bytes = is_fp ? 4 : (arg_type.IsPointer() ? 8 : 4);
    Operand dst = is_fp ? NewF() : NewG(bytes);
    Operand src = is_fp ? Operand::PReg(next_fpr++, RegClass::FPR, 4)
                        : Operand::PReg(next_gpr++, RegClass::GPR, bytes);
    incoming.push_back(
        MachineInstr(is_fp ? Opcode::FMov : Opcode::Mov, {dst, src}, 1));
    vmap_[arg] = dst;
  }
  mf_->SetArgCounts(next_gpr, next_fpr);
  for (auto& copy : incoming) mbb_->Add(std::move(copy));

  for (const auto& block : f.GetBlocks()) {
    mbb_ = bmap_[block.get()];
    LowerBlock(*block);
  }
}
// Lowers the instructions of one IR block, in order, into the current
// machine block.
void Lowerer::LowerBlock(const ir::BasicBlock& bb) {
  const auto& instructions = bb.GetInstructions();
  for (auto it = instructions.begin(); it != instructions.end(); ++it) {
    LowerInst(**it);
  }
}
// Loads the integer constant `v` into a fresh 4-byte GPR virtual register and
// returns that register.
Operand Lowerer::MaterializeInt(int v) {
  Operand reg = NewG(4);
  std::vector<Operand> operands{reg, Operand::Imm(v)};
  Emit(Opcode::MovImm, std::move(operands), 1);
  return reg;
}
// Returns a register operand holding the value of `v`.
//  - Values that already have a register in vmap_ reuse it.
//  - Integer constants are materialized with MovImm, float constants with
//    FMovImm (the immediate carries the raw 32-bit pattern of the float).
//  - Pointer-typed values without a register (for example a global or an
//    alloca referenced directly as an operand) have their address
//    materialized through AddressOf instead of being silently replaced by 0.
//  - Anything else falls back to the constant 0.
Operand Lowerer::GetReg(const ir::Value* v) {
  auto it = vmap_.find(v);
  if (it != vmap_.end()) return it->second;
  if (auto* ci = dynamic_cast<const ir::ConstantInt*>(v)) {
    return MaterializeInt(ci->GetValue());
  }
  if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(v)) {
    float f = cf->GetValue();
    unsigned bits;
    std::memcpy(&bits, &f, 4);
    Operand d = NewF();
    Emit(Opcode::FMovImm, {d, Operand::Imm(bits)}, 1);
    return d;
  }
  // An unmapped pointer is an address that still has to be computed.
  if (v != nullptr && v->GetType() && v->GetType()->IsPointer()) {
    return AddressOf(v);
  }
  // Last resort: treat an unknown value as 0.
  return MaterializeInt(0);
}
// Lowers the address arithmetic of a gep.
// Constant indices are folded into a byte offset, which is returned.
// Variable indices emit instructions (sign-extend, scale, add) that advance
// the base address. The resulting base register is written to *out_base;
// the final address is *out_base plus the returned offset.
long long Lowerer::GepConst(const ir::GepInst& gep, Operand* out_base) {
  Operand running = AddressOf(gep.GetBasePtr());
  // Type stepped over by the current index; starts at the pointee type of the
  // base pointer and descends into array element types.
  std::shared_ptr<ir::Type> pointee =
      gep.GetBasePtr()->GetType()->GetElementType();
  long long byte_offset = 0;

  for (ir::Value* index : gep.GetIndices()) {
    const int stride = pointee ? TypeSize(*pointee) : 4;
    auto* const_index = dynamic_cast<ir::ConstantInt*>(index);
    if (const_index != nullptr) {
      byte_offset += static_cast<long long>(const_index->GetValue()) * stride;
    } else {
      // running += sext(index) * stride
      Operand index32 = GetReg(index);
      Operand index64 = NewG(8);
      Emit(Opcode::Sxtw, {index64, index32}, 1);
      Operand scaled = NewG(8);
      if (!IsPow2(stride)) {
        Operand stride_reg = NewG(8);
        Emit(Opcode::MovImm, {stride_reg, Operand::Imm(stride)}, 1);
        Emit(Opcode::Mul, {scaled, index64, stride_reg}, 1);
      } else {
        // Power-of-two strides become a left shift.
        Emit(Opcode::LslImm, {scaled, index64, Operand::Imm(Log2(stride))}, 1);
      }
      Operand next = NewG(8);
      Emit(Opcode::Add, {next, running, scaled}, 1);
      running = next;
    }
    if (pointee && pointee->IsArray()) pointee = pointee->GetElementType();
  }

  *out_base = running;
  return byte_offset;
}
// Returns a register holding the address denoted by the pointer-typed IR
// value `ptr`.
//  - A value already present in vmap_ reuses its register.
//  - An alloca gets a stack object (created on first use) and an AddrFrame
//    instruction; the resulting register is cached in vmap_.
//  - Everything else is treated as a global symbol named ptr->GetName() and
//    materialized with AddrGlobal (not cached).
Operand Lowerer::AddressOf(const ir::Value* ptr) {
  const auto cached = vmap_.find(ptr);
  if (cached != vmap_.end()) return cached->second;

  const auto* alloca_inst = dynamic_cast<const ir::AllocaInst*>(ptr);
  if (alloca_inst == nullptr) {
    // Global variable.
    Operand global_addr = NewG(8);
    Emit(Opcode::AddrGlobal, {global_addr, Operand::Global(ptr->GetName())}, 1);
    return global_addr;
  }

  int slot;
  const auto known = allocas_.find(alloca_inst);
  if (known != allocas_.end()) {
    slot = known->second;
  } else {
    slot = mf_->CreateStackObject(
        TypeSize(*alloca_inst->GetAllocatedType()),
        alloca_inst->GetAllocatedType()->IsArray() ? 16 : 4);
    allocas_[alloca_inst] = slot;
  }
  Operand frame_addr = NewG(8);
  Emit(Opcode::AddrFrame, {frame_addr, Operand::Frame(slot)}, 1);
  vmap_[ptr] = frame_addr;
  return frame_addr;
}
// Maps a signed integer comparison predicate to the condition code tested
// after a Cmp. Predicates not listed map to Cond::EQ.
Cond Lowerer::ICmpCond(ir::ICmpPredicate p) {
  Cond mapped = Cond::EQ;
  switch (p) {
    case ir::ICmpPredicate::Eq:
      mapped = Cond::EQ;
      break;
    case ir::ICmpPredicate::Ne:
      mapped = Cond::NE;
      break;
    case ir::ICmpPredicate::Slt:
      mapped = Cond::LT;
      break;
    case ir::ICmpPredicate::Sle:
      mapped = Cond::LE;
      break;
    case ir::ICmpPredicate::Sgt:
      mapped = Cond::GT;
      break;
    case ir::ICmpPredicate::Sge:
      mapped = Cond::GE;
      break;
  }
  return mapped;
}
// Maps a floating-point comparison predicate to the condition code tested
// after an FCmp. Predicates not listed map to Cond::EQ.
// NOTE(review): assuming AArch64 flag semantics, MI/LS/GT/GE/EQ are all false
// after an unordered compare, but NE is true, so `One` behaves like
// "unordered or not equal" when an operand is NaN. Confirm this is intended.
Cond Lowerer::FCmpCond(ir::FCmpPredicate p) {
  Cond mapped = Cond::EQ;
  switch (p) {
    case ir::FCmpPredicate::Oeq:
      mapped = Cond::EQ;
      break;
    case ir::FCmpPredicate::One:
      mapped = Cond::NE;
      break;
    case ir::FCmpPredicate::Olt:
      mapped = Cond::MI;
      break;
    case ir::FCmpPredicate::Ole:
      mapped = Cond::LS;
      break;
    case ir::FCmpPredicate::Ogt:
      mapped = Cond::GT;
      break;
    case ir::FCmpPredicate::Oge:
      mapped = Cond::GE;
      break;
  }
  return mapped;
}
// Lowers one IR instruction. Arithmetic, casts and comparisons are handled
// here; every other opcode is forwarded to LowerInstMem. The result register
// of a value-producing instruction is recorded in vmap_.
void Lowerer::LowerInst(const ir::Instruction& inst) {
  const ir::Opcode kind = inst.GetOpcode();
  switch (kind) {
    // ---- Integer arithmetic ----------------------------------------------
    case ir::Opcode::Add:
    case ir::Opcode::Sub:
    case ir::Opcode::Mul:
    case ir::Opcode::SDiv:
    case ir::Opcode::SRem: {
      const auto& bin = static_cast<const ir::BinaryInst&>(inst);
      Operand lhs = GetReg(bin.GetLhs());
      Operand rhs = GetReg(bin.GetRhs());
      Operand dst = NewG(4);
      if (kind == ir::Opcode::SRem) {
        // Remainder via quotient: dst = lhs - (lhs / rhs) * rhs.
        Operand quotient = NewG(4);
        Emit(mir::Opcode::SDiv, {quotient, lhs, rhs}, 1);
        Emit(mir::Opcode::MSub, {dst, quotient, rhs, lhs}, 1);
      } else {
        mir::Opcode machine_op;
        switch (kind) {
          case ir::Opcode::Add:
            machine_op = mir::Opcode::Add;
            break;
          case ir::Opcode::Sub:
            machine_op = mir::Opcode::Sub;
            break;
          case ir::Opcode::Mul:
            machine_op = mir::Opcode::Mul;
            break;
          default:  // SDiv
            machine_op = mir::Opcode::SDiv;
            break;
        }
        Emit(machine_op, {dst, lhs, rhs}, 1);
      }
      vmap_[&inst] = dst;
      break;
    }

    // ---- Floating-point arithmetic ---------------------------------------
    case ir::Opcode::FAdd:
    case ir::Opcode::FSub:
    case ir::Opcode::FMul:
    case ir::Opcode::FDiv: {
      const auto& bin = static_cast<const ir::BinaryInst&>(inst);
      Operand lhs = GetReg(bin.GetLhs());
      Operand rhs = GetReg(bin.GetRhs());
      Operand dst = NewF();
      mir::Opcode machine_op;
      switch (kind) {
        case ir::Opcode::FSub:
          machine_op = mir::Opcode::FSub;
          break;
        case ir::Opcode::FMul:
          machine_op = mir::Opcode::FMul;
          break;
        case ir::Opcode::FDiv:
          machine_op = mir::Opcode::FDiv;
          break;
        default:  // FAdd
          machine_op = mir::Opcode::FAdd;
          break;
      }
      Emit(machine_op, {dst, lhs, rhs}, 1);
      vmap_[&inst] = dst;
      break;
    }

    // ---- Conversions -----------------------------------------------------
    // SIToFP: int -> float, FPToSI: float -> int, ZExt: plain move (the i1
    // source produced by CSet already holds 0 or 1).
    case ir::Opcode::SIToFP:
    case ir::Opcode::FPToSI:
    case ir::Opcode::ZExt: {
      const auto& cast = static_cast<const ir::CastInst&>(inst);
      Operand src = GetReg(cast.GetValue());
      Operand dst = (kind == ir::Opcode::SIToFP) ? NewF() : NewG(4);
      mir::Opcode machine_op = mir::Opcode::Mov;
      if (kind == ir::Opcode::SIToFP) {
        machine_op = mir::Opcode::SCvtF;
      } else if (kind == ir::Opcode::FPToSI) {
        machine_op = mir::Opcode::FCvtZS;
      }
      Emit(machine_op, {dst, src}, 1);
      vmap_[&inst] = dst;
      break;
    }

    // ---- Comparisons: set flags, then capture the condition as 0/1 --------
    case ir::Opcode::ICmp: {
      const auto& cmp = static_cast<const ir::ICmpInst&>(inst);
      Operand lhs = GetReg(cmp.GetLhs());
      Operand rhs = GetReg(cmp.GetRhs());
      Emit(mir::Opcode::Cmp, {lhs, rhs}, 0);
      Operand flag = NewG(4);
      Emit(mir::Opcode::CSet, {flag}, 1, ICmpCond(cmp.GetPredicate()));
      vmap_[&inst] = flag;
      break;
    }
    case ir::Opcode::FCmp: {
      const auto& cmp = static_cast<const ir::FCmpInst&>(inst);
      Operand lhs = GetReg(cmp.GetLhs());
      Operand rhs = GetReg(cmp.GetRhs());
      Emit(mir::Opcode::FCmp, {lhs, rhs}, 0);
      Operand flag = NewG(4);
      Emit(mir::Opcode::CSet, {flag}, 1, FCmpCond(cmp.GetPredicate()));
      vmap_[&inst] = flag;
      break;
    }

    default:
      LowerInstMem(inst);
      break;
  }
}
// Lowers memory, call and control-flow instructions: Alloca, Load, Store,
// Gep, Call, Br, CondBr and Ret. Opcodes not listed are dropped without
// emitting anything.
// NOTE(review): confirm that no value-producing opcode can reach the default
// case, since its users would then fall back to GetReg's default handling.
void Lowerer::LowerInstMem(const ir::Instruction& inst) {
  using ir::Opcode;

  // Resolves the address of a load/store: a gep pointer is expanded through
  // GepConst (base register + constant offset), anything else through
  // AddressOf with offset 0.
  auto resolve_address = [this](const ir::Value* ptr,
                                Operand* base) -> long long {
    if (auto* gep = dynamic_cast<const ir::GepInst*>(ptr)) {
      return GepConst(*gep, base);
    }
    *base = AddressOf(ptr);
    return 0;
  };

  switch (inst.GetOpcode()) {
    case Opcode::Alloca: {
      auto& a = static_cast<const ir::AllocaInst&>(inst);
      // AddressOf may already have created the stack object if this alloca
      // was referenced before being lowered; reuse that slot so that every
      // access refers to the same storage.
      int idx;
      auto known = allocas_.find(&a);
      if (known != allocas_.end()) {
        idx = known->second;
      } else {
        idx = mf_->CreateStackObject(TypeSize(*a.GetAllocatedType()),
                                     a.GetAllocatedType()->IsArray() ? 16 : 4);
        allocas_[&a] = idx;
      }
      Operand d = NewG(8);
      Emit(mir::Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1);
      vmap_[&inst] = d;
      break;
    }
    case Opcode::Load: {
      auto& ld = static_cast<const ir::LoadInst&>(inst);
      Operand base;
      long long off = resolve_address(ld.GetPtr(), &base);
      bool is_f = inst.GetType()->IsFloat();
      Operand d = is_f ? NewF() : NewG(BytesOf(*inst.GetType()));
      Emit(mir::Opcode::Ldr, {d, base, Operand::Imm(off)}, 1);
      vmap_[&inst] = d;
      break;
    }
    case Opcode::Store: {
      auto& st = static_cast<const ir::StoreInst&>(inst);
      // The stored value is evaluated before the address, as before.
      Operand val = GetReg(st.GetValue());
      Operand base;
      long long off = resolve_address(st.GetPtr(), &base);
      Emit(mir::Opcode::Str, {val, base, Operand::Imm(off)}, 0);
      break;
    }
    case Opcode::Gep: {
      auto& gep = static_cast<const ir::GepInst&>(inst);
      Operand base;
      long long off = GepConst(gep, &base);
      Operand d = NewG(8);
      if (off == 0) {
        Emit(mir::Opcode::Mov, {d, base}, 1);
      } else {
        Emit(mir::Opcode::AddImm, {d, base, Operand::Imm(off)}, 1);
      }
      vmap_[&inst] = d;
      break;
    }
    case Opcode::Call: {
      auto& call = static_cast<const ir::CallInst&>(inst);
      // Evaluate every argument into a virtual register first, then move them
      // into the physical argument registers back to back, so that computing
      // a later argument cannot clobber an argument register already set up.
      std::vector<Operand> vals;
      for (auto* arg : call.GetArgs()) vals.push_back(GetReg(arg));
      // Integer/pointer arguments use GPR physical registers 0, 1, ...;
      // float arguments use FPR physical registers 0, 1, ...
      // NOTE(review): no stack-passing path exists for arguments beyond the
      // register set of the calling convention.
      int ig = 0, fg = 0;
      std::vector<Operand> arg_uses;
      for (size_t i = 0; i < call.GetArgs().size(); ++i) {
        ir::Value* arg = call.GetArgs()[i];
        if (arg->GetType()->IsFloat()) {
          Operand p = Operand::PReg(fg++, RegClass::FPR, 4);
          Emit(mir::Opcode::FMov, {p, vals[i]}, 1);
          arg_uses.push_back(p);
        } else {
          int bytes = arg->GetType()->IsPointer() ? 8 : 4;
          Operand p = Operand::PReg(ig++, RegClass::GPR, bytes);
          Emit(mir::Opcode::Mov, {p, vals[i]}, 1);
          arg_uses.push_back(p);
        }
      }
      // Bl operands: callee symbol followed by the argument registers it uses.
      std::vector<Operand> ops;
      ops.push_back(Operand::Global(call.GetCallee()->GetName()));
      for (auto& u : arg_uses) ops.push_back(u);
      Emit(mir::Opcode::Bl, std::move(ops), 0);
      // Copy the result out of physical register 0 of the matching class.
      if (!inst.GetType()->IsVoid()) {
        if (inst.GetType()->IsFloat()) {
          Operand d = NewF();
          Emit(mir::Opcode::FMov, {d, Operand::PReg(0, RegClass::FPR, 4)}, 1);
          vmap_[&inst] = d;
        } else {
          Operand d = NewG(BytesOf(*inst.GetType()));
          Emit(mir::Opcode::Mov,
               {d, Operand::PReg(0, RegClass::GPR, d.GetBytes())}, 1);
          vmap_[&inst] = d;
        }
      }
      break;
    }
    case Opcode::Br: {
      auto& br = static_cast<const ir::BranchInst&>(inst);
      Emit(mir::Opcode::B, {Operand::Label(br.GetDest()->GetName())}, 0);
      break;
    }
    case Opcode::CondBr: {
      auto& cbr = static_cast<const ir::CondBrInst&>(inst);
      // Branch to the true target when the condition is non-zero, otherwise
      // fall into an unconditional branch to the false target.
      Operand c = GetReg(cbr.GetCond());
      Emit(mir::Opcode::CmpImm, {c, Operand::Imm(0)}, 0);
      Emit(mir::Opcode::BCond, {Operand::Label(cbr.GetTrueDest()->GetName())}, 0,
           Cond::NE);
      Emit(mir::Opcode::B, {Operand::Label(cbr.GetFalseDest()->GetName())}, 0);
      break;
    }
    case Opcode::Ret: {
      auto& ret = static_cast<const ir::ReturnInst&>(inst);
      if (ret.HasReturnValue()) {
        // Place the return value in physical register 0 of its class.
        Operand v = GetReg(ret.GetValue());
        if (ret.GetValue()->GetType()->IsFloat()) {
          Emit(mir::Opcode::FMov, {Operand::PReg(0, RegClass::FPR, 4), v}, 1);
        } else {
          Emit(mir::Opcode::Mov,
               {Operand::PReg(0, RegClass::GPR, v.GetBytes()), v}, 1);
        }
      }
      Emit(mir::Opcode::Ret, {}, 0);
      break;
    }
    default:
      break;
  }
}
} // namespace
// Public entry point: lowers `module` into a freshly allocated MachineModule
// and returns ownership of it.
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
  std::unique_ptr<MachineModule> result = std::make_unique<MachineModule>();
  Lowerer(module, *result).Run();
  return result;
}
} // namespace mir