lab3代码实现

lc 2 months ago
parent 8f807adb08
commit 3dda941176

@ -19,7 +19,17 @@ class MIRContext {
MIRContext& DefaultContext();
// AArch64 physical registers known to the backend.
//
// The enumerators are grouped by register class:
//   W0..W15  - 32-bit general purpose views
//   X0..X17  - 64-bit general purpose registers (the printer uses X16 as a
//              scratch register for address computation)
//   S0..S15  - single-precision floating-point registers
//   X29/X30  - frame pointer and link register
//   SP       - stack pointer
//   WZR/XZR  - 32-bit / 64-bit zero registers
enum class PhysReg {
  W0, W1, W2, W3, W4, W5, W6, W7,
  W8, W9, W10, W11, W12, W13, W14, W15,
  X0, X1, X2, X3, X4, X5, X6, X7,
  X8, X9, X10, X11, X12, X13, X14, X15,
  X16, X17,
  S0, S1, S2, S3, S4, S5, S6, S7,
  S8, S9, S10, S11, S12, S13, S14, S15,
  X29, X30, SP, WZR, XZR
};
const char* PhysRegName(PhysReg reg);
@ -27,31 +37,67 @@ enum class Opcode {
Prologue,
Epilogue,
MovImm,
MovRR,
LoadStack,
StoreStack,
AddrStack,
LoadGlobal,
StoreGlobal,
AddRR,
AddRRI,
AddRRR_LSL,
SubRR,
MulRR,
SDivRR,
MSubRRR,
Sxtw,
NegR,
CmpRR,
CSet,
FAdd,
FSub,
FMUL,
FDiv,
FNeg,
FCmp,
FCvtSI2FP,
FCvtFP2SI,
LoadR,
StoreR,
Call,
B,
BCond,
Ret,
};
// Condition codes attached to CSet / BCond operands. The enumerator order is
// significant: Operand::GetCond() recovers the code from an integer payload.
enum class CondCode {
  EQ,  // equal
  NE,  // not equal
  LT,  // less than
  LE,  // less than or equal
  GT,  // greater than
  GE   // greater than or equal
};
// A single operand of a machine instruction.
//
// An operand is a small tagged value: `Kind` says which payload is meaningful
// and the accessors below read that payload back. Several kinds share storage
// (see the accessor comments), so callers should check GetKind() before
// reading a payload.
class Operand {
 public:
  // Discriminator for the payload carried by this operand.
  enum class Kind { Reg, Imm, FrameIndex, Label, Global, Cond };

  // Factory functions (defined elsewhere); their names mirror the Kind
  // enumerators.
  static Operand Reg(PhysReg reg);
  static Operand Imm(int value);
  static Operand FrameIndex(int index);
  static Operand Label(const std::string& name);
  static Operand Global(const std::string& name);
  static Operand Cond(CondCode cc);

  Kind GetKind() const { return kind_; }
  PhysReg GetReg() const { return reg_; }
  int GetImm() const { return imm_; }
  // Frame indices are stored in the same integer field as immediates.
  int GetFrameIndex() const { return imm_; }
  // Labels and global symbol names are stored in the same string field.
  const std::string& GetLabel() const { return label_; }
  const std::string& GetGlobal() const { return label_; }
  // The condition code is read back from the integer field.
  CondCode GetCond() const { return static_cast<CondCode>(imm_); }

 private:
  // Single constructor used by the factories; `label` defaults to empty for
  // the kinds that carry no string payload.
  Operand(Kind kind, PhysReg reg, int imm, std::string label = "");

  Kind kind_;
  PhysReg reg_;
  int imm_;
  std::string label_;
};
class MachineInstr {
@ -93,8 +139,10 @@ class MachineFunction {
explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; }
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
MachineBasicBlock& CreateBlock(const std::string& name);
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() { return blocks_; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const { return blocks_; }
int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index);
@ -106,14 +154,35 @@ class MachineFunction {
private:
std::string name_;
MachineBasicBlock entry_;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0;
};
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
// Description of one module-level variable handed to the assembly printer.
struct GlobalVariable {
  // Symbol name used for the label and the `.global` directive.
  std::string name;
  // Initial value; the printer emits it with `.word` when it is non-zero and
  // the variable is at most 4 bytes.
  int init_value{0};
  // Storage size in bytes.
  size_t size{4};
  // Whether the variable was declared constant.
  bool is_const{false};
};
class MachineModule {
public:
MachineModule() = default;
std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() { return functions_; }
const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const { return functions_; }
std::vector<GlobalVariable>& GetGlobals() { return globals_; }
const std::vector<GlobalVariable>& GetGlobals() const { return globals_; }
private:
std::vector<std::unique_ptr<MachineFunction>> functions_;
std::vector<GlobalVariable> globals_;
};
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os);
void PrintAsm(const MachineModule& module, std::ostream& os);
} // namespace mir

