You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/Lowering.cpp

1922 lines
68 KiB

#include "mir/MIR.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir
{
namespace
{
// Maps an IR value to the id of the virtual register holding its lowered result.
using ValueVRegMap = std::unordered_map<const ir::Value *, int>;
// Maps a non-array alloca instruction to the frame index of its stack slot.
using LocalScalarMap = std::unordered_map<const ir::Value *, int>;
// Maps an array alloca instruction to the frame index of its stack storage.
using LocalArrayMap = std::unordered_map<const ir::Value *, int>;
// Presumably maps each IR basic block to its lowered machine block (consumed by
// LowerInstruction) -- confirm against the driver loop.
using BlockMap = std::unordered_map<const ir::BasicBlock *, MachineBasicBlock *>;
// Forward declaration; the definition appears further down in this file.
static bool TryGetConstantInt(const ir::Value *value, int &out);
// Size in bytes used when allocating storage for a value of `type`:
// pointer types take 8 bytes, every other (or missing) type takes 4.
static int GetTypeSize(const std::shared_ptr<ir::Type> &type)
{
    const bool is_pointer = type && (type->IsPtrInt32() || type->IsPtrFloat32());
    return is_pointer ? 8 : 4;
}
// Rounds `value` up to a multiple of `align` (align must be positive).
// Uses truncating integer division, so it is intended for non-negative values.
static int AlignTo(int value, int align)
{
    const int units = (value + align - 1) / align;
    return units * align;
}
// True when `value` is non-null and has an int32 or float32 pointer type.
static bool IsPointerValue(const ir::Value *value)
{
    if (value == nullptr)
        return false;
    const auto type = value->GetType();
    if (!type)
        return false;
    return type->IsPtrInt32() || type->IsPtrFloat32();
}
// True when `type` is non-null and is an int32 or float32 pointer type.
static bool IsPointerType(const std::shared_ptr<ir::Type> &type)
{
    if (!type)
        return false;
    return type->IsPtrInt32() || type->IsPtrFloat32();
}
// True when `type` is non-null and is the float32 scalar type.
static bool IsFloatType(const std::shared_ptr<ir::Type> &type)
{
    if (!type)
        return false;
    return type->IsFloat32();
}
// True when `value` is non-null and its type is float32.
static bool IsFloatValue(const ir::Value *value)
{
    if (value == nullptr)
        return false;
    return IsFloatType(value->GetType());
}
// True for the six relational IR opcodes (eq, ne, lt, le, gt, ge).
static bool IsIntegerCompareOpcode(ir::Opcode opcode)
{
    return opcode == ir::Opcode::Eq || opcode == ir::Opcode::Ne ||
           opcode == ir::Opcode::Lt || opcode == ir::Opcode::Le ||
           opcode == ir::Opcode::Gt || opcode == ir::Opcode::Ge;
}
// Maps a relational IR opcode to the machine condition code that is true when
// the comparison holds. Throws std::runtime_error for any other opcode.
static CondCode GetCondCodeForCompareOpcode(ir::Opcode opcode)
{
    switch (opcode)
    {
    case ir::Opcode::Eq: return CondCode::EQ;
    case ir::Opcode::Ne: return CondCode::NE;
    case ir::Opcode::Lt: return CondCode::LT;
    case ir::Opcode::Le: return CondCode::LE;
    case ir::Opcode::Gt: return CondCode::GT;
    case ir::Opcode::Ge: return CondCode::GE;
    default: break;
    }
    throw std::runtime_error(FormatError("mir", "不支持的比较 opcode"));
}
// Returns the logical complement of a signed comparison condition
// (EQ<->NE, LT<->GE, LE<->GT). Unlisted conditions map to NE.
static CondCode NegateCondCode(CondCode cond)
{
    switch (cond)
    {
    case CondCode::EQ: return CondCode::NE;
    case CondCode::NE: return CondCode::EQ;
    case CondCode::LT: return CondCode::GE;
    case CondCode::GE: return CondCode::LT;
    case CondCode::LE: return CondCode::GT;
    case CondCode::GT: return CondCode::LE;
    default: break;
    }
    return CondCode::NE;
}
// 32-bit integer argument register for argument slot `index` (w0..w7);
// indices past 7 fall back to w0.
static PhysReg GetArgWReg(size_t index)
{
    static const PhysReg kWArgRegs[8] = {
        PhysReg::W0, PhysReg::W1, PhysReg::W2, PhysReg::W3,
        PhysReg::W4, PhysReg::W5, PhysReg::W6, PhysReg::W7};
    if (index >= 8)
        return PhysReg::W0;
    return kWArgRegs[index];
}
// 64-bit integer/pointer argument register for argument slot `index` (x0..x7);
// indices past 7 fall back to x0.
static PhysReg GetArgXReg(size_t index)
{
    static const PhysReg kXArgRegs[8] = {
        PhysReg::X0, PhysReg::X1, PhysReg::X2, PhysReg::X3,
        PhysReg::X4, PhysReg::X5, PhysReg::X6, PhysReg::X7};
    if (index >= 8)
        return PhysReg::X0;
    return kXArgRegs[index];
}
// Single-precision float argument register for argument slot `index` (s0..s7);
// indices past 7 fall back to s0.
static PhysReg GetArgSReg(size_t index)
{
    static const PhysReg kSArgRegs[8] = {
        PhysReg::S0, PhysReg::S1, PhysReg::S2, PhysReg::S3,
        PhysReg::S4, PhysReg::S5, PhysReg::S6, PhysReg::S7};
    if (index >= 8)
        return PhysReg::S0;
    return kSArgRegs[index];
}
// If `value` is an ir::ConstantInt, stores its value in `out` and returns true;
// otherwise leaves `out` untouched and returns false.
static bool TryGetConstantInt(const ir::Value *value, int &out)
{
    const auto *constant = dynamic_cast<const ir::ConstantInt *>(value);
    if (constant == nullptr)
        return false;
    out = constant->GetValue();
    return true;
}
// Reinterprets the IEEE-754 bit pattern of `value` as an int (no numeric conversion).
// memcpy is the well-defined way to type-pun here.
static int FloatToBits(float value)
{
    static_assert(sizeof(int) == sizeof(float), "FloatToBits requires 32-bit int");
    int raw = 0;
    std::memcpy(&raw, &value, sizeof raw);
    return raw;
}
// If `value` is an ir::ConstantFloat, stores the bit pattern of its value (narrowed
// to float) in `out` and returns true; otherwise returns false.
static bool TryGetConstantFloatBits(const ir::Value *value, int &out)
{
    const auto *constant = dynamic_cast<const ir::ConstantFloat *>(value);
    if (constant == nullptr)
        return false;
    const float narrowed = static_cast<float>(constant->GetValue());
    out = FloatToBits(narrowed);
    return true;
}
// Returns `value` as a GlobalVariable when it is a non-array global with a pointer
// type; returns nullptr otherwise.
static const ir::GlobalVariable *AsGlobalScalarObject(const ir::Value *value)
{
    const auto *global = dynamic_cast<const ir::GlobalVariable *>(value);
    const bool is_scalar = global != nullptr && !global->IsArray() && IsPointerValue(global);
    return is_scalar ? global : nullptr;
}
// Returns `value` as a GlobalVariable when it is an array global with a pointer
// type; returns nullptr otherwise.
static const ir::GlobalVariable *AsGlobalArrayObject(const ir::Value *value)
{
    const auto *global = dynamic_cast<const ir::GlobalVariable *>(value);
    const bool is_array = global != nullptr && global->IsArray() && IsPointerValue(global);
    return is_array ? global : nullptr;
}
// True when `value` is the integer constant 0.
static bool IsZeroIntConstant(const ir::Value *value)
{
    int constant_value = 0;
    if (!TryGetConstantInt(value, constant_value))
        return false;
    return constant_value == 0;
}
// True when `inst` has exactly one use and that user is a conditional branch.
[[maybe_unused]] static bool IsSolelyConsumedByCondBr(const ir::Instruction &inst)
{
    const auto &uses = inst.GetUses();
    if (uses.size() == 1)
        return dynamic_cast<const ir::CondBranchInst *>(uses.front().GetUser()) != nullptr;
    return false;
}
// True when `inst` has exactly one use and that user only needs its truth value:
// a conditional branch, a zext, or an eq/ne comparison of `inst` against the
// integer constant 0 (on either side).
static bool IsSolelyConsumedByCanonicalBoolUse(const ir::Instruction &inst)
{
    const auto &uses = inst.GetUses();
    if (uses.size() != 1)
        return false;
    const auto *user = dynamic_cast<const ir::Instruction *>(uses.front().GetUser());
    if (user == nullptr)
        return false;
    if (dynamic_cast<const ir::CondBranchInst *>(user) != nullptr)
        return true;
    if (const auto *cast = dynamic_cast<const ir::CastInst *>(user))
        return cast->GetOpcode() == ir::Opcode::ZExt;

    const auto *cmp = dynamic_cast<const ir::BinaryInst *>(user);
    if (cmp == nullptr)
        return false;
    const ir::Opcode op = cmp->GetOpcode();
    if (op != ir::Opcode::Eq && op != ir::Opcode::Ne)
        return false;
    const bool inst_on_left = cmp->GetLhs() == &inst && IsZeroIntConstant(cmp->GetRhs());
    const bool inst_on_right = cmp->GetRhs() == &inst && IsZeroIntConstant(cmp->GetLhs());
    return inst_on_left || inst_on_right;
}
// Resolves `ptr` to the frame index of a local scalar slot, if it addresses one
// directly: either `ptr` is itself a scalar alloca, or it is a GEP with constant
// index 0 whose base is a scalar alloca. On success writes `out_slot`.
static bool TryResolveDirectScalarSlot(const ir::Value *ptr,
                                       const LocalScalarMap &scalar_slots,
                                       int &out_slot)
{
    if (auto direct = scalar_slots.find(ptr); direct != scalar_slots.end())
    {
        out_slot = direct->second;
        return true;
    }

    const auto *gep = dynamic_cast<const ir::GetElementPtrInst *>(ptr);
    int index = 0;
    if (gep == nullptr || !TryGetConstantInt(gep->GetIndex(), index) || index != 0)
        return false;

    const auto base = scalar_slots.find(gep->GetBasePtr());
    if (base == scalar_slots.end())
        return false;
    out_slot = base->second;
    return true;
}
// Forward declarations: the three value emitters below are mutually recursive
// (integer lowering needs float operands for fptosi and addresses for loads;
// float lowering needs integer operands for sitofp; pointer lowering needs
// integer GEP indices).
static int EmitIntValue(const ir::Value *value, MachineFunction &function,
ValueVRegMap &value_vregs, const LocalScalarMap &scalar_slots,
const LocalArrayMap &array_slots, MachineBasicBlock &block);
static int EmitFloatValue(const ir::Value *value, MachineFunction &function,
ValueVRegMap &value_vregs, MachineBasicBlock &block);
static int EmitPtrValue(const ir::Value *value, MachineFunction &function,
ValueVRegMap &value_vregs, const LocalScalarMap &scalar_slots,
const LocalArrayMap &array_slots, MachineBasicBlock &block);
static int EmitIntValue(const ir::Value *value, MachineFunction &function,
ValueVRegMap &value_vregs, const LocalScalarMap &scalar_slots,
const LocalArrayMap &array_slots, MachineBasicBlock &block)
{
if (!value)
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(0)});
return vreg;
}
auto it = value_vregs.find(value);
if (it != value_vregs.end())
{
if (function.GetVRegClass(it->second) == VRegClass::Float)
{
int dst = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::FCvtzs,
{Operand::VReg(dst, VRegClass::Int), Operand::VReg(it->second, VRegClass::Float)});
return dst;
}
return it->second;
}
int imm = 0;
if (TryGetConstantInt(value, imm))
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(imm)});
value_vregs[value] = vreg;
return vreg;
}
if (TryGetConstantFloatBits(value, imm))
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(imm)});
value_vregs[value] = vreg;
return vreg;
}
if (auto *global = AsGlobalScalarObject(value))
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::LoadGlobal,
{Operand::VReg(vreg, VRegClass::Int), Operand::Symbol(global->GetName())});
value_vregs[value] = vreg;
return vreg;
}
if (auto *cast = dynamic_cast<const ir::CastInst *>(value))
{
if (cast->GetOpcode() == ir::Opcode::ZExt)
{
int src = EmitIntValue(cast->GetOperandValue(), function, value_vregs,
scalar_slots, array_slots, block);
value_vregs[value] = src;
return src;
}
if (cast->GetOpcode() == ir::Opcode::FPToSI)
{
int src = EmitFloatValue(cast->GetOperandValue(), function, value_vregs, block);
int dst = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::FCvtzs,
{Operand::VReg(dst, VRegClass::Int), Operand::VReg(src, VRegClass::Float)});
value_vregs[value] = dst;
return dst;
}
}
if (auto *bin = dynamic_cast<const ir::BinaryInst *>(value))
{
if (IsIntegerCompareOpcode(bin->GetOpcode()))
{
int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs,
scalar_slots, array_slots, block);
int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::CmpRR,
{Operand::VReg(lhs, VRegClass::Int), Operand::VReg(rhs, VRegClass::Int)});
int dst = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::CSet,
{Operand::VReg(dst, VRegClass::Int),
Operand::Imm(static_cast<int>(GetCondCodeForCompareOpcode(bin->GetOpcode())))});
value_vregs[value] = dst;
return dst;
}
if (IsFloatType(bin->GetType()))
{
return EmitFloatValue(value, function, value_vregs, block);
}
Opcode opcode = Opcode::AddRR;
switch (bin->GetOpcode())
{
case ir::Opcode::Add:
opcode = Opcode::AddRR;
break;
case ir::Opcode::Sub:
opcode = Opcode::SubRR;
break;
case ir::Opcode::Mul:
opcode = Opcode::MulRR;
break;
case ir::Opcode::Div:
opcode = Opcode::DivRR;
break;
case ir::Opcode::Mod:
opcode = Opcode::ModRR;
break;
default:
break;
}
int lhs = EmitIntValue(bin->GetLhs(), function, value_vregs,
scalar_slots, array_slots, block);
int rhs = EmitIntValue(bin->GetRhs(), function, value_vregs,
scalar_slots, array_slots, block);
int dst = function.CreateVReg(VRegClass::Int);
if (opcode == Opcode::MulRR)
{
auto *rhs_const = dynamic_cast<const ir::ConstantInt *>(bin->GetRhs());
if (rhs_const)
{
int val = rhs_const->GetValue();
if (val > 0 && (val & (val - 1)) == 0)
{
int shift = 0;
while (val > 1)
{
val >>= 1;
++shift;
}
block.Append(Opcode::ShlRR,
{Operand::VReg(dst, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(shift)});
value_vregs[value] = dst;
return dst;
}
}
}
if (opcode == Opcode::DivRR)
{
auto *rhs_const = dynamic_cast<const ir::ConstantInt *>(bin->GetRhs());
if (rhs_const)
{
int val = rhs_const->GetValue();
if (val == 1)
{
value_vregs[value] = lhs;
return lhs;
}
if (val == -1)
{
block.Append(Opcode::NegRR,
{Operand::VReg(dst, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int)});
value_vregs[value] = dst;
return dst;
}
if (val > 0 && (val & (val - 1)) == 0)
{
int shift = 0;
int tmp = val;
while (tmp > 1)
{
tmp >>= 1;
++shift;
}
int bias = (1 << shift) - 1;
int biased = function.CreateVReg(VRegClass::Int);
if (bias <= 4095)
{
block.Append(Opcode::AddRR,
{Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(bias)});
}
else
{
int bias_reg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm,
{Operand::VReg(bias_reg, VRegClass::Int),
Operand::Imm(bias)});
block.Append(Opcode::AddRR,
{Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::VReg(bias_reg, VRegClass::Int)});
}
block.Append(Opcode::CmpImm,
{Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(0)});
int selected = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::Csel,
{Operand::VReg(selected, VRegClass::Int),
Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(static_cast<int>(CondCode::LT))});
block.Append(Opcode::AsrRR,
{Operand::VReg(dst, VRegClass::Int),
Operand::VReg(selected, VRegClass::Int),
Operand::Imm(shift)});
value_vregs[value] = dst;
return dst;
}
if (val < 0 && (-val & (-val - 1)) == 0 && val != -1)
{
int abs_val = -val;
int shift = 0;
int tmp = abs_val;
while (tmp > 1)
{
tmp >>= 1;
++shift;
}
int bias = (1 << shift) - 1;
int biased = function.CreateVReg(VRegClass::Int);
if (bias <= 4095)
{
block.Append(Opcode::AddRR,
{Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(bias)});
}
else
{
int bias_reg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm,
{Operand::VReg(bias_reg, VRegClass::Int),
Operand::Imm(bias)});
block.Append(Opcode::AddRR,
{Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::VReg(bias_reg, VRegClass::Int)});
}
block.Append(Opcode::CmpImm,
{Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(0)});
int selected = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::Csel,
{Operand::VReg(selected, VRegClass::Int),
Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(static_cast<int>(CondCode::LT))});
int pos_q = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::AsrRR,
{Operand::VReg(pos_q, VRegClass::Int),
Operand::VReg(selected, VRegClass::Int),
Operand::Imm(shift)});
block.Append(Opcode::NegRR,
{Operand::VReg(dst, VRegClass::Int),
Operand::VReg(pos_q, VRegClass::Int)});
value_vregs[value] = dst;
return dst;
}
}
}
if (opcode == Opcode::ModRR)
{
auto *rhs_const = dynamic_cast<const ir::ConstantInt *>(bin->GetRhs());
if (rhs_const)
{
int val = rhs_const->GetValue();
if (val > 0 && (val & (val - 1)) == 0)
{
int bias = val - 1;
int biased = function.CreateVReg(VRegClass::Int);
if (bias <= 4095)
{
block.Append(Opcode::AddRR,
{Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(bias)});
}
else
{
int bias_reg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm,
{Operand::VReg(bias_reg, VRegClass::Int),
Operand::Imm(bias)});
block.Append(Opcode::AddRR,
{Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::VReg(bias_reg, VRegClass::Int)});
}
int shift = 0;
int tmp = val;
while (tmp > 1)
{
tmp >>= 1;
++shift;
}
block.Append(Opcode::CmpImm,
{Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(0)});
int selected = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::Csel,
{Operand::VReg(selected, VRegClass::Int),
Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(static_cast<int>(CondCode::LT))});
int q_dst = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::AsrRR,
{Operand::VReg(q_dst, VRegClass::Int),
Operand::VReg(selected, VRegClass::Int),
Operand::Imm(shift)});
int d_reg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm,
{Operand::VReg(d_reg, VRegClass::Int),
Operand::Imm(val)});
block.Append(Opcode::Msub,
{Operand::VReg(dst, VRegClass::Int),
Operand::VReg(q_dst, VRegClass::Int),
Operand::VReg(d_reg, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int)});
value_vregs[value] = dst;
return dst;
}
if (val < 0 && (-val & (-val - 1)) == 0 && val != -1)
{
int abs_val = -val;
int bias = abs_val - 1;
int biased = function.CreateVReg(VRegClass::Int);
if (bias <= 4095)
{
block.Append(Opcode::AddRR,
{Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(bias)});
}
else
{
int bias_reg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm,
{Operand::VReg(bias_reg, VRegClass::Int),
Operand::Imm(bias)});
block.Append(Opcode::AddRR,
{Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::VReg(bias_reg, VRegClass::Int)});
}
int shift = 0;
int tmp = abs_val;
while (tmp > 1)
{
tmp >>= 1;
++shift;
}
block.Append(Opcode::CmpImm,
{Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(0)});
int selected = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::Csel,
{Operand::VReg(selected, VRegClass::Int),
Operand::VReg(biased, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::Imm(static_cast<int>(CondCode::LT))});
int asr_result = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::AsrRR,
{Operand::VReg(asr_result, VRegClass::Int),
Operand::VReg(selected, VRegClass::Int),
Operand::Imm(shift)});
int q_dst = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::NegRR,
{Operand::VReg(q_dst, VRegClass::Int),
Operand::VReg(asr_result, VRegClass::Int)});
int d_reg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm,
{Operand::VReg(d_reg, VRegClass::Int),
Operand::Imm(val)});
block.Append(Opcode::Msub,
{Operand::VReg(dst, VRegClass::Int),
Operand::VReg(q_dst, VRegClass::Int),
Operand::VReg(d_reg, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int)});
value_vregs[value] = dst;
return dst;
}
}
}
block.Append(opcode,
{Operand::VReg(dst, VRegClass::Int),
Operand::VReg(lhs, VRegClass::Int),
Operand::VReg(rhs, VRegClass::Int)});
value_vregs[value] = dst;
return dst;
}
if (auto *phi = dynamic_cast<const ir::PhiInst *>(value))
{
auto phi_it = value_vregs.find(value);
if (phi_it != value_vregs.end())
return phi_it->second;
}
if (auto *load = dynamic_cast<const ir::LoadInst *>(value))
{
int scalar_slot = -1;
if (TryResolveDirectScalarSlot(load->GetPtr(), scalar_slots, scalar_slot))
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::LoadStack,
{Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(scalar_slot)});
value_vregs[value] = vreg;
return vreg;
}
if (auto *global = AsGlobalScalarObject(load->GetPtr()))
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::LoadGlobal,
{Operand::VReg(vreg, VRegClass::Int), Operand::Symbol(global->GetName())});
value_vregs[value] = vreg;
return vreg;
}
int addr = EmitPtrValue(load->GetPtr(), function, value_vregs,
scalar_slots, array_slots, block);
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::LoadMem,
{Operand::VReg(vreg, VRegClass::Int), Operand::VReg(addr, VRegClass::Ptr)});
value_vregs[value] = vreg;
return vreg;
}
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(0)});
value_vregs[value] = vreg;
return vreg;
}
/// Lowers `value` into a virtual register of class Float and returns its id.
///
/// Already-lowered values come from `value_vregs`; a cached non-Float vreg is
/// converted numerically with Scvtf (not cached). Otherwise the value is
/// materialized into `block`: float constants, integer constants (converted
/// numerically, not cached), scalar globals, sitofp casts, float arithmetic and
/// loads from globals or from already-materialized addresses. Anything else,
/// including a null `value`, lowers to +0.0.
static int EmitFloatValue(const ir::Value *value, MachineFunction &function,
                          ValueVRegMap &value_vregs, MachineBasicBlock &block)
{
    // Moves a single-precision bit pattern into a fresh Float vreg through a W register.
    auto materialize_bits = [&function, &block](int bits) {
        const int wvreg = function.CreateVReg(VRegClass::Int);
        const int vreg = function.CreateVReg(VRegClass::Float);
        block.Append(Opcode::MovImm, {Operand::VReg(wvreg, VRegClass::Int), Operand::Imm(bits)});
        block.Append(Opcode::FMovWS,
                     {Operand::VReg(vreg, VRegClass::Float), Operand::VReg(wvreg, VRegClass::Int)});
        return vreg;
    };

    if (!value)
        return materialize_bits(0);

    auto it = value_vregs.find(value);
    if (it != value_vregs.end())
    {
        if (function.GetVRegClass(it->second) != VRegClass::Float)
        {
            int dst = function.CreateVReg(VRegClass::Float);
            block.Append(Opcode::Scvtf,
                         {Operand::VReg(dst, VRegClass::Float), Operand::VReg(it->second, VRegClass::Int)});
            return dst;
        }
        return it->second;
    }
    int bits = 0;
    if (TryGetConstantFloatBits(value, bits))
    {
        const int vreg = materialize_bits(bits);
        value_vregs[value] = vreg;
        return vreg;
    }
    int int_imm = 0;
    if (TryGetConstantInt(value, int_imm))
    {
        // Numeric conversion, consistent with the Scvtf applied to cached Int vregs.
        // Not cached: value_vregs[value] must keep meaning the integer (an
        // int -> float -> int round trip loses precision above 2^24).
        return materialize_bits(FloatToBits(static_cast<float>(int_imm)));
    }
    if (auto *global = AsGlobalScalarObject(value))
    {
        int vreg = function.CreateVReg(VRegClass::Float);
        block.Append(Opcode::LoadGlobal,
                     {Operand::VReg(vreg, VRegClass::Float), Operand::Symbol(global->GetName())});
        value_vregs[value] = vreg;
        return vreg;
    }
    if (auto *cast = dynamic_cast<const ir::CastInst *>(value))
    {
        if (cast->GetOpcode() == ir::Opcode::SIToFP)
        {
            // NOTE(review): no stack-slot maps are available here, so an operand that
            // is an uncached load from a local will not resolve its slot.
            int src = EmitIntValue(cast->GetOperandValue(), function, value_vregs,
                                   LocalScalarMap(), LocalArrayMap(), block);
            int dst = function.CreateVReg(VRegClass::Float);
            block.Append(Opcode::Scvtf,
                         {Operand::VReg(dst, VRegClass::Float), Operand::VReg(src, VRegClass::Int)});
            value_vregs[value] = dst;
            return dst;
        }
    }
    if (auto *bin = dynamic_cast<const ir::BinaryInst *>(value))
    {
        if (IsFloatType(bin->GetType()))
        {
            Opcode opcode = Opcode::FAddRR;
            switch (bin->GetOpcode())
            {
            case ir::Opcode::Add:
                opcode = Opcode::FAddRR;
                break;
            case ir::Opcode::Sub:
                opcode = Opcode::FSubRR;
                break;
            case ir::Opcode::Mul:
                opcode = Opcode::FMulRR;
                break;
            case ir::Opcode::Div:
                opcode = Opcode::FDivRR;
                break;
            default:
                break;
            }
            int lhs = EmitFloatValue(bin->GetLhs(), function, value_vregs, block);
            int rhs = EmitFloatValue(bin->GetRhs(), function, value_vregs, block);
            int dst = function.CreateVReg(VRegClass::Float);
            block.Append(opcode,
                         {Operand::VReg(dst, VRegClass::Float),
                          Operand::VReg(lhs, VRegClass::Float),
                          Operand::VReg(rhs, VRegClass::Float)});
            value_vregs[value] = dst;
            return dst;
        }
    }
    if (auto *load = dynamic_cast<const ir::LoadInst *>(value))
    {
        // Stack-slot maps are not available here, so only loads whose address can be
        // formed without them are handled: scalar globals, and pointers that already
        // have a Ptr vreg (parameters, lowered GEPs, previously materialized slots).
        if (auto *global = AsGlobalScalarObject(load->GetPtr()))
        {
            int vreg = function.CreateVReg(VRegClass::Float);
            block.Append(Opcode::LoadGlobal,
                         {Operand::VReg(vreg, VRegClass::Float), Operand::Symbol(global->GetName())});
            value_vregs[value] = vreg;
            return vreg;
        }
        auto ptr_it = value_vregs.find(load->GetPtr());
        if (ptr_it != value_vregs.end() && function.GetVRegClass(ptr_it->second) == VRegClass::Ptr)
        {
            const int addr = ptr_it->second;
            int vreg = function.CreateVReg(VRegClass::Float);
            block.Append(Opcode::LoadMem,
                         {Operand::VReg(vreg, VRegClass::Float), Operand::VReg(addr, VRegClass::Ptr)});
            value_vregs[value] = vreg;
            return vreg;
        }
    }
    // Unhandled values (including phis not yet assigned a vreg) lower to +0.0.
    const int vreg = materialize_bits(0);
    value_vregs[value] = vreg;
    return vreg;
}
/// Lowers a pointer-valued IR value into a Ptr-class virtual register and returns it.
///
/// Handles cached values, global objects (LoadGlobalAddr), local scalar/array
/// allocas (LoadStackAddr) and GEPs (base + index * 4 bytes). Anything else lowers
/// to a null pointer. All results are cached in `value_vregs`.
static int EmitPtrValue(const ir::Value *value, MachineFunction &function,
                        ValueVRegMap &value_vregs, const LocalScalarMap &scalar_slots,
                        const LocalArrayMap &array_slots, MachineBasicBlock &block)
{
    auto it = value_vregs.find(value);
    if (it != value_vregs.end())
        return it->second;

    // Global scalars and global arrays: take the symbol's address.
    const ir::GlobalVariable *global = AsGlobalScalarObject(value);
    if (!global)
        global = AsGlobalArrayObject(value);
    if (global)
    {
        int vreg = function.CreateVReg(VRegClass::Ptr);
        block.Append(Opcode::LoadGlobalAddr,
                     {Operand::VReg(vreg, VRegClass::Ptr), Operand::Symbol(global->GetName())});
        value_vregs[value] = vreg;
        return vreg;
    }
    // Local allocas: take the address of the frame slot.
    auto scalar_it = scalar_slots.find(value);
    if (scalar_it != scalar_slots.end())
    {
        int vreg = function.CreateVReg(VRegClass::Ptr);
        block.Append(Opcode::LoadStackAddr,
                     {Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(scalar_it->second)});
        value_vregs[value] = vreg;
        return vreg;
    }
    auto array_it = array_slots.find(value);
    if (array_it != array_slots.end())
    {
        int vreg = function.CreateVReg(VRegClass::Ptr);
        block.Append(Opcode::LoadStackAddr,
                     {Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(array_it->second)});
        value_vregs[value] = vreg;
        return vreg;
    }
    if (auto *gep = dynamic_cast<const ir::GetElementPtrInst *>(value))
    {
        int base = EmitPtrValue(gep->GetBasePtr(), function, value_vregs,
                                scalar_slots, array_slots, block);
        int idx_imm = 0;
        if (TryGetConstantInt(gep->GetIndex(), idx_imm))
        {
            // Elements are 4 bytes. Compute the byte offset in 64 bits: idx * 4
            // overflows int for |idx| >= 2^29.
            const std::int64_t byte_offset = static_cast<std::int64_t>(idx_imm) * 4;
            if (byte_offset == 0)
            {
                value_vregs[value] = base;
                return base;
            }
            const std::int64_t abs_off = byte_offset > 0 ? byte_offset : -byte_offset;
            if (abs_off <= INT32_MAX)
            {
                int dst = function.CreateVReg(VRegClass::Ptr);
                int offset_vreg = function.CreateVReg(VRegClass::Ptr);
                block.Append(Opcode::MovImm,
                             {Operand::VReg(offset_vreg, VRegClass::Ptr),
                              Operand::Imm(static_cast<int>(abs_off))});
                block.Append(byte_offset > 0 ? Opcode::AddRR : Opcode::SubRR,
                             {Operand::VReg(dst, VRegClass::Ptr),
                              Operand::VReg(base, VRegClass::Ptr),
                              Operand::VReg(offset_vreg, VRegClass::Ptr)});
                value_vregs[value] = dst;
                return dst;
            }
            // Offset does not fit an int immediate: fall through to the register
            // path, which sign-extends the index to 64 bits before scaling.
        }
        int idx = EmitIntValue(gep->GetIndex(), function, value_vregs,
                               scalar_slots, array_slots, block);
        int sext = function.CreateVReg(VRegClass::Ptr);
        block.Append(Opcode::Sxtw,
                     {Operand::VReg(sext, VRegClass::Ptr), Operand::VReg(idx, VRegClass::Int)});
        int shifted = function.CreateVReg(VRegClass::Ptr);
        block.Append(Opcode::ShlRR,
                     {Operand::VReg(shifted, VRegClass::Ptr),
                      Operand::VReg(sext, VRegClass::Ptr),
                      Operand::Imm(2)});
        int dst = function.CreateVReg(VRegClass::Ptr);
        block.Append(Opcode::AddRR,
                     {Operand::VReg(dst, VRegClass::Ptr),
                      Operand::VReg(base, VRegClass::Ptr),
                      Operand::VReg(shifted, VRegClass::Ptr)});
        value_vregs[value] = dst;
        return dst;
    }
    // Unknown pointer source (e.g. a value not lowered yet): fall back to null.
    int vreg = function.CreateVReg(VRegClass::Ptr);
    block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Ptr), Operand::Imm(0)});
    value_vregs[value] = vreg;
    return vreg;
}
// Emits a flag-setting comparison of bin's two operands into `block`: FCmpRR when
// either operand is a float value, otherwise an integer CmpRR. The caller chooses
// the condition code to test afterwards.
static void EmitCompareToFlags(const ir::BinaryInst &bin,
                               MachineFunction &function,
                               ValueVRegMap &value_vregs,
                               const LocalScalarMap &scalar_slots,
                               const LocalArrayMap &array_slots,
                               MachineBasicBlock &block)
{
    const bool float_compare = IsFloatValue(bin.GetLhs()) || IsFloatValue(bin.GetRhs());
    if (!float_compare)
    {
        const int lhs_reg = EmitIntValue(bin.GetLhs(), function, value_vregs,
                                         scalar_slots, array_slots, block);
        const int rhs_reg = EmitIntValue(bin.GetRhs(), function, value_vregs,
                                         scalar_slots, array_slots, block);
        block.Append(Opcode::CmpRR,
                     {Operand::VReg(lhs_reg, VRegClass::Int), Operand::VReg(rhs_reg, VRegClass::Int)});
        return;
    }
    const int lhs_reg = EmitFloatValue(bin.GetLhs(), function, value_vregs, block);
    const int rhs_reg = EmitFloatValue(bin.GetRhs(), function, value_vregs, block);
    block.Append(Opcode::FCmpRR,
                 {Operand::VReg(lhs_reg, VRegClass::Float), Operand::VReg(rhs_reg, VRegClass::Float)});
}
// Emits code that sets the condition flags from the truth value of `value` and
// reports in `true_cond` the condition that holds when `value` is true.
//  - zext(x) is looked through.
//  - A comparison is emitted directly; `x == 0` / `x != 0` where x is itself a
//    condition reuses x's flags (negated for ==).
//  - Any other value is compared against zero (float or integer).
// Returns false only for a null value or when the look-through depth limit is hit.
static bool TryEmitCondValueToFlags(const ir::Value *value,
                                    MachineFunction &function,
                                    ValueVRegMap &value_vregs,
                                    const LocalScalarMap &scalar_slots,
                                    const LocalArrayMap &array_slots,
                                    MachineBasicBlock &block,
                                    CondCode &true_cond, int depth = 0)
{
    constexpr int kMaxDepth = 8;
    if (value == nullptr || depth > kMaxDepth)
        return false;

    const auto *cast = dynamic_cast<const ir::CastInst *>(value);
    if (cast != nullptr && cast->GetOpcode() == ir::Opcode::ZExt)
        return TryEmitCondValueToFlags(cast->GetOperandValue(), function, value_vregs,
                                       scalar_slots, array_slots, block, true_cond, depth + 1);

    const auto *cmp = dynamic_cast<const ir::BinaryInst *>(value);
    if (cmp != nullptr && IsIntegerCompareOpcode(cmp->GetOpcode()))
    {
        const ir::Opcode op = cmp->GetOpcode();
        if (op == ir::Opcode::Eq || op == ir::Opcode::Ne)
        {
            const ir::Value *tested = IsZeroIntConstant(cmp->GetLhs())   ? cmp->GetRhs()
                                      : IsZeroIntConstant(cmp->GetRhs()) ? cmp->GetLhs()
                                                                         : nullptr;
            CondCode tested_cond = CondCode::NE;
            if (tested != nullptr &&
                TryEmitCondValueToFlags(tested, function, value_vregs, scalar_slots,
                                        array_slots, block, tested_cond, depth + 1))
            {
                true_cond = (op == ir::Opcode::Eq) ? NegateCondCode(tested_cond) : tested_cond;
                return true;
            }
        }
        EmitCompareToFlags(*cmp, function, value_vregs, scalar_slots, array_slots, block);
        true_cond = GetCondCodeForCompareOpcode(op);
        return true;
    }

    if (!IsFloatValue(value))
    {
        const int value_reg = EmitIntValue(value, function, value_vregs, scalar_slots, array_slots, block);
        const int zero_reg = function.CreateVReg(VRegClass::Int);
        block.Append(Opcode::MovImm, {Operand::VReg(zero_reg, VRegClass::Int), Operand::Imm(0)});
        block.Append(Opcode::CmpRR,
                     {Operand::VReg(value_reg, VRegClass::Int), Operand::VReg(zero_reg, VRegClass::Int)});
        true_cond = CondCode::NE;
        return true;
    }

    const int value_reg = EmitFloatValue(value, function, value_vregs, block);
    const int zero_w = function.CreateVReg(VRegClass::Int);
    const int zero_s = function.CreateVReg(VRegClass::Float);
    block.Append(Opcode::MovImm, {Operand::VReg(zero_w, VRegClass::Int), Operand::Imm(0)});
    block.Append(Opcode::FMovWS,
                 {Operand::VReg(zero_s, VRegClass::Float), Operand::VReg(zero_w, VRegClass::Int)});
    block.Append(Opcode::FCmpRR,
                 {Operand::VReg(value_reg, VRegClass::Float), Operand::VReg(zero_s, VRegClass::Float)});
    true_cond = CondCode::NE;
    return true;
}
// Emits `sp = sp <opcode> amount`, using x14 as a scratch register for the byte
// count. Does nothing when `amount` is zero or negative.
static void EmitStackPointerAdjust(MachineBasicBlock &block, Opcode opcode, int amount)
{
    if (amount > 0)
    {
        const PhysReg scratch = PhysReg::X14;
        block.Append(Opcode::MovImm, {Operand::Reg(scratch), Operand::Imm(amount)});
        block.Append(opcode,
                     {Operand::Reg(PhysReg::SP), Operand::Reg(PhysReg::SP), Operand::Reg(scratch)});
    }
}
// Bytes of outgoing stack-argument space needed by `call`: float arguments use up
// to 8 FP registers and all other arguments up to 8 GP registers; each argument
// beyond that takes an 8-byte stack slot. The result is not rounded up.
static int ComputeStackArgumentBytes(const ir::CallInst &call)
{
    constexpr size_t kRegArgCount = 8;
    constexpr int kStackSlotBytes = 8;
    size_t used_gp = 0;
    size_t used_fp = 0;
    int stack_bytes = 0;
    const size_t arg_count = call.GetNumArgs();
    for (size_t arg_index = 0; arg_index < arg_count; ++arg_index)
    {
        auto *arg = call.GetArg(arg_index);
        const bool is_float = IsFloatType(arg ? arg->GetType() : nullptr);
        size_t &used = is_float ? used_fp : used_gp;
        if (used < kRegArgCount)
            ++used;
        else
            stack_bytes += kStackSlotBytes;
    }
    return stack_bytes;
}
// Copies each incoming parameter of `function` into a fresh vreg at the start of
// the machine entry block and records it in `value_vregs`.
// Float params arrive in s0..s7, pointer params in x0..x7, other params in w0..w7
// (pointers and integers share one GP counter). Parameters beyond the register
// budget are loaded from callee-side stack-argument slots placed 8 bytes apart.
static void LowerFunctionParams(const ir::Function &function,
                                MachineFunction &machine_func,
                                ValueVRegMap &value_vregs)
{
    if (!machine_func.GetEntryPtr())
        return;
    auto &entry = machine_func.GetEntry();

    constexpr size_t kArgRegCount = 8;
    constexpr int kStackArgSlotBytes = 8;
    size_t next_gp = 0;
    size_t next_fp = 0;
    int stack_offset = 0;

    for (const auto &param : function.GetParams())
    {
        if (!param)
            continue;
        const bool is_float = IsFloatType(param->GetType());
        const bool is_ptr = !is_float && IsPointerType(param->GetType());
        const VRegClass cls = is_float ? VRegClass::Float : (is_ptr ? VRegClass::Ptr : VRegClass::Int);
        size_t &next_reg = is_float ? next_fp : next_gp;

        int vreg = 0;
        if (next_reg < kArgRegCount)
        {
            const PhysReg arg_reg = is_float ? GetArgSReg(next_reg)
                                    : is_ptr ? GetArgXReg(next_reg)
                                             : GetArgWReg(next_reg);
            vreg = machine_func.CreateVReg(cls);
            entry.Append(Opcode::MovReg, {Operand::VReg(vreg, cls), Operand::Reg(arg_reg)});
            ++next_reg;
        }
        else
        {
            const int arg_slot = machine_func.CreateCalleeStackArgFrameIndex(GetTypeSize(param->GetType()));
            machine_func.GetFrameSlot(arg_slot).offset = stack_offset;
            stack_offset += kStackArgSlotBytes;
            vreg = machine_func.CreateVReg(cls);
            entry.Append(Opcode::LoadStack, {Operand::VReg(vreg, cls), Operand::FrameIndex(arg_slot)});
        }
        value_vregs[param.get()] = vreg;
    }
}
/// Emits into `mir_block` (the lowered form of predecessor `src_bb`) the copies
/// that give each leading phi of `dst_bb` its incoming value along the edge
/// src_bb -> dst_bb.
///
/// The copies on one edge have parallel-assignment semantics: every incoming value
/// must be read before any phi register is overwritten. Emitting "evaluate, move"
/// one phi at a time is wrong when one phi's incoming value is another phi of the
/// same block (e.g. a loop that swaps two variables). So this works in phases:
///   1. evaluate every incoming value (code emitted here still sees the old phi values);
///   2. copy any source that is itself a phi destination into a fresh temporary;
///   3. emit the phi moves.
/// `value_vregs` is const for interface compatibility, but the Emit*Value helpers
/// cache what they lower, so it is updated through const_cast.
static void EmitPhiValueStores(const ir::BasicBlock *src_bb,
                               const ir::BasicBlock *dst_bb,
                               MachineFunction &function,
                               const ValueVRegMap &value_vregs,
                               const LocalScalarMap &scalar_slots,
                               const LocalArrayMap &array_slots,
                               MachineBasicBlock &mir_block)
{
    if (!src_bb || !dst_bb)
        return;

    // One pending copy "dst <- src" of register class `cls`.
    struct PhiCopy
    {
        int dst;
        int src;
        VRegClass cls;
    };
    std::vector<PhiCopy> copies;
    auto &vregs = const_cast<ValueVRegMap &>(value_vregs);

    // Phase 1: evaluate all incoming values. The scan stops at the first non-phi.
    for (const auto &inst_ptr : dst_bb->GetInstructions())
    {
        auto *phi = dynamic_cast<const ir::PhiInst *>(inst_ptr.get());
        if (!phi)
            break;
        auto phi_it = value_vregs.find(phi);
        if (phi_it == value_vregs.end())
            continue;
        const int phi_vreg = phi_it->second;

        // Operands alternate (value, incoming block); select the one for src_bb.
        const ir::Value *incoming_value = nullptr;
        const size_t num_ops = phi->GetNumOperands();
        for (size_t i = 0; i + 1 < num_ops; i += 2)
        {
            auto *bb_ptr = dynamic_cast<ir::BasicBlock *>(phi->GetOperand(i + 1));
            if (bb_ptr && bb_ptr == src_bb)
            {
                incoming_value = phi->GetOperand(i);
                break;
            }
        }
        if (!incoming_value)
            continue;

        const VRegClass phi_class = function.GetVRegClass(phi_vreg);
        int src = 0;
        if (phi_class == VRegClass::Float)
            src = EmitFloatValue(incoming_value, function, vregs, mir_block);
        else if (phi_class == VRegClass::Ptr)
            src = EmitPtrValue(incoming_value, function, vregs, scalar_slots, array_slots, mir_block);
        else
            src = EmitIntValue(incoming_value, function, vregs, scalar_slots, array_slots, mir_block);

        // A phi that receives itself on this edge needs no copy.
        if (src != phi_vreg)
            copies.push_back({phi_vreg, src, phi_class});
    }

    // Phase 2: a source that is overwritten by another copy must be saved first.
    auto is_copy_destination = [&copies](int vreg) {
        return std::any_of(copies.begin(), copies.end(),
                           [vreg](const PhiCopy &copy) { return copy.dst == vreg; });
    };
    for (auto &copy : copies)
    {
        if (!is_copy_destination(copy.src))
            continue;
        const int temp = function.CreateVReg(copy.cls);
        mir_block.Append(Opcode::MovReg,
                         {Operand::VReg(temp, copy.cls), Operand::VReg(copy.src, copy.cls)});
        copy.src = temp;
    }

    // Phase 3: every source is now safe to read; write the phi registers.
    for (const auto &copy : copies)
    {
        mir_block.Append(Opcode::MovReg,
                         {Operand::VReg(copy.dst, copy.cls), Operand::VReg(copy.src, copy.cls)});
    }
}
// Lower one IR instruction, appending the resulting MIR to `block`.
//
// Results are recorded in `value_vregs` (IR value -> vreg). Allocas get frame
// indices in `scalar_slots` / `array_slots` instead of vregs. Several opcodes
// intentionally emit nothing here:
//   - GEP: presumably materialised at each use by EmitPtrValue (TODO confirm).
//   - Phi: vregs are pre-created by LowerOneFunction and filled in by
//     EmitPhiValueStores on every incoming branch (Br / CondBr below).
//   - Compare / ZExt whose only consumer is a "canonical bool use"
//     (IsSolelyConsumedByCanonicalBoolUse): presumably folded into that
//     consumer, e.g. a CondBr via TryEmitCondValueToFlags -- confirm.
// Opcodes not handled by the switch are silently ignored (default case).
static void LowerInstruction(const ir::Instruction &inst,
MachineFunction &function,
ValueVRegMap &value_vregs,
LocalScalarMap &scalar_slots,
LocalArrayMap &array_slots,
const BlockMap &block_map,
MachineBasicBlock &block)
{
switch (inst.GetOpcode())
{
case ir::Opcode::Alloca:
{
// Reserve a frame slot; no code is emitted.
auto &alloca = static_cast<const ir::AllocaInst &>(inst);
const int elem_size = GetTypeSize(alloca.GetElementType());
if (alloca.IsArrayAlloca())
{
// Array allocas reserve elem_size * count bytes when the count is a
// positive compile-time constant.
// NOTE(review): any other count (non-constant or <= 0) reserves only one
// element; confirm the frontend never emits dynamically-sized allocas.
int count = 0;
if (TryGetConstantInt(alloca.GetCount(), count) && count > 0)
array_slots[&inst] = function.CreateFrameIndex(elem_size * count);
else
array_slots[&inst] = function.CreateFrameIndex(elem_size);
}
else
{
const int slot = function.CreateFrameIndex(elem_size);
scalar_slots[&inst] = slot;
}
return;
}
case ir::Opcode::Load:
{
auto &load = static_cast<const ir::LoadInst &>(inst);
const bool is_ptr = IsPointerType(load.GetType());
const bool is_float = IsFloatType(load.GetType());
int scalar_slot = -1;
// Fast path 1: the pointer resolves directly to a local scalar stack slot.
if (TryResolveDirectScalarSlot(load.GetPtr(), scalar_slots, scalar_slot))
{
if (is_ptr)
{
int vreg = function.CreateVReg(VRegClass::Ptr);
block.Append(Opcode::LoadStack,
{Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(scalar_slot)});
value_vregs[&load] = vreg;
}
else if (is_float)
{
int vreg = function.CreateVReg(VRegClass::Float);
block.Append(Opcode::LoadStack,
{Operand::VReg(vreg, VRegClass::Float), Operand::FrameIndex(scalar_slot)});
value_vregs[&load] = vreg;
}
else
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::LoadStack,
{Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(scalar_slot)});
value_vregs[&load] = vreg;
}
return;
}
// Fast path 2: a non-pointer load from a global scalar uses its symbol.
if (!is_ptr)
{
if (auto *global = AsGlobalScalarObject(load.GetPtr()))
{
if (is_float)
{
int vreg = function.CreateVReg(VRegClass::Float);
block.Append(Opcode::LoadGlobal,
{Operand::VReg(vreg, VRegClass::Float), Operand::Symbol(global->GetName())});
value_vregs[&load] = vreg;
}
else
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::LoadGlobal,
{Operand::VReg(vreg, VRegClass::Int), Operand::Symbol(global->GetName())});
value_vregs[&load] = vreg;
}
return;
}
}
// General case: compute the address into a Ptr vreg and load through it.
int addr = EmitPtrValue(load.GetPtr(), function, value_vregs,
scalar_slots, array_slots, block);
if (is_ptr)
{
int vreg = function.CreateVReg(VRegClass::Ptr);
block.Append(Opcode::LoadMem,
{Operand::VReg(vreg, VRegClass::Ptr), Operand::VReg(addr, VRegClass::Ptr)});
value_vregs[&load] = vreg;
}
else if (is_float)
{
int vreg = function.CreateVReg(VRegClass::Float);
block.Append(Opcode::LoadMem,
{Operand::VReg(vreg, VRegClass::Float), Operand::VReg(addr, VRegClass::Ptr)});
value_vregs[&load] = vreg;
}
else
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::LoadMem,
{Operand::VReg(vreg, VRegClass::Int), Operand::VReg(addr, VRegClass::Ptr)});
value_vregs[&load] = vreg;
}
return;
}
case ir::Opcode::Store:
{
// Same three addressing forms as Load; the stored value's type selects
// the register class of the source operand.
auto &store = static_cast<const ir::StoreInst &>(inst);
const bool value_is_ptr = IsPointerType(store.GetValue()->GetType());
const bool value_is_float = IsFloatType(store.GetValue()->GetType());
int scalar_slot = -1;
// Fast path 1: store straight into a local scalar stack slot.
if (TryResolveDirectScalarSlot(store.GetPtr(), scalar_slots, scalar_slot))
{
if (value_is_ptr)
{
int vreg = EmitPtrValue(store.GetValue(), function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::StoreStack,
{Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(scalar_slot)});
}
else if (value_is_float)
{
int vreg = EmitFloatValue(store.GetValue(), function, value_vregs, block);
block.Append(Opcode::StoreStack,
{Operand::VReg(vreg, VRegClass::Float), Operand::FrameIndex(scalar_slot)});
}
else
{
int vreg = EmitIntValue(store.GetValue(), function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::StoreStack,
{Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(scalar_slot)});
}
return;
}
// Fast path 2: a non-pointer value stored to a global scalar by symbol.
if (!value_is_ptr)
{
if (auto *global = AsGlobalScalarObject(store.GetPtr()))
{
if (value_is_float)
{
int vreg = EmitFloatValue(store.GetValue(), function, value_vregs, block);
block.Append(Opcode::StoreGlobal,
{Operand::VReg(vreg, VRegClass::Float), Operand::Symbol(global->GetName())});
}
else
{
int vreg = EmitIntValue(store.GetValue(), function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::StoreGlobal,
{Operand::VReg(vreg, VRegClass::Int), Operand::Symbol(global->GetName())});
}
return;
}
}
// General case: the address is computed first, then the value.
int addr = EmitPtrValue(store.GetPtr(), function, value_vregs,
scalar_slots, array_slots, block);
if (value_is_ptr)
{
int val = EmitPtrValue(store.GetValue(), function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::StoreMem,
{Operand::VReg(val, VRegClass::Ptr), Operand::VReg(addr, VRegClass::Ptr)});
}
else if (value_is_float)
{
int val = EmitFloatValue(store.GetValue(), function, value_vregs, block);
block.Append(Opcode::StoreMem,
{Operand::VReg(val, VRegClass::Float), Operand::VReg(addr, VRegClass::Ptr)});
}
else
{
int val = EmitIntValue(store.GetValue(), function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::StoreMem,
{Operand::VReg(val, VRegClass::Int), Operand::VReg(addr, VRegClass::Ptr)});
}
return;
}
case ir::Opcode::GEP:
// No code here; the address is presumably computed at each use.
return;
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Mod:
{
// The result type selects the float or integer emitter. The returned
// vreg is discarded here, so the helper is expected to record it in
// value_vregs itself.
auto &bin = static_cast<const ir::BinaryInst &>(inst);
if (IsFloatType(bin.GetType()))
{
EmitFloatValue(&bin, function, value_vregs, block);
return;
}
EmitIntValue(&bin, function, value_vregs, scalar_slots, array_slots, block);
return;
}
case ir::Opcode::Eq:
case ir::Opcode::Ne:
case ir::Opcode::Lt:
case ir::Opcode::Le:
case ir::Opcode::Gt:
case ir::Opcode::Ge:
{
// Skipped when the only consumer presumably folds the comparison itself.
auto &bin = static_cast<const ir::BinaryInst &>(inst);
if (IsSolelyConsumedByCanonicalBoolUse(bin))
return;
EmitIntValue(&bin, function, value_vregs, scalar_slots, array_slots, block);
return;
}
case ir::Opcode::SIToFP:
case ir::Opcode::FPToSI:
case ir::Opcode::ZExt:
{
// SIToFP yields a float; FPToSI and ZExt yield integers. A ZExt whose
// only consumer is a canonical bool use is skipped, like a compare.
auto &cast = static_cast<const ir::CastInst &>(inst);
if (inst.GetOpcode() == ir::Opcode::ZExt)
{
if (IsSolelyConsumedByCanonicalBoolUse(cast))
return;
}
if (inst.GetOpcode() == ir::Opcode::SIToFP)
{
EmitFloatValue(&inst, function, value_vregs, block);
return;
}
if (inst.GetOpcode() == ir::Opcode::FPToSI)
{
EmitIntValue(&inst, function, value_vregs, scalar_slots, array_slots, block);
return;
}
if (inst.GetOpcode() == ir::Opcode::ZExt)
{
EmitIntValue(&inst, function, value_vregs, scalar_slots, array_slots, block);
return;
}
return;
}
case ir::Opcode::Phi:
// Filled in by EmitPhiValueStores on the predecessors' branches.
return;
case ir::Opcode::Br:
{
// PHI copies for the target are emitted just before the jump. If the
// target has no machine block, nothing is emitted (not even the copies).
auto &br = static_cast<const ir::BranchInst &>(inst);
auto it = block_map.find(br.GetTarget());
if (it != block_map.end() && it->second)
{
EmitPhiValueStores(inst.GetParent(), br.GetTarget(), function,
value_vregs, scalar_slots, array_slots, block);
block.Append(Opcode::Br, {Operand::Label(it->second->GetLabelId())});
}
return;
}
case ir::Opcode::CondBr:
{
// Flags are set first. true_cond starts as NE and is presumably updated
// by TryEmitCondValueToFlags to the condition selecting the true edge.
// Layout: [flag setup] [true-target PHI copies] CondBr true
//         [false-target PHI copies] Br false
// NOTE(review): the true target's PHI copies run before the CondBr, so
// they also execute when the false edge is taken. If a true-target PHI
// vreg is live along the false edge (e.g. a loop-header PHI read after
// the loop exits through this branch), its value is clobbered. Splitting
// the critical edge (a dedicated block holding the copies) would avoid
// this. The copies also sit between the flag setup and the CondBr, so
// Emit*Value for incoming values must not emit flag-setting code.
auto &br = static_cast<const ir::CondBranchInst &>(inst);
CondCode true_cond = CondCode::NE;
TryEmitCondValueToFlags(br.GetCond(), function, value_vregs,
scalar_slots, array_slots, block, true_cond);
auto true_it = block_map.find(br.GetTrueTarget());
if (true_it != block_map.end() && true_it->second)
{
EmitPhiValueStores(inst.GetParent(), br.GetTrueTarget(), function,
value_vregs, scalar_slots, array_slots, block);
block.Append(Opcode::CondBr,
{Operand::Imm(static_cast<int>(true_cond)),
Operand::Label(true_it->second->GetLabelId())});
}
auto false_it = block_map.find(br.GetFalseTarget());
if (false_it != block_map.end() && false_it->second)
{
EmitPhiValueStores(inst.GetParent(), br.GetFalseTarget(), function,
value_vregs, scalar_slots, array_slots, block);
block.Append(Opcode::Br,
{Operand::Label(false_it->second->GetLabelId())});
}
return;
}
case ir::Opcode::Call:
{
auto &call = static_cast<const ir::CallInst &>(inst);
auto *callee = call.GetCallee();
// Without a resolvable callee no call is emitted; a non-void result is
// replaced by zero in the matching register class.
if (!callee)
{
if (!call.GetType()->IsVoid())
{
if (IsPointerType(call.GetType()))
{
int vreg = function.CreateVReg(VRegClass::Ptr);
block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Ptr), Operand::Imm(0)});
value_vregs[&call] = vreg;
}
else if (IsFloatType(call.GetType()))
{
// Float zero is built as integer 0 moved bitwise into an FP vreg.
int vreg = function.CreateVReg(VRegClass::Float);
int wvreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm, {Operand::VReg(wvreg, VRegClass::Int), Operand::Imm(0)});
block.Append(Opcode::FMovWS,
{Operand::VReg(vreg, VRegClass::Float), Operand::VReg(wvreg, VRegClass::Int)});
value_vregs[&call] = vreg;
}
else
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovImm, {Operand::VReg(vreg, VRegClass::Int), Operand::Imm(0)});
value_vregs[&call] = vreg;
}
}
return;
}
// Register arguments: up to 8 float args in s-registers (GetArgSReg) and
// up to 8 integer/pointer args sharing one counter (pointers in
// x-registers, ints in w-registers), mirroring the AArch64 procedure-call
// convention. Overflow args are remembered by index and stored to the
// stack below. Null args are skipped entirely.
std::vector<size_t> stack_arg_indices;
size_t gp_idx = 0;
size_t fp_idx = 0;
for (size_t i = 0; i < call.GetNumArgs(); ++i)
{
auto *arg = call.GetArg(i);
if (!arg)
continue;
if (IsFloatType(arg->GetType()))
{
if (fp_idx < 8)
{
int vreg = EmitFloatValue(arg, function, value_vregs, block);
block.Append(Opcode::MovReg,
{Operand::Reg(GetArgSReg(fp_idx)),
Operand::VReg(vreg, VRegClass::Float)});
++fp_idx;
}
else
{
stack_arg_indices.push_back(i);
}
}
else if (IsPointerValue(arg))
{
if (gp_idx < 8)
{
int vreg = EmitPtrValue(arg, function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::MovReg,
{Operand::Reg(GetArgXReg(gp_idx)),
Operand::VReg(vreg, VRegClass::Ptr)});
++gp_idx;
}
else
{
stack_arg_indices.push_back(i);
}
}
else
{
if (gp_idx < 8)
{
int vreg = EmitIntValue(arg, function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::MovReg,
{Operand::Reg(GetArgWReg(gp_idx)),
Operand::VReg(vreg, VRegClass::Int)});
++gp_idx;
}
else
{
stack_arg_indices.push_back(i);
}
}
}
// Stack arguments: SP is lowered by the 16-byte-aligned size reported by
// ComputeStackArgumentBytes, and each overflow arg is stored in its own
// 8-byte slot at increasing offsets, in argument order. SP is restored
// right after the call.
// NOTE(review): the `!arg` branch in the loop below is unreachable, since
// null args are skipped above and never pushed to stack_arg_indices.
// NOTE(review): arg values are evaluated *after* SP has been adjusted. If
// any Emit*Value call here addresses a FrameIndex relative to SP (e.g.
// taking the address of a local array), the result would be off by
// aligned_stack_arg_bytes. Confirm the frame base used for FrameIndex.
const int raw_stack_arg_bytes = ComputeStackArgumentBytes(call);
const int aligned_stack_arg_bytes = AlignTo(raw_stack_arg_bytes, 16);
if (aligned_stack_arg_bytes > 0)
{
EmitStackPointerAdjust(block, Opcode::SubRR, aligned_stack_arg_bytes);
int offset = 0;
for (size_t idx : stack_arg_indices)
{
auto *arg = call.GetArg(idx);
const int arg_size = GetTypeSize(arg ? arg->GetType() : nullptr);
if (!arg)
{
offset += 8;
continue;
}
const int slot = function.CreateStackArgFrameIndex(arg_size);
function.GetFrameSlot(slot).offset = offset;
if (IsFloatType(arg->GetType()))
{
int vreg = EmitFloatValue(arg, function, value_vregs, block);
block.Append(Opcode::StoreStack,
{Operand::VReg(vreg, VRegClass::Float), Operand::FrameIndex(slot)});
}
else if (IsPointerValue(arg))
{
int vreg = EmitPtrValue(arg, function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::StoreStack,
{Operand::VReg(vreg, VRegClass::Ptr), Operand::FrameIndex(slot)});
}
else
{
int vreg = EmitIntValue(arg, function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::StoreStack,
{Operand::VReg(vreg, VRegClass::Int), Operand::FrameIndex(slot)});
}
offset += 8;
}
}
block.Append(Opcode::Call, {Operand::Symbol(callee->GetName())});
if (aligned_stack_arg_bytes > 0)
{
EmitStackPointerAdjust(block, Opcode::AddRR, aligned_stack_arg_bytes);
}
// Copy the result out of x0 / s0 / w0 according to the result type.
if (!call.GetType()->IsVoid())
{
if (IsPointerType(call.GetType()))
{
int vreg = function.CreateVReg(VRegClass::Ptr);
block.Append(Opcode::MovReg,
{Operand::VReg(vreg, VRegClass::Ptr), Operand::Reg(PhysReg::X0)});
value_vregs[&call] = vreg;
}
else if (IsFloatType(call.GetType()))
{
int vreg = function.CreateVReg(VRegClass::Float);
block.Append(Opcode::MovReg,
{Operand::VReg(vreg, VRegClass::Float), Operand::Reg(PhysReg::S0)});
value_vregs[&call] = vreg;
}
else
{
int vreg = function.CreateVReg(VRegClass::Int);
block.Append(Opcode::MovReg,
{Operand::VReg(vreg, VRegClass::Int), Operand::Reg(PhysReg::W0)});
value_vregs[&call] = vreg;
}
}
return;
}
case ir::Opcode::Ret:
{
// Move the return value (if any) into x0 / s0 / w0, then emit Ret.
auto &ret = static_cast<const ir::ReturnInst &>(inst);
if (ret.HasValue())
{
if (ret.GetValue()->GetType()->IsPtrInt32() ||
ret.GetValue()->GetType()->IsPtrFloat32())
{
int vreg = EmitPtrValue(ret.GetValue(), function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::MovReg,
{Operand::Reg(PhysReg::X0), Operand::VReg(vreg, VRegClass::Ptr)});
}
else if (IsFloatValue(ret.GetValue()))
{
int vreg = EmitFloatValue(ret.GetValue(), function, value_vregs, block);
block.Append(Opcode::MovReg,
{Operand::Reg(PhysReg::S0), Operand::VReg(vreg, VRegClass::Float)});
}
else
{
int vreg = EmitIntValue(ret.GetValue(), function, value_vregs,
scalar_slots, array_slots, block);
block.Append(Opcode::MovReg,
{Operand::Reg(PhysReg::W0), Operand::VReg(vreg, VRegClass::Int)});
}
}
block.Append(Opcode::Ret);
return;
}
default:
// Unsupported opcodes currently produce no code.
return;
}
}
// Lower one IR function body into `machine_func`.
//
// Phases:
//   1. map every IR basic block to a machine block (the IR entry maps to the
//      machine function's existing entry block, the rest are created);
//   2. lower the incoming parameters;
//   3. pre-allocate a vreg for every PHI, so that branches lowered before the
//      PHI's own block can already target it;
//   4. lower each block's instructions in IR order.
//
// Throws std::runtime_error when the IR function has no entry block.
static void LowerOneFunction(const ir::Function &ir_function,
                             MachineFunction &machine_func)
{
    ValueVRegMap value_vregs;
    LocalScalarMap scalar_slots;
    LocalArrayMap array_slots;
    BlockMap block_map;

    const auto *entry_bb = ir_function.GetEntry();
    if (entry_bb == nullptr)
    {
        throw std::runtime_error(
            FormatError("mir", "IR 函数缺少入口基本块: " + ir_function.GetName()));
    }

    // Phase 1: block mapping.
    block_map.emplace(entry_bb, &machine_func.GetEntry());
    for (const auto &ir_bb : ir_function.GetBlocks())
    {
        if (ir_bb && ir_bb.get() != entry_bb)
            block_map.emplace(ir_bb.get(), &machine_func.CreateBlock(ir_bb->GetName()));
    }

    // Phase 2: parameters.
    LowerFunctionParams(ir_function, machine_func, value_vregs);

    // Phase 3: one vreg per PHI (PHIs lead their block).
    for (const auto &ir_bb : ir_function.GetBlocks())
    {
        if (!ir_bb)
            continue;
        for (const auto &ir_inst : ir_bb->GetInstructions())
        {
            const auto *phi = dynamic_cast<const ir::PhiInst *>(ir_inst.get());
            if (phi == nullptr)
                break;
            const VRegClass cls = IsFloatType(phi->GetType())     ? VRegClass::Float
                                  : IsPointerType(phi->GetType()) ? VRegClass::Ptr
                                                                  : VRegClass::Int;
            value_vregs[phi] = machine_func.CreateVReg(cls);
        }
    }

    // Phase 4: instruction lowering, block by block.
    for (const auto &ir_bb : ir_function.GetBlocks())
    {
        if (!ir_bb)
            continue;
        const auto found = block_map.find(ir_bb.get());
        if (found == block_map.end() || found->second == nullptr)
            continue;
        MachineBasicBlock &mir_block = *found->second;

        // Drop cached vregs for constants, globals, allocas and GEPs so that
        // they are re-materialised inside this block rather than reused from
        // a block that may not dominate it.
        for (auto it = value_vregs.begin(); it != value_vregs.end();)
        {
            const ir::Value *cached = it->first;
            const bool rematerialize =
                dynamic_cast<const ir::ConstantInt *>(cached) ||
                dynamic_cast<const ir::ConstantFloat *>(cached) ||
                AsGlobalScalarObject(cached) ||
                AsGlobalArrayObject(cached) ||
                dynamic_cast<const ir::AllocaInst *>(cached) ||
                dynamic_cast<const ir::GetElementPtrInst *>(cached);
            if (rematerialize)
                it = value_vregs.erase(it);
            else
                ++it;
        }

        for (const auto &ir_inst : ir_bb->GetInstructions())
        {
            LowerInstruction(*ir_inst, machine_func,
                             value_vregs, scalar_slots, array_slots,
                             block_map, mir_block);
        }
    }
}
// Emit a data definition in `machine_module` for every pointer-typed IR
// global. Float globals are emitted as the FloatToBits encoding of their
// single-precision value, so every global is handed to the backend as 32-bit
// integer data (AddGlobalI32 / AddGlobalArrayI32).
static void LowerGlobals(const ir::Module &module,
                         MachineModule &machine_module)
{
    for (const auto &global : module.GetGlobals())
    {
        if (!global || !IsPointerValue(global.get()))
            continue;

        const bool is_float = global->IsPtrFloat32();

        // Scalar global: a single 32-bit word.
        if (!global->IsArray())
        {
            if (is_float)
                machine_module.AddGlobalI32(global->GetName(),
                                            FloatToBits(static_cast<float>(global->GetInitFloatValue())));
            else
                machine_module.AddGlobalI32(global->GetName(), global->GetInitValue());
            continue;
        }

        // Float array: convert each initializer to its 32-bit pattern. An
        // uninitialised float array is passed an empty initializer list.
        if (is_float)
        {
            std::vector<int> encoded;
            if (global->HasInitValues())
            {
                const auto &float_inits = global->GetInitFloatValues();
                encoded.reserve(float_inits.size());
                for (double element : float_inits)
                    encoded.push_back(FloatToBits(static_cast<float>(element)));
            }
            machine_module.AddGlobalArrayI32(global->GetName(),
                                             global->GetArraySize(),
                                             std::move(encoded));
            continue;
        }

        // Integer array, with or without explicit initializers.
        if (global->HasInitValues())
            machine_module.AddGlobalArrayI32(global->GetName(),
                                             global->GetArraySize(),
                                             global->GetInitValues());
        else
            machine_module.AddGlobalArrayI32(global->GetName(),
                                             global->GetArraySize());
    }
}
} // namespace
// Lower a whole IR module to MIR: globals first, then one MachineFunction per
// function that has a body (external declarations are skipped).
// DefaultContext() is called first for its side effects (presumably setting
// up a shared IR context -- TODO confirm).
std::unique_ptr<MachineModule> LowerModuleToMIR(const ir::Module &module)
{
    DefaultContext();
    auto lowered = std::make_unique<MachineModule>();
    LowerGlobals(module, *lowered);
    for (const auto &ir_func : module.GetFunctions())
    {
        if (ir_func && !ir_func->IsExternal())
            LowerOneFunction(*ir_func, lowered->CreateFunction(ir_func->GetName()));
    }
    return lowered;
}
// Convenience entry point: lower the whole module and return only the
// MachineFunction for `main`.
// Throws std::runtime_error if lowering yields no module or no `main`.
// NOTE(review): the owning MachineModule (including its lowered globals) is
// destroyed when this function returns; only `main` survives. Confirm that
// callers of this entry point do not need the module-level data.
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module &module)
{
    std::unique_ptr<MachineModule> lowered = LowerModuleToMIR(module);
    if (lowered == nullptr)
        throw std::runtime_error(FormatError("mir", "LowerModuleToMIR 失败"));

    for (auto &candidate : lowered->GetFunctions())
    {
        // `candidate` refers to the module's own slot, so moving transfers
        // ownership out of the module.
        if (candidate && candidate->GetName() == "main")
            return std::move(candidate);
    }
    throw std::runtime_error(FormatError("mir", "未找到 main 函数对应的 MIR"));
}
} // namespace mir