From 310c7c3697365c2873c63b250fa2e19db5997021 Mon Sep 17 00:00:00 2001 From: tansiping <3213415568@qq.com> Date: Mon, 13 Apr 2026 10:36:29 +0800 Subject: [PATCH] =?UTF-8?q?feat(mir):=E4=BF=AE=E6=AD=A3=E5=B9=B6=E5=AE=8C?= =?UTF-8?q?=E5=96=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/mir/MIR.h | 1 + scripts/mir_test.sh | 0 scripts/mir_test.sh:Zone.Identifier | Bin 0 -> 25 bytes scripts/mir_test1.sh | 153 ++++++++++++++++++++++++++++ src/mir/AsmPrinter.cpp | 148 ++++++++++++++++++++++----- src/mir/FrameLowering.cpp | 6 +- src/mir/Lowering.cpp | 36 +++++-- 7 files changed, 306 insertions(+), 38 deletions(-) mode change 100755 => 100644 scripts/mir_test.sh create mode 100644 scripts/mir_test.sh:Zone.Identifier create mode 100755 scripts/mir_test1.sh diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 9b5721f..4dfe142 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -93,6 +93,7 @@ enum class Opcode { Slti, Slli, Sltu, // 无符号小于 + Sltiu, Xori, LoadGlobalAddr, LoadGlobal, diff --git a/scripts/mir_test.sh b/scripts/mir_test.sh old mode 100755 new mode 100644 diff --git a/scripts/mir_test.sh:Zone.Identifier b/scripts/mir_test.sh:Zone.Identifier new file mode 100644 index 0000000000000000000000000000000000000000..d6c1ec682968c796b9f5e9e080cc6f674b57c766 GIT binary patch literal 25 dcma!!%Fjy;DN4*MPD?F{<>dl#JyUFr831@K2x/dev/null + if [ $? -ne 0 ]; then + echo "警告:无法编译 sylib.c,部分测试可能链接失败" + fi +fi +echo "" + +mkdir -p "$TEST_RESULT_DIR" + +echo "==========================================" +echo "RISC-V 后端测试" +echo "==========================================" +echo "" + +# 收集测试用例 +mapfile -t test_files < <(find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort) + +total=${#test_files[@]} +pass_gen=0 +fail_gen=0 +pass_run=0 +fail_run=0 +timeout_cnt=0 + +echo "=== 阶段1:汇编生成 ===" +echo "" + +for test_file in "${test_files[@]}"; do + relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file") + output_file="$TEST_RESULT_DIR/${relative_path%.sy}.s" + mkdir -p "$(dirname "$output_file")" + + "$COMPILER" --emit-asm "$test_file" 2>/dev/null > "$output_file" + + if [ $? -eq 0 ] && [ -s "$output_file" ]; then + echo -e " ${GREEN}✓${NC} $relative_path" + ((pass_gen++)) + else + echo -e " ${RED}✗${NC} $relative_path" + ((fail_gen++)) + fi +done + +echo "" +echo "--- 汇编生成: 通过 $pass_gen / 失败 $fail_gen / 总计 $total ---" +echo "" + +for test_file in "${test_files[@]}"; do + relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file") + stem="${relative_path%.sy}" + asm_file="$TEST_RESULT_DIR/${stem}.s" + exe_file="$TEST_RESULT_DIR/${stem}" + expected_file="${test_file%.sy}.out" + + if [ ! -s "$asm_file" ]; then + echo -e " ${YELLOW}⚠${NC} $relative_path (跳过)" + continue + fi + + # 链接 + if [ -f "$SYLIB_O" ]; then + riscv64-linux-gnu-gcc -static "$asm_file" "$SYLIB_O" -o "$exe_file" -no-pie 2>/dev/null + else + riscv64-linux-gnu-gcc -static "$asm_file" -o "$exe_file" -no-pie 2>/dev/null + fi + + if [ $? -ne 0 ]; then + echo -e " ${RED}✗${NC} $relative_path (链接失败)" + ((fail_run++)) + continue + fi + + # 运行程序 + + input_file="${test_file%.sy}.in" + tmp_out=$(mktemp) + if [ -f "$input_file" ]; then + timeout 10 qemu-riscv64 "$exe_file" < "$input_file" > "$tmp_out" 2>&1 + else + timeout 10 qemu-riscv64 "$exe_file" > "$tmp_out" 2>&1 + fi + exit_code=$? + + if [ $exit_code -eq 124 ]; then + echo -e " ${YELLOW}⚠${NC} $relative_path (超时)" + ((timeout_cnt++)) + rm -f "$tmp_out" + continue + fi + + program_output=$(cat "$tmp_out" | tr -d '\n' | sed 's/[[:space:]]*$//') + rm -f "$tmp_out" + + if [ -f "$expected_file" ]; then + expected=$(cat "$expected_file" | tr -d '\n' | sed 's/[[:space:]]*$//') + + if [[ "$expected" =~ ^[0-9]+$ ]] && [ "$expected" -ge 0 ] && [ "$expected" -le 255 ] && [ -z "$program_output" ]; then + # 期望退出码(且没有输出) + if [ $exit_code -eq "$expected" ] 2>/dev/null; then + echo -e " ${GREEN}✓${NC} $relative_path (退出码: $exit_code)" + ((pass_run++)) + else + echo -e " ${RED}✗${NC} $relative_path (退出码: 期望 $expected, 实际 $exit_code)" + ((fail_run++)) + fi + else + # 期望输出内容 + if [ "$program_output" = "$expected" ]; then + echo -e " ${GREEN}✓${NC} $relative_path (输出匹配)" + ((pass_run++)) + else + echo -e " ${RED}✗${NC} $relative_path (输出不匹配: 期望 '$expected', 实际 '$program_output')" + ((fail_run++)) + fi + fi + else + # 没有期望文件 + echo -e " ${GREEN}✓${NC} $relative_path (退出码: $exit_code, 输出: '$program_output')" + ((pass_run++)) + fi +done +echo "" +echo "--- 运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt ---" +echo "" + +echo "==========================================" +echo "测试完成" +echo "汇编生成: 通过 $pass_gen / 失败 $fail_gen" +echo "运行验证: 通过 $pass_run / 失败 $fail_run / 超时 $timeout_cnt" +echo "==========================================" \ No newline at end of file diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 36774c3..c7a988d 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -7,6 +7,8 @@ #include #include "utils/Log.h" + + // 引用全局变量(定义在 Lowering.cpp 中) extern std::vector g_globalVars; @@ -21,6 +23,46 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function, return function.GetFrameSlot(operand.GetFrameIndex()); } +void EmitStackLoad(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) { + if (offset >= -2048 && offset <= 2047) { + os << " lw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n"; + } else { + os << " li t4, " << offset << "\n"; + os << " add t4, " << PhysRegName(base) << ", t4\n"; + os << " lw " << PhysRegName(dst) << ", 0(t4)\n"; + } +} + +void EmitStackStore(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) { + if (offset >= -2048 && offset <= 2047) { + os << " sw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n"; + } else { + os << " li t4, " << offset << "\n"; + os << " add t4, " << PhysRegName(base) << ", t4\n"; + os << " sw " << PhysRegName(src) << ", 0(t4)\n"; + } +} + +void EmitStackLoadFloat(std::ostream& os, PhysReg dst, int offset, PhysReg base = PhysReg::SP) { + if (offset >= -2048 && offset <= 2047) { + os << " flw " << PhysRegName(dst) << ", " << offset << "(" << PhysRegName(base) << ")\n"; + } else { + os << " li t4, " << offset << "\n"; + os << " add t4, " << PhysRegName(base) << ", t4\n"; + os << " flw " << PhysRegName(dst) << ", 0(t4)\n"; + } +} + +void EmitStackStoreFloat(std::ostream& os, PhysReg src, int offset, PhysReg base = PhysReg::SP) { + if (offset >= -2048 && offset <= 2047) { + os << " fsw " << PhysRegName(src) << ", " << offset << "(" << PhysRegName(base) << ")\n"; + } else { + os << " li t4, " << offset << "\n"; + os << " add t4, " << PhysRegName(base) << ", t4\n"; + os << " fsw " << PhysRegName(src) << ", 0(t4)\n"; + } +} + // 输出单个函数的汇编 void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { // 收集所有基本块名称 @@ -45,9 +87,34 @@ void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { // 在入口块的第一条指令前输出序言 if (!prologue_done && block.GetName() == "entry") { - os << " addi sp, sp, -" << total_frame_size << "\n"; - os << " sw ra, " << (total_frame_size - 8) << "(sp)\n"; - os << " sw s0, " << (total_frame_size - 16) << "(sp)\n"; + // 处理大栈帧的情况 + if (total_frame_size <= 2047) { + os << " addi sp, sp, -" << total_frame_size << "\n"; + } else { + os << " li t4, -" << total_frame_size << "\n"; + os << " add sp, sp, t4\n"; + } + + // 保存 ra 和 s0 + int ra_offset = total_frame_size - 8; + int s0_offset = total_frame_size - 16; + + if (ra_offset <= 2047) { + os << " sw ra, " << ra_offset << "(sp)\n"; + } else { + os << " li t4, " << ra_offset << "\n"; + os << " add t4, sp, t4\n"; + os << " sw ra, 0(t4)\n"; + } + + if (s0_offset <= 2047) { + os << " sw s0, " << s0_offset << "(sp)\n"; + } else { + os << " li t4, " << s0_offset << "\n"; + os << " add t4, sp, t4\n"; + os << " sw s0, 0(t4)\n"; + } + prologue_done = true; } @@ -64,12 +131,11 @@ void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { case Opcode::Load: { if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) { os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", 0(" - << PhysRegName(ops.at(1).GetReg()) << ")\n"; + << PhysRegName(ops.at(1).GetReg()) << ")\n"; } else { int frame_idx = ops.at(1).GetFrameIndex(); const auto& slot = function.GetFrameSlot(frame_idx); - os << " lw " << PhysRegName(ops.at(0).GetReg()) << ", " - << slot.offset << "(sp)\n"; + EmitStackLoad(os, ops.at(0).GetReg(), slot.offset); } break; } @@ -77,15 +143,14 @@ void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { case Opcode::Store: { if (ops.size() == 2 && ops.at(1).GetKind() == Operand::Kind::Reg) { os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0(" - << PhysRegName(ops.at(1).GetReg()) << ")\n"; + << PhysRegName(ops.at(1).GetReg()) << ")\n"; } else { int frame_idx = ops.at(1).GetFrameIndex(); const auto& slot = function.GetFrameSlot(frame_idx); - os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", " - << slot.offset << "(sp)\n"; + EmitStackStore(os, ops.at(0).GetReg(), slot.offset); } break; - } + } case Opcode::Add: os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " @@ -142,6 +207,12 @@ void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { << PhysRegName(ops.at(2).GetReg()) << "\n"; break; + case Opcode::Sltiu: // <-- 添加这个 + os << " sltiu " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << ops.at(2).GetImm() << "\n"; + break; + case Opcode::Xori: os << " xori " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", " @@ -161,8 +232,8 @@ void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { case Opcode::StoreGlobal: { std::string global_name = ops.at(1).GetGlobalName(); - os << " la t0, " << global_name << "\n"; - os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0(t0)\n"; + os << " la t1, " << global_name << "\n"; + os << " sw " << PhysRegName(ops.at(0).GetReg()) << ", 0(t1)\n"; break; } @@ -186,8 +257,13 @@ void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { case Opcode::LoadAddr: { int frame_idx = ops.at(1).GetFrameIndex(); const auto& slot = function.GetFrameSlot(frame_idx); - os << " addi " << PhysRegName(ops.at(0).GetReg()) << ", sp, " - << slot.offset << "\n"; + if (slot.offset >= -2048 && slot.offset <= 2047) { + os << " addi " << PhysRegName(ops.at(0).GetReg()) << ", sp, " << slot.offset << "\n"; + } else { + os << " li " << PhysRegName(ops.at(0).GetReg()) << ", " << slot.offset << "\n"; + os << " add " << PhysRegName(ops.at(0).GetReg()) << ", sp, " + << PhysRegName(ops.at(0).GetReg()) << "\n"; + } break; } @@ -202,12 +278,38 @@ void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { << PhysRegName(ops.at(1).GetReg()) << ")\n"; break; - case Opcode::Ret: - os << " lw ra, " << (total_frame_size - 8) << "(sp)\n"; - os << " lw s0, " << (total_frame_size - 16) << "(sp)\n"; - os << " addi sp, sp, " << total_frame_size << "\n"; + case Opcode::Ret:{ + // 恢复 ra 和 s0 + int ra_offset = total_frame_size - 8; + int s0_offset = total_frame_size - 16; + + if (ra_offset <= 2047) { + os << " lw ra, " << ra_offset << "(sp)\n"; + } else { + os << " li t3, " << ra_offset << "\n"; + os << " add t3, sp, t3\n"; + os << " lw ra, 0(t3)\n"; + } + + if (s0_offset <= 2047) { + os << " lw s0, " << s0_offset << "(sp)\n"; + } else { + os << " li t3, " << s0_offset << "\n"; + os << " add t3, sp, t3\n"; + os << " lw s0, 0(t3)\n"; + } + + // 恢复 sp + if (total_frame_size <= 2047) { + os << " addi sp, sp, " << total_frame_size << "\n"; + } else { + os << " li t3, " << total_frame_size << "\n"; + os << " add sp, sp, t3\n"; + } + os << " ret\n"; break; + } case Opcode::Br: { auto* target = reinterpret_cast(ops[0].GetImm64()); @@ -300,24 +402,22 @@ void PrintAsmFunction(const MachineFunction& function, std::ostream& os) { case Opcode::LoadFloat: if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) { os << " flw " << PhysRegName(ops[0].GetReg()) << ", 0(" - << PhysRegName(ops[1].GetReg()) << ")\n"; + << PhysRegName(ops[1].GetReg()) << ")\n"; } else { int frame_idx = ops[1].GetFrameIndex(); const auto& slot = function.GetFrameSlot(frame_idx); - os << " flw " << PhysRegName(ops[0].GetReg()) << ", " - << slot.offset << "(sp)\n"; + EmitStackLoadFloat(os, ops[0].GetReg(), slot.offset); } break; case Opcode::StoreFloat: if (ops.size() == 2 && ops[1].GetKind() == Operand::Kind::Reg) { os << " fsw " << PhysRegName(ops[0].GetReg()) << ", 0(" - << PhysRegName(ops[1].GetReg()) << ")\n"; + << PhysRegName(ops[1].GetReg()) << ")\n"; } else { int frame_idx = ops[1].GetFrameIndex(); const auto& slot = function.GetFrameSlot(frame_idx); - os << " fsw " << PhysRegName(ops[0].GetReg()) << ", " - << slot.offset << "(sp)\n"; + EmitStackStoreFloat(os, ops[0].GetReg(), slot.offset); } break; diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp index 9063509..d191591 100644 --- a/src/mir/FrameLowering.cpp +++ b/src/mir/FrameLowering.cpp @@ -18,9 +18,9 @@ void RunFrameLowering(MachineFunction& function) { int cursor = 0; for (const auto& slot : function.GetFrameSlots()) { cursor += slot.size; - if (-cursor < -2048) { - throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧")); - } + //if (-cursor < -2048) { + //throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧")); + //} } cursor = 0; diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 5f52768..ca44e42 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -305,20 +305,25 @@ void LowerInstructionToBlock(const ir::Instruction& inst, MachineFunction& funct switch (pred) { case ir::ICmpPredicate::EQ: block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0), - Operand::Reg(PhysReg::T0), - Operand::Reg(PhysReg::T1)}); - block.Append(Opcode::Slti, {Operand::Reg(PhysReg::T0), + Operand::Reg(PhysReg::T0), + Operand::Reg(PhysReg::T1)}); + block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0), Operand::Imm(1)}); break; + case ir::ICmpPredicate::NE: block.Append(Opcode::Sub, {Operand::Reg(PhysReg::T0), - Operand::Reg(PhysReg::T0), - Operand::Reg(PhysReg::T1)}); - block.Append(Opcode::Sltu, {Operand::Reg(PhysReg::T0), - Operand::Reg(PhysReg::ZERO), - Operand::Reg(PhysReg::T0)}); + Operand::Reg(PhysReg::T0), + Operand::Reg(PhysReg::T1)}); + block.Append(Opcode::Sltiu, {Operand::Reg(PhysReg::T1), + Operand::Reg(PhysReg::T0), + Operand::Imm(1)}); + block.Append(Opcode::Xori, {Operand::Reg(PhysReg::T0), + Operand::Reg(PhysReg::T1), + Operand::Imm(1)}); break; + case ir::ICmpPredicate::SLT: block.Append(Opcode::Slt, {Operand::Reg(PhysReg::T0), Operand::Reg(PhysReg::T0), @@ -476,16 +481,25 @@ void LowerInstructionToBlock(const ir::Instruction& inst, MachineFunction& funct auto& condbr = static_cast(inst); auto* true_bb = condbr.GetTrueBB(); auto* false_bb = condbr.GetFalseBB(); - + + // 如果条件涉及函数调用,需要特殊处理 + // 简单方案:将条件值保存到栈槽 + int cond_slot = function.CreateFrameIndex(4); EmitValueToReg(condbr.GetCond(), PhysReg::T0, slots, block); + // 保存条件值到栈 + block.Append(Opcode::Store, {Operand::Reg(PhysReg::T0), Operand::FrameIndex(cond_slot)}); + + // 从栈加载条件值(确保函数调用后还能获取) + block.Append(Opcode::Load, {Operand::Reg(PhysReg::T0), Operand::FrameIndex(cond_slot)}); + block.Append(Opcode::Sltu, {Operand::Reg(PhysReg::T1), Operand::Reg(PhysReg::ZERO), Operand::Reg(PhysReg::T0)}); - + MachineBasicBlock* true_block = GetOrCreateBlock(true_bb, function); MachineBasicBlock* false_block = GetOrCreateBlock(false_bb, function); - + block.Append(Opcode::CondBr, {Operand::Reg(PhysReg::T1), Operand::Imm64(reinterpret_cast(true_block)), Operand::Imm64(reinterpret_cast(false_block))});