@ -46,13 +46,15 @@ int main(int argc, char** argv) {
}
if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
auto machine_module = mir::LowerToMIR(*module);
for (auto& func : machine_module->GetFunctions()) {
mir::RunRegAlloc(*func);
mir::RunFrameLowering(*func);
}
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_func, std::cout);
mir::PrintAsm(*machine_module, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {

@ -16,63 +16,290 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex());
}
// Emits the instruction sequence that materializes the constant `imm` in
// `reg`.
//
//  * Values in [-32768, 65535] are emitted as a single `mov`.
//  * Any other value is built from its two 16-bit halves: `mov` for bits
//    0..15 followed by `movk ... lsl #16` for bits 16..31.
//  * When the destination is a 64-bit `x` register and `imm` is negative,
//    bits 32..63 are additionally filled with ones so the register holds the
//    sign-extended value rather than a zero-extended 32-bit pattern.
void PrintMovImm(std::ostream& os, PhysReg reg, int imm) {
  const char* reg_name = PhysRegName(reg);
  if (imm >= -32768 && imm <= 65535) {
    os << " mov " << reg_name << ", #" << imm << "\n";
    return;
  }

  const uint32_t bits = static_cast<uint32_t>(imm);
  os << " mov " << reg_name << ", #" << (bits & 0xFFFF) << "\n";
  os << " movk " << reg_name << ", #" << ((bits >> 16) & 0xFFFF) << ", lsl #16\n";

  // Register class is recognised by the first letter of its printed name,
  // the same convention the instruction printer uses for `w`/`s` registers.
  if (imm < 0 && reg_name[0] == 'x') {
    os << " movk " << reg_name << ", #65535, lsl #32\n";
    os << " movk " << reg_name << ", #65535, lsl #48\n";
  }
}
// Emits a load or store of `reg` at `offset` bytes from the frame pointer
// (x29).
//
// `mnemonic` is the unscaled-offset instruction to use ("ldur" or "stur");
// only its first character is inspected to tell loads from stores. Offsets
// in [-256, 255] are emitted directly with that mnemonic. Larger offsets are
// first turned into an absolute address in the scratch register x16, and the
// access is then emitted as a plain `ldr`/`str` through x16.
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                      int offset) {
  const char* reg_name = PhysRegName(reg);

  if (offset >= -256 && offset <= 255) {
    os << " " << mnemonic << " " << reg_name << ", [x29, #" << offset
       << "]\n";
    return;
  }

  // Offset out of range for ldur/stur: compute x29 +/- |offset| into x16.
  const bool below_fp = offset < 0;
  PrintMovImm(os, PhysReg::X16, below_fp ? -offset : offset);
  os << (below_fp ? " sub x16, x29, x16\n" : " add x16, x29, x16\n");

  const bool is_load = mnemonic[0] == 'l';
  os << (is_load ? " ldr " : " str ") << reg_name << ", [x16]\n";
}
// Maps a condition code to the lowercase suffix used in assembly
// (e.g. `b.eq`, `cset w0, lt`). Values outside the enum yield "??".
const char* CondCodeName(CondCode cc) {
  const char* suffix = "??";
  switch (cc) {
    case CondCode::EQ:
      suffix = "eq";
      break;
    case CondCode::NE:
      suffix = "ne";
      break;
    case CondCode::LT:
      suffix = "lt";
      break;
    case CondCode::LE:
      suffix = "le";
      break;
    case CondCode::GT:
      suffix = "gt";
      break;
    case CondCode::GE:
      suffix = "ge";
      break;
  }
  return suffix;
}
} // namespace
void PrintAsm(const MachineFunction& function, std::ostream& os) {
void PrintAsm(const MachineModule& module, std::ostream& os) {
// Print global variables
if (!module.GetGlobals().empty()) {
os << ".data\n";
for (const auto& gv : module.GetGlobals()) {
os << ".global " << gv.name << "\n";
os << ".align 4\n";
os << gv.name << ":\n";
if (gv.size > 4 || gv.init_value == 0) {
os << " .zero " << gv.size << "\n";
} else {
os << " .word " << gv.init_value << "\n";
}
}
os << "\n";
}
os << ".text\n";
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
for (const auto& function : module.GetFunctions()) {
os << ".global " << function->GetName() << "\n";
os << ".type " << function->GetName() << ", %function\n";
os << function->GetName() << ":\n";
for (const auto& inst : function.GetEntry().GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
for (const auto& block : function->GetBlocks()) {
os << ".L" << function->GetName() << "_" << block->GetName() << ":\n";
for (const auto& inst : block->GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function->GetFrameSize() > 0) {
if (function->GetFrameSize() <= 4095) {
os << " sub sp, sp, #" << function->GetFrameSize() << "\n";
} else {
PrintMovImm(os, PhysReg::X11, function->GetFrameSize());
os << " sub sp, sp, x11\n";
}
}
break;
case Opcode::Epilogue:
if (function->GetFrameSize() > 0) {
if (function->GetFrameSize() <= 4095) {
os << " add sp, sp, #" << function->GetFrameSize() << "\n";
} else {
PrintMovImm(os, PhysReg::X11, function->GetFrameSize());
os << " add sp, sp, x11\n";
}
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
if (ops.at(1).GetKind() == Operand::Kind::Global) {
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", " << ops.at(1).GetGlobal() << "\n";
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(0).GetReg())
<< ", :lo12:" << ops.at(1).GetGlobal() << "\n";
} else {
PrintMovImm(os, ops.at(0).GetReg(), ops.at(1).GetImm());
}
break;
case Opcode::MovRR: {
const char* dst = PhysRegName(ops.at(0).GetReg());
const char* src = PhysRegName(ops.at(1).GetReg());
if (dst[0] == 's' && src[0] == 'w') {
os << " fmov " << dst << ", " << src << "\n";
} else if (dst[0] == 'w' && src[0] == 's') {
os << " fmov " << dst << ", " << src << "\n";
} else if (dst[0] == 's' && src[0] == 's') {
os << " fmov " << dst << ", " << src << "\n";
} else {
os << " mov " << dst << ", " << src << "\n";
}
break;
}
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(*function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(*function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::AddrStack: {
const auto& slot = GetFrameSlot(*function, ops.at(1));
int offset = slot.offset;
if (offset >= 0) {
if (offset <= 4095) {
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << offset << "\n";
} else {
PrintMovImm(os, PhysReg::X16, offset);
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", x29, x16\n";
}
} else {
int abs_offset = -offset;
if (abs_offset <= 4095) {
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", x29, #" << abs_offset << "\n";
} else {
PrintMovImm(os, PhysReg::X16, abs_offset);
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", x29, x16\n";
}
}
break;
}
case Opcode::LoadGlobal:
os << " adrp x16, " << ops.at(1).GetGlobal() << "\n";
os << " add x16, x16, :lo12:" << ops.at(1).GetGlobal() << "\n";
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", [x16]\n";
break;
case Opcode::StoreGlobal:
os << " adrp x16, " << ops.at(1).GetGlobal() << "\n";
os << " add x16, x16, :lo12:" << ops.at(1).GetGlobal() << "\n";
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", [x16]\n";
break;
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::AddRRI:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", #" << ops.at(2).GetImm() << "\n";
break;
case Opcode::AddRRR_LSL: {
const char* reg2_name = PhysRegName(ops.at(2).GetReg());
std::string reg2_str = reg2_name;
std::string extension = "lsl";
if (reg2_name[0] == 'w') {
extension = "sxtw";
}
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< reg2_str << ", " << extension << " #" << ops.at(3).GetImm() << "\n";
break;
}
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SDivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MSubRRR:
os << " msub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << ", "
<< PhysRegName(ops.at(3).GetReg()) << "\n";
break;
case Opcode::Sxtw:
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::NegR:
os << " neg " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CmpRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::CSet:
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
<< CondCodeName(ops.at(1).GetCond()) << "\n";
break;
case Opcode::FAdd:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSub:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMUL:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDiv:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FNeg:
os << " fneg " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCmp:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCvtSI2FP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCvtFP2SI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::LoadR:
os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
case Opcode::StoreR:
os << " str " << PhysRegName(ops.at(0).GetReg()) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
case Opcode::Call:
os << " bl " << ops.at(0).GetLabel() << "\n";
break;
case Opcode::B:
os << " b .L" << function->GetName() << "_" << ops.at(0).GetLabel() << "\n";
break;
case Opcode::BCond:
os << " cmp " << PhysRegName(ops.at(1).GetReg()) << ", #0\n";
os << " b." << CondCodeName(ops.at(0).GetCond()) << " .L" << function->GetName() << "_" << ops.at(2).GetLabel() << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm:
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< ops.at(1).GetImm() << "\n";
break;
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
}
os << ".size " << function->GetName() << ", .-" << function->GetName() << "\n\n";
}
os << ".size " << function.GetName() << ", .-" << function.GetName()
<< "\n";
}
} // namespace mir

@ -19,7 +19,8 @@ void RunFrameLowering(MachineFunction& function) {
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
// For now, keep the 256-byte limit for simplicity (ldur/stur range)
// throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
}
}
@ -30,16 +31,24 @@ void RunFrameLowering(MachineFunction& function) {
}
function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue);
// Add Prologue to the first block
if (!function.GetBlocks().empty()) {
auto& entry_insts = function.GetBlocks().front()->GetInstructions();
entry_insts.insert(entry_insts.begin(), MachineInstr(Opcode::Prologue));
}
// Add Epilogue before every Ret
for (auto& block : function.GetBlocks()) {
auto& insts = block->GetInstructions();
std::vector<MachineInstr> lowered;
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue);
}
lowered.push_back(inst);
}
lowered.push_back(inst);
insts = std::move(lowered);
}
insts = std::move(lowered);
}
} // namespace mir

Loading…
Cancel
Save