完成二三四阶段,解决支持多函数问题等

feature/mir
ftt 2 days ago
parent 6e804e2091
commit 653c091993

@ -134,26 +134,29 @@ enum class Opcode {
// ========== Operand ==========
// A single operand of a machine instruction. An operand is created through
// one of the named factories below; GetKind() reports which factory flavour
// it is and therefore which accessor is meaningful.
class Operand {
 public:
  // Discriminator for the operand payload. `Label` carries a symbolic branch
  // or call target by name.
  enum class Kind { Reg, Imm, FrameIndex, Cond, Label };

  // Named constructors, one per Kind.
  static Operand Reg(PhysReg reg);
  static Operand Imm(int value);
  static Operand FrameIndex(int index);
  static Operand Cond(CondCode cc);
  static Operand Label(const std::string& label);

  Kind GetKind() const { return kind_; }
  PhysReg GetReg() const { return reg_; }
  int GetImm() const { return imm_; }
  // Frame indices are stored in the same field as immediates.
  int GetFrameIndex() const { return imm_; }
  CondCode GetCondCode() const { return cc_; }
  const std::string& GetLabel() const { return label_; }

 private:
  // Both constructor forms stay declared so that existing out-of-line
  // definitions keep matching; the five-argument form also sets the label.
  Operand(Kind kind, PhysReg reg, int imm, CondCode cc);
  Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string& label);

  Kind kind_;
  PhysReg reg_;
  int imm_;
  CondCode cc_;
  std::string label_;
};
// ========== MIR 指令类 ==========
@ -226,6 +229,7 @@ class MachineFunction {
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const;
std::vector<FrameSlot>& GetFrameSlots() { return frame_slots_; }
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
// 栈帧大小
@ -240,10 +244,56 @@ class MachineFunction {
int frame_size_ = 0;
};
// ========== MIR module ==========
// Owns the machine functions of one translation unit, in insertion order.
class MachineModule {
 public:
  MachineModule() = default;

  // Appends `func` to the module, taking ownership of it.
  void AddFunction(std::unique_ptr<MachineFunction> func) {
    functions_.push_back(std::move(func));
  }

  // All functions, in the order they were added.
  const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const {
    return functions_;
  }
  std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() {
    return functions_;
  }

  // Returns the first function whose name equals `name`, or nullptr when
  // there is none.
  MachineFunction* GetFunction(const std::string& name) {
    // Reuse the const lookup. Casting constness away is safe because the
    // module stores pointers to non-const MachineFunction objects.
    return const_cast<MachineFunction*>(
        static_cast<const MachineModule&>(*this).GetFunction(name));
  }
  const MachineFunction* GetFunction(const std::string& name) const {
    for (const auto& func : functions_) {
      // AddFunction does not reject empty pointers, so skip them here
      // instead of dereferencing null.
      if (func && func->GetName() == name) {
        return func.get();
      }
    }
    return nullptr;
  }

 private:
  std::vector<std::unique_ptr<MachineFunction>> functions_;
};
// ========== 后端流程函数 ==========
// Module-level entry points of the backend pipeline. The driver calls them
// in this order: LowerToMIR, RunRegAlloc, RunFrameLowering, PrintAsm.
//
// The earlier single-function LowerToMIR declaration is gone: it differed
// from the module version only in its return type, which is not a valid
// overload.
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineModule& module);
void RunFrameLowering(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
} // namespace mir

