// NOTE(review): this file was recovered from a copy in which every `<...>`
// token (include targets and template-argument lists) had been stripped.
// The angle-bracket contents below are reconstructed from usage; the project
// type names used in casts (ir::BinaryInst, ir::UnaryInst, ir::ZextInst, ...)
// are inferred from the accessors called on them — confirm against ir/IR.h.
#include "mir/MIR.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ir/IR.h"
#include "ir/passes/MathIdiomUtils.h"
#include "utils/Log.h"

namespace mir {
namespace {

// How an IR value was materialized for the machine level.
enum class LoweredKind { Invalid, VReg, StackObject, Global };

// Returns the blocks of `function` in a deterministic lowering order:
// a reverse-post-order-style DFS from the entry block (successors pushed in
// reverse so they pop in source order), followed by any unreachable blocks
// in their original declaration order.
std::vector<ir::BasicBlock*> CollectLoweringOrder(ir::Function& function) {
  std::vector<ir::BasicBlock*> order;
  auto* entry = function.GetEntryBlock();
  if (!entry) {
    return order;
  }
  std::unordered_set<ir::BasicBlock*> visited;
  std::vector<ir::BasicBlock*> stack{entry};
  while (!stack.empty()) {
    auto* block = stack.back();
    stack.pop_back();
    if (!block || !visited.insert(block).second) {
      continue;
    }
    order.push_back(block);
    const auto& succs = block->GetSuccessors();
    // Reverse iteration so the first successor is visited first (LIFO stack).
    for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
      if (*it != nullptr) {
        stack.push_back(*it);
      }
    }
  }
  // Append blocks unreachable from the entry so every block is still lowered.
  for (const auto& block : function.GetBlocks()) {
    if (block && visited.insert(block.get()).second) {
      order.push_back(block.get());
    }
  }
  return order;
}

// The machine-level representation of one IR value.
struct LoweredValue {
  LoweredKind kind = LoweredKind::Invalid;
  ValueType type = ValueType::Void;
  int index = -1;       // vreg number or stack-object index, depending on kind
  std::string symbol;   // global symbol name when kind == Global
};

// Maps a scalar IR type to its machine value type. Throws for aggregate or
// otherwise unsupported types.
ValueType LowerType(const std::shared_ptr<ir::Type>& type) {
  if (!type || type->IsVoid()) {
    return ValueType::Void;
  }
  if (type->IsInt1()) {
    return ValueType::I1;
  }
  if (type->IsInt32()) {
    return ValueType::I32;
  }
  if (type->IsFloat()) {
    return ValueType::F32;
  }
  if (type->IsPointer()) {
    return ValueType::Ptr;
  }
  throw std::runtime_error(FormatError("mir", "unsupported IR type in backend lowering"));
}

// Alignment of an IR type in bytes; arrays align to their element type.
int GetIRTypeAlign(const std::shared_ptr<ir::Type>& type) {
  if (!type) {
    return 1;
  }
  if (type->IsArray()) {
    return GetIRTypeAlign(type->GetElementType());
  }
  return GetValueAlign(LowerType(type));
}

// Large arrays (>= 256 bytes) get their frame address materialized once into
// a vreg instead of being re-addressed per access.
bool ShouldMaterializeAllocaBase(const std::shared_ptr<ir::Type>& type) {
  return type && type->IsArray() && type->GetSize() >= 256;
}

// Translates an integer-compare IR opcode to a machine condition code.
CondCode LowerIntCond(ir::Opcode opcode) {
  switch (opcode) {
    case ir::Opcode::ICmpEQ: return CondCode::EQ;
    case ir::Opcode::ICmpNE: return CondCode::NE;
    case ir::Opcode::ICmpLT: return CondCode::LT;
    case ir::Opcode::ICmpGT: return CondCode::GT;
    case ir::Opcode::ICmpLE: return CondCode::LE;
    case ir::Opcode::ICmpGE: return CondCode::GE;
    default:
      throw std::runtime_error(FormatError("mir", "invalid integer compare opcode"));
  }
}

// Translates a float-compare IR opcode to a machine condition code.
CondCode LowerFloatCond(ir::Opcode opcode) {
  switch (opcode) {
    case ir::Opcode::FCmpEQ: return CondCode::EQ;
    case ir::Opcode::FCmpNE: return CondCode::NE;
    case ir::Opcode::FCmpLT: return CondCode::LT;
    case ir::Opcode::FCmpGT: return CondCode::GT;
    case ir::Opcode::FCmpLE: return CondCode::LE;
    case ir::Opcode::FCmpGE: return CondCode::GE;
    default:
      throw std::runtime_error(FormatError("mir", "invalid float compare opcode"));
  }
}

// Reinterprets a float's bit pattern as an integer immediate (memcpy avoids
// strict-aliasing UB).
std::int64_t FloatBits(float value) {
  std::uint32_t bits = 0;
  std::memcpy(&bits, &value, sizeof(bits));
  return static_cast<std::int64_t>(bits);
}

// Lowers an ir::Module to a MachineModule: SSA values become virtual
// registers, allocas become stack objects, phis become edge copies, and small
// single-block callees are inlined at the machine level.
class Lowerer {
 public:
  explicit Lowerer(const ir::Module& module)
      : module_(module), machine_module_(std::make_unique<MachineModule>(module)) {}

