You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/mir/Lowering.cpp

1037 lines
38 KiB

#include "mir/MIR.h"
#include <algorithm>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/IR.h"
#include "ir/passes/MathIdiomUtils.h"
#include "utils/Log.h"
namespace mir {
namespace {
// How an IR value was materialized in the machine function: as a virtual
// register, a frame stack object, or a global symbol. Invalid = not lowered.
enum class LoweredKind { Invalid, VReg, StackObject, Global };
// Computes the block emission order for lowering: a depth-first preorder
// walk from the entry block (successors visited left-to-right), followed by
// any blocks unreachable from the entry, kept in their original order so
// nothing in the function is dropped.
std::vector<ir::BasicBlock*> CollectLoweringOrder(ir::Function& function) {
  std::vector<ir::BasicBlock*> result;
  auto* entry = function.GetEntryBlock();
  if (entry == nullptr) {
    return result;
  }
  std::unordered_set<ir::BasicBlock*> seen;
  std::vector<ir::BasicBlock*> worklist;
  worklist.push_back(entry);
  while (!worklist.empty()) {
    auto* current = worklist.back();
    worklist.pop_back();
    if (current == nullptr || !seen.insert(current).second) {
      continue;
    }
    result.push_back(current);
    // Push successors in reverse so the leftmost successor is popped first.
    const auto& successors = current->GetSuccessors();
    for (auto rit = successors.rbegin(); rit != successors.rend(); ++rit) {
      if (*rit != nullptr) {
        worklist.push_back(*rit);
      }
    }
  }
  // Append blocks the DFS never reached (dead code still gets lowered).
  for (const auto& owned : function.GetBlocks()) {
    if (owned && seen.insert(owned.get()).second) {
      result.push_back(owned.get());
    }
  }
  return result;
}
// Result of lowering one IR value. `index` is the vreg number or frame-object
// index depending on `kind`; `symbol` is only meaningful for Global values.
struct LoweredValue {
LoweredKind kind = LoweredKind::Invalid;
ValueType type = ValueType::Void;
int index = -1;
std::string symbol;
};
// Maps a scalar IR type onto its backend value type. A null or void type
// lowers to Void; anything outside {i1, i32, float, pointer} has no machine
// representation and raises a lowering error.
ValueType LowerType(const std::shared_ptr<ir::Type>& type) {
  if (!type || type->IsVoid()) {
    return ValueType::Void;
  }
  if (type->IsInt1()) return ValueType::I1;
  if (type->IsInt32()) return ValueType::I32;
  if (type->IsFloat()) return ValueType::F32;
  if (type->IsPointer()) return ValueType::Ptr;
  throw std::runtime_error(FormatError("mir", "unsupported IR type in backend lowering"));
}
// Returns the byte alignment of an IR type. Arrays align like their
// innermost element type; a missing type conservatively aligns to 1.
int GetIRTypeAlign(const std::shared_ptr<ir::Type>& type) {
  auto current = type;
  // Strip array layers iteratively instead of recursing.
  while (current && current->IsArray()) {
    current = current->GetElementType();
  }
  if (!current) {
    return 1;
  }
  return GetValueAlign(LowerType(current));
}
// Large local arrays (256 bytes or more) get their frame address materialized
// once into a base vreg via Lea instead of re-deriving it at every access.
bool ShouldMaterializeAllocaBase(const std::shared_ptr<ir::Type>& type) {
  if (!type || !type->IsArray()) {
    return false;
  }
  return type->GetSize() >= 256;
}
// Translates an integer-compare IR opcode into its machine condition code.
// Any non-integer-compare opcode is a lowering bug and throws.
CondCode LowerIntCond(ir::Opcode opcode) {
  switch (opcode) {
    case ir::Opcode::ICmpEQ: return CondCode::EQ;
    case ir::Opcode::ICmpNE: return CondCode::NE;
    case ir::Opcode::ICmpLT: return CondCode::LT;
    case ir::Opcode::ICmpGT: return CondCode::GT;
    case ir::Opcode::ICmpLE: return CondCode::LE;
    case ir::Opcode::ICmpGE: return CondCode::GE;
    default: break;
  }
  throw std::runtime_error(FormatError("mir", "invalid integer compare opcode"));
}
// Translates a float-compare IR opcode into its machine condition code.
// Any non-float-compare opcode is a lowering bug and throws.
CondCode LowerFloatCond(ir::Opcode opcode) {
  switch (opcode) {
    case ir::Opcode::FCmpEQ: return CondCode::EQ;
    case ir::Opcode::FCmpNE: return CondCode::NE;
    case ir::Opcode::FCmpLT: return CondCode::LT;
    case ir::Opcode::FCmpGT: return CondCode::GT;
    case ir::Opcode::FCmpLE: return CondCode::LE;
    case ir::Opcode::FCmpGE: return CondCode::GE;
    default: break;
  }
  throw std::runtime_error(FormatError("mir", "invalid float compare opcode"));
}
// Reinterprets the IEEE-754 bit pattern of `value` as a non-negative 64-bit
// integer (the upper 32 bits are always zero). Used to encode float
// constants as integer immediates.
std::int64_t FloatBits(float value) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32-bit");
  std::uint32_t raw;
  std::memcpy(&raw, &value, sizeof(raw));  // well-defined type punning
  return static_cast<std::int64_t>(raw);
}
// Drives IR -> machine-IR lowering for one module. A single instance owns the
// output MachineModule; per-function state (value map, block map, current
// block cursor) is reset at the start of each function. Responsibilities:
// translating instructions to MachineInstrs, assigning vregs/stack objects,
// rewriting phi nodes into edge copies (splitting critical edges as needed),
// and opportunistically inlining tiny single-block callees.
class Lowerer {
public:
explicit Lowerer(const ir::Module& module)
: module_(module), machine_module_(std::make_unique<MachineModule>(module)) {}
// Lowers every defined (non-external) function, then hands over ownership
// of the finished MachineModule. Call at most once.
std::unique_ptr<MachineModule> Run() {
for (const auto& func : module_.GetFunctions()) {
if (func && !func->IsExternal()) {
LowerFunction(*func);
}
}
return std::move(machine_module_);
}
private:
// Maps callee IR values to already-resolved operands while inlining.
using OperandMap = std::unordered_map<const ir::Value*, MachineOperand>;
// Resolves a scalar IR value to a machine operand. Constants become
// immediates (floats encoded as their raw IEEE-754 bits via FloatBits);
// everything else must already be materialized as a vreg. When
// `inline_values` is supplied (during call inlining) it is consulted
// before the function-wide value map, so callee-local values shadow it.
MachineOperand ResolveScalarOperand(ir::Value* value,
const OperandMap* inline_values = nullptr) {
if (auto* ci = ir::dyncast<ir::ConstantInt>(value)) {
return MachineOperand::Imm(ci->GetValue());
}
if (auto* cb = ir::dyncast<ir::ConstantI1>(value)) {
return MachineOperand::Imm(cb->GetValue() ? 1 : 0);
}
if (auto* cf = ir::dyncast<ir::ConstantFloat>(value)) {
return MachineOperand::Imm(FloatBits(cf->GetValue()));
}
if (inline_values != nullptr) {
auto inline_it = inline_values->find(value);
if (inline_it != inline_values->end()) {
return inline_it->second;
}
}
auto it = values_.find(value);
if (it == values_.end() || it->second.kind != LoweredKind::VReg) {
throw std::runtime_error(
FormatError("mir", "value is not materialized as a virtual register: " +
value->GetName()));
}
return MachineOperand::VReg(it->second.index);
}
// Convenience wrapper: resolve without an inline-value map.
MachineOperand LowerScalarOperand(ir::Value* value) {
return ResolveScalarOperand(value, nullptr);
}
// Builds an AddressExpr for a pointer-valued IR value: globals address
// their symbol directly; otherwise the value must have been lowered to a
// frame object, a global, or a vreg holding the pointer.
AddressExpr LowerAddress(ir::Value* value) {
if (auto* global = ir::dyncast<ir::GlobalValue>(value)) {
AddressExpr address;
address.base_kind = AddrBaseKind::Global;
address.symbol = global->GetName();
return address;
}
auto it = values_.find(value);
if (it == values_.end()) {
throw std::runtime_error(FormatError("mir", "missing lowered address value"));
}
AddressExpr address;
switch (it->second.kind) {
case LoweredKind::StackObject:
address.base_kind = AddrBaseKind::FrameObject;
address.base_index = it->second.index;
return address;
case LoweredKind::Global:
address.base_kind = AddrBaseKind::Global;
address.symbol = it->second.symbol;
return address;
case LoweredKind::VReg:
address.base_kind = AddrBaseKind::VReg;
address.base_index = it->second.index;
return address;
case LoweredKind::Invalid:
break;
}
throw std::runtime_error(FormatError("mir", "invalid address lowering"));
}
// One-to-one mapping from IR binary/arithmetic opcodes to machine opcodes.
MachineInstr::Opcode LowerBinaryOpcode(ir::Opcode opcode) {
switch (opcode) {
case ir::Opcode::Add:
return MachineInstr::Opcode::Add;
case ir::Opcode::Sub:
return MachineInstr::Opcode::Sub;
case ir::Opcode::Mul:
return MachineInstr::Opcode::Mul;
case ir::Opcode::Div:
return MachineInstr::Opcode::Div;
case ir::Opcode::Rem:
return MachineInstr::Opcode::Rem;
case ir::Opcode::And:
return MachineInstr::Opcode::And;
case ir::Opcode::Or:
return MachineInstr::Opcode::Or;
case ir::Opcode::Xor:
return MachineInstr::Opcode::Xor;
case ir::Opcode::Shl:
return MachineInstr::Opcode::Shl;
case ir::Opcode::AShr:
return MachineInstr::Opcode::AShr;
case ir::Opcode::LShr:
return MachineInstr::Opcode::LShr;
case ir::Opcode::FAdd:
return MachineInstr::Opcode::FAdd;
case ir::Opcode::FSub:
return MachineInstr::Opcode::FSub;
case ir::Opcode::FMul:
return MachineInstr::Opcode::FMul;
case ir::Opcode::FDiv:
return MachineInstr::Opcode::FDiv;
default:
throw std::runtime_error(FormatError("mir", "unsupported binary opcode"));
}
}
// Allocates a fresh vreg of `type` in the current function.
LoweredValue NewVRegValue(ValueType type) {
return {LoweredKind::VReg, type, current_function_->NewVReg(type), ""};
}
// Ensures an operand lives in a vreg: a vreg operand is wrapped as-is,
// anything else (e.g. an immediate) is copied into a fresh vreg in the
// current block.
LoweredValue MaterializeOperandAsValue(const MachineOperand& operand, ValueType type) {
if (operand.GetKind() == OperandKind::VReg) {
return {LoweredKind::VReg, type, operand.GetVReg(), ""};
}
auto lowered = NewVRegValue(type);
current_block_->Append(MachineInstr::Opcode::Copy,
{MachineOperand::VReg(lowered.index), operand});
return lowered;
}
// Inserts `instr` just before the block terminator (or at the end if the
// block has no terminator yet). Assumes at most one trailing terminator.
void InsertBeforeTerminator(MachineBasicBlock* block, MachineInstr instr) {
auto& instructions = block->GetInstructions();
auto insert_pos = instructions.end();
if (!instructions.empty() && instructions.back().IsTerminator()) {
insert_pos = instructions.end() - 1;
}
instructions.insert(insert_pos, std::move(instr));
}
// A single phi-resolution copy: dst_vreg <- src.
struct PhiCopy {
int dst_vreg = -1;
MachineOperand src;
};
// Sequentializes a parallel copy set for one CFG edge. Self-copies are
// dropped; copies whose destination is not needed as a later source are
// emitted greedily; a remaining cycle is broken by saving one source into
// a temporary vreg and redirecting its readers to the temp.
void EmitResolvedPhiCopies(MachineBasicBlock* block, std::vector<PhiCopy> copies) {
copies.erase(std::remove_if(copies.begin(), copies.end(),
[](const PhiCopy& copy) {
return copy.src.GetKind() == OperandKind::VReg &&
copy.src.GetVReg() == copy.dst_vreg;
}),
copies.end());
while (!copies.empty()) {
bool progress = false;
for (auto it = copies.begin(); it != copies.end(); ++it) {
// Emitting this copy would clobber a value another copy still reads.
const bool dst_is_still_needed_as_source =
std::any_of(copies.begin(), copies.end(), [&](const PhiCopy& other) {
return other.src.GetKind() == OperandKind::VReg &&
other.src.GetVReg() == it->dst_vreg;
});
if (dst_is_still_needed_as_source) {
continue;
}
InsertBeforeTerminator(
block, MachineInstr(MachineInstr::Opcode::Copy,
{MachineOperand::VReg(it->dst_vreg), it->src}));
copies.erase(it);
progress = true;
break;
}
if (progress) {
continue;
}
// No safe copy exists: every remaining dst is also a src, i.e. a cycle.
auto& cycle = copies.front();
if (cycle.src.GetKind() != OperandKind::VReg) {
throw std::runtime_error(FormatError("mir", "invalid phi copy cycle"));
}
const int src_vreg = cycle.src.GetVReg();
const auto temp_type = current_function_->GetVRegInfo(src_vreg).type;
const int temp_vreg = current_function_->NewVReg(temp_type);
InsertBeforeTerminator(
block, MachineInstr(MachineInstr::Opcode::Copy,
{MachineOperand::VReg(temp_vreg), MachineOperand::VReg(src_vreg)}));
for (auto& copy : copies) {
if (copy.src.GetKind() == OperandKind::VReg && copy.src.GetVReg() == src_vreg) {
copy.src = MachineOperand::VReg(temp_vreg);
}
}
}
}
// Rewrites pred_block's terminator so the edge that targeted `succ_name`
// now targets the freshly created `phi_block_name` (critical-edge split).
void RedirectEdgeToPhiBlock(MachineBasicBlock* pred_block,
const std::string& succ_name,
const std::string& phi_block_name) {
auto& instructions = pred_block->GetInstructions();
if (instructions.empty() || !instructions.back().IsTerminator()) {
throw std::runtime_error(FormatError("mir", "phi predecessor has no terminator"));
}
auto& term = instructions.back();
auto& operands = term.GetOperands();
switch (term.GetOpcode()) {
case MachineInstr::Opcode::Br:
if (!operands.empty() && operands[0].GetKind() == OperandKind::Block &&
operands[0].GetText() == succ_name) {
operands[0] = MachineOperand::Block(phi_block_name);
return;
}
break;
case MachineInstr::Opcode::CondBr:
// Operand 0 is the condition; block targets start at index 1.
for (size_t i = 1; i < operands.size(); ++i) {
if (operands[i].GetKind() == OperandKind::Block &&
operands[i].GetText() == succ_name) {
operands[i] = MachineOperand::Block(phi_block_name);
return;
}
}
break;
default:
break;
}
throw std::runtime_error(FormatError("mir", "failed to redirect phi edge"));
}
// Pre-assigns a result vreg to every phi so incoming-value copies emitted
// later have a stable destination. Phis are assumed to lead their block
// (the scan stops at the first non-phi instruction).
void PreparePhiResults(ir::Function& function) {
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
if (inst->GetOpcode() != ir::Opcode::Phi) {
break;
}
auto lowered = NewVRegValue(LowerType(inst->GetType()));
values_[inst.get()] = lowered;
}
}
}
// Replaces all phi nodes with copies on incoming edges. Copies are grouped
// per (predecessor, successor) edge; on an unconditional edge they go
// straight into the predecessor, while a conditional edge is split with a
// new "phi.edge.N" block so the copies only run on that edge.
void EmitPhiCopies(ir::Function& function) {
struct EdgeCopies {
MachineBasicBlock* succ_block = nullptr;
std::vector<PhiCopy> copies;
};
std::unordered_map<MachineBasicBlock*, std::vector<EdgeCopies>> pending;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
if (inst->GetOpcode() != ir::Opcode::Phi) {
break;
}
auto* phi = static_cast<ir::PhiInst*>(inst.get());
const int dest_vreg = values_.at(phi).index;
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
auto* pred_block = blocks_.at(phi->GetIncomingBlock(i));
auto* succ_block = blocks_.at(block.get());
auto& edges = pending[pred_block];
auto edge_it = std::find_if(edges.begin(), edges.end(), [&](const EdgeCopies& edge) {
return edge.succ_block == succ_block;
});
if (edge_it == edges.end()) {
edges.push_back({succ_block, {}});
edge_it = std::prev(edges.end());
}
edge_it->copies.push_back(
{dest_vreg, LowerScalarOperand(phi->GetIncomingValue(i))});
}
}
}
int phi_block_index = 0;
for (auto& item : pending) {
auto* pred_block = item.first;
auto& pred_instructions = pred_block->GetInstructions();
if (pred_instructions.empty() || !pred_instructions.back().IsTerminator()) {
throw std::runtime_error(FormatError("mir", "phi predecessor has no terminator"));
}
const auto terminator_opcode = pred_instructions.back().GetOpcode();
for (auto& edge : item.second) {
if (terminator_opcode == MachineInstr::Opcode::Br) {
// Unconditional edge: safe to place copies in the predecessor.
EmitResolvedPhiCopies(pred_block, std::move(edge.copies));
continue;
}
if (terminator_opcode != MachineInstr::Opcode::CondBr) {
throw std::runtime_error(
FormatError("mir", "unsupported terminator for phi lowering"));
}
// Conditional edge: split it so copies run only when taken.
auto* phi_block = current_function_->CreateBlock(
"phi.edge." + std::to_string(phi_block_index++));
EmitResolvedPhiCopies(phi_block, std::move(edge.copies));
phi_block->Append(MachineInstr::Opcode::Br,
{MachineOperand::Block(edge.succ_block->GetName())});
RedirectEdgeToPhiBlock(pred_block, edge.succ_block->GetName(), phi_block->GetName());
}
}
}
// A callee is inline-eligible when it is defined, consists of a single
// block, and uses only straight-line opcodes from the whitelist below
// (no control flow other than the final return, no memory ops).
bool CanInlineDirectCall(const ir::Function& function) const {
if (function.IsExternal() || function.GetBlocks().size() != 1) {
return false;
}
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block->GetInstructions()) {
switch (inst->GetOpcode()) {
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Rem:
case ir::Opcode::And:
case ir::Opcode::Or:
case ir::Opcode::Xor:
case ir::Opcode::Shl:
case ir::Opcode::AShr:
case ir::Opcode::LShr:
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv:
case ir::Opcode::FNeg:
case ir::Opcode::ICmpEQ:
case ir::Opcode::ICmpNE:
case ir::Opcode::ICmpLT:
case ir::Opcode::ICmpGT:
case ir::Opcode::ICmpLE:
case ir::Opcode::ICmpGE:
case ir::Opcode::FCmpEQ:
case ir::Opcode::FCmpNE:
case ir::Opcode::FCmpLT:
case ir::Opcode::FCmpGT:
case ir::Opcode::FCmpLE:
case ir::Opcode::FCmpGE:
case ir::Opcode::Zext:
case ir::Opcode::IToF:
case ir::Opcode::FtoI:
case ir::Opcode::Call:
case ir::Opcode::Return:
break;
default:
return false;
}
}
}
return true;
}
// Emits the callee's single-block body inline into the current block,
// mapping callee values through *inline_values. Nested calls may inline
// recursively up to a fixed depth; anything deeper, or any unsupported
// instruction, aborts with false (caller falls back to a real call).
// Returns the callee's return operand through *return_operand/*has_return.
bool TryInlineFunctionBody(const ir::Function& callee, OperandMap* inline_values,
MachineOperand* return_operand, bool* has_return,
int inline_depth) {
if (inline_depth > 2) {
// Cap nested inlining to keep code growth bounded.
return false;
}
for (const auto& inst : callee.GetBlocks().front()->GetInstructions()) {
switch (inst->GetOpcode()) {
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Rem:
case ir::Opcode::And:
case ir::Opcode::Or:
case ir::Opcode::Xor:
case ir::Opcode::Shl:
case ir::Opcode::AShr:
case ir::Opcode::LShr:
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto* binary = static_cast<ir::BinaryInst*>(inst.get());
auto lowered = NewVRegValue(LowerType(binary->GetType()));
current_block_->Append(LowerBinaryOpcode(inst->GetOpcode()),
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(binary->GetLhs(), inline_values),
ResolveScalarOperand(binary->GetRhs(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::FNeg: {
auto* unary = static_cast<ir::UnaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::F32);
current_block_->Append(MachineInstr::Opcode::FNeg,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(unary->GetOprd(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::ICmpEQ:
case ir::Opcode::ICmpNE:
case ir::Opcode::ICmpLT:
case ir::Opcode::ICmpGT:
case ir::Opcode::ICmpLE:
case ir::Opcode::ICmpGE: {
auto* binary = static_cast<ir::BinaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::I1);
MachineInstr instr(MachineInstr::Opcode::ICmp,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(binary->GetLhs(), inline_values),
ResolveScalarOperand(binary->GetRhs(), inline_values)});
instr.SetCondCode(LowerIntCond(inst->GetOpcode()));
current_block_->Append(std::move(instr));
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::FCmpEQ:
case ir::Opcode::FCmpNE:
case ir::Opcode::FCmpLT:
case ir::Opcode::FCmpGT:
case ir::Opcode::FCmpLE:
case ir::Opcode::FCmpGE: {
auto* binary = static_cast<ir::BinaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::I1);
MachineInstr instr(MachineInstr::Opcode::FCmp,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(binary->GetLhs(), inline_values),
ResolveScalarOperand(binary->GetRhs(), inline_values)});
instr.SetCondCode(LowerFloatCond(inst->GetOpcode()));
current_block_->Append(std::move(instr));
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::Zext: {
auto* zext = static_cast<ir::ZextInst*>(inst.get());
auto lowered = NewVRegValue(LowerType(zext->GetType()));
current_block_->Append(MachineInstr::Opcode::ZExt,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(zext->GetValue(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::IToF: {
auto* unary = static_cast<ir::UnaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::F32);
current_block_->Append(MachineInstr::Opcode::ItoF,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(unary->GetOprd(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::FtoI: {
auto* unary = static_cast<ir::UnaryInst*>(inst.get());
auto lowered = NewVRegValue(ValueType::I32);
current_block_->Append(MachineInstr::Opcode::FtoI,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(unary->GetOprd(), inline_values)});
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
break;
}
case ir::Opcode::Call: {
auto* nested_call = static_cast<ir::CallInst*>(inst.get());
auto* nested_callee = nested_call->GetCallee();
// Refuse indirect calls and direct self-recursion of the function
// currently being lowered.
if (nested_callee == nullptr || nested_callee == current_ir_function_) {
return false;
}
MachineOperand math_idiom_result;
if (TryEmitMathIdiomCall(nested_call, inline_values, &math_idiom_result)) {
(*inline_values)[inst.get()] = math_idiom_result;
break;
}
if (CanInlineDirectCall(*nested_callee)) {
MachineOperand nested_return_operand;
bool nested_has_return = false;
OperandMap nested_values;
const auto& nested_args = nested_callee->GetArguments();
const auto& nested_call_args = nested_call->GetArguments();
if (nested_args.size() != nested_call_args.size()) {
return false;
}
for (size_t i = 0; i < nested_call_args.size(); ++i) {
nested_values[nested_args[i].get()] =
ResolveScalarOperand(nested_call_args[i], inline_values);
}
if (!TryInlineFunctionBody(*nested_callee, &nested_values, &nested_return_operand,
&nested_has_return, inline_depth + 1)) {
return false;
}
if (!nested_call->GetType()->IsVoid()) {
if (!nested_has_return) {
throw std::runtime_error(
FormatError("mir", "inlined nested call is missing return value"));
}
auto nested_value =
MaterializeOperandAsValue(nested_return_operand, LowerType(nested_call->GetType()));
(*inline_values)[inst.get()] = MachineOperand::VReg(nested_value.index);
}
break;
}
// Not inlinable: emit a real call instruction instead.
std::vector<MachineOperand> operands;
if (!nested_call->GetType()->IsVoid()) {
auto lowered = NewVRegValue(LowerType(nested_call->GetType()));
operands.push_back(MachineOperand::VReg(lowered.index));
(*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
}
std::vector<ValueType> arg_types;
for (auto* arg : nested_call->GetArguments()) {
operands.push_back(ResolveScalarOperand(arg, inline_values));
arg_types.push_back(LowerType(arg->GetType()));
}
MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands));
instr.SetCallInfo(nested_callee->GetName(), std::move(arg_types),
LowerType(nested_call->GetType()));
current_block_->Append(std::move(instr));
break;
}
case ir::Opcode::Return: {
auto* ret = static_cast<ir::ReturnInst*>(inst.get());
if (ret->HasReturnValue()) {
*return_operand = ResolveScalarOperand(ret->GetReturnValue(), inline_values);
*has_return = true;
}
break;
}
default:
return false;
}
}
return true;
}
// Pattern-matches calls to a private Newton-iteration sqrt helper
// (detected by ir::mathidiom) that takes and returns a float, and lowers
// them to a single FSqrt. If the helper carries a state global, its symbol
// is attached as the instruction address (presumably consumed by a later
// pass — confirm against FSqrt handling downstream). With a null
// `result_operand` the result is recorded in values_ (top-level call);
// otherwise it is returned to the inlining caller.
bool TryEmitMathIdiomCall(ir::CallInst* call, const OperandMap* inline_values,
MachineOperand* result_operand) {
auto* callee = call == nullptr ? nullptr : call->GetCallee();
const ir::GlobalValue* sqrt_state = nullptr;
if (callee == nullptr || call->GetType() == nullptr || !call->GetType()->IsFloat() ||
call->GetArguments().size() != 1 ||
!call->GetArguments()[0]->GetType()->IsFloat() ||
!ir::mathidiom::IsPrivateToleranceNewtonSqrt(*callee, &sqrt_state)) {
return false;
}
auto lowered = NewVRegValue(ValueType::F32);
MachineInstr instr(MachineInstr::Opcode::FSqrt,
{MachineOperand::VReg(lowered.index),
ResolveScalarOperand(call->GetArguments()[0], inline_values)});
if (sqrt_state != nullptr) {
AddressExpr address;
address.base_kind = AddrBaseKind::Global;
address.symbol = sqrt_state->GetName();
instr.SetAddress(std::move(address));
}
current_block_->Append(std::move(instr));
if (result_operand != nullptr) {
*result_operand = MachineOperand::VReg(lowered.index);
} else {
values_[call] = lowered;
}
return true;
}
// Attempts to inline a top-level direct call: binds arguments into an
// OperandMap, emits the callee body, and records the (materialized)
// return value for the call instruction. Returns false when the call must
// be emitted as a real Call instead.
bool TryInlineDirectCall(ir::CallInst* call) {
auto* callee = call->GetCallee();
if (callee == nullptr || callee == current_ir_function_ || !CanInlineDirectCall(*callee)) {
return false;
}
const auto& callee_args = callee->GetArguments();
const auto& call_args = call->GetArguments();
if (callee_args.size() != call_args.size()) {
return false;
}
OperandMap inline_values;
for (size_t i = 0; i < call_args.size(); ++i) {
inline_values[callee_args[i].get()] = ResolveScalarOperand(call_args[i], nullptr);
}
MachineOperand return_operand;
bool has_return = false;
if (!TryInlineFunctionBody(*callee, &inline_values, &return_operand, &has_return, 0)) {
return false;
}
if (!call->GetType()->IsVoid()) {
if (!has_return) {
throw std::runtime_error(FormatError("mir", "inlined call is missing return value"));
}
values_[call] = MaterializeOperandAsValue(return_operand, LowerType(call->GetType()));
}
return true;
}
// Lowers one IR instruction into the current machine block. Results are
// recorded in values_; phis are skipped here (handled by PreparePhiResults
// and EmitPhiCopies). Throws on opcodes the backend cannot lower.
void LowerInstruction(ir::Instruction& inst) {
switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: {
auto* alloca_inst = static_cast<ir::AllocaInst*>(&inst);
const auto allocated_type = alloca_inst->GetAllocatedType();
const int object = current_function_->CreateStackObject(
allocated_type->GetSize(), GetIRTypeAlign(allocated_type),
StackObjectKind::Local, inst.GetName());
if (ShouldMaterializeAllocaBase(allocated_type)) {
// Large arrays: compute the frame address once into a base vreg.
auto lowered = NewVRegValue(ValueType::Ptr);
MachineInstr lea(MachineInstr::Opcode::Lea,
{MachineOperand::VReg(lowered.index)});
AddressExpr address;
address.base_kind = AddrBaseKind::FrameObject;
address.base_index = object;
lea.SetAddress(std::move(address));
current_block_->Append(std::move(lea));
values_[&inst] = lowered;
} else {
values_[&inst] = {LoweredKind::StackObject, ValueType::Ptr, object, ""};
}
return;
}
case ir::Opcode::Load: {
auto* load = static_cast<ir::LoadInst*>(&inst);
auto lowered = NewVRegValue(LowerType(load->GetType()));
MachineInstr instr(MachineInstr::Opcode::Load,
{MachineOperand::VReg(lowered.index)});
instr.SetAddress(LowerAddress(load->GetPtr()));
current_block_->Append(std::move(instr));
values_[&inst] = lowered;
return;
}
case ir::Opcode::Store: {
auto* store = static_cast<ir::StoreInst*>(&inst);
MachineInstr instr(MachineInstr::Opcode::Store,
{LowerScalarOperand(store->GetValue())});
instr.SetValueType(LowerType(store->GetValue()->GetType()));
instr.SetAddress(LowerAddress(store->GetPtr()));
current_block_->Append(std::move(instr));
return;
}
case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Rem:
case ir::Opcode::And:
case ir::Opcode::Or:
case ir::Opcode::Xor:
case ir::Opcode::Shl:
case ir::Opcode::AShr:
case ir::Opcode::LShr:
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto* binary = static_cast<ir::BinaryInst*>(&inst);
auto lowered = NewVRegValue(LowerType(binary->GetType()));
current_block_->Append(LowerBinaryOpcode(inst.GetOpcode()),
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(binary->GetLhs()),
LowerScalarOperand(binary->GetRhs())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::FNeg: {
auto* unary = static_cast<ir::UnaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::F32);
current_block_->Append(MachineInstr::Opcode::FNeg,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(unary->GetOprd())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::ICmpEQ:
case ir::Opcode::ICmpNE:
case ir::Opcode::ICmpLT:
case ir::Opcode::ICmpGT:
case ir::Opcode::ICmpLE:
case ir::Opcode::ICmpGE: {
auto* binary = static_cast<ir::BinaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::I1);
MachineInstr instr(MachineInstr::Opcode::ICmp,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(binary->GetLhs()),
LowerScalarOperand(binary->GetRhs())});
instr.SetCondCode(LowerIntCond(inst.GetOpcode()));
current_block_->Append(std::move(instr));
values_[&inst] = lowered;
return;
}
case ir::Opcode::FCmpEQ:
case ir::Opcode::FCmpNE:
case ir::Opcode::FCmpLT:
case ir::Opcode::FCmpGT:
case ir::Opcode::FCmpLE:
case ir::Opcode::FCmpGE: {
auto* binary = static_cast<ir::BinaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::I1);
MachineInstr instr(MachineInstr::Opcode::FCmp,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(binary->GetLhs()),
LowerScalarOperand(binary->GetRhs())});
instr.SetCondCode(LowerFloatCond(inst.GetOpcode()));
current_block_->Append(std::move(instr));
values_[&inst] = lowered;
return;
}
case ir::Opcode::Zext: {
auto* zext = static_cast<ir::ZextInst*>(&inst);
auto lowered = NewVRegValue(LowerType(zext->GetType()));
current_block_->Append(MachineInstr::Opcode::ZExt,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(zext->GetValue())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::IToF: {
auto* unary = static_cast<ir::UnaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::F32);
current_block_->Append(MachineInstr::Opcode::ItoF,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(unary->GetOprd())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::FtoI: {
auto* unary = static_cast<ir::UnaryInst*>(&inst);
auto lowered = NewVRegValue(ValueType::I32);
current_block_->Append(MachineInstr::Opcode::FtoI,
{MachineOperand::VReg(lowered.index),
LowerScalarOperand(unary->GetOprd())});
values_[&inst] = lowered;
return;
}
case ir::Opcode::GetElementPtr: {
// Folds the GEP into a single AddressExpr: constant indices accumulate
// into const_offset, dynamic indices become (vreg, stride) pairs; the
// stride at each step is the size of the current (array) level.
auto* gep = static_cast<ir::GetElementPtrInst*>(&inst);
auto lowered = NewVRegValue(ValueType::Ptr);
AddressExpr address = LowerAddress(gep->GetPointer());
auto current_type = gep->GetSourceType();
for (size_t i = 0; i < gep->GetNumIndices(); ++i) {
auto* index = gep->GetIndex(i);
const std::int64_t stride = current_type ? current_type->GetSize() : 0;
if (auto* ci = ir::dyncast<ir::ConstantInt>(index)) {
address.const_offset += static_cast<std::int64_t>(ci->GetValue()) * stride;
} else if (auto* cb = ir::dyncast<ir::ConstantI1>(index)) {
address.const_offset +=
static_cast<std::int64_t>(cb->GetValue() ? 1 : 0) * stride;
} else {
address.scaled_vregs.push_back({LowerScalarOperand(index).GetVReg(), stride});
}
if (current_type && current_type->IsArray()) {
current_type = current_type->GetElementType();
}
}
MachineInstr instr(MachineInstr::Opcode::Lea,
{MachineOperand::VReg(lowered.index)});
instr.SetAddress(std::move(address));
current_block_->Append(std::move(instr));
values_[&inst] = lowered;
return;
}
case ir::Opcode::Call: {
auto* call = static_cast<ir::CallInst*>(&inst);
// Prefer idiom replacement, then inlining, then a plain call.
if (TryEmitMathIdiomCall(call, nullptr, nullptr)) {
return;
}
if (TryInlineDirectCall(call)) {
return;
}
std::vector<MachineOperand> operands;
if (!call->GetType()->IsVoid()) {
auto lowered = NewVRegValue(LowerType(call->GetType()));
operands.push_back(MachineOperand::VReg(lowered.index));
values_[&inst] = lowered;
}
std::vector<ValueType> arg_types;
for (auto* arg : call->GetArguments()) {
operands.push_back(LowerScalarOperand(arg));
arg_types.push_back(LowerType(arg->GetType()));
}
MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands));
// NOTE(review): GetCallee() is dereferenced unchecked here; an indirect
// call (null callee) would crash — confirm the IR verifier forbids it.
instr.SetCallInfo(call->GetCallee()->GetName(), std::move(arg_types),
LowerType(call->GetType()));
current_block_->Append(std::move(instr));
return;
}
case ir::Opcode::Br: {
auto* br = static_cast<ir::UncondBrInst*>(&inst);
current_block_->Append(MachineInstr::Opcode::Br,
{MachineOperand::Block(blocks_.at(br->GetDest())->GetName())});
return;
}
case ir::Opcode::CondBr: {
auto* br = static_cast<ir::CondBrInst*>(&inst);
current_block_->Append(MachineInstr::Opcode::CondBr,
{LowerScalarOperand(br->GetCondition()),
MachineOperand::Block(blocks_.at(br->GetThenBlock())->GetName()),
MachineOperand::Block(blocks_.at(br->GetElseBlock())->GetName())});
return;
}
case ir::Opcode::Return: {
auto* ret = static_cast<ir::ReturnInst*>(&inst);
if (ret->HasReturnValue()) {
MachineInstr instr(MachineInstr::Opcode::Ret,
{LowerScalarOperand(ret->GetReturnValue())});
instr.SetValueType(LowerType(ret->GetReturnValue()->GetType()));
current_block_->Append(std::move(instr));
} else {
current_block_->Append(MachineInstr::Opcode::Ret);
}
return;
}
case ir::Opcode::Memset: {
auto* memset_inst = static_cast<ir::MemsetInst*>(&inst);
MachineInstr instr(MachineInstr::Opcode::Memset,
{LowerScalarOperand(memset_inst->GetValue()),
LowerScalarOperand(memset_inst->GetLength())});
instr.SetAddress(LowerAddress(memset_inst->GetDest()));
current_block_->Append(std::move(instr));
return;
}
case ir::Opcode::Unreachable:
current_block_->Append(MachineInstr::Opcode::Unreachable);
return;
case ir::Opcode::Phi:
// Handled by PreparePhiResults / EmitPhiCopies; nothing to emit here.
return;
case ir::Opcode::FRem:
case ir::Opcode::Neg:
case ir::Opcode::Not:
// Expected to be legalized away before lowering.
throw std::runtime_error(
FormatError("mir", "unsupported instruction in backend lowering"));
}
throw std::runtime_error(FormatError("mir", "unsupported IR opcode in backend lowering"));
}
// Lowers one function: creates all machine blocks up front (so branches
// can name their targets), binds arguments via Arg pseudo-instructions in
// the entry block, pre-assigns phi result vregs, lowers bodies in
// CollectLoweringOrder, and finally rewrites phis into edge copies.
void LowerFunction(ir::Function& function) {
values_.clear();
blocks_.clear();
std::vector<ValueType> param_types;
for (const auto& type : function.GetParamTypes()) {
param_types.push_back(LowerType(type));
}
auto machine_function = std::make_unique<MachineFunction>(
function.GetName(), LowerType(function.GetReturnType()), std::move(param_types));
current_ir_function_ = &function;
current_function_ = machine_function.get();
const auto ordered_blocks = CollectLoweringOrder(function);
for (const auto& block : function.GetBlocks()) {
blocks_[block.get()] = current_function_->CreateBlock(block->GetName());
}
if (!function.GetBlocks().empty()) {
auto* entry = blocks_.at(function.GetBlocks().front().get());
for (const auto& argument : function.GetArguments()) {
auto lowered = NewVRegValue(LowerType(argument->GetType()));
entry->Append(MachineInstr::Opcode::Arg,
{MachineOperand::VReg(lowered.index),
MachineOperand::Imm(static_cast<std::int64_t>(argument->GetIndex()))});
values_[argument.get()] = lowered;
}
}
PreparePhiResults(function);
for (auto* block : ordered_blocks) {
current_block_ = blocks_.at(block);
for (const auto& inst : block->GetInstructions()) {
LowerInstruction(*inst);
}
}
EmitPhiCopies(function);
machine_module_->AddFunction(std::move(machine_function));
current_ir_function_ = nullptr;
current_function_ = nullptr;
current_block_ = nullptr;
}
const ir::Module& module_;                    // input IR module (not owned)
std::unique_ptr<MachineModule> machine_module_;  // output, released by Run()
ir::Function* current_ir_function_ = nullptr;    // function being lowered
MachineFunction* current_function_ = nullptr;    // its machine counterpart
MachineBasicBlock* current_block_ = nullptr;     // block receiving emissions
std::unordered_map<const ir::Value*, LoweredValue> values_;        // per-function value map
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> blocks_;  // IR -> MIR block map
};
} // namespace
// Public entry point of the lowering pass: makes sure the default context is
// initialized, then lowers every defined function in `module` into a fresh
// MachineModule.
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
  DefaultContext();
  Lowerer lowerer(module);
  return lowerer.Run();
}
} // namespace mir