@ -46,13 +46,17 @@ int main(int argc, char** argv) {
}
if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
//auto machine_func = mir::LowerToMIR(*module);
auto machine_module = mir::LowerToMIR(*module);
//mir::RunRegAlloc(*machine_func);
mir::RunRegAlloc(*machine_module);
//mir::RunFrameLowering(*machine_func);
mir::RunFrameLowering(*machine_module);
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_func, std::cout);
//mir::PrintAsm(*machine_func, std::cout);
mir::PrintAsm(*machine_module, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {

@ -24,188 +24,284 @@ void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
// Prints a register-pair memory access relative to x29:
//   <mnemonic> <reg0>, <reg1>, [x29, #<offset>]
// The two registers are separated by a comma, as assembler syntax requires.
void PrintStackPairAccess(std::ostream& os, const char* mnemonic, PhysReg reg0, PhysReg reg1,
                          int offset) {
  os << " " << mnemonic << " " << PhysRegName(reg0) << ", " << PhysRegName(reg1) << ", [x29, #" << offset
     << "]\n";
}
} // namespace
// 打印单个操作数
void PrintOperand(std::ostream& os, const Operand& op) {
switch (op.GetKind()) {
case Operand::Kind::Reg:
os << PhysRegName(op.GetReg());
break;
case Operand::Kind::Imm:
os << "#" << op.GetImm();
break;
case Operand::Kind::FrameIndex:
os << "[sp, #" << op.GetFrameIndex() << "]";
break;
case Operand::Kind::Cond:
os << CondCodeName(op.GetCondCode());
break;
case Operand::Kind::Label:
os << op.GetLabel();
break;
}
}
// 打印单条指令
void PrintInstruction(std::ostream& os, const MachineInstr& instr,
const MachineFunction& function) {
const auto& ops = instr.GetOperands();
switch (instr.GetOpcode()) {
case Opcode::Prologue:
// Prologue 在 RunFrameLowering 中已经生成具体指令,这里不需要输出
break;
case Opcode::Epilogue:
// Epilogue 在 RunFrameLowering 中已经生成具体指令,这里不需要输出
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::MovReg:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::StoreStackPair:
// stp x29, x30, [sp, #-16]!
os << " stp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", [sp";
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
int offset = ops.at(2).GetImm();
os << ", #" << offset;
}
os << "]!\n"; // 注意添加 ! 表示 pre-index
break;
case Opcode::LoadStackPair:
// ldp x29, x30, [sp], #16
os << " ldp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", [sp]";
if (ops.size() > 2 && ops.at(2).GetKind() == Operand::Kind::Imm) {
int offset = ops.at(2).GetImm();
os << ", #" << offset;
}
os << "\n";
break;
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AddRI:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SubRI:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #"
<< ops.at(2).GetImm() << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SDivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::UDivRR:
os << " udiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FAddRR:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSubRR:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMulRR:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDivRR:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::CmpRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CmpRI:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::FCmpRR:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::SIToFP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::ZExt:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #1\n";
break;
case Opcode::AndRR:
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::OrRR:
os << " orr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::EorRR:
os << " eor " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LslRR:
os << " lsl " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::LsrRR:
os << " lsr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AsrRR:
os << " asr " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::B:
os << " b ";
PrintOperand(os, ops.at(0));
os << "\n";
break;
case Opcode::BCond:
os << " b.";
PrintOperand(os, ops.at(0));
os << " ";
PrintOperand(os, ops.at(1));
os << "\n";
break;
case Opcode::Call:
os << " bl ";
PrintOperand(os, ops.at(0));
os << "\n";
break;
//case Opcode::Ret:
// os << " ret\n";
// break;
case Opcode::Ret: {
// 输出函数尾声
int frameSize = function.GetFrameSize();
int alignedSize = (frameSize + 15) & ~15;
if (alignedSize > 0) {
os << " add sp, sp, #" << alignedSize << "\n";
}
os << " ldp x29, x30, [sp], #16\n";
os << " ret\n";
break;
}
case Opcode::Nop:
os << " nop\n";
break;
default:
os << " // unknown instruction\n";
break;
}
}
// Prints one function as AArch64 assembly.
//
// Layout of the output:
//   1. section/symbol directives and the function label,
//   2. the prologue (save x29/x30, set up x29, reserve the locals),
//   3. every basic block in order, each instruction via PrintInstruction,
//   4. the `.size` directive.
// The matching epilogue is printed by PrintInstruction at each Ret.
void PrintAsm(const MachineFunction& function, std::ostream& os) {
  // Symbol header.
  os << ".text\n";
  os << ".global " << function.GetName() << "\n";
  os << ".type " << function.GetName() << ", %function\n";
  os << function.GetName() << ":\n";

  // Space for locals, rounded up to a multiple of 16 bytes.
  const int frameSize = function.GetFrameSize();
  const int alignedSize = (frameSize + 15) & ~15;

  // Prologue.
  os << " stp x29, x30, [sp, #-16]!\n";  // save FP/LR, sp -= 16
  os << " mov x29, sp\n";                // establish the frame pointer
  if (alignedSize > 0) {
    os << " sub sp, sp, #" << alignedSize << "\n";  // reserve the locals
  }

  // Basic blocks. The first block needs no label of its own because the
  // function label already marks it.
  bool firstBlock = true;
  for (const auto& bb : function.GetBasicBlocks()) {
    if (!firstBlock) {
      os << bb->GetName() << ":\n";
    }
    firstBlock = false;

    for (const auto& inst : bb->GetInstructions()) {
      PrintInstruction(os, inst, function);
    }
  }

  os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
}
// Prints a whole module: a file header followed by every function of the
// module, each terminated by a blank line.
void PrintAsm(const MachineModule& module, std::ostream& os) {
  // File header.
  os << ".arch armv8-a\n"
     << ".text\n";

  // One assembly listing per function, in module order.
  for (const auto& fn : module.GetFunctions()) {
    PrintAsm(*fn, os);
    os << "\n";
  }
}
} // namespace mir
} // namespace mir