  // Lowers every defined (non-external) function and yields the module.
  std::unique_ptr<MachineModule> Run() {
    for (const auto& func : module_.GetFunctions()) {
      if (func && !func->IsExternal()) {
        LowerFunction(*func);
      }
    }
    return std::move(machine_module_);
  }

 private:
  // Maps callee IR values to already-lowered operands during inlining.
  using OperandMap = std::unordered_map<ir::Value*, MachineOperand>;

  // Produces the operand for a scalar IR value: constants fold to immediates
  // (floats as raw bits), inline-substituted values resolve through
  // `inline_values`, everything else must already live in a vreg.
  MachineOperand ResolveScalarOperand(ir::Value* value,
                                      const OperandMap* inline_values = nullptr) {
    if (auto* ci = ir::dyncast<ir::ConstantInt>(value)) {
      return MachineOperand::Imm(ci->GetValue());
    }
    if (auto* cb = ir::dyncast<ir::ConstantBool>(value)) {
      return MachineOperand::Imm(cb->GetValue() ? 1 : 0);
    }
    if (auto* cf = ir::dyncast<ir::ConstantFloat>(value)) {
      return MachineOperand::Imm(FloatBits(cf->GetValue()));
    }
    if (inline_values != nullptr) {
      auto inline_it = inline_values->find(value);
      if (inline_it != inline_values->end()) {
        return inline_it->second;
      }
    }
    auto it = values_.find(value);
    if (it == values_.end() || it->second.kind != LoweredKind::VReg) {
      throw std::runtime_error(
          FormatError("mir", "value is not materialized as a virtual register: " +
                                 value->GetName()));
    }
    return MachineOperand::VReg(it->second.index);
  }

  // Convenience wrapper for the non-inlining case.
  MachineOperand LowerScalarOperand(ir::Value* value) {
    return ResolveScalarOperand(value, nullptr);
  }

  // Builds an address expression for a pointer-producing IR value
  // (global symbol, frame object, or pointer held in a vreg).
  AddressExpr LowerAddress(ir::Value* value) {
    if (auto* global = ir::dyncast<ir::GlobalValue>(value)) {
      AddressExpr address;
      address.base_kind = AddrBaseKind::Global;
      address.symbol = global->GetName();
      return address;
    }
    auto it = values_.find(value);
    if (it == values_.end()) {
      throw std::runtime_error(FormatError("mir", "missing lowered address value"));
    }
    AddressExpr address;
    switch (it->second.kind) {
      case LoweredKind::StackObject:
        address.base_kind = AddrBaseKind::FrameObject;
        address.base_index = it->second.index;
        return address;
      case LoweredKind::Global:
        address.base_kind = AddrBaseKind::Global;
        address.symbol = it->second.symbol;
        return address;
      case LoweredKind::VReg:
        address.base_kind = AddrBaseKind::VReg;
        address.base_index = it->second.index;
        return address;
      case LoweredKind::Invalid:
        break;
    }
    throw std::runtime_error(FormatError("mir", "invalid address lowering"));
  }

  // Maps an IR binary opcode to the corresponding machine opcode.
  MachineInstr::Opcode LowerBinaryOpcode(ir::Opcode opcode) {
    switch (opcode) {
      case ir::Opcode::Add:  return MachineInstr::Opcode::Add;
      case ir::Opcode::Sub:  return MachineInstr::Opcode::Sub;
      case ir::Opcode::Mul:  return MachineInstr::Opcode::Mul;
      case ir::Opcode::Div:  return MachineInstr::Opcode::Div;
      case ir::Opcode::Rem:  return MachineInstr::Opcode::Rem;
      case ir::Opcode::And:  return MachineInstr::Opcode::And;
      case ir::Opcode::Or:   return MachineInstr::Opcode::Or;
      case ir::Opcode::Xor:  return MachineInstr::Opcode::Xor;
      case ir::Opcode::Shl:  return MachineInstr::Opcode::Shl;
      case ir::Opcode::AShr: return MachineInstr::Opcode::AShr;
      case ir::Opcode::LShr: return MachineInstr::Opcode::LShr;
      case ir::Opcode::FAdd: return MachineInstr::Opcode::FAdd;
      case ir::Opcode::FSub: return MachineInstr::Opcode::FSub;
      case ir::Opcode::FMul: return MachineInstr::Opcode::FMul;
      case ir::Opcode::FDiv: return MachineInstr::Opcode::FDiv;
      default:
        throw std::runtime_error(FormatError("mir", "unsupported binary opcode"));
    }
  }

  // Allocates a fresh vreg of `type` in the current function.
  LoweredValue NewVRegValue(ValueType type) {
    return {LoweredKind::VReg, type, current_function_->NewVReg(type), ""};
  }

  // Ensures `operand` lives in a vreg, emitting a Copy for immediates.
  LoweredValue MaterializeOperandAsValue(const MachineOperand& operand, ValueType type) {
    if (operand.GetKind() == OperandKind::VReg) {
      return {LoweredKind::VReg, type, operand.GetVReg(), ""};
    }
    auto lowered = NewVRegValue(type);
    current_block_->Append(MachineInstr::Opcode::Copy,
                           {MachineOperand::VReg(lowered.index), operand});
    return lowered;
  }

