You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1155 lines
41 KiB

#include "mir/MIR.h"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <ostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
// Rounds `value` up to the smallest multiple of `align` that is >= value.
// Alignments of 0 or 1 leave the value unchanged.
// Negative inputs are rounded up correctly (e.g. AlignTo(-17, 16) == -16);
// the former (value + align - 1) / align form truncated toward zero and
// produced 0 for that input.
int AlignTo(int value, int align) {
  if (align <= 1) {
    return value;
  }
  const int remainder = value % align;
  if (remainder == 0) {
    return value;
  }
  // '%' keeps the sign of the dividend, so the adjustment depends on the sign.
  return remainder > 0 ? value + (align - remainder) : value - remainder;
}
// Returns true when `value` is a strictly positive power of two (1, 2, 4, ...).
bool IsPowerOfTwo(std::int64_t value) {
  if (value <= 0) {
    return false;
  }
  // A power of two has exactly one set bit; clearing the lowest set bit leaves 0.
  return (value & (value - 1)) == 0;
}
// Returns floor(log2(value)) for value >= 1, and 0 for any value <= 1.
int Log2(std::int64_t value) {
  int bits = 0;
  for (std::int64_t rest = value; rest > 1; rest >>= 1) {
    ++bits;
  }
  return bits;
}
// Returns the assembler name ("d0".."d31") of the 64-bit FP register `index`.
// Throws std::runtime_error for indices outside [0, 32).
const char* GetDRegName(int index) {
  static const char* const kDRegs[32] = {
      "d0",  "d1",  "d2",  "d3",  "d4",  "d5",  "d6",  "d7",
      "d8",  "d9",  "d10", "d11", "d12", "d13", "d14", "d15",
      "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23",
      "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"};
  const bool in_range = index >= 0 && index < 32;
  if (!in_range) {
    throw std::runtime_error("float register index out of range");
  }
  return kDRegs[index];
}
// Builds the local assembler label ".L.<function>.<block>" for a machine block.
std::string BlockLabel(const MachineFunction& function,
                       const MachineBasicBlock& block) {
  std::string label = ".L.";
  label += function.GetName();
  label += ".";
  label += block.GetName();
  return label;
}
// Builds the local assembler label ".L.<function>.<block_name>" from a raw block name.
std::string BlockLabel(const MachineFunction& function, const std::string& block_name) {
  std::string label = ".L.";
  label += function.GetName();
  label += ".";
  label += block_name;
  return label;
}
// Converts a byte alignment to the power-of-two exponent used by `.align`.
// Non-power-of-two alignments round up; alignments <= 1 map to 0.
int ToAsmAlign(int align) {
  int exponent = 0;
  for (int boundary = 1; boundary < align; boundary <<= 1) {
    ++exponent;
  }
  return exponent;
}
// Returns the raw IEEE-754 bit pattern of `value` (memcpy avoids aliasing UB).
std::uint32_t FloatBits(float value) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32 bits wide");
  std::uint32_t raw{};
  std::memcpy(&raw, &value, sizeof raw);
  return raw;
}
// Maps an IR scalar type to the machine ValueType the printer works with.
// A null or void type maps to ValueType::Void; anything not handled below
// (e.g. arrays) is rejected with std::runtime_error.
ValueType LowerAsmType(const std::shared_ptr<ir::Type>& type) {
  if (type == nullptr || type->IsVoid()) return ValueType::Void;
  if (type->IsInt1()) return ValueType::I1;
  if (type->IsInt32()) return ValueType::I32;
  if (type->IsFloat()) return ValueType::F32;
  if (type->IsPointer()) return ValueType::Ptr;
  throw std::runtime_error(FormatError("mir", "unsupported IR type in asm printer"));
}
// Returns the alignment of an IR type in bytes. Arrays take the alignment of
// their innermost element type; a null type yields 1.
int GetIRTypeAlign(const std::shared_ptr<ir::Type>& type) {
  std::shared_ptr<ir::Type> scalar = type;
  while (scalar && scalar->IsArray()) {
    scalar = scalar->GetElementType();
  }
  if (!scalar) {
    return 1;
  }
  return GetValueAlign(LowerAsmType(scalar));
}
// Strips every array layer from `type` and returns the innermost element type
// (or `type` itself when it is not an array).
const ir::Type& GetScalarElementType(const ir::Type& type) {
  return type.IsArray() ? GetScalarElementType(*type.GetElementType()) : type;
}
// Reports whether a scalar initializer is all-zero bits. A null value (missing
// initializer element) counts as zero. Floats are compared by bit pattern, so
// -0.0f is NOT treated as zero. Any other value kind is non-zero.
bool IsZeroScalarConstant(const ir::Value* value) {
  if (!value) return true;
  if (auto* as_int = ir::dyncast<ir::ConstantInt>(value)) return as_int->GetValue() == 0;
  if (auto* as_bool = ir::dyncast<ir::ConstantI1>(value)) return !as_bool->GetValue();
  if (auto* as_float = ir::dyncast<ir::ConstantFloat>(value)) return FloatBits(as_float->GetValue()) == 0;
  return false;
}
// Returns the total number of scalar slots in `type`: the product of all array
// dimensions, or 1 for a non-array type.
std::size_t CountScalarElements(const ir::Type& type) {
  std::size_t total = 1;
  const ir::Type* level = &type;
  while (level->IsArray()) {
    total *= level->GetNumElements();
    level = level->GetElementType().get();
  }
  return total;
}
// Appends one entry per scalar slot of `type` to `out`.
// A non-array type contributes `init` itself. For arrays, elements of a
// ConstantArrayValue initializer are taken positionally (they are treated as
// already flattened), missing trailing elements become nullptr, and a
// non-ConstantArrayValue initializer yields all-nullptr slots.
void FlattenGlobalScalars(const ir::Type& type, ir::Value* init,
                          std::vector<ir::Value*>& out) {
  if (!type.IsArray()) {
    out.push_back(init);
    return;
  }
  const std::size_t slot_count = CountScalarElements(type);
  auto* array_value = ir::dyncast<ir::ConstantArrayValue>(init);
  if (!array_value) {
    out.insert(out.end(), slot_count, nullptr);
    return;
  }
  const auto& elements = array_value->GetElements();
  out.reserve(out.size() + slot_count);
  for (std::size_t slot = 0; slot < slot_count; ++slot) {
    out.push_back(slot < elements.size() ? elements[slot] : nullptr);
  }
}
// Emits one 32-bit `.word` for a scalar global initializer.
// For float slots the value is written as its IEEE bit pattern (integer
// constants are converted to float first). For other slots, ConstantInt and
// ConstantI1 values are written directly. Anything else emits 0.
void EmitGlobalScalar(std::ostream& os, const ir::Type& type, ir::Value* value) {
  if (!type.IsFloat()) {
    int word = 0;
    if (auto* as_int = ir::dyncast<ir::ConstantInt>(value)) {
      word = as_int->GetValue();
    } else if (auto* as_bool = ir::dyncast<ir::ConstantI1>(value)) {
      word = as_bool->GetValue() ? 1 : 0;
    }
    os << " .word " << word << "\n";
    return;
  }
  float real = 0.0f;
  if (auto* as_float = ir::dyncast<ir::ConstantFloat>(value)) {
    real = as_float->GetValue();
  } else if (auto* as_int = ir::dyncast<ir::ConstantInt>(value)) {
    real = static_cast<float>(as_int->GetValue());
  }
  os << " .word " << FloatBits(real) << "\n";
}
// Emits the data definition for one IR global.
// Placement: all-zero contents go to .bss, constant globals to .rodata, and
// everything else to .data. Arrays are flattened to scalar slots. Runs of zero
// slots become `.zero` directives (4 bytes per slot); the other slots are
// emitted with EmitGlobalScalar.
void EmitGlobal(const ir::GlobalValue& global, std::ostream& os) {
  const auto object_type = global.GetObjectType();

  // Section selection, alignment, visibility and label shared by both paths.
  auto emit_header = [&](bool in_bss) {
    if (in_bss) {
      os << ".bss\n";
    } else if (global.IsConstant()) {
      os << ".section .rodata\n";
    } else {
      os << ".data\n";
    }
    os << " .align " << ToAsmAlign(GetIRTypeAlign(object_type)) << "\n";
    os << " .global " << global.GetName() << "\n";
    os << global.GetName() << ":\n";
  };

  if (!object_type->IsArray()) {
    const bool zero_init =
        !global.HasInitializer() || IsZeroScalarConstant(global.GetInitializer());
    emit_header(zero_init);
    if (zero_init) {
      os << " .zero " << object_type->GetSize() << "\n";
    } else {
      EmitGlobalScalar(os, *object_type, global.GetInitializer());
    }
    return;
  }

  std::vector<ir::Value*> scalars;
  FlattenGlobalScalars(*object_type, global.GetInitializer(), scalars);
  const bool all_zero =
      std::all_of(scalars.begin(), scalars.end(),
                  [](ir::Value* slot) { return IsZeroScalarConstant(slot); });
  emit_header(all_zero);
  if (all_zero) {
    os << " .zero " << object_type->GetSize() << "\n";
    return;
  }

  const ir::Type& element_type = GetScalarElementType(*object_type);
  std::size_t cursor = 0;
  while (cursor < scalars.size()) {
    if (!IsZeroScalarConstant(scalars[cursor])) {
      EmitGlobalScalar(os, element_type, scalars[cursor]);
      ++cursor;
      continue;
    }
    // Collapse consecutive zero slots into one .zero directive.
    std::size_t run_end = cursor;
    while (run_end < scalars.size() && IsZeroScalarConstant(scalars[run_end])) {
      ++run_end;
    }
    os << " .zero " << static_cast<int>((run_end - cursor) * 4) << "\n";
    cursor = run_end;
  }
}
// Returns the index of the first stack object named `name`, or -1 if none exists.
int FindStackObject(const MachineFunction& function, const std::string& name) {
  const auto& objects = function.GetStackObjects();
  const auto match = std::find_if(objects.begin(), objects.end(),
                                  [&name](const auto& candidate) { return candidate.name == name; });
  return match == objects.end() ? -1 : match->index;
}
// True when `reg` names a 32-bit general register (its name starts with 'w').
bool Is32BitRegName(const char* reg) {
  if (reg == nullptr) {
    return false;
  }
  return *reg == 'w';
}
// True when `value` fits the unshifted 12-bit add/sub immediate field [0, 4095].
bool IsAddSubImm12(std::int64_t value) {
  return 0 <= value && value < 4096;
}
// True when `value` is encodable as a 12-bit add/sub immediate shifted left by
// 12: a non-negative multiple of 4096 no larger than 4095 << 12.
bool IsAddSubImm12Shifted(std::int64_t value) {
  if (value < 0 || (value & 0xfffll) != 0) {
    return false;
  }
  return (value >> 12) <= 4095;
}
// True when `value` is encodable as an add/sub immediate, either directly
// (0..4095) or as a 12-bit value shifted left by 12.
bool IsAddSubImm(std::int64_t value) {
  if (value < 0) {
    return false;
  }
  if (value <= 4095) {
    return true;
  }
  return (value & 0xfffll) == 0 && value <= (4095ll << 12);
}
// Emits `<opcode> dst, src, #imm`, using the `lsl #12` form when the value
// only fits the shifted encoding. Throws if the value fits neither encoding.
void EmitAddSubImm(std::ostream& os, const char* opcode, const char* dst,
                   const char* src, std::int64_t value) {
  if (!IsAddSubImm(value)) {
    throw std::runtime_error(FormatError("mir", "invalid add/sub immediate"));
  }
  const bool needs_shift = !IsAddSubImm12(value);
  os << " " << opcode << " " << dst << ", " << src << ", #"
     << (needs_shift ? value >> 12 : value);
  os << (needs_shift ? ", lsl #12\n" : "\n");
}
// Emits instructions computing dst = src + value.
// A zero adjustment becomes a plain `mov` (omitted when dst == src).
// Otherwise the magnitude is split into `lsl #12` chunks of at most 4095
// followed by one unshifted remainder, using add for positive values and sub
// for negative ones. Only the first instruction reads `src`; later ones
// accumulate into `dst`.
void EmitAdjustRegByImm(std::ostream& os, const char* dst, const char* src,
                        std::int64_t value) {
  if (value == 0) {
    if (std::strcmp(dst, src) != 0) {
      os << " mov " << dst << ", " << src << "\n";
    }
    return;
  }
  const bool is_add = value > 0;
  const char* mnemonic = is_add ? "add" : "sub";
  std::uint64_t magnitude = is_add ? static_cast<std::uint64_t>(value)
                                   : static_cast<std::uint64_t>(-value);
  const char* from = src;
  while (magnitude != 0) {
    const bool shifted = magnitude >= 4096;
    const std::uint64_t amount =
        shifted ? std::min<std::uint64_t>(magnitude >> 12, 4095) : magnitude;
    os << " " << mnemonic << " " << dst << ", " << from << ", #" << amount;
    if (shifted) {
      os << ", lsl #12";
    }
    os << "\n";
    magnitude -= shifted ? (amount << 12) : amount;
    from = dst;
  }
}
// Materializes `value` into `reg` with a movz/movk sequence.
// For a w-register only the low 32 bits are used (two 16-bit halves); an
// x-register uses all four halves. The first half is always written with
// movz (which clears the register); later halves are emitted with movk only
// when non-zero. Zero is emitted as `mov reg, #0`.
// Throws when `reg` is null or empty.
void EmitMoveImm(std::ostream& os, const char* reg, std::int64_t value) {
  if (reg == nullptr || reg[0] == '\0') {
    throw std::runtime_error(FormatError("mir", "invalid register for immediate materialization"));
  }
  if (value == 0) {
    os << " mov " << reg << ", #0\n";
    return;
  }
  const bool narrow = Is32BitRegName(reg);
  const int width = narrow ? 32 : 64;
  const std::uint64_t pattern =
      narrow ? static_cast<std::uint64_t>(static_cast<std::uint32_t>(value))
             : static_cast<std::uint64_t>(value);
  bool wrote_first = false;
  for (int shift = 0; shift < width; shift += 16) {
    const std::uint64_t half = (pattern >> shift) & 0xffffull;
    if (wrote_first && half == 0) {
      continue;  // already zero after the initial movz
    }
    os << " " << (wrote_first ? "movk" : "movz") << " " << reg << ", #" << half;
    if (shift != 0) {
      os << ", lsl #" << shift;
    }
    os << "\n";
    wrote_first = true;
  }
}
// Emits a register-to-register move (fmov for FP registers, mov otherwise).
// Nothing is emitted when source and destination name the same register.
void EmitCopy(std::ostream& os, const char* dst, const char* src, bool is_float) {
  if (std::strcmp(dst, src) == 0) {
    return;
  }
  const char* mnemonic = is_float ? "fmov" : "mov";
  os << " " << mnemonic << " " << dst << ", " << src << "\n";
}
// Computes the address of stack object `object_index` (x29 + object offset)
// into `addr_reg`.
void EmitFrameAddress(const MachineFunction& function, int object_index,
                      const char* addr_reg, std::ostream& os) {
  EmitAdjustRegByImm(os, addr_reg, "x29", function.GetStackObject(object_index).offset);
}
// Computes the address of an incoming stack-passed argument into `addr_reg`.
// These arguments sit above the 16-byte (x29, x30) pair the prologue pushes.
void EmitIncomingStackAddress(int stack_offset, const char* addr_reg, std::ostream& os) {
  const int kFrameRecordBytes = 16;
  EmitAdjustRegByImm(os, addr_reg, "x29", kFrameRecordBytes + stack_offset);
}
// Emits `ldr dst, [addr_reg]` for every non-void value type. The access
// width follows from the register name (w/x/s), so one form covers all types.
// Void emits nothing.
void EmitLoadFromAddr(ValueType type, const char* dst, const char* addr_reg,
                      std::ostream& os) {
  switch (type) {
    case ValueType::I1:
    case ValueType::I32:
    case ValueType::F32:
    case ValueType::Ptr:
      os << " ldr " << dst << ", [" << addr_reg << "]\n";
      break;
    case ValueType::Void:
      break;
  }
}
// Emits `str src, [addr_reg]` for every non-void value type. The access
// width follows from the register name (w/x/s). Void emits nothing.
void EmitStoreToAddr(ValueType type, const char* src, const char* addr_reg,
                     std::ostream& os) {
  switch (type) {
    case ValueType::I1:
    case ValueType::I32:
    case ValueType::F32:
    case ValueType::Ptr:
      os << " str " << src << ", [" << addr_reg << "]\n";
      break;
    case ValueType::Void:
      break;
  }
}
// Reloads a spilled value from stack object `object_index` into `dst`.
// Uses x17 as the address register.
void EmitLoadSpill(const MachineFunction& function, int object_index, ValueType type,
                   const char* dst, std::ostream& os) {
  static constexpr const char* kSpillAddrReg = "x17";
  EmitFrameAddress(function, object_index, kSpillAddrReg, os);
  EmitLoadFromAddr(type, dst, kSpillAddrReg, os);
}
// Writes `src` to spill slot `object_index`, using x17 as the address register.
void EmitStoreSpill(const MachineFunction& function, int object_index, ValueType type,
                    const char* src, std::ostream& os) {
  static constexpr const char* kSpillAddrReg = "x17";
  EmitFrameAddress(function, object_index, kSpillAddrReg, os);
  EmitStoreToAddr(type, src, kSpillAddrReg, os);
}
// Where an instruction writes its result.
struct DefReg {
  std::string reg_name;  // register the instruction writes into
  bool spilled{false};   // true when reg_name is a scratch register that must be stored afterwards
  int spill_object{-1};  // stack object receiving the value when spilled; -1 otherwise
};
// Chooses the destination register for a GPR-class vreg definition.
// Returns the vreg's physical register if it has one. Otherwise returns GPR
// scratch `scratch_index`, flagged as spilled, with its stack object.
DefReg PrepareGprDef(const MachineFunction& function, int vreg, int scratch_index) {
  const auto& alloc = function.GetAllocation(vreg);
  const auto type = function.GetVRegInfo(vreg).type;
  DefReg def;
  if (alloc.kind == Allocation::Kind::PhysReg) {
    def.reg_name = GetPhysRegName(alloc.phys, type);
    return def;
  }
  def.reg_name = GetPhysRegName({RegClass::GPR, scratch_index}, type);
  def.spilled = true;
  def.spill_object = alloc.stack_object;
  return def;
}
// Chooses the destination register (as an F32 name) for an FPR-class vreg
// definition. Returns the physical register if the vreg has one; otherwise
// FPR scratch `scratch_index`, flagged as spilled, with its stack object.
DefReg PrepareFprDef(const MachineFunction& function, int vreg, int scratch_index) {
  const auto& alloc = function.GetAllocation(vreg);
  DefReg def;
  if (alloc.kind == Allocation::Kind::PhysReg) {
    def.reg_name = GetPhysRegName(alloc.phys, ValueType::F32);
    return def;
  }
  def.reg_name = GetPhysRegName({RegClass::FPR, scratch_index}, ValueType::F32);
  def.spilled = true;
  def.spill_object = alloc.stack_object;
  return def;
}
// Completes a definition: if it went into a scratch register, stores the
// scratch to the vreg's spill slot. Does nothing for register-allocated vregs.
void FinalizeDef(const MachineFunction& function, int vreg, const DefReg& def,
                 std::ostream& os) {
  if (def.spilled) {
    const auto vreg_type = function.GetVRegInfo(vreg).type;
    EmitStoreSpill(function, def.spill_object, vreg_type, def.reg_name.c_str(), os);
  }
}
// Returns the name of a GPR that holds `operand`.
// - Immediates are moved into scratch `scratch_index` (named for `type`).
// - Register-allocated vregs return their physical register (named for the
//   vreg's own type).
// - Spilled vregs are reloaded into the scratch register.
// Throws std::runtime_error for any other operand kind.
std::string MaterializeGprUse(const MachineFunction& function,
                              const MachineOperand& operand, ValueType type,
                              int scratch_index, std::ostream& os) {
  const char* scratch = GetPhysRegName({RegClass::GPR, scratch_index}, type);
  switch (operand.GetKind()) {
    case OperandKind::Imm:
      EmitMoveImm(os, scratch, operand.GetImm());
      return scratch;
    case OperandKind::VReg:
      break;
    default:
      throw std::runtime_error(FormatError("mir", "expected gpr operand"));
  }
  const int vreg = operand.GetVReg();
  const auto& alloc = function.GetAllocation(vreg);
  const auto vreg_type = function.GetVRegInfo(vreg).type;
  if (alloc.kind != Allocation::Kind::PhysReg) {
    EmitLoadSpill(function, alloc.stack_object, vreg_type, scratch, os);
    return scratch;
  }
  return GetPhysRegName(alloc.phys, vreg_type);
}
// Returns the name of an FP register holding `operand` as F32.
// - Immediates are loaded as raw bits into GPR scratch `scratch_gpr`, then
//   moved into FPR scratch `scratch_fpr` with fmov.
// - Register-allocated vregs return their physical register.
// - Spilled vregs are reloaded into the FPR scratch.
// Throws std::runtime_error for other operand kinds.
std::string MaterializeFprUse(const MachineFunction& function,
                              const MachineOperand& operand, int scratch_fpr,
                              int scratch_gpr, std::ostream& os) {
  const char* fpr_scratch = GetPhysRegName({RegClass::FPR, scratch_fpr}, ValueType::F32);
  if (operand.GetKind() == OperandKind::Imm) {
    const char* gpr_scratch = GetPhysRegName({RegClass::GPR, scratch_gpr}, ValueType::I32);
    EmitMoveImm(os, gpr_scratch, operand.GetImm());
    os << " fmov " << fpr_scratch << ", " << gpr_scratch << "\n";
    return fpr_scratch;
  }
  if (operand.GetKind() != OperandKind::VReg) {
    throw std::runtime_error(FormatError("mir", "expected fpr operand"));
  }
  const auto& alloc = function.GetAllocation(operand.GetVReg());
  if (alloc.kind != Allocation::Kind::PhysReg) {
    EmitLoadSpill(function, alloc.stack_object, ValueType::F32, fpr_scratch, os);
    return fpr_scratch;
  }
  return GetPhysRegName(alloc.phys, ValueType::F32);
}
// Computes the effective address described by `address` into x16.
// Steps: load the base (frame object, global symbol via adrp/:lo12:, or a
// vreg), add the constant offset, then add each scaled index term. Index
// vregs are read as 32-bit values and sign-extended. Power-of-two strides up
// to 16 use the extended-register add form; other strides multiply through
// x17/x11.
void EmitAddressExpr(const MachineFunction& function, const AddressExpr& address,
                     std::ostream& os) {
  if (address.base_kind == AddrBaseKind::FrameObject) {
    EmitFrameAddress(function, address.base_index, "x16", os);
  } else if (address.base_kind == AddrBaseKind::Global) {
    os << " adrp x16, " << address.symbol << "\n";
    os << " add x16, x16, :lo12:" << address.symbol << "\n";
  } else if (address.base_kind == AddrBaseKind::VReg) {
    const auto& base_alloc = function.GetAllocation(address.base_index);
    if (base_alloc.kind != Allocation::Kind::PhysReg) {
      EmitLoadSpill(function, base_alloc.stack_object, ValueType::Ptr, "x16", os);
    } else {
      EmitCopy(os, "x16", GetPhysRegName(base_alloc.phys, ValueType::Ptr), false);
    }
  } else if (address.base_kind == AddrBaseKind::None) {
    throw std::runtime_error(FormatError("mir", "address expression has no base"));
  }

  if (address.const_offset != 0) {
    EmitAdjustRegByImm(os, "x16", "x16", address.const_offset);
  }

  for (const auto& term : address.scaled_vregs) {
    // The index is materialized before the zero-stride check, as the
    // instruction order downstream expects.
    const std::string index_reg = MaterializeGprUse(
        function, MachineOperand::VReg(term.first), ValueType::I32, 10, os);
    const std::int64_t stride = term.second;
    if (stride == 0) {
      continue;
    }
    const bool use_extended_add = IsPowerOfTwo(stride) && Log2(stride) <= 4;
    if (use_extended_add) {
      os << " add x16, x16, " << index_reg << ", sxtw #" << Log2(stride) << "\n";
    } else {
      os << " sxtw x17, " << index_reg << "\n";
      EmitMoveImm(os, "x11", stride);
      os << " mul x17, x17, x11\n";
      os << " add x16, x16, x17\n";
    }
  }
}
// Returns the AArch64 condition suffix for an integer comparison.
// The table is indexed by the CondCode enumerator value, so it presumes the
// enum order is eq, ne, lt, gt, le, ge — NOTE(review): confirm against MIR.h.
const char* GetCondMnemonic(CondCode code) {
  static const char* const kMnemonics[] = {"eq", "ne", "lt", "gt", "le", "ge"};
  const int slot = static_cast<int>(code);
  return kMnemonics[slot];
}
bool TryEmitFusedCompareBranch(const MachineFunction& function, const MachineInstr& cmp,
const MachineInstr& branch,
const std::unordered_map<int, int>& use_counts,
std::ostream& os) {
if ((cmp.GetOpcode() != MachineInstr::Opcode::ICmp &&
cmp.GetOpcode() != MachineInstr::Opcode::FCmp) ||
branch.GetOpcode() != MachineInstr::Opcode::CondBr) {
return false;
}
const auto& cond = branch.GetOperands()[0];
if (cond.GetKind() != OperandKind::VReg) {
return false;
}
const int cond_vreg = cond.GetVReg();
if (cmp.GetOperands().empty() || cmp.GetOperands()[0].GetKind() != OperandKind::VReg ||
cmp.GetOperands()[0].GetVReg() != cond_vreg) {
return false;
}
auto it = use_counts.find(cond_vreg);
if (it == use_counts.end() || it->second != 1) {
return false;
}
if (cmp.GetOpcode() == MachineInstr::Opcode::ICmp) {
const auto lhs = MaterializeGprUse(function, cmp.GetOperands()[1], ValueType::I32, 10, os);
const auto& rhs_op = cmp.GetOperands()[2];
if (rhs_op.GetKind() == OperandKind::Imm && rhs_op.GetImm() >= 0 &&
IsAddSubImm(rhs_op.GetImm())) {
os << " cmp " << lhs << ", #" << rhs_op.GetImm() << "\n";
} else {
const auto rhs = MaterializeGprUse(function, rhs_op, ValueType::I32, 11, os);
os << " cmp " << lhs << ", " << rhs << "\n";
}
} else {
const auto lhs = MaterializeFprUse(function, cmp.GetOperands()[1], 16, 10, os);
const auto rhs = MaterializeFprUse(function, cmp.GetOperands()[2], 17, 11, os);
os << " fcmp " << lhs << ", " << rhs << "\n";
}
os << " b." << GetCondMnemonic(cmp.GetCondCode()) << " "
<< BlockLabel(function, branch.GetOperands()[1].GetText()) << "\n";
os << " b " << BlockLabel(function, branch.GetOperands()[2].GetText()) << "\n";
return true;
}
// Where an incoming argument lives under the calling convention modelled by
// ComputeArgLocation.
struct ArgLocation {
  bool in_reg{false};                // true: passed in a register; false: on the stack
  RegClass reg_class{RegClass::GPR};  // register class of the argument
  int reg_index{-1};                 // register number when in_reg, else -1
  int stack_offset{0};               // byte offset in the caller's outgoing area when !in_reg
};
// Computes where parameter `target` is passed.
// GPR and FPR arguments are assigned independently to the first 8 registers
// of their class. Once a class is exhausted, its arguments take consecutive
// 8-byte stack slots (shared across both classes, in parameter order).
// Throws std::runtime_error when `target` is not a valid parameter index;
// previously an index >= param_types.size() read past the vector's end.
ArgLocation ComputeArgLocation(const std::vector<ValueType>& param_types, int target) {
  if (target < 0 || static_cast<std::size_t>(target) >= param_types.size()) {
    throw std::runtime_error(FormatError("mir", "argument index out of range"));
  }
  constexpr int kMaxRegArgs = 8;
  constexpr int kStackSlotBytes = 8;
  int next_gpr = 0;
  int next_fpr = 0;
  int stack_offset = 0;
  for (int i = 0; i <= target; ++i) {
    const bool is_fp = IsFPR(param_types[static_cast<std::size_t>(i)]);
    const RegClass reg_class = is_fp ? RegClass::FPR : RegClass::GPR;
    int& next_reg = is_fp ? next_fpr : next_gpr;
    if (next_reg < kMaxRegArgs) {
      if (i == target) {
        return {true, reg_class, next_reg, 0};
      }
      ++next_reg;
    } else {
      if (i == target) {
        return {false, reg_class, -1, stack_offset};
      }
      stack_offset += kStackSlotBytes;
    }
  }
  // Unreachable: the loop always reaches i == target after the range check.
  throw std::runtime_error(FormatError("mir", "argument location computation failed"));
}
// Moves sp by `bytes`. An opcode beginning with 's' ("sub") lowers sp; any
// other opcode raises it. A zero adjustment emits nothing.
void EmitStackAdjust(std::ostream& os, const char* opcode, int bytes) {
  if (bytes == 0) {
    return;
  }
  const bool lowers_sp = opcode[0] == 's';
  const std::int64_t magnitude = bytes;
  EmitAdjustRegByImm(os, "sp", "sp", lowers_sp ? -magnitude : magnitude);
}
// Emits the assembly for one machine function.
// Output order: section/symbol directives, a frame-pointer prologue, saves of
// used callee-saved registers, then every basic block. ICmp/FCmp + CondBr
// pairs are fused where possible, and the epilogue is emitted inline at each
// Ret.
//
// Register conventions visible in this printer: x16 holds computed
// addresses, x17 addresses spill slots, w9-w12 / s16-s18 (and x11) are
// scratch registers, and x29 is the frame pointer. NOTE(review): this presumes
// the register allocator never assigns those scratch registers to live
// vregs — confirm against the allocator.
//
// Floating-point compares use "mi"/"ls" for less-than / less-or-equal so an
// unordered (NaN) comparison yields false, as C requires. The integer "lt"/"le"
// would yield true after fcmp in the unordered case.
void EmitFunction(const MachineFunction& function, std::ostream& os) {
  os << ".text\n";
  os << " .align 2\n";
  os << ".global " << function.GetName() << "\n";
  os << ".type " << function.GetName() << ", %function\n";
  os << function.GetName() << ":\n";
  // Prologue: push the (x29, x30) pair, establish x29, reserve the frame.
  os << " stp x29, x30, [sp, #-16]!\n";
  os << " mov x29, sp\n";
  EmitStackAdjust(os, "sub", function.GetFrameSize());
  // Save used callee-saved registers into their "save.*" stack objects.
  for (int reg : function.GetUsedCalleeSavedGPRs()) {
    const int slot = FindStackObject(function, "save.x" + std::to_string(reg));
    if (slot >= 0) {
      EmitFrameAddress(function, slot, "x16", os);
      os << " str x" << reg << ", [x16]\n";
    }
  }
  for (int reg : function.GetUsedCalleeSavedFPRs()) {
    const int slot = FindStackObject(function, "save.v" + std::to_string(reg));
    if (slot >= 0) {
      EmitFrameAddress(function, slot, "x16", os);
      os << " str " << GetDRegName(reg) << ", [x16]\n";
    }
  }
  // Epilogue, emitted inline at every Ret: restore saved registers, discard
  // the frame via x29, pop the frame record and return.
  auto emit_epilogue = [&]() {
    for (int reg : function.GetUsedCalleeSavedFPRs()) {
      const int slot = FindStackObject(function, "save.v" + std::to_string(reg));
      if (slot >= 0) {
        EmitFrameAddress(function, slot, "x16", os);
        os << " ldr " << GetDRegName(reg) << ", [x16]\n";
      }
    }
    for (int reg : function.GetUsedCalleeSavedGPRs()) {
      const int slot = FindStackObject(function, "save.x" + std::to_string(reg));
      if (slot >= 0) {
        EmitFrameAddress(function, slot, "x16", os);
        os << " ldr x" << reg << ", [x16]\n";
      }
    }
    os << " mov sp, x29\n";
    os << " ldp x29, x30, [sp], #16\n";
    os << " ret\n";
  };
  // Function-wide use counts; used to decide whether a compare result can be
  // fused into its branch without being materialized.
  std::unordered_map<int, int> use_counts;
  for (const auto& block : function.GetBlocks()) {
    for (const auto& inst : block->GetInstructions()) {
      for (int vreg : inst.GetUses()) {
        ++use_counts[vreg];
      }
    }
  }
  for (const auto& block : function.GetBlocks()) {
    os << BlockLabel(function, *block) << ":\n";
    const auto& instructions = block->GetInstructions();
    for (std::size_t inst_index = 0; inst_index < instructions.size(); ++inst_index) {
      const auto& inst = instructions[inst_index];
      if ((inst.GetOpcode() == MachineInstr::Opcode::ICmp ||
           inst.GetOpcode() == MachineInstr::Opcode::FCmp) &&
          inst_index + 1 < instructions.size() &&
          TryEmitFusedCompareBranch(function, inst, instructions[inst_index + 1], use_counts,
                                    os)) {
        ++inst_index;  // the CondBr was consumed by the fused sequence
        continue;
      }
      switch (inst.GetOpcode()) {
        case MachineInstr::Opcode::Arg: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const int arg_index = static_cast<int>(inst.GetOperands()[1].GetImm());
          const auto type = function.GetVRegInfo(vreg).type;
          const auto location = ComputeArgLocation(function.GetParamTypes(), arg_index);
          if (IsFPR(type)) {
            const auto def = PrepareFprDef(function, vreg, 16);
            if (location.in_reg) {
              EmitCopy(os, def.reg_name.c_str(),
                       GetPhysRegName({RegClass::FPR, location.reg_index}, type), true);
            } else {
              EmitIncomingStackAddress(location.stack_offset, "x16", os);
              EmitLoadFromAddr(type, def.reg_name.c_str(), "x16", os);
            }
            FinalizeDef(function, vreg, def, os);
          } else {
            const auto def = PrepareGprDef(function, vreg, 9);
            if (location.in_reg) {
              EmitCopy(os, def.reg_name.c_str(),
                       GetPhysRegName({RegClass::GPR, location.reg_index}, type), false);
            } else {
              EmitIncomingStackAddress(location.stack_offset, "x16", os);
              EmitLoadFromAddr(type, def.reg_name.c_str(), "x16", os);
            }
            FinalizeDef(function, vreg, def, os);
          }
          break;
        }
        case MachineInstr::Opcode::Copy:
        case MachineInstr::Opcode::ZExt: {
          // ZExt is printed as a plain copy.
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto type = function.GetVRegInfo(vreg).type;
          if (IsFPR(type)) {
            const auto def = PrepareFprDef(function, vreg, 16);
            const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os);
            EmitCopy(os, def.reg_name.c_str(), src.c_str(), true);
            FinalizeDef(function, vreg, def, os);
          } else {
            const auto def = PrepareGprDef(function, vreg, 9);
            const auto src = MaterializeGprUse(function, inst.GetOperands()[1], type, 10, os);
            EmitCopy(os, def.reg_name.c_str(), src.c_str(), false);
            FinalizeDef(function, vreg, def, os);
          }
          break;
        }
        case MachineInstr::Opcode::Load: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto type = function.GetVRegInfo(vreg).type;
          EmitAddressExpr(function, inst.GetAddress(), os);
          if (IsFPR(type)) {
            const auto def = PrepareFprDef(function, vreg, 16);
            EmitLoadFromAddr(type, def.reg_name.c_str(), "x16", os);
            FinalizeDef(function, vreg, def, os);
          } else {
            const auto def = PrepareGprDef(function, vreg, 9);
            EmitLoadFromAddr(type, def.reg_name.c_str(), "x16", os);
            FinalizeDef(function, vreg, def, os);
          }
          break;
        }
        case MachineInstr::Opcode::Store: {
          const auto& src_op = inst.GetOperands()[0];
          const ValueType type = src_op.GetKind() == OperandKind::VReg
                                     ? function.GetVRegInfo(src_op.GetVReg()).type
                                     : inst.GetValueType();
          EmitAddressExpr(function, inst.GetAddress(), os);
          if (IsFPR(type)) {
            const auto src = MaterializeFprUse(function, src_op, 16, 9, os);
            EmitStoreToAddr(type, src.c_str(), "x16", os);
          } else {
            const auto src = MaterializeGprUse(function, src_op, type, 9, os);
            EmitStoreToAddr(type, src.c_str(), "x16", os);
          }
          break;
        }
        case MachineInstr::Opcode::Lea: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareGprDef(function, vreg, 9);
          EmitAddressExpr(function, inst.GetAddress(), os);
          EmitCopy(os, def.reg_name.c_str(), "x16", false);
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::Add:
        case MachineInstr::Opcode::Sub:
        case MachineInstr::Opcode::Mul:
        case MachineInstr::Opcode::Div:
        case MachineInstr::Opcode::And:
        case MachineInstr::Opcode::Or:
        case MachineInstr::Opcode::Xor:
        case MachineInstr::Opcode::Shl:
        case MachineInstr::Opcode::AShr:
        case MachineInstr::Opcode::LShr: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareGprDef(function, vreg, 9);
          const auto& lhs_op = inst.GetOperands()[1];
          const auto& rhs_op = inst.GetOperands()[2];
          if (inst.GetOpcode() == MachineInstr::Opcode::Add ||
              inst.GetOpcode() == MachineInstr::Opcode::Sub) {
            // Try an immediate add/sub. A negative immediate flips to the
            // opposite opcode with its magnitude.
            auto emit_add_sub_imm = [&](const MachineOperand& reg_op, std::int64_t imm,
                                        const char* pos_opcode,
                                        const char* neg_opcode) -> bool {
              if (reg_op.GetKind() == OperandKind::Imm) {
                return false;
              }
              if (imm >= 0 && IsAddSubImm(imm)) {
                const auto src = MaterializeGprUse(function, reg_op, ValueType::I32, 10, os);
                EmitAddSubImm(os, pos_opcode, def.reg_name.c_str(), src.c_str(), imm);
                FinalizeDef(function, vreg, def, os);
                return true;
              }
              if (imm < 0 && IsAddSubImm(-imm)) {
                const auto src = MaterializeGprUse(function, reg_op, ValueType::I32, 10, os);
                EmitAddSubImm(os, neg_opcode, def.reg_name.c_str(), src.c_str(), -imm);
                FinalizeDef(function, vreg, def, os);
                return true;
              }
              return false;
            };
            if (rhs_op.GetKind() == OperandKind::Imm) {
              if (emit_add_sub_imm(lhs_op, rhs_op.GetImm(),
                                   inst.GetOpcode() == MachineInstr::Opcode::Add ? "add" : "sub",
                                   inst.GetOpcode() == MachineInstr::Opcode::Add ? "sub" : "add")) {
                break;
              }
            }
            if (inst.GetOpcode() == MachineInstr::Opcode::Add &&
                lhs_op.GetKind() == OperandKind::Imm) {
              if (emit_add_sub_imm(rhs_op, lhs_op.GetImm(), "add", "sub")) {
                break;
              }
            }
          }
          if (inst.GetOpcode() == MachineInstr::Opcode::Div &&
              rhs_op.GetKind() == OperandKind::Imm &&
              rhs_op.GetImm() > 0 && IsPowerOfTwo(rhs_op.GetImm())) {
            // Signed division by 2^k: add (2^k - 1) to negative dividends
            // before the arithmetic shift so the result truncates toward zero.
            const auto lhs = MaterializeGprUse(function, lhs_op, ValueType::I32, 10, os);
            const int shift = Log2(rhs_op.GetImm());
            if (shift == 0) {
              EmitCopy(os, def.reg_name.c_str(), lhs.c_str(), false);
            } else {
              os << " asr w11, " << lhs << ", #31\n";
              os << " and w11, w11, #" << ((1ll << shift) - 1) << "\n";
              os << " add w11, " << lhs << ", w11\n";
              os << " asr " << def.reg_name << ", w11, #" << shift << "\n";
            }
            FinalizeDef(function, vreg, def, os);
            break;
          }
          const auto lhs = MaterializeGprUse(function, lhs_op, ValueType::I32, 10, os);
          const auto rhs = MaterializeGprUse(function, rhs_op, ValueType::I32, 11, os);
          const char* mnemonic = "add";
          switch (inst.GetOpcode()) {
            case MachineInstr::Opcode::Add:
              mnemonic = "add";
              break;
            case MachineInstr::Opcode::Sub:
              mnemonic = "sub";
              break;
            case MachineInstr::Opcode::Mul:
              mnemonic = "mul";
              break;
            case MachineInstr::Opcode::Div:
              mnemonic = "sdiv";
              break;
            case MachineInstr::Opcode::And:
              mnemonic = "and";
              break;
            case MachineInstr::Opcode::Or:
              mnemonic = "orr";
              break;
            case MachineInstr::Opcode::Xor:
              mnemonic = "eor";
              break;
            case MachineInstr::Opcode::Shl:
              mnemonic = "lsl";
              break;
            case MachineInstr::Opcode::AShr:
              mnemonic = "asr";
              break;
            case MachineInstr::Opcode::LShr:
              mnemonic = "lsr";
              break;
            default:
              break;
          }
          os << " " << mnemonic << " " << def.reg_name << ", " << lhs << ", " << rhs << "\n";
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::Rem: {
          // rem = lhs - (lhs / rhs) * rhs, with the quotient in w12.
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareGprDef(function, vreg, 9);
          const auto lhs = MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
          const auto rhs = MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
          os << " sdiv w12, " << lhs << ", " << rhs << "\n";
          os << " msub " << def.reg_name << ", w12, " << rhs << ", " << lhs << "\n";
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::FAdd:
        case MachineInstr::Opcode::FSub:
        case MachineInstr::Opcode::FMul:
        case MachineInstr::Opcode::FDiv: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareFprDef(function, vreg, 16);
          const auto lhs = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os);
          const auto rhs = MaterializeFprUse(function, inst.GetOperands()[2], 18, 10, os);
          const char* mnemonic = "fadd";
          switch (inst.GetOpcode()) {
            case MachineInstr::Opcode::FAdd:
              mnemonic = "fadd";
              break;
            case MachineInstr::Opcode::FSub:
              mnemonic = "fsub";
              break;
            case MachineInstr::Opcode::FMul:
              mnemonic = "fmul";
              break;
            case MachineInstr::Opcode::FDiv:
              mnemonic = "fdiv";
              break;
            default:
              break;
          }
          os << " " << mnemonic << " " << def.reg_name << ", " << lhs << ", " << rhs << "\n";
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::FNeg: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareFprDef(function, vreg, 16);
          const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 17, 9, os);
          os << " fneg " << def.reg_name << ", " << src << "\n";
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::ICmp: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareGprDef(function, vreg, 9);
          const auto lhs = MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
          const auto rhs = MaterializeGprUse(function, inst.GetOperands()[2], ValueType::I32, 11, os);
          os << " cmp " << lhs << ", " << rhs << "\n";
          os << " cset " << def.reg_name << ", " << GetCondMnemonic(inst.GetCondCode()) << "\n";
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::FCmp: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareGprDef(function, vreg, 9);
          const auto lhs = MaterializeFprUse(function, inst.GetOperands()[1], 16, 10, os);
          const auto rhs = MaterializeFprUse(function, inst.GetOperands()[2], 17, 11, os);
          os << " fcmp " << lhs << ", " << rhs << "\n";
          // Indexed like GetCondMnemonic (eq, ne, lt, gt, le, ge). "mi"/"ls" are
          // false on an unordered (NaN) result, unlike "lt"/"le".
          static const char* const kFloatCond[] = {"eq", "ne", "mi", "gt", "ls", "ge"};
          os << " cset " << def.reg_name << ", "
             << kFloatCond[static_cast<int>(inst.GetCondCode())] << "\n";
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::ItoF: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareFprDef(function, vreg, 16);
          const auto src = MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
          os << " scvtf " << def.reg_name << ", " << src << "\n";
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::FtoI: {
          const int vreg = inst.GetOperands()[0].GetVReg();
          const auto def = PrepareGprDef(function, vreg, 9);
          const auto src = MaterializeFprUse(function, inst.GetOperands()[1], 16, 10, os);
          os << " fcvtzs " << def.reg_name << ", " << src << "\n";
          FinalizeDef(function, vreg, def, os);
          break;
        }
        case MachineInstr::Opcode::Br:
          os << " b " << BlockLabel(function, inst.GetOperands()[0].GetText()) << "\n";
          break;
        case MachineInstr::Opcode::CondBr: {
          const auto& cond = inst.GetOperands()[0];
          if (cond.GetKind() == OperandKind::Imm) {
            // Constant condition: branch unconditionally to the selected target.
            os << " b " << BlockLabel(function,
                                      cond.GetImm() != 0 ? inst.GetOperands()[1].GetText()
                                                         : inst.GetOperands()[2].GetText())
               << "\n";
            break;
          }
          const auto cond_reg = MaterializeGprUse(function, cond, ValueType::I1, 9, os);
          os << " cbnz " << cond_reg << ", "
             << BlockLabel(function, inst.GetOperands()[1].GetText()) << "\n";
          os << " b " << BlockLabel(function, inst.GetOperands()[2].GetText()) << "\n";
          break;
        }
        case MachineInstr::Opcode::Call: {
          struct CallArgPlacement {
            MachineOperand operand;
            ValueType type = ValueType::Void;
            bool on_stack = false;
            int reg_index = -1;
            int stack_offset = 0;
          };
          // Assign outgoing arguments: 8 GPRs, 8 FPRs, then 8-byte stack slots.
          std::vector<CallArgPlacement> placements;
          int gpr = 0;
          int fpr = 0;
          int stack = 0;
          // Operand 0 is the result vreg when the call returns a value.
          size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
          for (size_t i = arg_begin; i < inst.GetOperands().size(); ++i) {
            const auto type = inst.GetCallArgTypes()[i - arg_begin];
            CallArgPlacement placement;
            placement.operand = inst.GetOperands()[i];
            placement.type = type;
            if (IsFPR(type)) {
              if (fpr < 8) {
                placement.reg_index = fpr++;
              } else {
                placement.on_stack = true;
                placement.stack_offset = stack;
                stack += 8;
              }
            } else {
              if (gpr < 8) {
                placement.reg_index = gpr++;
              } else {
                placement.on_stack = true;
                placement.stack_offset = stack;
                stack += 8;
              }
            }
            placements.push_back(placement);
          }
          // The outgoing argument area keeps sp 16-byte aligned.
          const int stack_bytes = AlignTo(stack, 16);
          EmitStackAdjust(os, "sub", stack_bytes);
          for (const auto& placement : placements) {
            if (!placement.on_stack) {
              continue;
            }
            if (IsFPR(placement.type)) {
              const auto src = MaterializeFprUse(function, placement.operand, 16, 9, os);
              EmitMoveImm(os, "x11", placement.stack_offset);
              os << " add x16, sp, x11\n";
              EmitStoreToAddr(placement.type, src.c_str(), "x16", os);
            } else {
              const auto src = MaterializeGprUse(function, placement.operand, placement.type, 9, os);
              EmitMoveImm(os, "x11", placement.stack_offset);
              os << " add x16, sp, x11\n";
              EmitStoreToAddr(placement.type, src.c_str(), "x16", os);
            }
          }
          // NOTE(review): register arguments are copied sequentially, not as a
          // parallel move; this is only safe if no argument source is allocated
          // to an argument register written earlier — confirm with the allocator.
          for (const auto& placement : placements) {
            if (placement.on_stack) {
              continue;
            }
            if (IsFPR(placement.type)) {
              const auto src = MaterializeFprUse(function, placement.operand, 16, 9, os);
              EmitCopy(os, GetPhysRegName({RegClass::FPR, placement.reg_index}, placement.type),
                       src.c_str(), true);
            } else {
              const auto src = MaterializeGprUse(function, placement.operand, placement.type, 9, os);
              EmitCopy(os, GetPhysRegName({RegClass::GPR, placement.reg_index}, placement.type),
                       src.c_str(), false);
            }
          }
          os << " bl " << inst.GetCallee() << "\n";
          EmitStackAdjust(os, "add", stack_bytes);
          if (inst.GetCallReturnType() != ValueType::Void) {
            const int dest_vreg = inst.GetOperands()[0].GetVReg();
            if (IsFPR(inst.GetCallReturnType())) {
              const auto def = PrepareFprDef(function, dest_vreg, 16);
              EmitCopy(os, def.reg_name.c_str(), "s0", true);
              FinalizeDef(function, dest_vreg, def, os);
            } else {
              const auto def = PrepareGprDef(function, dest_vreg, 9);
              EmitCopy(os, def.reg_name.c_str(),
                       GetPhysRegName({RegClass::GPR, 0}, inst.GetCallReturnType()), false);
              FinalizeDef(function, dest_vreg, def, os);
            }
          }
          break;
        }
        case MachineInstr::Opcode::Ret: {
          if (!inst.GetOperands().empty()) {
            const auto& value = inst.GetOperands()[0];
            ValueType type = value.GetKind() == OperandKind::VReg
                                 ? function.GetVRegInfo(value.GetVReg()).type
                                 : inst.GetValueType();
            if (IsFPR(type)) {
              const auto src = MaterializeFprUse(function, value, 16, 9, os);
              EmitCopy(os, "s0", src.c_str(), true);
            } else {
              const auto src = MaterializeGprUse(function, value, type, 9, os);
              EmitCopy(os, GetPhysRegName({RegClass::GPR, 0}, type), src.c_str(), false);
            }
          }
          emit_epilogue();
          break;
        }
        case MachineInstr::Opcode::Memset: {
          // Lowered to a libc call: memset(address, value, length).
          EmitAddressExpr(function, inst.GetAddress(), os);
          EmitCopy(os, "x0", "x16", false);
          const auto value_reg = MaterializeGprUse(function, inst.GetOperands()[0], ValueType::I32, 9, os);
          EmitCopy(os, "w1", value_reg.c_str(), false);
          const auto len_reg = MaterializeGprUse(function, inst.GetOperands()[1], ValueType::I32, 10, os);
          EmitCopy(os, "w2", len_reg.c_str(), false);
          os << " bl memset\n";
          break;
        }
        case MachineInstr::Opcode::Unreachable:
          os << " brk #0\n";
          break;
      }
    }
  }
  os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
}
} // namespace
// Writes the whole module as assembly: every global of the source IR module
// first, then every machine function in order.
void PrintAsm(const MachineModule& module, std::ostream& os) {
  for (const auto& global_value : module.GetSourceModule().GetGlobalValues()) {
    EmitGlobal(*global_value, os);
  }
  for (const auto& machine_function : module.GetFunctions()) {
    EmitFunction(*machine_function, os);
  }
}
} // namespace mir