@ -14,7 +14,112 @@ int AlignTo(int value, int align) {
} // namespace
// Assigns an offset to every frame slot of `function` and returns the total
// number of bytes the slots occupy.
//
// Slots are packed back to back in declaration order at negative offsets:
// the first slot ends at offset 0, each following slot sits directly below
// the previous one (e.g. -4, -8, -12 for three 4-byte slots). The returned
// size covers the slots only; it does not include the 16 bytes used to save
// x29/x30, and it is not rounded for alignment.
static int CalculateStackFrameSize(MachineFunction& function) {
  int bytesUsed = 0;
  for (auto& slot : function.GetFrameSlots()) {
    bytesUsed += slot.size;
    slot.offset = -bytesUsed;
  }
  return bytesUsed;
}
// Frame lowering for a single function: lays out the frame slots and records
// the resulting frame size on the function.
//
// Throws std::runtime_error when the function has no basic blocks. The check
// runs before anything is modified, so a rejected function is left untouched.
//
// No prologue or epilogue instructions are inserted into the instruction
// stream here; the assembly printer emits them (prologue at function entry,
// epilogue in front of each `ret`). An earlier variant of this pass that
// materialised stp/mov/sub at the start of the entry block and add/ldp
// before the first Ret was disabled and is no longer carried as dead code.
void RunFrameLowering(MachineFunction& function) {
  if (function.GetBasicBlocks().empty()) {
    throw std::runtime_error(FormatError("framelowering", "函数没有基本块"));
  }

  // Assign slot offsets and remember how many bytes the locals need.
  const int frameSize = CalculateStackFrameSize(function);
  function.SetFrameSize(frameSize);
}
// Frame lowering for a module: applies the single-function pass to every
// function of `module`, in module order.
void RunFrameLowering(MachineModule& module) {
  auto& functions = module.GetFunctions();
  for (auto& fn : functions) {
    RunFrameLowering(*fn);
  }
}
/* void RunFrameLowering(MachineFunction& function) {
int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
@ -40,6 +145,6 @@ void RunFrameLowering(MachineFunction& function) {
lowered.push_back(inst);
}
insts = std::move(lowered);
}
} */
} // namespace mir