  // Inserts `instr` just before the block's terminator (or at the end if the
  // block has none).
  void InsertBeforeTerminator(MachineBasicBlock* block, MachineInstr instr) {
    auto& instructions = block->GetInstructions();
    auto insert_pos = instructions.end();
    if (!instructions.empty() && instructions.back().IsTerminator()) {
      insert_pos = instructions.end() - 1;
    }
    instructions.insert(insert_pos, std::move(instr));
  }

  // One pending phi-resolution copy: dst_vreg <- src.
  struct PhiCopy {
    int dst_vreg = -1;
    MachineOperand src;
  };

  // Emits a set of parallel phi copies as a sequential schedule: self-copies
  // are dropped, copies whose destination is no longer needed as a source are
  // emitted first, and remaining cycles are broken with a temporary vreg.
  void EmitResolvedPhiCopies(MachineBasicBlock* block, std::vector<PhiCopy> copies) {
    copies.erase(std::remove_if(copies.begin(), copies.end(),
                                [](const PhiCopy& copy) {
                                  return copy.src.GetKind() == OperandKind::VReg &&
                                         copy.src.GetVReg() == copy.dst_vreg;
                                }),
                 copies.end());
    while (!copies.empty()) {
      bool progress = false;
      for (auto it = copies.begin(); it != copies.end(); ++it) {
        const bool dst_is_still_needed_as_source =
            std::any_of(copies.begin(), copies.end(), [&](const PhiCopy& other) {
              return other.src.GetKind() == OperandKind::VReg &&
                     other.src.GetVReg() == it->dst_vreg;
            });
        if (dst_is_still_needed_as_source) {
          continue;
        }
        InsertBeforeTerminator(
            block, MachineInstr(MachineInstr::Opcode::Copy,
                                {MachineOperand::VReg(it->dst_vreg), it->src}));
        copies.erase(it);
        progress = true;
        break;
      }
      if (progress) {
        continue;
      }
      // Every remaining copy participates in a cycle; break it by saving one
      // source into a temporary and redirecting its readers.
      auto& cycle = copies.front();
      if (cycle.src.GetKind() != OperandKind::VReg) {
        throw std::runtime_error(FormatError("mir", "invalid phi copy cycle"));
      }
      const int src_vreg = cycle.src.GetVReg();
      const auto temp_type = current_function_->GetVRegInfo(src_vreg).type;
      const int temp_vreg = current_function_->NewVReg(temp_type);
      InsertBeforeTerminator(
          block, MachineInstr(MachineInstr::Opcode::Copy,
                              {MachineOperand::VReg(temp_vreg),
                               MachineOperand::VReg(src_vreg)}));
      for (auto& copy : copies) {
        if (copy.src.GetKind() == OperandKind::VReg && copy.src.GetVReg() == src_vreg) {
          copy.src = MachineOperand::VReg(temp_vreg);
        }
      }
    }
  }

  // Rewrites the terminator of `pred_block` so that its edge to `succ_name`
  // targets `phi_block_name` instead (used when splitting critical edges).
  void RedirectEdgeToPhiBlock(MachineBasicBlock* pred_block, const std::string& succ_name,
                              const std::string& phi_block_name) {
    auto& instructions = pred_block->GetInstructions();
    if (instructions.empty() || !instructions.back().IsTerminator()) {
      throw std::runtime_error(FormatError("mir", "phi predecessor has no terminator"));
    }
    auto& term = instructions.back();
    auto& operands = term.GetOperands();
    switch (term.GetOpcode()) {
      case MachineInstr::Opcode::Br:
        if (!operands.empty() && operands[0].GetKind() == OperandKind::Block &&
            operands[0].GetText() == succ_name) {
          operands[0] = MachineOperand::Block(phi_block_name);
          return;
        }
        break;
      case MachineInstr::Opcode::CondBr:
        // Operand 0 is the condition; block targets start at index 1.
        for (size_t i = 1; i < operands.size(); ++i) {
          if (operands[i].GetKind() == OperandKind::Block &&
              operands[i].GetText() == succ_name) {
            operands[i] = MachineOperand::Block(phi_block_name);
            return;
          }
        }
        break;
      default:
        break;
    }
    throw std::runtime_error(FormatError("mir", "failed to redirect phi edge"));
  }

  // Pre-assigns a destination vreg to every phi so that incoming-edge copies
  // can be generated before the phi's own block is lowered. Phis are assumed
  // to be grouped at the top of each block, hence the early `break`.
  void PreparePhiResults(ir::Function& function) {
    for (const auto& block : function.GetBlocks()) {
      for (const auto& inst : block->GetInstructions()) {
        if (inst->GetOpcode() != ir::Opcode::Phi) {
          break;
        }
        auto lowered = NewVRegValue(LowerType(inst->GetType()));
        values_[inst.get()] = lowered;
      }
    }
  }

