From d0ecc5c2e939a767f74684fe2c5bedef18b1386d Mon Sep 17 00:00:00 2001 From: mxr <> Date: Sat, 18 Apr 2026 23:54:34 +0800 Subject: [PATCH] =?UTF-8?q?(mir)=E4=BF=AE=E5=A4=8D=E4=BC=A0=E9=80=92?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E6=98=AF=E6=8C=87=E9=92=88=E6=97=B6=E5=AD=98?= =?UTF-8?q?=E5=9C=A8=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mir/Lowering.cpp | 187 ++++++++++++++++++++++++++++--------------- 1 file changed, 122 insertions(+), 65 deletions(-) diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index a509514..1c54a86 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -7,7 +7,7 @@ #include "ir/IR.h" #include "utils/Log.h" -#define DEBUG_Lower +//#define DEBUG_Lower #ifdef DEBUG_Lower #include @@ -208,28 +208,16 @@ void EmitValueToReg(const ir::Value* value, PhysReg target, // ========== 处理全局变量 ========== if (auto* global = dynamic_cast(value)) { - // 加载全局变量的地址到目标寄存器 - // 注意:地址计算需要使用 64 位寄存器 - PhysReg addrReg = target; - bool needMove = false; - - // 如果目标是 32 位寄存器,改用 X8 临时寄存器 - if (target == PhysReg::W0 || target == PhysReg::W1 || target == PhysReg::W2 || - target == PhysReg::W3 || target == PhysReg::W4 || target == PhysReg::W5 || - target == PhysReg::W6 || target == PhysReg::W7 || target == PhysReg::W8 || - target == PhysReg::W9) { - addrReg = PhysReg::X8; - needMove = true; - } - - // 使用 ADRP + ADD 加载全局变量地址 - block.Append(Opcode::Adrp, {Operand::Reg(addrReg), Operand::Label(global->GetName())}); - block.Append(Opcode::AddLabel, {Operand::Reg(addrReg), Operand::Reg(addrReg), Operand::Label(global->GetName())}); - - // 如果需要,将地址移动到目标寄存器 - if (needMove) { - block.Append(Opcode::MovReg, {Operand::Reg(target), Operand::Reg(addrReg)}); + // 如果目标是 32 位寄存器,升级为对应的 64 位寄存器 + PhysReg addrTarget = target; + if (target >= PhysReg::W0 && target <= PhysReg::W30) { + // 映射 Wn → Xn + addrTarget = static_cast( + static_cast(target) - static_cast(PhysReg::W0) + static_cast(PhysReg::X0)); } + // 现在 addrTarget 一定是 64 位寄存器 + block.Append(Opcode::Adrp, {Operand::Reg(addrTarget), Operand::Label(global->GetName())}); + block.Append(Opcode::AddLabel, {Operand::Reg(addrTarget), Operand::Reg(addrTarget), Operand::Label(global->GetName())}); return; } @@ -254,8 +242,19 @@ void EmitValueToReg(const ir::Value* value, PhysReg target, FormatError("mir", "找不到值对应的栈槽: " + valueName)); } - block.Append(Opcode::LoadStack, - {Operand::Reg(target), Operand::FrameIndex(it->second)}); + PhysReg actualTarget = target; + const ir::Type* ty = value->GetType().get(); + bool isPointer = ty->IsPtrInt32() || ty->IsPtrFloat() || ty->IsPtrInt1() + || ty->IsArray(); // 数组类型在地址上下文中视为指针 + + // 若非指针类型且目标是 64 位寄存器,降级为对应的 32 位寄存器(自动零扩展) + if (!isPointer && target >= PhysReg::X0 && target <= PhysReg::X30) { + actualTarget = static_cast( + static_cast(target) - static_cast(PhysReg::X0) + static_cast(PhysReg::W0)); +} + +block.Append(Opcode::LoadStack, + {Operand::Reg(actualTarget), Operand::FrameIndex(it->second)}); } void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, @@ -363,53 +362,85 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } case ir::Opcode::Add: { auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + const ir::Type* lhsTy = bin.GetLhs()->GetType().get(); + const ir::Type* rhsTy = bin.GetRhs()->GetType().get(); + + // 指针判断:指令结果类型是指针,或者任一操作数是指针(指针算术) + bool isPointer = (lhsTy->IsPtrInt32() || lhsTy->IsPtrFloat() || lhsTy->IsPtrInt1() || lhsTy->IsArray()) || + (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1() || rhsTy->IsArray()) || + inst.GetType()->IsPtrInt32() || inst.GetType()->IsPtrFloat() || inst.GetType()->IsPtrInt1() || inst.GetType()->IsArray(); + int slotSize = isPointer ? 8 : 4; + PhysReg lhsReg = isPointer ? PhysReg::X8 : PhysReg::W8; + PhysReg rhsReg = isPointer ? PhysReg::X9 : PhysReg::W9; + PhysReg dstReg = isPointer ? PhysReg::X8 : PhysReg::W8; + + int dst_slot = function.CreateFrameIndex(slotSize); // 使用计算出的 slotSize + EmitValueToReg(bin.GetLhs(), lhsReg, slots, block, function); + EmitValueToReg(bin.GetRhs(), rhsReg, slots, block, function); + block.Append(Opcode::AddRR, {Operand::Reg(dstReg), Operand::Reg(lhsReg), Operand::Reg(rhsReg)}); + block.Append(Opcode::StoreStack, {Operand::Reg(dstReg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } case ir::Opcode::Sub: { auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽 - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function); - block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + const ir::Type* lhsTy = bin.GetLhs()->GetType().get(); + const ir::Type* rhsTy = bin.GetRhs()->GetType().get(); + + // 指针判断:指令结果类型是指针,或者任一操作数是指针(指针算术) + bool isPointer = (lhsTy->IsPtrInt32() || lhsTy->IsPtrFloat() || lhsTy->IsPtrInt1() || lhsTy->IsArray()) || + (rhsTy->IsPtrInt32() || rhsTy->IsPtrFloat() || rhsTy->IsPtrInt1() || rhsTy->IsArray()) || + inst.GetType()->IsPtrInt32() || inst.GetType()->IsPtrFloat() || inst.GetType()->IsPtrInt1() || inst.GetType()->IsArray(); + int slotSize = isPointer ? 8 : 4; + PhysReg lhsReg = isPointer ? PhysReg::X8 : PhysReg::W8; + PhysReg rhsReg = isPointer ? PhysReg::X9 : PhysReg::W9; + PhysReg dstReg = isPointer ? PhysReg::X8 : PhysReg::W8; + + int dst_slot = function.CreateFrameIndex(slotSize); // 使用计算出的 slotSize + EmitValueToReg(bin.GetLhs(), lhsReg, slots, block, function); + EmitValueToReg(bin.GetRhs(), rhsReg, slots, block, function); + block.Append(Opcode::SubRR, {Operand::Reg(dstReg), Operand::Reg(lhsReg), Operand::Reg(rhsReg)}); + block.Append(Opcode::StoreStack, {Operand::Reg(dstReg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } case ir::Opcode::Mul: { auto& bin = static_cast(inst); + const ir::Type* ty = inst.GetType().get(); + bool isPointer = ty->IsPtrInt32() || ty->IsPtrFloat() || ty->IsPtrInt1(); + int slotSize = isPointer ? 8 : 4; + PhysReg lhsReg = isPointer ? PhysReg::X8 : PhysReg::W8; + PhysReg rhsReg = isPointer ? PhysReg::X9 : PhysReg::W9; + PhysReg dstReg = isPointer ? PhysReg::X8 : PhysReg::W8; // 可复用 lhsReg + int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽 - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function); - block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); + EmitValueToReg(bin.GetLhs(), lhsReg, slots, block, function); + EmitValueToReg(bin.GetRhs(), rhsReg, slots, block, function); + block.Append(Opcode::MulRR, {Operand::Reg(dstReg), + Operand::Reg(lhsReg), + Operand::Reg(rhsReg)}); block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + {Operand::Reg(dstReg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } case ir::Opcode::Div: { auto& bin = static_cast(inst); + const ir::Type* ty = inst.GetType().get(); + bool isPointer = ty->IsPtrInt32() || ty->IsPtrFloat() || ty->IsPtrInt1(); + int slotSize = isPointer ? 8 : 4; + PhysReg lhsReg = isPointer ? PhysReg::X8 : PhysReg::W8; + PhysReg rhsReg = isPointer ? PhysReg::X9 : PhysReg::W9; + PhysReg dstReg = isPointer ? PhysReg::X8 : PhysReg::W8; // 可复用 lhsReg + int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽 - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function); - block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); + EmitValueToReg(bin.GetLhs(), lhsReg, slots, block, function); + EmitValueToReg(bin.GetRhs(), rhsReg, slots, block, function); + block.Append(Opcode::SDivRR, {Operand::Reg(dstReg), + Operand::Reg(lhsReg), + Operand::Reg(rhsReg)}); block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + {Operand::Reg(dstReg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); return; } @@ -417,7 +448,16 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, auto& ret = static_cast(inst); const ir::Value* retVal = ret.GetValue(); if (retVal != nullptr) { - EmitValueToReg(retVal, PhysReg::W0, slots, block, function); + const ir::Type* retType = retVal->GetType().get(); + PhysReg retReg = PhysReg::W0; // 默认整数返回值 + if (retType->IsFloat()) { + retReg = PhysReg::S0; + } else if (retType->IsPtrInt32() || retType->IsPtrFloat() || retType->IsPtrInt1()) { + retReg = PhysReg::X0; + } else { + retReg = PhysReg::W0; + } + EmitValueToReg(retVal, retReg, slots, block, function); } block.Append(Opcode::Ret); return; @@ -589,8 +629,13 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, PhysReg reg = static_cast(static_cast(PhysReg::S0) + fpArgCount); EmitValueToReg(arg, reg, slots, block, function); fpArgCount++; + } else if (argType->IsPtrInt32() || argType->IsPtrFloat() || argType->IsPtrInt1()) { + // 指针参数 → X 寄存器(占用一个整数参数槽) + PhysReg reg = static_cast(static_cast(PhysReg::X0) + intArgCount); + EmitValueToReg(arg, reg, slots, block, function); + intArgCount++; } else { - // 整数参数 + // 普通整数 → W 寄存器 PhysReg reg = static_cast(static_cast(PhysReg::W0) + intArgCount); EmitValueToReg(arg, reg, slots, block, function); intArgCount++; @@ -602,13 +647,17 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, block.Append(Opcode::Call, {Operand::Label(calleeName)}); // 保存返回值 if (dst_slot != -1) { - if (inst.GetType()->IsFloat()) { - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + const ir::Type* retType = inst.GetType().get(); + PhysReg srcReg = PhysReg::W0; + if (retType->IsFloat()) { + srcReg = PhysReg::S0; + } else if (retType->IsPtrInt32() || retType->IsPtrFloat() || retType->IsPtrInt1()) { + srcReg = PhysReg::X0; } else { - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W0), Operand::FrameIndex(dst_slot)}); + srcReg = PhysReg::W0; } + block.Append(Opcode::StoreStack, + {Operand::Reg(srcReg), Operand::FrameIndex(dst_slot)}); slots.emplace(&inst, dst_slot); } return; @@ -765,13 +814,13 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } // 加载当前索引到 x10 - EmitValueToReg(indices[idx_pos], PhysReg::W10, slots, block, function); + EmitValueToReg(indices[idx_pos], PhysReg::X10, slots, block, function); // 乘以步长 - block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W11), Operand::Imm(strides[i])}); - block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W10), - Operand::Reg(PhysReg::W10), - Operand::Reg(PhysReg::W11)}); + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X11), Operand::Imm(strides[i])}); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X10), + Operand::Reg(PhysReg::X10), + Operand::Reg(PhysReg::X11)}); // 累加到偏移量 block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9), @@ -999,6 +1048,7 @@ std::unique_ptr LowerFunction(const ir::Function& func) { const ir::Value* arg; int slot; bool isFloat; + bool isPointer; }; std::vector paramInfos; @@ -1007,7 +1057,8 @@ std::unique_ptr LowerFunction(const ir::Function& func) { int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get())); slots.emplace(arg.get(), slot); bool isFloat = arg->GetType()->IsFloat(); - paramInfos.push_back({arg.get(), slot, isFloat}); + bool isPointer = arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat() || arg->GetType()->IsPtrInt1(); + paramInfos.push_back({arg.get(), slot, isFloat, isPointer}); } // IR 基本块到 MIR 基本块的映射 @@ -1037,6 +1088,12 @@ std::unique_ptr LowerFunction(const ir::Function& func) { {Operand::Reg(reg), Operand::FrameIndex(param.slot)}); } fpArgIdx++; + } else if (param.isPointer) { + if (intArgIdx < 8) { + PhysReg reg = static_cast(static_cast(PhysReg::X0) + intArgIdx); + entryBB->Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(param.slot)}); + } + intArgIdx++; } else { if (intArgIdx < 8) { PhysReg reg = static_cast(static_cast(PhysReg::W0) + intArgIdx);