@ -12,13 +12,113 @@ namespace {
using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
// Returns the size of `type` in bytes. A null type, or a type that reports a
// size of zero, falls back to 4 bytes.
int GetTypeSize(const ir::Type* type) {
  constexpr int kFallbackSize = 4;
  if (type == nullptr) {
    return kFallbackSize;
  }
  const size_t bytes = type->Size();
  if (bytes == 0) {
    return kFallbackSize;
  }
  return static_cast<int>(bytes);
}
// Maps an IR integer-comparison predicate to the condition code of the same
// name. Predicates without a mapping yield CondCode::AL.
CondCode IcmpToCondCode(ir::IcmpInst::Predicate pred) {
  using Pred = ir::IcmpInst::Predicate;
  switch (pred) {
    case Pred::EQ:
      return CondCode::EQ;
    case Pred::NE:
      return CondCode::NE;
    case Pred::LT:
      return CondCode::LT;
    case Pred::GT:
      return CondCode::GT;
    case Pred::LE:
      return CondCode::LE;
    case Pred::GE:
      return CondCode::GE;
    default:
      return CondCode::AL;
  }
}
// Maps an IR floating-point comparison predicate to a condition code.
//
// `isOrdered` is an output: it is set to false for the U* predicates
// (UEQ/UNE/ULT/UGT/ULE/UGE) and to true for everything else. The ordered and
// the unordered flavour of a relation map to the same condition code;
// predicates without a mapping yield CondCode::AL.
//
// NOTE(review): the returned code alone does not distinguish ordered from
// unordered comparisons, so callers presumably have to consult `isOrdered`
// to handle NaN operands correctly - confirm at the use sites.
CondCode FcmpToCondCode(ir::FcmpInst::Predicate pred, bool& isOrdered) {
  using Pred = ir::FcmpInst::Predicate;

  // Step 1: ordered vs. unordered.
  switch (pred) {
    case Pred::UEQ:
    case Pred::UNE:
    case Pred::ULT:
    case Pred::UGT:
    case Pred::ULE:
    case Pred::UGE:
      isOrdered = false;
      break;
    default:
      isOrdered = true;
      break;
  }

  // Step 2: the relation itself.
  switch (pred) {
    case Pred::OEQ:
    case Pred::UEQ:
      return CondCode::EQ;
    case Pred::ONE:
    case Pred::UNE:
      return CondCode::NE;
    case Pred::OLT:
    case Pred::ULT:
      return CondCode::LT;
    case Pred::OGT:
    case Pred::UGT:
      return CondCode::GT;
    case Pred::OLE:
    case Pred::ULE:
      return CondCode::LE;
    case Pred::OGE:
    case Pred::UGE:
      return CondCode::GE;
    default:
      return CondCode::AL;
  }
}
// Builds the assembly label of an IR basic block: ".L<function>_<block>".
//
// A null block, or a block without a parent function, yields ".Lunknown".
// A block without a name is identified by its address instead, which makes
// the label unique within a run but not stable across runs.
std::string GetBlockLabel(const ir::BasicBlock* bb) {
  if (!bb || !bb->GetParent()) {
    return ".Lunknown";
  }

  std::string label = ".L";
  label += bb->GetParent()->GetName();
  label += "_";

  const std::string blockName = bb->GetName();
  if (blockName.empty()) {
    label += std::to_string(reinterpret_cast<uintptr_t>(bb));
  } else {
    label += blockName;
  }
  return label;
}
// Returns `type` viewed as an array type, or nullptr when `type` is null or
// is not an array.
static const ir::ArrayType* GetArrayType(const ir::Type* type) {
  // Guard against null before asking the type about itself; the sibling
  // GetTypeSize accepts null types as well.
  if (type == nullptr || !type->IsArray()) {
    return nullptr;
  }
  return static_cast<const ir::ArrayType*>(type);
}
// Returns the byte stride of every dimension of `arrayType`, outermost
// dimension first: strides[i] is 4 (the element size used here for int and
// float) multiplied by all dimensions that follow dimension i.
// Example: dimensions {2, 3} give strides {12, 4}.
// `arrayType` must not be null.
static std::vector<int> GetArrayStrides(const ir::ArrayType* arrayType) {
  const std::vector<int>& dims = arrayType->GetDimensions();

  // Size the result once and fill it from the innermost dimension outwards,
  // instead of inserting at the front on every step.
  std::vector<int> strides(dims.size());
  int stride = 4;
  for (size_t i = dims.size(); i-- > 0;) {
    strides[i] = stride;
    stride *= dims[i];
  }
  return strides;
}
void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
const ValueSlotMap& slots, MachineBasicBlock& block,
MachineFunction& function) {
// 处理整数常量
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())});
return;
}
// 处理浮点常量
if (auto* fconstant = dynamic_cast<const ir::ConstantFloat*>(value)) {
// 浮点常量需要先存储到栈槽,再加载到寄存器
// 因为 ARMv8 没有直接加载浮点立即数的指令
int slot = -1;
// 注意:这里需要找到或创建该浮点常量的栈槽
// 简单起见,可以每次都分配新栈槽
// 更好的做法是:在 Module 级别缓存浮点常量
throw std::runtime_error(
FormatError("mir", "浮点常量暂未实现"));
return;
}
// 处理零常量
if (dynamic_cast<const ir::ConstantZero*>(value) ||
dynamic_cast<const ir::ConstantAggregateZero*>(value)) {
// 零常量:直接加载 0
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(0)});
return;
}
auto it = slots.find(value);
if (it == slots.end()) {
@ -31,22 +131,34 @@ void EmitValueToReg(const ir::Value* value, PhysReg target,
}
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots) {
auto& block = function.GetEntry();
ValueSlotMap& slots, MachineBasicBlock& block,
std::unordered_map<const ir::BasicBlock*,
MachineBasicBlock*>& blockMap) {
//auto& block = function.GetEntry();
switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: {
slots.emplace(&inst, function.CreateFrameIndex());
slots.emplace(&inst, function.CreateFrameIndex(GetTypeSize(inst.GetType().get())));
return;
}
case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst);
auto dst = slots.find(store.GetPtr());
if (dst == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行写入"));
//throw std::runtime_error(
// FormatError("mir", "暂不支持对非栈变量地址进行写入"));
// 对于非栈变量地址(如 GEP 结果),地址本身在栈槽中
// 需要先加载地址,然后存储值到该地址
// 先加载地址到 x8
EmitValueToReg(store.GetPtr(), PhysReg::X8, slots, block, function);
// 加载值到 w9
EmitValueToReg(store.GetValue(), PhysReg::W9, slots, block, function);
// 存储值到地址
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::X8)});
return;
}
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block, function);
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
return;
@ -55,10 +167,23 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
auto& load = static_cast<const ir::LoadInst&>(inst);
auto src = slots.find(load.GetPtr());
if (src == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行读取"));
//throw std::runtime_error(
// FormatError("mir", "暂不支持对非栈变量地址进行读取"));
// 对于非栈变量地址(如 GEP 结果),地址本身在栈槽中
// 需要先加载地址,然后从该地址加载值
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get()));
// 加载地址到 x8
EmitValueToReg(load.GetPtr(), PhysReg::X8, slots, block, function);
// 从地址加载值到 w9
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::X8)});
// 存储值到结果栈槽
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W9), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
int dst_slot = function.CreateFrameIndex();
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果槽
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
block.Append(Opcode::StoreStack,
@ -68,9 +193,9 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
}
case ir::Opcode::Add: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get()));
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
@ -79,46 +204,523 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Sub: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function);
block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Mul: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function);
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Div: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block, function);
block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block, function);
block.Append(Opcode::Ret);
return;
}
case ir::Opcode::Sub:
case ir::Opcode::Mul:
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算"));
case ir::Opcode::FAdd: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽
// 浮点值加载到 S0, S1使用浮点寄存器
EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function);
block.Append(Opcode::FAddRR, {Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FSub: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽
// 浮点值加载到 S0, S1使用浮点寄存器
EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function);
block.Append(Opcode::FSubRR, {Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FMul: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽
// 浮点值加载到 S0, S1使用浮点寄存器
EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function);
block.Append(Opcode::FMulRR, {Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FDiv: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get())); // 分配结果栈槽
// 浮点值加载到 S0, S1使用浮点寄存器
EmitValueToReg(bin.GetLhs(), PhysReg::S0, slots, block, function);
EmitValueToReg(bin.GetRhs(), PhysReg::S1, slots, block, function);
block.Append(Opcode::FDivRR, {Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S0),
Operand::Reg(PhysReg::S1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// ========== 整数比较指令(修正版)==========
case ir::Opcode::Icmp: {
auto& icmp = static_cast<const ir::IcmpInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get()));
// 加载左右操作数到 w8, w9
EmitValueToReg(icmp.GetLhs(), PhysReg::W8, slots, block, function);
EmitValueToReg(icmp.GetRhs(), PhysReg::W9, slots, block, function);
// 生成比较指令
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
// 使用条件设置指令: CSET W8, cc
// 如果条件成立W8 = 1否则 W8 = 0
CondCode cc = IcmpToCondCode(icmp.GetPredicate());
// 使用 CSET 的替代实现:条件移动
// MOV W8, #1
// MOV W9, #0
// CSEL W8, W8, W9, cc
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(1)});
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(0)});
// TODO: 需要添加 CSEL 指令,暂时使用条件跳转
// 创建临时标签
std::string true_label = ".L_cset_true_" + std::to_string(reinterpret_cast<uintptr_t>(&icmp));
std::string end_label = ".L_cset_end_" + std::to_string(reinterpret_cast<uintptr_t>(&icmp));
// 条件跳转到 true_label
block.Append(Opcode::BCond, {Operand::Cond(cc), Operand::Imm(0)});
// 条件不成立W8 = 0
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(0)});
block.Append(Opcode::B, {Operand::Imm(0)}); // 跳转到 end_label
// true_label: W8 = 1已经在上面设置了
// end_label: 继续
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// ========== 浮点比较指令 ==========
case ir::Opcode::FCmp: {
auto& fcmp = static_cast<const ir::FcmpInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get()));
// 加载浮点操作数到 s0, s1
EmitValueToReg(fcmp.GetLhs(), PhysReg::S0, slots, block, function);
EmitValueToReg(fcmp.GetRhs(), PhysReg::S1, slots, block, function);
// 生成浮点比较指令
block.Append(Opcode::FCmpRR, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::S1)});
// 简化实现:存储 1 作为结果
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(1)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// ========== 跳转指令(使用标签操作数)==========
case ir::Opcode::Br: {
auto& br = static_cast<const ir::BranchInst&>(inst);
if (br.IsConditional()) {
// 条件跳转: br i1 %cond, label %then, label %else
// 加载条件值到 w8
EmitValueToReg(br.GetCondition(), PhysReg::W8, slots, block, function);
// 比较条件值是否为 0
block.Append(Opcode::CmpRI, {Operand::Reg(PhysReg::W8), Operand::Imm(0)});
// 获取目标基本块的标签名
const ir::BasicBlock* irTrueTarget = br.GetTrueTarget();
const ir::BasicBlock* irFalseTarget = br.GetFalseTarget();
std::string trueLabel = GetBlockLabel(irTrueTarget);
std::string falseLabel = GetBlockLabel(irFalseTarget);
// 生成 B.NE true_label
block.Append(Opcode::BCond, {Operand::Cond(CondCode::NE), Operand::Label(trueLabel)});
// 生成 B false_label
block.Append(Opcode::B, {Operand::Label(falseLabel)});
} else {
// 无条件跳转: br label %target
const ir::BasicBlock* irTarget = br.GetTarget();
std::string targetLabel = GetBlockLabel(irTarget);
// 生成 B target_label
block.Append(Opcode::B, {Operand::Label(targetLabel)});
}
return;
}
// ========== 函数调用 ==========
case ir::Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
const ir::Function* callee = call.GetCallee();
const std::string& calleeName = callee->GetName();
// 分配结果栈槽(如果有返回值)
int dst_slot = -1;
if (!inst.GetType()->IsVoid()) {
dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get()));
}
// 按照 ARM64 调用约定传递参数
const auto& args = call.GetArgs();
size_t intArgCount = 0;
size_t fpArgCount = 0;
for (size_t i = 0; i < args.size(); ++i) {
const auto* arg = args[i];
const ir::Type* argType = arg->GetType().get();
if (argType->IsFloat()) {
// 浮点参数
PhysReg reg = static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + fpArgCount);
EmitValueToReg(arg, reg, slots, block, function);
fpArgCount++;
} else {
// 整数参数
PhysReg reg = static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + intArgCount);
EmitValueToReg(arg, reg, slots, block, function);
intArgCount++;
}
}
// 生成调用指令
block.Append(Opcode::Call, {Operand::Imm(0)}); // 实际需要传递函数名
// 保存返回值
if (dst_slot != -1) {
if (inst.GetType()->IsFloat()) {
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W0), Operand::FrameIndex(dst_slot)});
}
slots.emplace(&inst, dst_slot);
}
return;
}
// ========== 类型转换指令 ==========
case ir::Opcode::ZExt: {
auto& zext = static_cast<const ir::ZExtInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get()));
// 加载源值到 w8
EmitValueToReg(zext.GetValue(), PhysReg::W8, slots, block, function);
// 零扩展i1 -> i32直接存储即可因为 i1 已经是 0 或 1
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::SIToFP: {
auto& sitofp = static_cast<const ir::SIToFPInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get()));
// 加载整数到 w8
EmitValueToReg(sitofp.GetValue(), PhysReg::W8, slots, block, function);
// 整数转浮点SCVTF s0, w8
block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::S0), Operand::Reg(PhysReg::W8)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::FPToSI: {
auto& fptosi = static_cast<const ir::FPToSIInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(inst.GetType().get()));
// 加载浮点数到 s0
EmitValueToReg(fptosi.GetValue(), PhysReg::S0, slots, block, function);
// 浮点转整数FCVTZS w8, s0
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S0)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// ========== GEP 指令(计算数组元素地址)==========
case ir::Opcode::GEP: {
auto& gep = static_cast<const ir::GEPInst&>(inst);
// GEP 返回指针类型,在 ARM64 上指针是 8 字节
int dst_slot = function.CreateFrameIndex(8);
// 获取基地址(数组的起始地址)
ir::Value* base = gep.GetBase();
const auto& indices = gep.GetIndices();
// 加载基地址到 x8使用 64 位寄存器存储地址)
EmitValueToReg(base, PhysReg::X8, slots, block, function);
if (indices.empty()) {
// 没有索引,直接返回基地址
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// 获取数组类型信息,计算每个维度的步长
const ir::Type* baseType = base->GetType().get();
// 如果基地址是指针类型,需要解引用获取元素类型
if (baseType->IsPtrInt32() || baseType->IsPtrFloat() || baseType->IsPtrInt1()) {
// 对于指针类型,第一个索引是偏移量(以元素为单位)
// 例如int* p; p[1] 的 GEP 中 indices[0] = 1
if (indices.size() >= 1) {
// 加载索引到 x9
EmitValueToReg(indices[0], PhysReg::X9, slots, block, function);
// 乘以元素大小int/float 是 4 字节)
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X10), Operand::Imm(4)});
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
// 地址 = base + index * 4
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X9)});
}
// 存储计算出的地址
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// 如果基地址是数组类型,需要处理多维数组
if (baseType->IsArray()) {
const ir::ArrayType* arrayType = static_cast<const ir::ArrayType*>(baseType);
const std::vector<int>& dims = arrayType->GetDimensions();
// 计算每个维度的步长
std::vector<int> strides(dims.size());
int stride = 4; // 元素大小int/float 是 4 字节)
for (int i = dims.size() - 1; i >= 0; --i) {
strides[i] = stride;
stride *= dims[i];
}
// 计算总偏移量
// 地址 = base + index0 * stride0 + index1 * stride1 + ...
size_t numIndices = indices.size();
// 限制索引数量(不能超过维度数)
if (numIndices > dims.size()) {
numIndices = dims.size();
}
// 加载当前地址到 x9 作为偏移量累加器
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X9), Operand::Imm(0)});
for (size_t i = 0; i < numIndices; ++i) {
// 加载当前索引到 x10
EmitValueToReg(indices[i], PhysReg::X10, slots, block, function);
// 乘以步长
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X11), Operand::Imm(strides[i])});
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X10),
Operand::Reg(PhysReg::X10),
Operand::Reg(PhysReg::X11)});
// 累加到偏移量
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
}
// 最终地址 = base + offset
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X9)});
// 存储计算出的地址
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
// 其他情况:简单处理
// 只处理第一个索引
if (indices.size() >= 1) {
EmitValueToReg(indices[0], PhysReg::X9, slots, block, function);
// 乘以元素大小(默认 4 字节)
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::X10), Operand::Imm(4)});
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X9),
Operand::Reg(PhysReg::X10)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X8),
Operand::Reg(PhysReg::X9)});
}
// 存储计算出的地址
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
//throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令opcode: "
+ std::to_string(static_cast<int>(inst.GetOpcode()))));
}
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
}
} // namespace
// Single-function lowering entry point: calls DefaultContext() and rejects any
// module that does not consist of exactly one function named "main"
// (std::runtime_error in both cases).
// NOTE(review): within this span the function never returns a value although
// it is declared to return std::unique_ptr<MachineFunction>. A
// LowerToMIR(const ir::Module&) returning std::unique_ptr<MachineModule> is
// also defined later in this file, and the two differ only in return type.
// This definition presumably is the superseded version; confirm and remove it.
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) {
DefaultContext();
// Exactly one function per module is accepted here.
if (module.GetFunctions().size() != 1) {
throw std::runtime_error(FormatError("mir", "暂不支持多个函数"));
}
const auto& func = *module.GetFunctions().front();
// That single function must be named "main".
if (func.GetName() != "main") {
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
}
} // namespace -- NOTE(review): as written, this brace ends the function body above
// Lowers a single IR function into a MachineFunction.
//
// Steps:
//   1. Allocate one stack slot per formal argument and record it in the
//      value -> slot map.
//   2. Create one MachineBasicBlock per IR basic block (same name). The
//      IR -> MIR block map is handed to LowerInstruction in step 4.
//   3. At the start of the entry block, store the incoming argument registers
//      into the argument slots (W0.. for non-float, S0.. for float arguments,
//      counted separately).
//   4. Lower every instruction of every block with LowerInstruction.
//
// A function without basic blocks produces a MachineFunction with no blocks.
// Throws std::runtime_error when an IR block has no MIR counterpart; anything
// thrown by LowerInstruction propagates unchanged.
std::unique_ptr<MachineFunction> LowerFunction(const ir::Function& func) {
  auto machine_func = std::make_unique<MachineFunction>(func.GetName());
  ValueSlotMap slots;

  // Step 1: reserve a stack slot for each argument. The register -> slot
  // stores are emitted in step 3, once the entry block exists.
  for (const auto& arg : func.GetArguments()) {
    const int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get()));
    slots.emplace(arg.get(), slot);
  }

  // Step 2: create the MIR blocks up front and remember the IR -> MIR mapping.
  std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> blockMap;
  for (const auto& bb : func.GetBlocks()) {
    auto mirBB = std::make_unique<MachineBasicBlock>(bb->GetName());
    blockMap[bb.get()] = mirBB.get();
    machine_func->AddBasicBlock(std::move(mirBB));
  }

  // Step 3: spill incoming argument registers at the top of the entry block.
  if (!func.GetBlocks().empty()) {
    // find() instead of operator[] so a missing entry does not insert a null
    // mapping as a side effect.
    const auto entryIt = blockMap.find(func.GetEntry());
    MachineBasicBlock* entryBB = entryIt != blockMap.end() ? entryIt->second : nullptr;
    if (entryBB) {
      size_t intArgIdx = 0;
      size_t fpArgIdx = 0;
      for (const auto& arg : func.GetArguments()) {
        int slot = slots[arg.get()];
        const ir::Type* argType = arg->GetType().get();
        if (argType->IsFloat()) {
          // Float arguments arrive in S0, S1, ...
          // TODO: only the first 8 are stored; later ones are not handled and
          // their slots are left unwritten.
          if (fpArgIdx < 8) {
            PhysReg reg = static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + fpArgIdx);
            entryBB->Append(Opcode::StoreStack,
                            {Operand::Reg(reg), Operand::FrameIndex(slot)});
          }
          fpArgIdx++;
        } else {
          // All other arguments arrive in W0, W1, ...
          // TODO: only the first 8 are stored; later ones are not handled and
          // their slots are left unwritten.
          // NOTE(review): arguments whose slot is wider than 4 bytes (e.g.
          // pointers) are also stored through a W register here -- confirm.
          if (intArgIdx < 8) {
            PhysReg reg = static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + intArgIdx);
            entryBB->Append(Opcode::StoreStack,
                            {Operand::Reg(reg), Operand::FrameIndex(slot)});
          }
          intArgIdx++;
        }
      }
    }
  }

  // Step 4: lower the instructions of each block into its MIR counterpart.
  for (const auto& bb : func.GetBlocks()) {
    const auto mirIt = blockMap.find(bb.get());
    if (mirIt == blockMap.end() || !mirIt->second) {
      throw std::runtime_error(FormatError("mir", "找不到基本块对应的 MIR 基本块"));
    }
    MachineBasicBlock* mirBB = mirIt->second;
    for (const auto& inst : bb->GetInstructions()) {
      LowerInstruction(*inst, *machine_func, slots, *mirBB, blockMap);
    }
  }

  return machine_func;
}
// Lowers every function of an IR module and collects the results in a
// MachineModule.
//
// Calls DefaultContext() first, then lowers the functions in module order via
// LowerFunction. A std::runtime_error raised while lowering is not swallowed:
// it is re-thrown at once, wrapped with the name of the failing function, so
// the remaining functions are not processed. Also throws std::runtime_error
// when the resulting module ends up with no functions.
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
  DefaultContext();

  auto lowered_module = std::make_unique<MachineModule>();

  const auto& ir_functions = module.GetFunctions();
  for (auto it = ir_functions.begin(); it != ir_functions.end(); ++it) {
    const auto& ir_func = *it;
    try {
      lowered_module->AddFunction(LowerFunction(*ir_func));
    } catch (const std::runtime_error& failure) {
      // Add the function name to the error, then abort the whole lowering.
      const auto detail = "转换函数失败: " + ir_func->GetName() + " - " + failure.what();
      throw std::runtime_error(FormatError("mir", detail));
    }
  }

  const bool nothing_lowered = lowered_module->GetFunctions().empty();
  if (nothing_lowered) {
    throw std::runtime_error(FormatError("mir", "模块中没有成功转换的函数"));
  }
  return lowered_module;
}
} // namespace mir