  // Collects, per (predecessor, successor) edge, the copies each phi needs,
  // then emits them: directly before an unconditional branch, or into a new
  // edge-splitting block ("phi.edge.N") for conditional branches.
  void EmitPhiCopies(ir::Function& function) {
    struct EdgeCopies {
      MachineBasicBlock* succ_block = nullptr;
      std::vector<PhiCopy> copies;
    };
    std::unordered_map<MachineBasicBlock*, std::vector<EdgeCopies>> pending;
    for (const auto& block : function.GetBlocks()) {
      for (const auto& inst : block->GetInstructions()) {
        if (inst->GetOpcode() != ir::Opcode::Phi) {
          break;
        }
        auto* phi = static_cast<ir::PhiInst*>(inst.get());
        const int dest_vreg = values_.at(phi).index;
        for (int i = 0; i < phi->GetNumIncomings(); ++i) {
          auto* pred_block = blocks_.at(phi->GetIncomingBlock(i));
          auto* succ_block = blocks_.at(block.get());
          auto& edges = pending[pred_block];
          auto edge_it =
              std::find_if(edges.begin(), edges.end(), [&](const EdgeCopies& edge) {
                return edge.succ_block == succ_block;
              });
          if (edge_it == edges.end()) {
            edges.push_back({succ_block, {}});
            edge_it = std::prev(edges.end());
          }
          edge_it->copies.push_back({dest_vreg, LowerScalarOperand(phi->GetIncomingValue(i))});
        }
      }
    }
    int phi_block_index = 0;
    for (auto& item : pending) {
      auto* pred_block = item.first;
      auto& pred_instructions = pred_block->GetInstructions();
      if (pred_instructions.empty() || !pred_instructions.back().IsTerminator()) {
        throw std::runtime_error(FormatError("mir", "phi predecessor has no terminator"));
      }
      const auto terminator_opcode = pred_instructions.back().GetOpcode();
      for (auto& edge : item.second) {
        if (terminator_opcode == MachineInstr::Opcode::Br) {
          // Single successor: safe to place the copies in the predecessor.
          EmitResolvedPhiCopies(pred_block, std::move(edge.copies));
          continue;
        }
        if (terminator_opcode != MachineInstr::Opcode::CondBr) {
          throw std::runtime_error(
              FormatError("mir", "unsupported terminator for phi lowering"));
        }
        // Conditional edge: split it so copies only run on this edge.
        auto* phi_block =
            current_function_->CreateBlock("phi.edge." + std::to_string(phi_block_index++));
        EmitResolvedPhiCopies(phi_block, std::move(edge.copies));
        phi_block->Append(MachineInstr::Opcode::Br,
                          {MachineOperand::Block(edge.succ_block->GetName())});
        RedirectEdgeToPhiBlock(pred_block, edge.succ_block->GetName(),
                               phi_block->GetName());
      }
    }
  }

  // A callee is inlinable when it is defined, has a single block, and only
  // contains straight-line arithmetic/compare/convert/call/return code.
  bool CanInlineDirectCall(const ir::Function& function) const {
    if (function.IsExternal() || function.GetBlocks().size() != 1) {
      return false;
    }
    for (const auto& block : function.GetBlocks()) {
      for (const auto& inst : block->GetInstructions()) {
        switch (inst->GetOpcode()) {
          case ir::Opcode::Add:
          case ir::Opcode::Sub:
          case ir::Opcode::Mul:
          case ir::Opcode::Div:
          case ir::Opcode::Rem:
          case ir::Opcode::And:
          case ir::Opcode::Or:
          case ir::Opcode::Xor:
          case ir::Opcode::Shl:
          case ir::Opcode::AShr:
          case ir::Opcode::LShr:
          case ir::Opcode::FAdd:
          case ir::Opcode::FSub:
          case ir::Opcode::FMul:
          case ir::Opcode::FDiv:
          case ir::Opcode::FNeg:
          case ir::Opcode::ICmpEQ:
          case ir::Opcode::ICmpNE:
          case ir::Opcode::ICmpLT:
          case ir::Opcode::ICmpGT:
          case ir::Opcode::ICmpLE:
          case ir::Opcode::ICmpGE:
          case ir::Opcode::FCmpEQ:
          case ir::Opcode::FCmpNE:
          case ir::Opcode::FCmpLT:
          case ir::Opcode::FCmpGT:
          case ir::Opcode::FCmpLE:
          case ir::Opcode::FCmpGE:
          case ir::Opcode::Zext:
          case ir::Opcode::IToF:
          case ir::Opcode::FtoI:
          case ir::Opcode::Call:
          case ir::Opcode::Return:
            break;
          default:
            return false;
        }
      }
    }
    return true;
  }

