// Forked from NUDT-compiler/nudt-compiler-cpp.
#include "mir/MIR.h"
|
|
|
|
#include <algorithm>
|
|
#include <cstring>
|
|
#include <memory>
|
|
#include <stdexcept>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "ir/IR.h"
|
|
#include "ir/passes/MathIdiomUtils.h"
|
|
#include "utils/Log.h"
|
|
|
|
namespace mir {
|
|
namespace {
|
|
|
|
enum class LoweredKind { Invalid, VReg, StackObject, Global };
|
|
|
|
// Returns the blocks of `function` in the order the lowerer should visit
// them: a preorder DFS from the entry block (successors explored
// left-to-right), followed by any blocks the DFS never reached.
std::vector<ir::BasicBlock*> CollectLoweringOrder(ir::Function& function) {
  std::vector<ir::BasicBlock*> order;
  auto* entry_block = function.GetEntryBlock();
  if (entry_block == nullptr) {
    return order;
  }

  std::unordered_set<ir::BasicBlock*> seen;
  std::vector<ir::BasicBlock*> worklist;
  worklist.push_back(entry_block);
  while (!worklist.empty()) {
    auto* current = worklist.back();
    worklist.pop_back();
    if (current == nullptr || !seen.insert(current).second) {
      continue;
    }
    order.push_back(current);

    // Push successors in reverse so the leftmost successor pops first.
    const auto& successors = current->GetSuccessors();
    for (auto rit = successors.rbegin(); rit != successors.rend(); ++rit) {
      auto* successor = *rit;
      if (successor != nullptr) {
        worklist.push_back(successor);
      }
    }
  }

  // Append unreachable blocks so that every block is still lowered.
  for (const auto& candidate : function.GetBlocks()) {
    if (candidate && seen.insert(candidate.get()).second) {
      order.push_back(candidate.get());
    }
  }
  return order;
}
|
|
|
|
// Result of lowering an IR value: how the value is represented in MIR.
struct LoweredValue {
  // How the value is materialized (virtual register, stack slot, global).
  LoweredKind kind = LoweredKind::Invalid;
  // MIR value type of the lowered value.
  ValueType type = ValueType::Void;
  // VReg number or stack-object index, depending on `kind`; -1 if unused.
  int index = -1;
  // Symbol name, used when `kind` is Global.
  std::string symbol;
};
|
|
|
|
// Maps an IR scalar type onto the corresponding MIR value type.
// A null or void type lowers to Void; any other unsupported type
// (aggregates etc.) is rejected with an exception.
ValueType LowerType(const std::shared_ptr<ir::Type>& type) {
  if (!type) {
    return ValueType::Void;
  }
  if (type->IsVoid()) {
    return ValueType::Void;
  }
  if (type->IsInt1()) {
    return ValueType::I1;
  }
  if (type->IsInt32()) {
    return ValueType::I32;
  }
  if (type->IsFloat()) {
    return ValueType::F32;
  }
  if (type->IsPointer()) {
    return ValueType::Ptr;
  }
  throw std::runtime_error(FormatError("mir", "unsupported IR type in backend lowering"));
}
|
|
|
|
// Alignment (bytes) for an IR type: arrays align like their innermost
// element type; scalars use the MIR alignment of their lowered value type.
// Null types get a conservative alignment of 1.
int GetIRTypeAlign(const std::shared_ptr<ir::Type>& type) {
  // Walk down through nested array types to the element type instead of
  // recursing; each step mirrors one recursive call of the original.
  auto element = type;
  while (element && element->IsArray()) {
    element = element->GetElementType();
  }
  if (!element) {
    return 1;
  }
  return GetValueAlign(LowerType(element));
}
|
|
|
|
// Whether an alloca's base address should be computed once into a vreg
// (via Lea) instead of being addressed as a frame object at each use.
// Only large arrays qualify. NOTE(review): 256 bytes looks like a tuning
// threshold — confirm it against the target's addressing constraints.
bool ShouldMaterializeAllocaBase(const std::shared_ptr<ir::Type>& type) {
  return type && type->IsArray() && type->GetSize() >= 256;
}
|
|
|
|
// Translates an integer-compare IR opcode into a MIR condition code.
// Throws if `opcode` is not one of the ICmp* opcodes.
CondCode LowerIntCond(ir::Opcode opcode) {
  switch (opcode) {
    case ir::Opcode::ICmpEQ:
      return CondCode::EQ;
    case ir::Opcode::ICmpNE:
      return CondCode::NE;
    case ir::Opcode::ICmpLT:
      return CondCode::LT;
    case ir::Opcode::ICmpGT:
      return CondCode::GT;
    case ir::Opcode::ICmpLE:
      return CondCode::LE;
    case ir::Opcode::ICmpGE:
      return CondCode::GE;
    default:
      throw std::runtime_error(FormatError("mir", "invalid integer compare opcode"));
  }
}
|
|
|
|
// Translates a float-compare IR opcode into a MIR condition code.
// Throws if `opcode` is not one of the FCmp* opcodes.
CondCode LowerFloatCond(ir::Opcode opcode) {
  switch (opcode) {
    case ir::Opcode::FCmpEQ:
      return CondCode::EQ;
    case ir::Opcode::FCmpNE:
      return CondCode::NE;
    case ir::Opcode::FCmpLT:
      return CondCode::LT;
    case ir::Opcode::FCmpGT:
      return CondCode::GT;
    case ir::Opcode::FCmpLE:
      return CondCode::LE;
    case ir::Opcode::FCmpGE:
      return CondCode::GE;
    default:
      throw std::runtime_error(FormatError("mir", "invalid float compare opcode"));
  }
}
|
|
|
|
// Returns the IEEE-754 bit pattern of `value` as a non-negative 64-bit
// integer (upper 32 bits zero), suitable for use as an immediate operand.
// memcpy is the well-defined way to type-pun the representation.
std::int64_t FloatBits(float value) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32-bit");
  std::uint32_t raw;
  std::memcpy(&raw, &value, sizeof(raw));
  return static_cast<std::int64_t>(raw);
}
|
|
|
|
class Lowerer {
|
|
public:
|
|
  // Captures the IR module and creates the (initially empty) MachineModule
  // that lowering will populate.
  explicit Lowerer(const ir::Module& module)
      : module_(module), machine_module_(std::make_unique<MachineModule>(module)) {}
|
|
|
|
  // Lowers every defined (non-external) function in the module and
  // releases ownership of the populated MachineModule to the caller.
  std::unique_ptr<MachineModule> Run() {
    for (const auto& func : module_.GetFunctions()) {
      if (func && !func->IsExternal()) {
        LowerFunction(*func);
      }
    }
    return std::move(machine_module_);
  }
|
|
|
|
private:
|
|
using OperandMap = std::unordered_map<const ir::Value*, MachineOperand>;
|
|
|
|
  // Lowers `value` to a machine operand usable in a scalar position.
  // Constants become immediates (floats are encoded as their raw IEEE-754
  // bits via FloatBits). Otherwise the value must already be lowered to a
  // virtual register: the optional `inline_values` map (used while
  // expanding an inlined callee body) is consulted first, then `values_`.
  // Throws if the value is not materialized as a vreg anywhere.
  MachineOperand ResolveScalarOperand(ir::Value* value,
                                      const OperandMap* inline_values = nullptr) {
    if (auto* ci = ir::dyncast<ir::ConstantInt>(value)) {
      return MachineOperand::Imm(ci->GetValue());
    }
    if (auto* cb = ir::dyncast<ir::ConstantI1>(value)) {
      return MachineOperand::Imm(cb->GetValue() ? 1 : 0);
    }
    if (auto* cf = ir::dyncast<ir::ConstantFloat>(value)) {
      // Float constants travel as their bit pattern in an immediate.
      return MachineOperand::Imm(FloatBits(cf->GetValue()));
    }

    // Inline-expansion bindings shadow the function-level value map.
    if (inline_values != nullptr) {
      auto inline_it = inline_values->find(value);
      if (inline_it != inline_values->end()) {
        return inline_it->second;
      }
    }

    auto it = values_.find(value);
    if (it == values_.end() || it->second.kind != LoweredKind::VReg) {
      throw std::runtime_error(
          FormatError("mir", "value is not materialized as a virtual register: " +
                                 value->GetName()));
    }
    return MachineOperand::VReg(it->second.index);
  }
|
|
|
|
  // Convenience wrapper: resolve `value` without an inline-value map.
  MachineOperand LowerScalarOperand(ir::Value* value) {
    return ResolveScalarOperand(value, nullptr);
  }
|
|
|
|
  // Lowers `value` to an addressing expression for a Load/Store/Lea/Memset.
  // Global values become symbol-based addresses directly; otherwise the
  // previously lowered form decides: stack objects become frame-object
  // addresses, vregs become register-base addresses. Throws when the value
  // was never lowered or has an invalid lowered kind.
  AddressExpr LowerAddress(ir::Value* value) {
    if (auto* global = ir::dyncast<ir::GlobalValue>(value)) {
      AddressExpr address;
      address.base_kind = AddrBaseKind::Global;
      address.symbol = global->GetName();
      return address;
    }

    auto it = values_.find(value);
    if (it == values_.end()) {
      throw std::runtime_error(FormatError("mir", "missing lowered address value"));
    }

    AddressExpr address;
    switch (it->second.kind) {
      case LoweredKind::StackObject:
        address.base_kind = AddrBaseKind::FrameObject;
        address.base_index = it->second.index;
        return address;
      case LoweredKind::Global:
        address.base_kind = AddrBaseKind::Global;
        address.symbol = it->second.symbol;
        return address;
      case LoweredKind::VReg:
        // e.g. a GEP result or a materialized alloca base.
        address.base_kind = AddrBaseKind::VReg;
        address.base_index = it->second.index;
        return address;
      case LoweredKind::Invalid:
        break;
    }

    throw std::runtime_error(FormatError("mir", "invalid address lowering"));
  }
|
|
|
|
  // Maps an IR binary arithmetic/bitwise opcode onto its one-to-one MIR
  // counterpart. Throws for opcodes without a direct MIR equivalent.
  MachineInstr::Opcode LowerBinaryOpcode(ir::Opcode opcode) {
    switch (opcode) {
      case ir::Opcode::Add:
        return MachineInstr::Opcode::Add;
      case ir::Opcode::Sub:
        return MachineInstr::Opcode::Sub;
      case ir::Opcode::Mul:
        return MachineInstr::Opcode::Mul;
      case ir::Opcode::Div:
        return MachineInstr::Opcode::Div;
      case ir::Opcode::Rem:
        return MachineInstr::Opcode::Rem;
      case ir::Opcode::And:
        return MachineInstr::Opcode::And;
      case ir::Opcode::Or:
        return MachineInstr::Opcode::Or;
      case ir::Opcode::Xor:
        return MachineInstr::Opcode::Xor;
      case ir::Opcode::Shl:
        return MachineInstr::Opcode::Shl;
      case ir::Opcode::AShr:
        return MachineInstr::Opcode::AShr;
      case ir::Opcode::LShr:
        return MachineInstr::Opcode::LShr;
      case ir::Opcode::FAdd:
        return MachineInstr::Opcode::FAdd;
      case ir::Opcode::FSub:
        return MachineInstr::Opcode::FSub;
      case ir::Opcode::FMul:
        return MachineInstr::Opcode::FMul;
      case ir::Opcode::FDiv:
        return MachineInstr::Opcode::FDiv;
      default:
        throw std::runtime_error(FormatError("mir", "unsupported binary opcode"));
    }
  }
|
|
|
|
  // Allocates a fresh virtual register of `type` in the current function
  // and wraps it as a LoweredValue.
  LoweredValue NewVRegValue(ValueType type) {
    return {LoweredKind::VReg, type, current_function_->NewVReg(type), ""};
  }
|
|
|
|
  // Ensures `operand` lives in a virtual register: a vreg operand is
  // wrapped as-is; anything else (e.g. an immediate) is copied into a
  // fresh vreg of `type` in the current block.
  LoweredValue MaterializeOperandAsValue(const MachineOperand& operand, ValueType type) {
    if (operand.GetKind() == OperandKind::VReg) {
      return {LoweredKind::VReg, type, operand.GetVReg(), ""};
    }

    auto lowered = NewVRegValue(type);
    current_block_->Append(MachineInstr::Opcode::Copy,
                           {MachineOperand::VReg(lowered.index), operand});
    return lowered;
  }
|
|
|
|
  // Inserts `instr` at the end of `block`, but before its terminator when
  // one is present, so branch instructions stay last in the block.
  void InsertBeforeTerminator(MachineBasicBlock* block, MachineInstr instr) {
    auto& instructions = block->GetInstructions();
    auto insert_pos = instructions.end();
    if (!instructions.empty() && instructions.back().IsTerminator()) {
      insert_pos = instructions.end() - 1;
    }
    instructions.insert(insert_pos, std::move(instr));
  }
|
|
|
|
  // One pending phi-elimination copy (dst_vreg <- src) on a specific edge.
  struct PhiCopy {
    int dst_vreg = -1;   // destination virtual register (the phi result)
    MachineOperand src;  // incoming value for this edge
  };
|
|
|
|
  // Emits the phi copies for one CFG edge into `block`, before its
  // terminator. The copies are logically parallel, so they must be
  // sequentialized carefully:
  //   1. drop self-copies (dst == src),
  //   2. repeatedly emit any copy whose destination is not still needed
  //      as a source by another pending copy,
  //   3. if no such copy exists, the remaining copies form a cycle; break
  //      it by saving one source into a fresh temporary vreg and
  //      redirecting every pending reader of that source to the temp.
  void EmitResolvedPhiCopies(MachineBasicBlock* block, std::vector<PhiCopy> copies) {
    // Step 1: self-copies are no-ops.
    copies.erase(std::remove_if(copies.begin(), copies.end(),
                                [](const PhiCopy& copy) {
                                  return copy.src.GetKind() == OperandKind::VReg &&
                                         copy.src.GetVReg() == copy.dst_vreg;
                                }),
                 copies.end());

    while (!copies.empty()) {
      bool progress = false;
      // Step 2: find a copy whose destination no other copy still reads.
      for (auto it = copies.begin(); it != copies.end(); ++it) {
        const bool dst_is_still_needed_as_source =
            std::any_of(copies.begin(), copies.end(), [&](const PhiCopy& other) {
              return other.src.GetKind() == OperandKind::VReg &&
                     other.src.GetVReg() == it->dst_vreg;
            });
        if (dst_is_still_needed_as_source) {
          continue;
        }

        InsertBeforeTerminator(
            block, MachineInstr(MachineInstr::Opcode::Copy,
                                {MachineOperand::VReg(it->dst_vreg), it->src}));
        copies.erase(it);
        progress = true;
        break;
      }

      if (progress) {
        continue;
      }

      // Step 3: every remaining destination is still read -> a cycle.
      // A non-vreg source can never participate in a cycle.
      auto& cycle = copies.front();
      if (cycle.src.GetKind() != OperandKind::VReg) {
        throw std::runtime_error(FormatError("mir", "invalid phi copy cycle"));
      }

      // Save one cycle member into a temp of the same type, then retarget
      // all readers of the original source; the next iterations can make
      // progress again.
      const int src_vreg = cycle.src.GetVReg();
      const auto temp_type = current_function_->GetVRegInfo(src_vreg).type;
      const int temp_vreg = current_function_->NewVReg(temp_type);
      InsertBeforeTerminator(
          block, MachineInstr(MachineInstr::Opcode::Copy,
                              {MachineOperand::VReg(temp_vreg), MachineOperand::VReg(src_vreg)}));
      for (auto& copy : copies) {
        if (copy.src.GetKind() == OperandKind::VReg && copy.src.GetVReg() == src_vreg) {
          copy.src = MachineOperand::VReg(temp_vreg);
        }
      }
    }
  }
|
|
|
|
void RedirectEdgeToPhiBlock(MachineBasicBlock* pred_block,
|
|
const std::string& succ_name,
|
|
const std::string& phi_block_name) {
|
|
auto& instructions = pred_block->GetInstructions();
|
|
if (instructions.empty() || !instructions.back().IsTerminator()) {
|
|
throw std::runtime_error(FormatError("mir", "phi predecessor has no terminator"));
|
|
}
|
|
|
|
auto& term = instructions.back();
|
|
auto& operands = term.GetOperands();
|
|
switch (term.GetOpcode()) {
|
|
case MachineInstr::Opcode::Br:
|
|
if (!operands.empty() && operands[0].GetKind() == OperandKind::Block &&
|
|
operands[0].GetText() == succ_name) {
|
|
operands[0] = MachineOperand::Block(phi_block_name);
|
|
return;
|
|
}
|
|
break;
|
|
case MachineInstr::Opcode::CondBr:
|
|
for (size_t i = 1; i < operands.size(); ++i) {
|
|
if (operands[i].GetKind() == OperandKind::Block &&
|
|
operands[i].GetText() == succ_name) {
|
|
operands[i] = MachineOperand::Block(phi_block_name);
|
|
return;
|
|
}
|
|
}
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
|
|
throw std::runtime_error(FormatError("mir", "failed to redirect phi edge"));
|
|
}
|
|
|
|
  // Pre-assigns a destination vreg to every phi result before any block is
  // lowered, so incoming values that reference a phi defined later resolve
  // correctly. Phis are assumed to be grouped at the top of each block
  // (the scan stops at the first non-phi instruction).
  void PreparePhiResults(ir::Function& function) {
    for (const auto& block : function.GetBlocks()) {
      for (const auto& inst : block->GetInstructions()) {
        if (inst->GetOpcode() != ir::Opcode::Phi) {
          break;
        }
        auto lowered = NewVRegValue(LowerType(inst->GetType()));
        values_[inst.get()] = lowered;
      }
    }
  }
|
|
|
|
  // Rewrites all phi nodes in `function` as explicit copies on the
  // incoming CFG edges. Copies are grouped per (predecessor, successor)
  // edge; for an unconditional-branch predecessor they go straight into
  // the predecessor, while a conditional-branch predecessor gets a fresh
  // "phi.edge.N" block spliced onto the edge (critical-edge splitting).
  // NOTE(review): `pending` is iterated in unordered_map order, so the
  // numbering/placement of phi.edge blocks is not deterministic across
  // runs — presumably benign, but confirm if byte-stable output matters.
  void EmitPhiCopies(ir::Function& function) {
    struct EdgeCopies {
      MachineBasicBlock* succ_block = nullptr;
      std::vector<PhiCopy> copies;
    };

    // Pending copies, grouped first by predecessor then by edge.
    std::unordered_map<MachineBasicBlock*, std::vector<EdgeCopies>> pending;

    for (const auto& block : function.GetBlocks()) {
      for (const auto& inst : block->GetInstructions()) {
        if (inst->GetOpcode() != ir::Opcode::Phi) {
          break;  // phis are grouped at the top of the block
        }
        auto* phi = static_cast<ir::PhiInst*>(inst.get());
        const int dest_vreg = values_.at(phi).index;
        for (int i = 0; i < phi->GetNumIncomings(); ++i) {
          auto* pred_block = blocks_.at(phi->GetIncomingBlock(i));
          auto* succ_block = blocks_.at(block.get());
          auto& edges = pending[pred_block];
          auto edge_it = std::find_if(edges.begin(), edges.end(), [&](const EdgeCopies& edge) {
            return edge.succ_block == succ_block;
          });
          if (edge_it == edges.end()) {
            edges.push_back({succ_block, {}});
            edge_it = std::prev(edges.end());
          }
          edge_it->copies.push_back(
              {dest_vreg, LowerScalarOperand(phi->GetIncomingValue(i))});
        }
      }
    }

    int phi_block_index = 0;
    for (auto& item : pending) {
      auto* pred_block = item.first;
      auto& pred_instructions = pred_block->GetInstructions();
      if (pred_instructions.empty() || !pred_instructions.back().IsTerminator()) {
        throw std::runtime_error(FormatError("mir", "phi predecessor has no terminator"));
      }

      const auto terminator_opcode = pred_instructions.back().GetOpcode();
      for (auto& edge : item.second) {
        // Unconditional branch: the edge is not critical, so the copies
        // can live directly in the predecessor.
        if (terminator_opcode == MachineInstr::Opcode::Br) {
          EmitResolvedPhiCopies(pred_block, std::move(edge.copies));
          continue;
        }

        if (terminator_opcode != MachineInstr::Opcode::CondBr) {
          throw std::runtime_error(
              FormatError("mir", "unsupported terminator for phi lowering"));
        }

        // Conditional branch: split the edge with a dedicated block that
        // holds the copies and then jumps to the real successor.
        auto* phi_block = current_function_->CreateBlock(
            "phi.edge." + std::to_string(phi_block_index++));
        EmitResolvedPhiCopies(phi_block, std::move(edge.copies));
        phi_block->Append(MachineInstr::Opcode::Br,
                          {MachineOperand::Block(edge.succ_block->GetName())});
        RedirectEdgeToPhiBlock(pred_block, edge.succ_block->GetName(), phi_block->GetName());
      }
    }
  }
|
|
|
|
  // Whether `function` is a candidate for inlining at a direct call site:
  // it must be defined, consist of exactly one block, and contain only the
  // straight-line opcodes that TryInlineFunctionBody knows how to expand
  // (no branches, no memory operations, no phis).
  bool CanInlineDirectCall(const ir::Function& function) const {
    if (function.IsExternal() || function.GetBlocks().size() != 1) {
      return false;
    }

    for (const auto& block : function.GetBlocks()) {
      for (const auto& inst : block->GetInstructions()) {
        switch (inst->GetOpcode()) {
          // Whitelist of opcodes TryInlineFunctionBody supports.
          case ir::Opcode::Add:
          case ir::Opcode::Sub:
          case ir::Opcode::Mul:
          case ir::Opcode::Div:
          case ir::Opcode::Rem:
          case ir::Opcode::And:
          case ir::Opcode::Or:
          case ir::Opcode::Xor:
          case ir::Opcode::Shl:
          case ir::Opcode::AShr:
          case ir::Opcode::LShr:
          case ir::Opcode::FAdd:
          case ir::Opcode::FSub:
          case ir::Opcode::FMul:
          case ir::Opcode::FDiv:
          case ir::Opcode::FNeg:
          case ir::Opcode::ICmpEQ:
          case ir::Opcode::ICmpNE:
          case ir::Opcode::ICmpLT:
          case ir::Opcode::ICmpGT:
          case ir::Opcode::ICmpLE:
          case ir::Opcode::ICmpGE:
          case ir::Opcode::FCmpEQ:
          case ir::Opcode::FCmpNE:
          case ir::Opcode::FCmpLT:
          case ir::Opcode::FCmpGT:
          case ir::Opcode::FCmpLE:
          case ir::Opcode::FCmpGE:
          case ir::Opcode::Zext:
          case ir::Opcode::IToF:
          case ir::Opcode::FtoI:
          case ir::Opcode::Call:
          case ir::Opcode::Return:
            break;
          default:
            return false;
        }
      }
    }
    return true;
  }
|
|
|
|
  // Expands the single-block, straight-line body of `callee` directly into
  // the current block. `inline_values` maps callee IR values to machine
  // operands (argument bindings are pre-seeded by the caller). When the
  // body Returns a value, the lowered operand is stored in
  // `*return_operand` and `*has_return` is set. Nested calls are first
  // idiom-matched, then recursively inlined (bounded by `inline_depth`),
  // and finally emitted as ordinary calls. Returns false when anything
  // unsupported is encountered.
  // NOTE(review): instructions appended before a `false` return are not
  // rolled back, so the caller's fallback path leaves dead MIR behind —
  // confirm this is acceptable.
  bool TryInlineFunctionBody(const ir::Function& callee, OperandMap* inline_values,
                             MachineOperand* return_operand, bool* has_return,
                             int inline_depth) {
    if (inline_depth > 2) {
      // Bound the depth of nested inlining.
      return false;
    }

    for (const auto& inst : callee.GetBlocks().front()->GetInstructions()) {
      switch (inst->GetOpcode()) {
        // Binary arithmetic lowers like LowerInstruction, but results are
        // recorded in inline_values instead of values_.
        case ir::Opcode::Add:
        case ir::Opcode::Sub:
        case ir::Opcode::Mul:
        case ir::Opcode::Div:
        case ir::Opcode::Rem:
        case ir::Opcode::And:
        case ir::Opcode::Or:
        case ir::Opcode::Xor:
        case ir::Opcode::Shl:
        case ir::Opcode::AShr:
        case ir::Opcode::LShr:
        case ir::Opcode::FAdd:
        case ir::Opcode::FSub:
        case ir::Opcode::FMul:
        case ir::Opcode::FDiv: {
          auto* binary = static_cast<ir::BinaryInst*>(inst.get());
          auto lowered = NewVRegValue(LowerType(binary->GetType()));
          current_block_->Append(LowerBinaryOpcode(inst->GetOpcode()),
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(binary->GetLhs(), inline_values),
                                  ResolveScalarOperand(binary->GetRhs(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::FNeg: {
          auto* unary = static_cast<ir::UnaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::F32);
          current_block_->Append(MachineInstr::Opcode::FNeg,
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(unary->GetOprd(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::ICmpEQ:
        case ir::Opcode::ICmpNE:
        case ir::Opcode::ICmpLT:
        case ir::Opcode::ICmpGT:
        case ir::Opcode::ICmpLE:
        case ir::Opcode::ICmpGE: {
          auto* binary = static_cast<ir::BinaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::I1);
          MachineInstr instr(MachineInstr::Opcode::ICmp,
                             {MachineOperand::VReg(lowered.index),
                              ResolveScalarOperand(binary->GetLhs(), inline_values),
                              ResolveScalarOperand(binary->GetRhs(), inline_values)});
          instr.SetCondCode(LowerIntCond(inst->GetOpcode()));
          current_block_->Append(std::move(instr));
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::FCmpEQ:
        case ir::Opcode::FCmpNE:
        case ir::Opcode::FCmpLT:
        case ir::Opcode::FCmpGT:
        case ir::Opcode::FCmpLE:
        case ir::Opcode::FCmpGE: {
          auto* binary = static_cast<ir::BinaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::I1);
          MachineInstr instr(MachineInstr::Opcode::FCmp,
                             {MachineOperand::VReg(lowered.index),
                              ResolveScalarOperand(binary->GetLhs(), inline_values),
                              ResolveScalarOperand(binary->GetRhs(), inline_values)});
          instr.SetCondCode(LowerFloatCond(inst->GetOpcode()));
          current_block_->Append(std::move(instr));
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::Zext: {
          auto* zext = static_cast<ir::ZextInst*>(inst.get());
          auto lowered = NewVRegValue(LowerType(zext->GetType()));
          current_block_->Append(MachineInstr::Opcode::ZExt,
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(zext->GetValue(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::IToF: {
          auto* unary = static_cast<ir::UnaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::F32);
          current_block_->Append(MachineInstr::Opcode::ItoF,
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(unary->GetOprd(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::FtoI: {
          auto* unary = static_cast<ir::UnaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::I32);
          current_block_->Append(MachineInstr::Opcode::FtoI,
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(unary->GetOprd(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::Call: {
          auto* nested_call = static_cast<ir::CallInst*>(inst.get());
          auto* nested_callee = nested_call->GetCallee();
          // Refuse indirect calls and self-recursion into the function
          // currently being lowered.
          if (nested_callee == nullptr || nested_callee == current_ir_function_) {
            return false;
          }

          // Prefer collapsing recognized math idioms to one instruction.
          MachineOperand math_idiom_result;
          if (TryEmitMathIdiomCall(nested_call, inline_values, &math_idiom_result)) {
            (*inline_values)[inst.get()] = math_idiom_result;
            break;
          }

          // Next, try to inline the nested callee one level deeper.
          if (CanInlineDirectCall(*nested_callee)) {
            MachineOperand nested_return_operand;
            bool nested_has_return = false;
            OperandMap nested_values;
            const auto& nested_args = nested_callee->GetArguments();
            const auto& nested_call_args = nested_call->GetArguments();
            if (nested_args.size() != nested_call_args.size()) {
              return false;
            }
            // Bind formals to the lowered actual operands.
            for (size_t i = 0; i < nested_call_args.size(); ++i) {
              nested_values[nested_args[i].get()] =
                  ResolveScalarOperand(nested_call_args[i], inline_values);
            }
            if (!TryInlineFunctionBody(*nested_callee, &nested_values, &nested_return_operand,
                                       &nested_has_return, inline_depth + 1)) {
              return false;
            }
            if (!nested_call->GetType()->IsVoid()) {
              if (!nested_has_return) {
                throw std::runtime_error(
                    FormatError("mir", "inlined nested call is missing return value"));
              }
              auto nested_value =
                  MaterializeOperandAsValue(nested_return_operand, LowerType(nested_call->GetType()));
              (*inline_values)[inst.get()] = MachineOperand::VReg(nested_value.index);
            }
            break;
          }

          // Fall back to emitting a real call instruction.
          std::vector<MachineOperand> operands;
          if (!nested_call->GetType()->IsVoid()) {
            auto lowered = NewVRegValue(LowerType(nested_call->GetType()));
            operands.push_back(MachineOperand::VReg(lowered.index));
            (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          }
          std::vector<ValueType> arg_types;
          for (auto* arg : nested_call->GetArguments()) {
            operands.push_back(ResolveScalarOperand(arg, inline_values));
            arg_types.push_back(LowerType(arg->GetType()));
          }
          MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands));
          instr.SetCallInfo(nested_callee->GetName(), std::move(arg_types),
                            LowerType(nested_call->GetType()));
          current_block_->Append(std::move(instr));
          break;
        }
        case ir::Opcode::Return: {
          // No branch is emitted: the inlined body simply flows into the
          // caller; only the returned operand is captured.
          auto* ret = static_cast<ir::ReturnInst*>(inst.get());
          if (ret->HasReturnValue()) {
            *return_operand = ResolveScalarOperand(ret->GetReturnValue(), inline_values);
            *has_return = true;
          }
          break;
        }
        default:
          return false;
      }
    }
    return true;
  }
|
|
|
|
  // Pattern-matches a call to a private Newton-iteration sqrt helper
  // (recognized by ir::mathidiom) and lowers it to a single FSqrt. The
  // callee must return float and take exactly one float argument. When the
  // helper keeps tolerance state in a global, that global's address is
  // attached to the FSqrt instruction. The result goes into
  // `*result_operand` when provided (inline-expansion path); otherwise it
  // is recorded in `values_` for the call instruction itself.
  bool TryEmitMathIdiomCall(ir::CallInst* call, const OperandMap* inline_values,
                            MachineOperand* result_operand) {
    auto* callee = call == nullptr ? nullptr : call->GetCallee();
    const ir::GlobalValue* sqrt_state = nullptr;
    if (callee == nullptr || call->GetType() == nullptr || !call->GetType()->IsFloat() ||
        call->GetArguments().size() != 1 ||
        !call->GetArguments()[0]->GetType()->IsFloat() ||
        !ir::mathidiom::IsPrivateToleranceNewtonSqrt(*callee, &sqrt_state)) {
      return false;
    }

    auto lowered = NewVRegValue(ValueType::F32);
    MachineInstr instr(MachineInstr::Opcode::FSqrt,
                       {MachineOperand::VReg(lowered.index),
                        ResolveScalarOperand(call->GetArguments()[0], inline_values)});
    if (sqrt_state != nullptr) {
      // Carry the helper's tolerance-state global along with the FSqrt.
      AddressExpr address;
      address.base_kind = AddrBaseKind::Global;
      address.symbol = sqrt_state->GetName();
      instr.SetAddress(std::move(address));
    }
    current_block_->Append(std::move(instr));
    if (result_operand != nullptr) {
      *result_operand = MachineOperand::VReg(lowered.index);
    } else {
      values_[call] = lowered;
    }
    return true;
  }
|
|
|
|
  // Attempts to inline `call` entirely at its call site. Rejects indirect
  // calls, self-recursion, arity mismatches, and callees that fail
  // CanInlineDirectCall. On success a non-void result is materialized into
  // a vreg and recorded in `values_` for the call instruction.
  bool TryInlineDirectCall(ir::CallInst* call) {
    auto* callee = call->GetCallee();
    if (callee == nullptr || callee == current_ir_function_ || !CanInlineDirectCall(*callee)) {
      return false;
    }

    const auto& callee_args = callee->GetArguments();
    const auto& call_args = call->GetArguments();
    if (callee_args.size() != call_args.size()) {
      return false;
    }

    // Bind the callee's formal arguments to the lowered actuals.
    OperandMap inline_values;
    for (size_t i = 0; i < call_args.size(); ++i) {
      inline_values[callee_args[i].get()] = ResolveScalarOperand(call_args[i], nullptr);
    }

    MachineOperand return_operand;
    bool has_return = false;
    if (!TryInlineFunctionBody(*callee, &inline_values, &return_operand, &has_return, 0)) {
      return false;
    }

    if (!call->GetType()->IsVoid()) {
      if (!has_return) {
        throw std::runtime_error(FormatError("mir", "inlined call is missing return value"));
      }
      values_[call] = MaterializeOperandAsValue(return_operand, LowerType(call->GetType()));
    }
    return true;
  }
|
|
|
|
  // Lowers one IR instruction into the current MIR block, recording any
  // produced value in `values_`. Phis are skipped here (handled by
  // PreparePhiResults/EmitPhiCopies); unsupported opcodes throw.
  void LowerInstruction(ir::Instruction& inst) {
    switch (inst.GetOpcode()) {
      case ir::Opcode::Alloca: {
        // Reserve a stack object; large arrays additionally get their base
        // address materialized once into a vreg (see
        // ShouldMaterializeAllocaBase).
        auto* alloca_inst = static_cast<ir::AllocaInst*>(&inst);
        const auto allocated_type = alloca_inst->GetAllocatedType();
        const int object = current_function_->CreateStackObject(
            allocated_type->GetSize(), GetIRTypeAlign(allocated_type),
            StackObjectKind::Local, inst.GetName());
        if (ShouldMaterializeAllocaBase(allocated_type)) {
          auto lowered = NewVRegValue(ValueType::Ptr);
          MachineInstr lea(MachineInstr::Opcode::Lea,
                           {MachineOperand::VReg(lowered.index)});
          AddressExpr address;
          address.base_kind = AddrBaseKind::FrameObject;
          address.base_index = object;
          lea.SetAddress(std::move(address));
          current_block_->Append(std::move(lea));
          values_[&inst] = lowered;
        } else {
          values_[&inst] = {LoweredKind::StackObject, ValueType::Ptr, object, ""};
        }
        return;
      }
      case ir::Opcode::Load: {
        auto* load = static_cast<ir::LoadInst*>(&inst);
        auto lowered = NewVRegValue(LowerType(load->GetType()));
        MachineInstr instr(MachineInstr::Opcode::Load,
                           {MachineOperand::VReg(lowered.index)});
        instr.SetAddress(LowerAddress(load->GetPtr()));
        current_block_->Append(std::move(instr));
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::Store: {
        auto* store = static_cast<ir::StoreInst*>(&inst);
        MachineInstr instr(MachineInstr::Opcode::Store,
                           {LowerScalarOperand(store->GetValue())});
        // The store width is derived from the stored value's type.
        instr.SetValueType(LowerType(store->GetValue()->GetType()));
        instr.SetAddress(LowerAddress(store->GetPtr()));
        current_block_->Append(std::move(instr));
        return;
      }
      // Binary arithmetic/bitwise operations map one-to-one.
      case ir::Opcode::Add:
      case ir::Opcode::Sub:
      case ir::Opcode::Mul:
      case ir::Opcode::Div:
      case ir::Opcode::Rem:
      case ir::Opcode::And:
      case ir::Opcode::Or:
      case ir::Opcode::Xor:
      case ir::Opcode::Shl:
      case ir::Opcode::AShr:
      case ir::Opcode::LShr:
      case ir::Opcode::FAdd:
      case ir::Opcode::FSub:
      case ir::Opcode::FMul:
      case ir::Opcode::FDiv: {
        auto* binary = static_cast<ir::BinaryInst*>(&inst);
        auto lowered = NewVRegValue(LowerType(binary->GetType()));
        current_block_->Append(LowerBinaryOpcode(inst.GetOpcode()),
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(binary->GetLhs()),
                                LowerScalarOperand(binary->GetRhs())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::FNeg: {
        auto* unary = static_cast<ir::UnaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::F32);
        current_block_->Append(MachineInstr::Opcode::FNeg,
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(unary->GetOprd())});
        values_[&inst] = lowered;
        return;
      }
      // Integer compares: one ICmp with a condition code, result is I1.
      case ir::Opcode::ICmpEQ:
      case ir::Opcode::ICmpNE:
      case ir::Opcode::ICmpLT:
      case ir::Opcode::ICmpGT:
      case ir::Opcode::ICmpLE:
      case ir::Opcode::ICmpGE: {
        auto* binary = static_cast<ir::BinaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::I1);
        MachineInstr instr(MachineInstr::Opcode::ICmp,
                           {MachineOperand::VReg(lowered.index),
                            LowerScalarOperand(binary->GetLhs()),
                            LowerScalarOperand(binary->GetRhs())});
        instr.SetCondCode(LowerIntCond(inst.GetOpcode()));
        current_block_->Append(std::move(instr));
        values_[&inst] = lowered;
        return;
      }
      // Float compares: one FCmp with a condition code, result is I1.
      case ir::Opcode::FCmpEQ:
      case ir::Opcode::FCmpNE:
      case ir::Opcode::FCmpLT:
      case ir::Opcode::FCmpGT:
      case ir::Opcode::FCmpLE:
      case ir::Opcode::FCmpGE: {
        auto* binary = static_cast<ir::BinaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::I1);
        MachineInstr instr(MachineInstr::Opcode::FCmp,
                           {MachineOperand::VReg(lowered.index),
                            LowerScalarOperand(binary->GetLhs()),
                            LowerScalarOperand(binary->GetRhs())});
        instr.SetCondCode(LowerFloatCond(inst.GetOpcode()));
        current_block_->Append(std::move(instr));
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::Zext: {
        auto* zext = static_cast<ir::ZextInst*>(&inst);
        auto lowered = NewVRegValue(LowerType(zext->GetType()));
        current_block_->Append(MachineInstr::Opcode::ZExt,
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(zext->GetValue())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::IToF: {
        auto* unary = static_cast<ir::UnaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::F32);
        current_block_->Append(MachineInstr::Opcode::ItoF,
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(unary->GetOprd())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::FtoI: {
        auto* unary = static_cast<ir::UnaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::I32);
        current_block_->Append(MachineInstr::Opcode::FtoI,
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(unary->GetOprd())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::GetElementPtr: {
        // Fold the GEP into one address expression: constant indices land
        // in const_offset; variable indices become (vreg, stride) pairs.
        // The stride for each index is the size of the type level being
        // indexed; array levels then descend to their element type.
        auto* gep = static_cast<ir::GetElementPtrInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::Ptr);
        AddressExpr address = LowerAddress(gep->GetPointer());
        auto current_type = gep->GetSourceType();
        for (size_t i = 0; i < gep->GetNumIndices(); ++i) {
          auto* index = gep->GetIndex(i);
          const std::int64_t stride = current_type ? current_type->GetSize() : 0;
          if (auto* ci = ir::dyncast<ir::ConstantInt>(index)) {
            address.const_offset += static_cast<std::int64_t>(ci->GetValue()) * stride;
          } else if (auto* cb = ir::dyncast<ir::ConstantI1>(index)) {
            address.const_offset +=
                static_cast<std::int64_t>(cb->GetValue() ? 1 : 0) * stride;
          } else {
            address.scaled_vregs.push_back({LowerScalarOperand(index).GetVReg(), stride});
          }
          if (current_type && current_type->IsArray()) {
            current_type = current_type->GetElementType();
          }
        }
        MachineInstr instr(MachineInstr::Opcode::Lea,
                           {MachineOperand::VReg(lowered.index)});
        instr.SetAddress(std::move(address));
        current_block_->Append(std::move(instr));
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::Call: {
        // Try cheaper lowerings first: math-idiom collapse, then inlining;
        // otherwise emit an actual call (result vreg first, then args).
        auto* call = static_cast<ir::CallInst*>(&inst);
        if (TryEmitMathIdiomCall(call, nullptr, nullptr)) {
          return;
        }
        if (TryInlineDirectCall(call)) {
          return;
        }
        std::vector<MachineOperand> operands;
        if (!call->GetType()->IsVoid()) {
          auto lowered = NewVRegValue(LowerType(call->GetType()));
          operands.push_back(MachineOperand::VReg(lowered.index));
          values_[&inst] = lowered;
        }
        std::vector<ValueType> arg_types;
        for (auto* arg : call->GetArguments()) {
          operands.push_back(LowerScalarOperand(arg));
          arg_types.push_back(LowerType(arg->GetType()));
        }
        MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands));
        instr.SetCallInfo(call->GetCallee()->GetName(), std::move(arg_types),
                          LowerType(call->GetType()));
        current_block_->Append(std::move(instr));
        return;
      }
      case ir::Opcode::Br: {
        auto* br = static_cast<ir::UncondBrInst*>(&inst);
        current_block_->Append(MachineInstr::Opcode::Br,
                               {MachineOperand::Block(blocks_.at(br->GetDest())->GetName())});
        return;
      }
      case ir::Opcode::CondBr: {
        // Operand order: condition, then-target, else-target.
        auto* br = static_cast<ir::CondBrInst*>(&inst);
        current_block_->Append(MachineInstr::Opcode::CondBr,
                               {LowerScalarOperand(br->GetCondition()),
                                MachineOperand::Block(blocks_.at(br->GetThenBlock())->GetName()),
                                MachineOperand::Block(blocks_.at(br->GetElseBlock())->GetName())});
        return;
      }
      case ir::Opcode::Return: {
        auto* ret = static_cast<ir::ReturnInst*>(&inst);
        if (ret->HasReturnValue()) {
          MachineInstr instr(MachineInstr::Opcode::Ret,
                             {LowerScalarOperand(ret->GetReturnValue())});
          instr.SetValueType(LowerType(ret->GetReturnValue()->GetType()));
          current_block_->Append(std::move(instr));
        } else {
          current_block_->Append(MachineInstr::Opcode::Ret);
        }
        return;
      }
      case ir::Opcode::Memset: {
        auto* memset_inst = static_cast<ir::MemsetInst*>(&inst);
        MachineInstr instr(MachineInstr::Opcode::Memset,
                           {LowerScalarOperand(memset_inst->GetValue()),
                            LowerScalarOperand(memset_inst->GetLength())});
        instr.SetAddress(LowerAddress(memset_inst->GetDest()));
        current_block_->Append(std::move(instr));
        return;
      }
      case ir::Opcode::Unreachable:
        current_block_->Append(MachineInstr::Opcode::Unreachable);
        return;
      case ir::Opcode::Phi:
        // Phis are lowered separately (PreparePhiResults/EmitPhiCopies).
        return;
      case ir::Opcode::FRem:
      case ir::Opcode::Neg:
      case ir::Opcode::Not:
        throw std::runtime_error(
            FormatError("mir", "unsupported instruction in backend lowering"));
    }

    throw std::runtime_error(FormatError("mir", "unsupported IR opcode in backend lowering"));
  }
|
|
|
|
  // Lowers one IR function into a MachineFunction and adds it to the
  // module. Pipeline: reset per-function state, create the MIR blocks,
  // bind arguments to vregs via Arg pseudo-instructions in the entry
  // block, pre-assign phi result vregs, lower each block in DFS order, and
  // finally rewrite phis into edge copies.
  void LowerFunction(ir::Function& function) {
    // Per-function state from any previous lowering is discarded.
    values_.clear();
    blocks_.clear();

    std::vector<ValueType> param_types;
    for (const auto& type : function.GetParamTypes()) {
      param_types.push_back(LowerType(type));
    }

    auto machine_function = std::make_unique<MachineFunction>(
        function.GetName(), LowerType(function.GetReturnType()), std::move(param_types));
    current_ir_function_ = &function;
    current_function_ = machine_function.get();
    const auto ordered_blocks = CollectLoweringOrder(function);

    // Create every MIR block up front so branches can reference them.
    for (const auto& block : function.GetBlocks()) {
      blocks_[block.get()] = current_function_->CreateBlock(block->GetName());
    }

    // Bind incoming arguments to vregs at the top of the entry block.
    if (!function.GetBlocks().empty()) {
      auto* entry = blocks_.at(function.GetBlocks().front().get());
      for (const auto& argument : function.GetArguments()) {
        auto lowered = NewVRegValue(LowerType(argument->GetType()));
        entry->Append(MachineInstr::Opcode::Arg,
                      {MachineOperand::VReg(lowered.index),
                       MachineOperand::Imm(static_cast<std::int64_t>(argument->GetIndex()))});
        values_[argument.get()] = lowered;
      }
    }

    PreparePhiResults(function);

    for (auto* block : ordered_blocks) {
      current_block_ = blocks_.at(block);
      for (const auto& inst : block->GetInstructions()) {
        LowerInstruction(*inst);
      }
    }

    EmitPhiCopies(function);

    machine_module_->AddFunction(std::move(machine_function));
    current_ir_function_ = nullptr;
    current_function_ = nullptr;
    current_block_ = nullptr;
  }
|
|
|
|
  const ir::Module& module_;                       // source IR module (not owned)
  std::unique_ptr<MachineModule> machine_module_;  // result; released by Run()
  ir::Function* current_ir_function_ = nullptr;    // IR function being lowered
  MachineFunction* current_function_ = nullptr;    // its MIR counterpart
  MachineBasicBlock* current_block_ = nullptr;     // block receiving new instrs
  // Function-scoped map from IR values to their lowered representation.
  std::unordered_map<const ir::Value*, LoweredValue> values_;
  // Function-scoped map from IR blocks to their MIR blocks.
  std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> blocks_;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
// Public entry point: lowers the whole IR module to machine IR.
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
  // NOTE(review): presumably ensures the default IR context is initialized
  // before lowering — confirm what state DefaultContext() touches.
  DefaultContext();
  return Lowerer(module).Run();
}
|
|
|
|
} // namespace mir
|