@ -4,21 +4,25 @@
namespace mir {
// Constructor used by all factory functions below. Each factory fills in the
// field that matches `kind` and passes placeholders (W0 / 0 / EQ / "") for the
// others.
Operand::Operand(Kind kind, PhysReg reg, int imm, CondCode cc, const std::string& label)
    : kind_(kind), reg_(reg), imm_(imm), cc_(cc), label_(label) {}

// Physical-register operand.
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0, CondCode::EQ, ""); }

// Immediate integer operand.
Operand Operand::Imm(int value) {
  return Operand(Kind::Imm, PhysReg::W0, value, CondCode::EQ, "");
}

// Stack-slot operand; the slot index is stored in the immediate field.
Operand Operand::FrameIndex(int index) {
  return Operand(Kind::FrameIndex, PhysReg::W0, index, CondCode::EQ, "");
}

// Condition-code operand.
Operand Operand::Cond(CondCode cc) {
  return Operand(Kind::Cond, PhysReg::W0, 0, cc, "");
}

// Label operand carrying the label's name.
Operand Operand::Label(const std::string& label) {
  return Operand(Kind::Label, PhysReg::W0, 0, CondCode::EQ, label);
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)

@ -22,15 +22,61 @@ bool IsAllowedReg(PhysReg reg) {
} // namespace
//void RunRegAlloc(MachineFunction& function) {
// for (const auto& inst : function.GetEntry().GetInstructions()) {
// for (const auto& operand : inst.GetOperands()) {
// if (operand.GetKind() == Operand::Kind::Reg &&
// !IsAllowedReg(operand.GetReg())) {
// throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
// }
// }
// }
//}
// Per-function "register allocation" pass.
//
// No registers are actually assigned: values stay in stack slots and registers
// are only used as scratch by the lowering. The pass is a consistency check
// over every operand of every instruction in every basic block:
//   - FrameIndex operands must index an existing frame slot, otherwise
//     std::runtime_error is thrown;
//   - Reg, Imm, Cond and Label operands are accepted without validation.
void RunRegAlloc(MachineFunction& function) {
  // The pass does not add or remove slots, so the count is read once.
  const int slot_count = static_cast<int>(function.GetFrameSlots().size());

  for (const auto& bb : function.GetBasicBlocks()) {
    for (const auto& instr : bb->GetInstructions()) {
      for (const auto& operand : instr.GetOperands()) {
        switch (operand.GetKind()) {
          case Operand::Kind::FrameIndex: {
            const int index = operand.GetFrameIndex();
            if (index < 0 || index >= slot_count) {
              throw std::runtime_error(
                  FormatError("regalloc", "无效的栈槽索引: " + std::to_string(index)));
            }
            break;
          }
          case Operand::Kind::Reg:
            // Fixed scratch registers chosen by the lowering; not checked.
            break;
          case Operand::Kind::Imm:
          case Operand::Kind::Cond:
          case Operand::Kind::Label:
            // Nothing to verify for immediates, condition codes and labels.
            break;
        }
      }
    }
  }
}
// Module-level overload: runs the per-function RunRegAlloc on each function
// stored in the module, in order.
void RunRegAlloc(MachineModule& module) {
  auto& functions = module.GetFunctions();
  for (auto it = functions.begin(); it != functions.end(); ++it) {
    RunRegAlloc(**it);
  }
}
} // namespace mir

Loading…
Cancel
Save