  // Expands the single-block body of `callee` directly into the current
  // block. Arguments arrive pre-resolved in `inline_values`; the return value
  // (if any) is reported through `return_operand`/`has_return`. Nested calls
  // may themselves be inlined up to depth 2. Returns false (possibly after
  // emitting a prefix of instructions) if anything is not inlinable.
  bool TryInlineFunctionBody(const ir::Function& callee, OperandMap* inline_values,
                             MachineOperand* return_operand, bool* has_return,
                             int inline_depth) {
    if (inline_depth > 2) {
      return false;
    }
    for (const auto& inst : callee.GetBlocks().front()->GetInstructions()) {
      switch (inst->GetOpcode()) {
        case ir::Opcode::Add:
        case ir::Opcode::Sub:
        case ir::Opcode::Mul:
        case ir::Opcode::Div:
        case ir::Opcode::Rem:
        case ir::Opcode::And:
        case ir::Opcode::Or:
        case ir::Opcode::Xor:
        case ir::Opcode::Shl:
        case ir::Opcode::AShr:
        case ir::Opcode::LShr:
        case ir::Opcode::FAdd:
        case ir::Opcode::FSub:
        case ir::Opcode::FMul:
        case ir::Opcode::FDiv: {
          auto* binary = static_cast<ir::BinaryInst*>(inst.get());
          auto lowered = NewVRegValue(LowerType(binary->GetType()));
          current_block_->Append(LowerBinaryOpcode(inst->GetOpcode()),
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(binary->GetLhs(), inline_values),
                                  ResolveScalarOperand(binary->GetRhs(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::FNeg: {
          auto* unary = static_cast<ir::UnaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::F32);
          current_block_->Append(MachineInstr::Opcode::FNeg,
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(unary->GetOprd(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::ICmpEQ:
        case ir::Opcode::ICmpNE:
        case ir::Opcode::ICmpLT:
        case ir::Opcode::ICmpGT:
        case ir::Opcode::ICmpLE:
        case ir::Opcode::ICmpGE: {
          auto* binary = static_cast<ir::BinaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::I1);
          MachineInstr instr(MachineInstr::Opcode::ICmp,
                             {MachineOperand::VReg(lowered.index),
                              ResolveScalarOperand(binary->GetLhs(), inline_values),
                              ResolveScalarOperand(binary->GetRhs(), inline_values)});
          instr.SetCondCode(LowerIntCond(inst->GetOpcode()));
          current_block_->Append(std::move(instr));
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::FCmpEQ:
        case ir::Opcode::FCmpNE:
        case ir::Opcode::FCmpLT:
        case ir::Opcode::FCmpGT:
        case ir::Opcode::FCmpLE:
        case ir::Opcode::FCmpGE: {
          auto* binary = static_cast<ir::BinaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::I1);
          MachineInstr instr(MachineInstr::Opcode::FCmp,
                             {MachineOperand::VReg(lowered.index),
                              ResolveScalarOperand(binary->GetLhs(), inline_values),
                              ResolveScalarOperand(binary->GetRhs(), inline_values)});
          instr.SetCondCode(LowerFloatCond(inst->GetOpcode()));
          current_block_->Append(std::move(instr));
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::Zext: {
          auto* zext = static_cast<ir::ZextInst*>(inst.get());
          auto lowered = NewVRegValue(LowerType(zext->GetType()));
          current_block_->Append(MachineInstr::Opcode::ZExt,
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(zext->GetValue(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::IToF: {
          auto* unary = static_cast<ir::UnaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::F32);
          current_block_->Append(MachineInstr::Opcode::ItoF,
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(unary->GetOprd(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::FtoI: {
          auto* unary = static_cast<ir::UnaryInst*>(inst.get());
          auto lowered = NewVRegValue(ValueType::I32);
          current_block_->Append(MachineInstr::Opcode::FtoI,
                                 {MachineOperand::VReg(lowered.index),
                                  ResolveScalarOperand(unary->GetOprd(), inline_values)});
          (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          break;
        }
        case ir::Opcode::Call: {
          auto* nested_call = static_cast<ir::CallInst*>(inst.get());
          auto* nested_callee = nested_call->GetCallee();
          // Indirect calls and (mutual) recursion into the function being
          // lowered are not inlinable.
          if (nested_callee == nullptr || nested_callee == current_ir_function_) {
            return false;
          }
          MachineOperand math_idiom_result;
          if (TryEmitMathIdiomCall(nested_call, inline_values, &math_idiom_result)) {
            (*inline_values)[inst.get()] = math_idiom_result;
            break;
          }
          if (CanInlineDirectCall(*nested_callee)) {
            MachineOperand nested_return_operand;
            bool nested_has_return = false;
            OperandMap nested_values;
            const auto& nested_args = nested_callee->GetArguments();
            const auto& nested_call_args = nested_call->GetArguments();
            if (nested_args.size() != nested_call_args.size()) {
              return false;
            }
            for (size_t i = 0; i < nested_call_args.size(); ++i) {
              nested_values[nested_args[i].get()] =
                  ResolveScalarOperand(nested_call_args[i], inline_values);
            }
            if (!TryInlineFunctionBody(*nested_callee, &nested_values,
                                       &nested_return_operand, &nested_has_return,
                                       inline_depth + 1)) {
              return false;
            }
            if (!nested_call->GetType()->IsVoid()) {
              if (!nested_has_return) {
                throw std::runtime_error(
                    FormatError("mir", "inlined nested call is missing return value"));
              }
              auto nested_value = MaterializeOperandAsValue(
                  nested_return_operand, LowerType(nested_call->GetType()));
              (*inline_values)[inst.get()] = MachineOperand::VReg(nested_value.index);
            }
            break;
          }
          // Not inlinable: emit it as a plain machine call.
          std::vector<MachineOperand> operands;
          if (!nested_call->GetType()->IsVoid()) {
            auto lowered = NewVRegValue(LowerType(nested_call->GetType()));
            operands.push_back(MachineOperand::VReg(lowered.index));
            (*inline_values)[inst.get()] = MachineOperand::VReg(lowered.index);
          }
          std::vector<ValueType> arg_types;
          for (auto* arg : nested_call->GetArguments()) {
            operands.push_back(ResolveScalarOperand(arg, inline_values));
            arg_types.push_back(LowerType(arg->GetType()));
          }
          MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands));
          instr.SetCallInfo(nested_callee->GetName(), std::move(arg_types),
                            LowerType(nested_call->GetType()));
          current_block_->Append(std::move(instr));
          break;
        }
        case ir::Opcode::Return: {
          auto* ret = static_cast<ir::ReturnInst*>(inst.get());
          if (ret->HasReturnValue()) {
            *return_operand = ResolveScalarOperand(ret->GetReturnValue(), inline_values);
            *has_return = true;
          }
          break;
        }
        default:
          return false;
      }
    }
    return true;
  }

  // Recognizes a private Newton-iteration sqrt helper and lowers the call to
  // a single FSqrt instruction; the helper's tolerance state global (if any)
  // is attached as the instruction's address. The result goes to
  // `*result_operand` when provided, otherwise into values_ for the call.
  bool TryEmitMathIdiomCall(ir::CallInst* call, const OperandMap* inline_values,
                            MachineOperand* result_operand) {
    auto* callee = call == nullptr ? nullptr : call->GetCallee();
    const ir::GlobalValue* sqrt_state = nullptr;
    if (callee == nullptr || call->GetType() == nullptr || !call->GetType()->IsFloat() ||
        call->GetArguments().size() != 1 ||
        !call->GetArguments()[0]->GetType()->IsFloat() ||
        !ir::mathidiom::IsPrivateToleranceNewtonSqrt(*callee, &sqrt_state)) {
      return false;
    }
    auto lowered = NewVRegValue(ValueType::F32);
    MachineInstr instr(MachineInstr::Opcode::FSqrt,
                       {MachineOperand::VReg(lowered.index),
                        ResolveScalarOperand(call->GetArguments()[0], inline_values)});
    if (sqrt_state != nullptr) {
      AddressExpr address;
      address.base_kind = AddrBaseKind::Global;
      address.symbol = sqrt_state->GetName();
      instr.SetAddress(std::move(address));
    }
    current_block_->Append(std::move(instr));
    if (result_operand != nullptr) {
      *result_operand = MachineOperand::VReg(lowered.index);
    } else {
      values_[call] = lowered;
    }
    return true;
  }

  // Top-level inlining entry: tries to expand a direct call in the current
  // function. Returns false when the callee is unknown, recursive, or not
  // inlinable.
  bool TryInlineDirectCall(ir::CallInst* call) {
    auto* callee = call->GetCallee();
    if (callee == nullptr || callee == current_ir_function_ ||
        !CanInlineDirectCall(*callee)) {
      return false;
    }
    const auto& callee_args = callee->GetArguments();
    const auto& call_args = call->GetArguments();
    if (callee_args.size() != call_args.size()) {
      return false;
    }
    OperandMap inline_values;
    for (size_t i = 0; i < call_args.size(); ++i) {
      inline_values[callee_args[i].get()] = ResolveScalarOperand(call_args[i], nullptr);
    }
    MachineOperand return_operand;
    bool has_return = false;
    if (!TryInlineFunctionBody(*callee, &inline_values, &return_operand, &has_return, 0)) {
      return false;
    }
    if (!call->GetType()->IsVoid()) {
      if (!has_return) {
        throw std::runtime_error(FormatError("mir", "inlined call is missing return value"));
      }
      values_[call] = MaterializeOperandAsValue(return_operand, LowerType(call->GetType()));
    }
    return true;
  }

  // Lowers one IR instruction into the current machine block.
  void LowerInstruction(ir::Instruction& inst) {
    switch (inst.GetOpcode()) {
      case ir::Opcode::Alloca: {
        auto* alloca_inst = static_cast<ir::AllocaInst*>(&inst);
        const auto allocated_type = alloca_inst->GetAllocatedType();
        const int object = current_function_->CreateStackObject(
            allocated_type->GetSize(), GetIRTypeAlign(allocated_type),
            StackObjectKind::Local, inst.GetName());
        if (ShouldMaterializeAllocaBase(allocated_type)) {
          // Large array: compute the frame address once into a vreg.
          auto lowered = NewVRegValue(ValueType::Ptr);
          MachineInstr lea(MachineInstr::Opcode::Lea,
                           {MachineOperand::VReg(lowered.index)});
          AddressExpr address;
          address.base_kind = AddrBaseKind::FrameObject;
          address.base_index = object;
          lea.SetAddress(std::move(address));
          current_block_->Append(std::move(lea));
          values_[&inst] = lowered;
        } else {
          values_[&inst] = {LoweredKind::StackObject, ValueType::Ptr, object, ""};
        }
        return;
      }
      case ir::Opcode::Load: {
        auto* load = static_cast<ir::LoadInst*>(&inst);
        auto lowered = NewVRegValue(LowerType(load->GetType()));
        MachineInstr instr(MachineInstr::Opcode::Load,
                           {MachineOperand::VReg(lowered.index)});
        instr.SetAddress(LowerAddress(load->GetPtr()));
        current_block_->Append(std::move(instr));
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::Store: {
        auto* store = static_cast<ir::StoreInst*>(&inst);
        MachineInstr instr(MachineInstr::Opcode::Store,
                           {LowerScalarOperand(store->GetValue())});
        instr.SetValueType(LowerType(store->GetValue()->GetType()));
        instr.SetAddress(LowerAddress(store->GetPtr()));
        current_block_->Append(std::move(instr));
        return;
      }
      case ir::Opcode::Add:
      case ir::Opcode::Sub:
      case ir::Opcode::Mul:
      case ir::Opcode::Div:
      case ir::Opcode::Rem:
      case ir::Opcode::And:
      case ir::Opcode::Or:
      case ir::Opcode::Xor:
      case ir::Opcode::Shl:
      case ir::Opcode::AShr:
      case ir::Opcode::LShr:
      case ir::Opcode::FAdd:
      case ir::Opcode::FSub:
      case ir::Opcode::FMul:
      case ir::Opcode::FDiv: {
        auto* binary = static_cast<ir::BinaryInst*>(&inst);
        auto lowered = NewVRegValue(LowerType(binary->GetType()));
        current_block_->Append(LowerBinaryOpcode(inst.GetOpcode()),
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(binary->GetLhs()),
                                LowerScalarOperand(binary->GetRhs())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::FNeg: {
        auto* unary = static_cast<ir::UnaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::F32);
        current_block_->Append(MachineInstr::Opcode::FNeg,
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(unary->GetOprd())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::ICmpEQ:
      case ir::Opcode::ICmpNE:
      case ir::Opcode::ICmpLT:
      case ir::Opcode::ICmpGT:
      case ir::Opcode::ICmpLE:
      case ir::Opcode::ICmpGE: {
        auto* binary = static_cast<ir::BinaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::I1);
        MachineInstr instr(MachineInstr::Opcode::ICmp,
                           {MachineOperand::VReg(lowered.index),
                            LowerScalarOperand(binary->GetLhs()),
                            LowerScalarOperand(binary->GetRhs())});
        instr.SetCondCode(LowerIntCond(inst.GetOpcode()));
        current_block_->Append(std::move(instr));
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::FCmpEQ:
      case ir::Opcode::FCmpNE:
      case ir::Opcode::FCmpLT:
      case ir::Opcode::FCmpGT:
      case ir::Opcode::FCmpLE:
      case ir::Opcode::FCmpGE: {
        auto* binary = static_cast<ir::BinaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::I1);
        MachineInstr instr(MachineInstr::Opcode::FCmp,
                           {MachineOperand::VReg(lowered.index),
                            LowerScalarOperand(binary->GetLhs()),
                            LowerScalarOperand(binary->GetRhs())});
        instr.SetCondCode(LowerFloatCond(inst.GetOpcode()));
        current_block_->Append(std::move(instr));
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::Zext: {
        auto* zext = static_cast<ir::ZextInst*>(&inst);
        auto lowered = NewVRegValue(LowerType(zext->GetType()));
        current_block_->Append(MachineInstr::Opcode::ZExt,
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(zext->GetValue())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::IToF: {
        auto* unary = static_cast<ir::UnaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::F32);
        current_block_->Append(MachineInstr::Opcode::ItoF,
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(unary->GetOprd())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::FtoI: {
        auto* unary = static_cast<ir::UnaryInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::I32);
        current_block_->Append(MachineInstr::Opcode::FtoI,
                               {MachineOperand::VReg(lowered.index),
                                LowerScalarOperand(unary->GetOprd())});
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::GetElementPtr: {
        auto* gep = static_cast<ir::GetElementPtrInst*>(&inst);
        auto lowered = NewVRegValue(ValueType::Ptr);
        AddressExpr address = LowerAddress(gep->GetPointer());
        auto current_type = gep->GetSourceType();
        for (size_t i = 0; i < gep->GetNumIndices(); ++i) {
          auto* index = gep->GetIndex(i);
          const std::int64_t stride = current_type ? current_type->GetSize() : 0;
          if (auto* ci = ir::dyncast<ir::ConstantInt>(index)) {
            address.const_offset += static_cast<std::int64_t>(ci->GetValue()) * stride;
          } else if (auto* cb = ir::dyncast<ir::ConstantBool>(index)) {
            address.const_offset += static_cast<std::int64_t>(cb->GetValue() ? 1 : 0) * stride;
          } else {
            address.scaled_vregs.push_back({LowerScalarOperand(index).GetVReg(), stride});
          }
          if (current_type && current_type->IsArray()) {
            current_type = current_type->GetElementType();
          }
        }
        MachineInstr instr(MachineInstr::Opcode::Lea,
                           {MachineOperand::VReg(lowered.index)});
        instr.SetAddress(std::move(address));
        current_block_->Append(std::move(instr));
        values_[&inst] = lowered;
        return;
      }
      case ir::Opcode::Call: {
        auto* call = static_cast<ir::CallInst*>(&inst);
        if (TryEmitMathIdiomCall(call, nullptr, nullptr)) {
          return;
        }
        if (TryInlineDirectCall(call)) {
          return;
        }
        std::vector<MachineOperand> operands;
        if (!call->GetType()->IsVoid()) {
          auto lowered = NewVRegValue(LowerType(call->GetType()));
          operands.push_back(MachineOperand::VReg(lowered.index));
          values_[&inst] = lowered;
        }
        std::vector<ValueType> arg_types;
        for (auto* arg : call->GetArguments()) {
          operands.push_back(LowerScalarOperand(arg));
          arg_types.push_back(LowerType(arg->GetType()));
        }
        MachineInstr instr(MachineInstr::Opcode::Call, std::move(operands));
        instr.SetCallInfo(call->GetCallee()->GetName(), std::move(arg_types),
                          LowerType(call->GetType()));
        current_block_->Append(std::move(instr));
        return;
      }
      case ir::Opcode::Br: {
        auto* br = static_cast<ir::BranchInst*>(&inst);
        current_block_->Append(MachineInstr::Opcode::Br,
                               {MachineOperand::Block(blocks_.at(br->GetDest())->GetName())});
        return;
      }
      case ir::Opcode::CondBr: {
        auto* br = static_cast<ir::CondBranchInst*>(&inst);
        current_block_->Append(
            MachineInstr::Opcode::CondBr,
            {LowerScalarOperand(br->GetCondition()),
             MachineOperand::Block(blocks_.at(br->GetThenBlock())->GetName()),
             MachineOperand::Block(blocks_.at(br->GetElseBlock())->GetName())});
        return;
      }
      case ir::Opcode::Return: {
        auto* ret = static_cast<ir::ReturnInst*>(&inst);
        if (ret->HasReturnValue()) {
          MachineInstr instr(MachineInstr::Opcode::Ret,
                             {LowerScalarOperand(ret->GetReturnValue())});
          instr.SetValueType(LowerType(ret->GetReturnValue()->GetType()));
          current_block_->Append(std::move(instr));
        } else {
          current_block_->Append(MachineInstr::Opcode::Ret);
        }
        return;
      }
      case ir::Opcode::Memset: {
        auto* memset_inst = static_cast<ir::MemsetInst*>(&inst);
        MachineInstr instr(MachineInstr::Opcode::Memset,
                           {LowerScalarOperand(memset_inst->GetValue()),
                            LowerScalarOperand(memset_inst->GetLength())});
        instr.SetAddress(LowerAddress(memset_inst->GetDest()));
        current_block_->Append(std::move(instr));
        return;
      }
      case ir::Opcode::Unreachable:
        current_block_->Append(MachineInstr::Opcode::Unreachable);
        return;
      case ir::Opcode::Phi:
        // Phis are handled separately (PreparePhiResults / EmitPhiCopies).
        return;
      case ir::Opcode::FRem:
      case ir::Opcode::Neg:
      case ir::Opcode::Not:
        throw std::runtime_error(
            FormatError("mir", "unsupported instruction in backend lowering"));
    }
    throw std::runtime_error(FormatError("mir", "unsupported IR opcode in backend lowering"));
  }

  // Lowers one IR function: creates the machine function and blocks, emits
  // Arg pseudo-instructions in the entry block, lowers every instruction in
  // CollectLoweringOrder, then resolves phis.
  void LowerFunction(ir::Function& function) {
    values_.clear();
    blocks_.clear();
    std::vector<ValueType> param_types;
    for (const auto& type : function.GetParamTypes()) {
      param_types.push_back(LowerType(type));
    }
    auto machine_function = std::make_unique<MachineFunction>(
        function.GetName(), LowerType(function.GetReturnType()), std::move(param_types));
    current_ir_function_ = &function;
    current_function_ = machine_function.get();
    const auto ordered_blocks = CollectLoweringOrder(function);
    for (const auto& block : function.GetBlocks()) {
      blocks_[block.get()] = current_function_->CreateBlock(block->GetName());
    }
    if (!function.GetBlocks().empty()) {
      auto* entry = blocks_.at(function.GetBlocks().front().get());
      for (const auto& argument : function.GetArguments()) {
        auto lowered = NewVRegValue(LowerType(argument->GetType()));
        entry->Append(MachineInstr::Opcode::Arg,
                      {MachineOperand::VReg(lowered.index),
                       MachineOperand::Imm(static_cast<std::int64_t>(argument->GetIndex()))});
        values_[argument.get()] = lowered;
      }
    }
    PreparePhiResults(function);
    for (auto* block : ordered_blocks) {
      current_block_ = blocks_.at(block);
      for (const auto& inst : block->GetInstructions()) {
        LowerInstruction(*inst);
      }
    }
    EmitPhiCopies(function);
    machine_module_->AddFunction(std::move(machine_function));
    current_ir_function_ = nullptr;
    current_function_ = nullptr;
    current_block_ = nullptr;
  }

  const ir::Module& module_;
  std::unique_ptr<MachineModule> machine_module_;
  // Per-function lowering state; valid only inside LowerFunction.
  ir::Function* current_ir_function_ = nullptr;
  MachineFunction* current_function_ = nullptr;
  MachineBasicBlock* current_block_ = nullptr;
  std::unordered_map<ir::Value*, LoweredValue> values_;
  std::unordered_map<ir::BasicBlock*, MachineBasicBlock*> blocks_;
};

}  // namespace

// Entry point of the lowering pass: converts `module` to machine IR.
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
  DefaultContext();
  return Lowerer(module).Run();
}

}  // namespace mir