From 2472624927e6ead17b64da2ff7ee0d7eb1d729e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AF=95=E6=81=A9=E5=87=AF?= <15609889+biankai001@user.noreply.gitee.com> Date: Mon, 1 Jun 2026 22:23:22 +0800 Subject: [PATCH] bnk --- .claude/settings.local.json | 61 ++++ include/mir/MIR.h | 284 ++++++++++++---- src/main.cpp | 2 +- src/mir/AsmPrinter.cpp | 383 +++++++++++++++++---- src/mir/CMakeLists.txt | 5 +- src/mir/FrameLowering.cpp | 58 ++-- src/mir/Lowering.cpp | 584 +++++++++++++++++++++++++++------ src/mir/MIRBasicBlock.cpp | 16 +- src/mir/MIRContext.cpp | 10 +- src/mir/MIRFunction.cpp | 33 +- src/mir/MIRInstr.cpp | 23 +- src/mir/RegAlloc.cpp | 358 ++++++++++++++++++-- src/mir/Register.cpp | 54 ++- src/mir/passes/PassManager.cpp | 8 +- src/mir/passes/Peephole.cpp | 86 ++++- 15 files changed, 1598 insertions(+), 367 deletions(-) create mode 100644 .claude/settings.local.json diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..3505354 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,61 @@ +{ + "permissions": { + "allow": [ + "Bash(cd \"\\\\\\\\wsl.localhost\\\\Ubuntu-24.04\\\\home\\\\bnk\\\\nudt-compiler-cpp\")", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'echo IN_WSL; which cmake g++ aarch64-linux-gnu-gcc qemu-aarch64 clang 2>/dev/null; pwd')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake -S . -B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo \"cfg ok\" || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" > /tmp/build.log 2>&1 && echo \"build ok\" || tail -40 /tmp/build.log')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && rm -rf build && cmake -S . 
-B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo \"cfg ok\" || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" > /tmp/build.log 2>&1 && echo \"build ok\" || tail -40 /tmp/build.log')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy 2>&1')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== simple_add no-opt ===\"; ./build/bin/compiler --emit-ir --no-opt test/test_case/functional/simple_add.sy 2>&1 | head -40')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for f in 11_add2 13_sub2 29_break 36_op_priority2; do echo \"=== $f \\(opt\\) ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/$f.sy 2>&1 | sed -n \"/define/,/^}/p\"; done')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 09_func_defn ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/09_func_defn.sy 2>&1 | sed -n \"1,60p\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 25_scope3 ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/25_scope3.sy 2>&1 | sed -n \"/define/,/^}/p\" | head -80')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 05_arr_defn4 ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/05_arr_defn4.sy 2>&1 | sed -n \"/define/,/^}/p\" | head -120')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 22 globals+main head ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/22_matrix_multiply.sy 2>&1 | sed -n \"1,40p\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== gep with var index \\(22\\) ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/22_matrix_multiply.sy 2>&1 | grep -A2 getelementptr | head -20; 
echo \"=== 95 float head ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/95_float.sy 2>&1 | sed -n \"1,50p\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== CMakeLists \\(mir glob?\\) ===\"; grep -n -i \"mir\\\\|glob\\\\|GLOB\\\\|file\\(\" CMakeLists.txt | head -40')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -n \"src/\\\\|GLOB\\\\|add_executable\\\\|add_library\\\\|SOURCES\\\\|\\\\.cpp\" CMakeLists.txt | head -60')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -n \"compiler_core\\\\|compiler\\\\b\\\\|target_sources\\\\|GLOB.*SRC\\\\|PROJECT_SRC\\\\|set\\(SOURCES\\\\|\\\\.cpp\\\\\"\" CMakeLists.txt | head; echo \"---\"; sed -n \"82,160p\" CMakeLists.txt')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/CMakeLists.txt')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== mir/CMakeLists ===\"; cat src/mir/CMakeLists.txt; echo \"=== ls src/mir ===\"; ls -R src/mir; echo \"=== ls include/mir ===\"; ls include/mir')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && tot=0; for f in test/test_case/functional/*.sy; do ./build/bin/compiler --emit-ir \"$f\" 2>/dev/null; done > /tmp/allir.txt; echo \"=== opcodes used ===\"; grep -oE \"= \\(add|sub|mul|sdiv|srem|fadd|fsub|fmul|fdiv|icmp|fcmp|call|getelementptr|load|alloca|phi|sitofp|fptosi|zext\\) \" /tmp/allir.txt | sort | uniq -c; echo \"=== bare ops ===\"; grep -oE \"^ \\(store|br|ret|call\\) \" /tmp/allir.txt | sort | uniq -c; echo \"=== phi count ===\"; grep -c \"phi\" /tmp/allir.txt; echo \"=== float present ===\"; grep -c \"float\" /tmp/allir.txt')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ls test/test_case/functional/*.sy | wc -l; echo \"--- verify uses ---\"; grep -nE \"qemu|aarch64|gcc|clang|gcc-|--target\" 
scripts/verify_asm.sh | head; echo \"--- a sample .out ---\"; ls test/test_case/functional/*.out 2>/dev/null | head -3; echo \"--- runtime lib? ---\"; ls test/ ; find . -name \"*.a\" -path \"*runtime*\" 2>/dev/null | head')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== sylib.h ===\"; cat sylib/sylib.h 2>/dev/null | head -60; echo \"=== test list ===\"; ls test/test_case/functional/*.sy | xargs -n1 basename')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"int |void |float |#define|starttime|_sysy\" sylib/sylib.c | head -40')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && tail -5 include/mir/MIR.h')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetFunctions|GetGlobals|IsDeclaration|GetBlocks|GetEntry|GetSuccessors|GetNumArgs|GetArg\\\\b|class Argument|GetInstructions|GetOpcode|class Module\" include/ir/IR.h | head -40')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== Opcode enum ===\"; grep -nA40 \"enum class Opcode\" include/ir/IR.h | head -50')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetLhs|GetRhs|class BinaryInst|class ICmpInst|class FCmpInst|GetPredicate|class CastInst|GetValue\\\\b|class LoadInst|GetPtr|class StoreInst|class AllocaInst|GetAllocatedType|class GepInst|GetBasePtr|GetIndices|class CallInst|GetCallee|GetArgs|class BranchInst|GetDest|class CondBrInst|GetCond|GetTrueDest|GetFalseDest|class ReturnInst|HasReturnValue\" include/ir/IR.h')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== Type ===\"; grep -nE \"class Type|enum class Kind|GetKind|GetArraySize|GetElementType|IsFloat|IsPointer|IsArray|IsVoid|Int1|Int32\" include/ir/IR.h | head -30; echo \"=== Constants/Global ===\"; grep -nE \"class ConstantInt|class ConstantFloat|class 
ConstantArray|GetElements|class ConstantValue|class GlobalVariable|GetValueType|IsConst|GetInitializer|class Function\\\\b|GetName\" include/ir/IR.h | head -40')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetType\\\\\\(\\\\\\)|class Value\\\\b|class GlobalValue\" include/ir/IR.h | head; echo \"=== line 40-60 ===\"; sed -n \"36,60p\" include/ir/IR.h; echo \"=== 120-160 ===\"; sed -n \"118,160p\" include/ir/IR.h')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -E \"error:|Error|错误\" | head -40; echo \"=== exit ${PIPESTATUS[0]} ===\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"undefined|multiple|duplicate|reference to\" | head -30')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/mir/CMakeLists.txt; echo \"=== grep mir in src/CMakeLists ===\"; grep -n mir src/CMakeLists.txt')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/mir/passes/CMakeLists.txt')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake -S . 
-B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo cfgok || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"error|undefined|错误\" | head -30; echo \"=== done ===\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy 2>&1 | head -40')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" /tmp/asmout --run > /tmp/v.log 2>&1; then pass=$\\(\\(pass+1\\)\\); else fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"PASS=$pass FAIL=$fail\"; echo \"FAILED:$fails\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 13_sub2 95_float; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -15; done')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"error|错误\" | head; for b in 13_sub2 95_float; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -6; done')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 15_graph_coloring 22_matrix_multiply; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -8; done')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"no-opt|no_opt|emit-asm|compiler|OPT|opt\" scripts/verify_asm.sh | head')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && CC=aarch64-linux-gnu-gcc; QEMU=qemu-aarch64; SYS=$\\(ls sysroot 2>/dev/null\\); for b in 22_matrix_multiply 95_float 15_graph_coloring 25_scope3; do 
f=test/test_case/functional/$b.sy; ./build/bin/compiler --no-opt --emit-asm \"$f\" > /tmp/n.s 2>/tmp/n.err || { echo \"$b COMPILE FAIL\"; cat /tmp/n.err; continue; }; aarch64-linux-gnu-gcc /tmp/n.s sylib/libsysy.a -o /tmp/n.exe 2>/tmp/as.err || { echo \"$b ASM/LINK FAIL\"; head -5 /tmp/as.err; continue; }; in=test/test_case/functional/$b.in; if [ -f \"$in\" ]; then qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe < \"$in\" > /tmp/n.out 2>&1; else qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe > /tmp/n.out 2>&1; fi; ec=$?; echo \"$b: exit=$ec out=$\\(tr \"\\\\n\" \"|\" < /tmp/n.out | head -c 80\\)\"; done')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && sed -n \"58,110p\" scripts/verify_asm.sh')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && aarch64-linux-gnu-gcc -O2 -Wno-unused-result -c sylib/sylib.c -o /tmp/sylib.o 2>/dev/null; pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); ./build/bin/compiler --no-opt --emit-asm \"$f\" > /tmp/n.s 2>/dev/null || { fails=\"$fails $b\\(cc\\)\"; fail=$\\(\\(fail+1\\)\\); continue; }; aarch64-linux-gnu-gcc /tmp/n.s /tmp/sylib.o -o /tmp/n.exe 2>/tmp/as.err || { fails=\"$fails $b\\(as\\)\"; fail=$\\(\\(fail+1\\)\\); continue; }; in=test/test_case/functional/$b.in; if [ -f \"$in\" ]; then qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe < \"$in\" > /tmp/n.out 2>&1; else qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe > /tmp/n.out 2>&1; fi; ec=$?; exp=test/test_case/functional/$b.out; { cat /tmp/n.out; [ -s /tmp/n.out ] && [ $\\(tail -c1 /tmp/n.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/n.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/n.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/n.act > /tmp/n.actn; if diff -q /tmp/n.exp /tmp/n.actn >/dev/null 2>&1; then pass=$\\(\\(pass+1\\)\\); else 
fails=\"$fails $b\\(diff\\)\"; fail=$\\(\\(fail+1\\)\\); fi; done; echo \"NOOPT PASS=$pass FAIL=$fail\"; echo \"FAILS:$fails\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== check spill present in no-opt 22 ===\"; ./build/bin/compiler --no-opt --emit-asm test/test_case/functional/22_matrix_multiply.sy 2>/dev/null | grep -cE \"x29, #\\(2[0-9]|[3-9][0-9]\\)\"; echo \"\\(stack accesses above = spills/locals\\)\"; echo \"=== self-move check \\(should be 0\\) ===\"; ./build/bin/compiler --emit-asm test/test_case/functional/22_matrix_multiply.sy 2>/dev/null | grep -cE \"\\\\bmov\\\\t\\(w|x\\)\\([0-9]+\\), \\\\1?\\\\2$\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== total instrs across all tests \\(opt\\) ===\"; tot=0; for f in test/test_case/functional/*.sy; do n=$\\(./build/bin/compiler --emit-asm \"$f\" 2>/dev/null | grep -cE \"^\\\\t\\(mov|add|sub|mul|ldr|str|b|bl|cmp|cset|fmov|ret|sxtw|lsl|sdiv|msub|fadd|fsub|fmul|fdiv|fcmp|scvtf|fcvtzs|adrp|stp|ldp\\)\"\\); tot=$\\(\\(tot+n\\)\\); done; echo \"total=$tot\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/function/asm --run > /tmp/v.log 2>&1; then pass=$\\(\\(pass+1\\)\\); else fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"OFFICIAL SCRIPT PASS=$pass FAIL=$fail\"; echo \"FAILS:$fails\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && f=test/test_case/functional/22_matrix_multiply.sy; ./build/bin/compiler --emit-asm \"$f\" 2>/dev/null > /tmp/with.s; wc -l < /tmp/with.s | xargs echo \"lines with peephole:\"; grep -c \"\tmov\t\" /tmp/with.s | xargs echo \"mov count:\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; for f in 
test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/function/asm --run > /tmp/v.log 2>&1; then echo \" PASS $b\"; pass=$\\(\\(pass+1\\)\\); else echo \" FAIL $b\"; fail=$\\(\\(fail+1\\)\\); fi; done; echo \"========================\"; echo \"PASS=$pass FAIL=$fail\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== test_case 目录结构 ===\"; find test/test_case -type d; echo \"=== 各目录 .sy 数量 ===\"; for d in $\\(find test/test_case -type d\\); do n=$\\(ls \"$d\"/*.sy 2>/dev/null | wc -l\\); [ \"$n\" -gt 0 ] && echo \"$n $d\"; done; echo \"=== 总数 ===\"; find test/test_case -name \"*.sy\" | wc -l')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/performance/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/perf/asm --run > /tmp/v.log 2>&1; then echo \" PASS $b\"; pass=$\\(\\(pass+1\\)\\); else echo \" FAIL $b\"; fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"===== PERF PASS=$pass FAIL=$fail =====\"; echo \"FAILS:$fails\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ls test/test_case/performance/*.sy | xargs -n1 basename')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && aarch64-linux-gnu-gcc -O2 -Wno-unused-result -c sylib/sylib.c -o /tmp/sylib.o 2>/dev/null; for f in test/test_case/performance/*.sy; do b=$\\(basename \"$f\" .sy\\); ./build/bin/compiler --emit-asm \"$f\" > /tmp/p.s 2>/tmp/p.err || { echo \"FAIL\\($b\\) compile\"; continue; }; aarch64-linux-gnu-gcc /tmp/p.s /tmp/sylib.o -o /tmp/p.exe 2>/tmp/p.as || { echo \"FAIL\\($b\\) assemble\"; head -3 /tmp/p.as; continue; }; in=test/test_case/performance/$b.in; if [ -f \"$in\" ]; then timeout 25 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/p.exe < \"$in\" > /tmp/p.out 2>&1; else timeout 25 qemu-aarch64 -L 
/usr/aarch64-linux-gnu /tmp/p.exe > /tmp/p.out 2>&1; fi; ec=$?; if [ $ec -eq 124 ]; then echo \"TIMEOUT\\($b\\)\"; continue; fi; exp=test/test_case/performance/$b.out; { cat /tmp/p.out; [ -s /tmp/p.out ] && [ $\\(tail -c1 /tmp/p.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/p.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/p.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/p.act > /tmp/p.actn; if diff -q /tmp/p.exp /tmp/p.actn >/dev/null 2>&1; then echo \"PASS\\($b\\)\"; else echo \"DIFF\\($b\\)\"; fi; done')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== if-combine3.sy ===\"; cat test/test_case/performance/if-combine3.sy; echo \"=== .out ===\"; cat test/test_case/performance/if-combine3.out; echo \"=== .in? ===\"; ls test/test_case/performance/if-combine3.in 2>/dev/null && cat test/test_case/performance/if-combine3.in')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 循环尾部\\(最后30行\\) ===\"; tail -30 test/test_case/performance/if-combine3.sy; echo \"=== .in ===\"; cat test/test_case/performance/if-combine3.in 2>/dev/null || echo \"\\(无\\)\"; echo \"=== .out ===\"; cat test/test_case/performance/if-combine3.out')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-asm test/test_case/performance/if-combine3.sy > /tmp/ic.s 2>/dev/null && aarch64-linux-gnu-gcc /tmp/ic.s /tmp/sylib.o -o /tmp/ic.exe 2>/dev/null && echo \"5\" | timeout 20 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/ic.exe 2>/dev/null; echo \"exit=$? 
\\(小输入 n=5,验证逻辑\\)\"')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"50000000\" | timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/ic.exe > /tmp/ic.out 2>&1; ec=$?; echo \"exit=$ec\"; cat /tmp/ic.out')", + "Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 2025-MYO-20 gameoflife-oscillator; do f=test/test_case/performance/$b.sy; ./build/bin/compiler --emit-asm \"$f\" > /tmp/x.s 2>/tmp/x.e || { echo \"$b COMPILE FAIL\"; head -3 /tmp/x.e; continue; }; aarch64-linux-gnu-gcc /tmp/x.s /tmp/sylib.o -o /tmp/x.exe 2>/tmp/x.a || { echo \"$b ASM FAIL\"; head -3 /tmp/x.a; continue; }; in=test/test_case/performance/$b.in; if [ -f \"$in\" ]; then timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/x.exe < \"$in\" > /tmp/x.out 2>&1; else timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/x.exe > /tmp/x.out 2>&1; fi; ec=$?; if [ $ec -eq 124 ]; then echo \"$b STILL TIMEOUT\\(>280s\\)\"; continue; fi; exp=test/test_case/performance/$b.out; { cat /tmp/x.out; [ -s /tmp/x.out ] && [ $\\(tail -c1 /tmp/x.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/x.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/x.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/x.act > /tmp/x.actn; if diff -q /tmp/x.exp /tmp/x.actn >/dev/null 2>&1; then echo \"$b PASS\"; else echo \"$b DIFF:\"; diff /tmp/x.exp /tmp/x.actn | head; fi; done')" + ] + } +} diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 2d91d22..71a49d7 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -1,6 +1,10 @@ +// Lab5 后端 MIR 表示: +// - 虚拟寄存器 + 物理寄存器,两类寄存器(GPR/FPR) +// - 多函数、多基本块、全局变量、栈对象 +// - 指令携带显式 def/use(约定:操作数前 num_defs 个为定值) #pragma once -#include +#include #include #include #include @@ -16,105 +20,273 @@ class MIRContext { public: MIRContext() = default; }; - MIRContext& DefaultContext(); -enum class PhysReg { W0, W8, W9, 
X29, X30, SP }; +enum class RegClass { GPR, FPR }; + +// 物理寄存器编号(按类内编号): +// GPR: 0..30 = x0..x30, 31 = sp, 32 = xzr +// FPR: 0..31 = s0..s31 +namespace preg { +constexpr int kSP = 31; +constexpr int kXZR = 32; +constexpr int kFP = 29; // x29 +constexpr int kLR = 30; // x30 +constexpr int kIP0 = 16; // x16 scratch +constexpr int kIP1 = 17; // x17 scratch +} // namespace preg -const char* PhysRegName(PhysReg reg); +enum class Cond { AL, EQ, NE, LT, LE, GT, GE, MI, LS, HI, HS }; enum class Opcode { - Prologue, - Epilogue, - MovImm, - LoadStack, - StoreStack, - AddRR, + Mov, // dst<-src (reg copy) + MovImm, // dst<-imm (materialize 32/64-bit) + Sxtw, // dst(64) = sign-extend src(32) + Add, // dst = a + b + Sub, // dst = a - b + Mul, // dst = a * b + SDiv, // dst = a / b + MSub, // dst = a - b*c + AddImm, // dst = a + imm + SubImm, // dst = a - imm + LslImm, // dst = a << imm + Cmp, // a ? b (sets flags) + CmpImm, // a ? imm + CSet, // dst = cond + FAdd, + FSub, + FMul, + FDiv, + FCmp, + FMov, // fpr<-fpr + FMovImm, // fpr <- 32-bit float bits (via scratch gpr) + SCvtF, // fpr = (float)gpr + FCvtZS, // gpr = (int)fpr + Ldr, // dst <- [base, #imm] + Str, // [base, #imm] <- src (def=0, ops: src,base,imm) + LdrStack, // dst <- [frame] + StrStack, // [frame] <- src (def=0, ops: src, frame) + AddrFrame, // dst = addr of frame slot + AddrGlobal, // dst = addr of global symbol + B, // branch label + BCond, // branch cond label (uses flags) + Bl, // call symbol (def=0; ops list arg pregs as uses) Ret, }; class Operand { public: - enum class Kind { Reg, Imm, FrameIndex }; + enum class Kind { None, VReg, PReg, Imm, Frame, Global, Label }; - static Operand Reg(PhysReg reg); - static Operand Imm(int value); - static Operand FrameIndex(int index); + static Operand VReg(int id, RegClass cls, int bytes) { + Operand o; + o.kind_ = Kind::VReg; + o.id_ = id; + o.cls_ = cls; + o.bytes_ = bytes; + return o; + } + static Operand PReg(int id, RegClass cls, int bytes) { + Operand o; + 
o.kind_ = Kind::PReg; + o.id_ = id; + o.cls_ = cls; + o.bytes_ = bytes; + return o; + } + static Operand Imm(long long v) { + Operand o; + o.kind_ = Kind::Imm; + o.imm_ = v; + return o; + } + static Operand Frame(int idx) { + Operand o; + o.kind_ = Kind::Frame; + o.id_ = idx; + return o; + } + static Operand Global(std::string name) { + Operand o; + o.kind_ = Kind::Global; + o.sym_ = std::move(name); + return o; + } + static Operand Label(std::string name) { + Operand o; + o.kind_ = Kind::Label; + o.sym_ = std::move(name); + return o; + } Kind GetKind() const { return kind_; } - PhysReg GetReg() const { return reg_; } - int GetImm() const { return imm_; } - int GetFrameIndex() const { return imm_; } + bool IsReg() const { return kind_ == Kind::VReg || kind_ == Kind::PReg; } + bool IsVReg() const { return kind_ == Kind::VReg; } + bool IsPReg() const { return kind_ == Kind::PReg; } + int GetId() const { return id_; } + RegClass GetClass() const { return cls_; } + int GetBytes() const { return bytes_; } + long long GetImm() const { return imm_; } + int GetFrame() const { return id_; } + const std::string& GetSym() const { return sym_; } + + void SetPReg(int id) { + kind_ = Kind::PReg; + id_ = id; + } + void SetVReg(int id) { + kind_ = Kind::VReg; + id_ = id; + } + void SetBytes(int b) { bytes_ = b; } private: - Operand(Kind kind, PhysReg reg, int imm); + Kind kind_ = Kind::None; + int id_ = 0; + RegClass cls_ = RegClass::GPR; + int bytes_ = 4; + long long imm_ = 0; + std::string sym_; +}; - Kind kind_; - PhysReg reg_; - int imm_; +struct MachineInstr { + Opcode op; + std::vector ops; + int num_defs = 0; + Cond cond = Cond::AL; + + MachineInstr(Opcode o, std::vector operands, int defs, + Cond c = Cond::AL) + : op(o), ops(std::move(operands)), num_defs(defs), cond(c) {} }; -class MachineInstr { +class MachineBasicBlock { public: - MachineInstr(Opcode opcode, std::vector operands = {}); + explicit MachineBasicBlock(std::string name) : name_(std::move(name)) {} + const 
std::string& GetName() const { return name_; } + std::vector& Instrs() { return instrs_; } + const std::vector& Instrs() const { return instrs_; } + std::vector& Succs() { return succs_; } + const std::vector& Succs() const { return succs_; } - Opcode GetOpcode() const { return opcode_; } - const std::vector& GetOperands() const { return operands_; } + void Add(MachineInstr mi) { instrs_.push_back(std::move(mi)); } private: - Opcode opcode_; - std::vector operands_; + std::string name_; + std::vector instrs_; + std::vector succs_; }; -struct FrameSlot { +struct VRegInfo { + RegClass cls = RegClass::GPR; + int bytes = 4; +}; + +struct StackObject { int index = 0; int size = 4; - int offset = 0; + int align = 4; + int offset = 0; // 相对 x29,负数 }; -class MachineBasicBlock { +class MachineFunction { public: - explicit MachineBasicBlock(std::string name); - + explicit MachineFunction(std::string name) : name_(std::move(name)) {} const std::string& GetName() const { return name_; } - std::vector& GetInstructions() { return instructions_; } - const std::vector& GetInstructions() const { return instructions_; } - MachineInstr& Append(Opcode opcode, - std::initializer_list operands = {}); + MachineBasicBlock* CreateBlock(const std::string& name) { + blocks_.push_back(std::make_unique(name)); + return blocks_.back().get(); + } + const std::vector>& Blocks() const { + return blocks_; + } + std::vector>& Blocks() { return blocks_; } - private: - std::string name_; - std::vector instructions_; -}; + int NewVReg(RegClass cls, int bytes) { + int id = static_cast(vregs_.size()); + vregs_.push_back(VRegInfo{cls, bytes}); + return id; + } + Operand NewVRegOp(RegClass cls, int bytes) { + return Operand::VReg(NewVReg(cls, bytes), cls, bytes); + } + int NumVRegs() const { return static_cast(vregs_.size()); } + const VRegInfo& VReg(int id) const { return vregs_[id]; } -class MachineFunction { - public: - explicit MachineFunction(std::string name); + int CreateStackObject(int size, int 
align) { + StackObject obj; + obj.index = static_cast(stack_.size()); + obj.size = size; + obj.align = align; + stack_.push_back(obj); + return obj.index; + } + std::vector& StackObjects() { return stack_; } + const std::vector& StackObjects() const { return stack_; } + StackObject& Stack(int idx) { return stack_[idx]; } - const std::string& GetName() const { return name_; } - MachineBasicBlock& GetEntry() { return entry_; } - const MachineBasicBlock& GetEntry() const { return entry_; } + int GetFrameSize() const { return frame_size_; } + void SetFrameSize(int s) { frame_size_ = s; } - int CreateFrameIndex(int size = 4); - FrameSlot& GetFrameSlot(int index); - const FrameSlot& GetFrameSlot(int index) const; - const std::vector& GetFrameSlots() const { return frame_slots_; } + // 寄存器分配产物:本函数用到、需要保存恢复的 callee-saved 物理寄存器 + std::vector& CalleeSavedGPR() { return callee_gpr_; } + std::vector& CalleeSavedFPR() { return callee_fpr_; } + const std::vector& CalleeSavedGPR() const { return callee_gpr_; } + const std::vector& CalleeSavedFPR() const { return callee_fpr_; } - int GetFrameSize() const { return frame_size_; } - void SetFrameSize(int size) { frame_size_ = size; } + int NumIntArgs() const { return num_int_args_; } + int NumFloatArgs() const { return num_float_args_; } + void SetArgCounts(int i, int f) { + num_int_args_ = i; + num_float_args_ = f; + } private: std::string name_; - MachineBasicBlock entry_; - std::vector frame_slots_; + std::vector> blocks_; + std::vector vregs_; + std::vector stack_; + std::vector callee_gpr_; + std::vector callee_fpr_; int frame_size_ = 0; + int num_int_args_ = 0; + int num_float_args_ = 0; +}; + +struct MachineGlobal { + std::string name; + int size = 0; // 总字节数 + int align = 4; + bool is_const = false; + bool zero_init = true; + std::vector words; // 非零初始化:按 4 字节小端存放的原始位 }; -std::unique_ptr LowerToMIR(const ir::Module& module); +class MachineModule { + public: + std::vector>& Functions() { return funcs_; } + const std::vector>& 
Functions() const { + return funcs_; + } + std::vector& Globals() { return globals_; } + const std::vector& Globals() const { return globals_; } + + private: + std::vector> funcs_; + std::vector globals_; +}; + +const char* GPRName(int id, int bytes); +const char* FPRName(int id, int bytes); + +std::unique_ptr LowerToMIR(const ir::Module& module); void RunRegAlloc(MachineFunction& function); void RunFrameLowering(MachineFunction& function); -void PrintAsm(const MachineFunction& function, std::ostream& os); +void RunPeephole(MachineFunction& function); +void RunBackendPipeline(MachineModule& module); +void PrintAsm(const MachineModule& module, std::ostream& os); +void PrintAArch64AsmFromMIR(const ir::Module& module, std::ostream& os); void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os); } // namespace mir diff --git a/src/main.cpp b/src/main.cpp index da87ec5..d02678e 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -52,7 +52,7 @@ int main(int argc, char** argv) { if (need_blank_line) { std::cout << "\n"; } - mir::PrintAArch64AsmFromIR(*module, std::cout); + mir::PrintAArch64AsmFromMIR(*module, std::cout); } #else if (opts.emit_ir || opts.emit_asm) { diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 4d1f65f..9ab26c3 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -1,78 +1,351 @@ +// AArch64 汇编发射(Lab5)。 +// 输入为寄存器分配 + 栈帧布局后的 MachineModule(操作数均为物理寄存器/栈对象)。 #include "mir/MIR.h" #include -#include - -#include "utils/Log.h" +#include +#include namespace mir { namespace { -const FrameSlot& GetFrameSlot(const MachineFunction& function, - const Operand& operand) { - if (operand.GetKind() != Operand::Kind::FrameIndex) { - throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数")); +int CalleeAreaBytes(const MachineFunction& f) { + return ((int)f.CalleeSavedGPR().size() + (int)f.CalleeSavedFPR().size()) * 8; +} + +class Printer { + public: + Printer(const MachineModule& m, std::ostream& os) : m_(m), os_(os) {} + 
void Run(); + + private: + const MachineModule& m_; + std::ostream& os_; + const MachineFunction* mf_ = nullptr; + + void EmitGlobals(); + void EmitFunction(const MachineFunction& f); + void EmitProlog(const MachineFunction& f); + void EmitEpilog(const MachineFunction& f); + void EmitInstr(const MachineInstr& mi); + + std::string R(const Operand& op) { + if (op.GetClass() == RegClass::FPR) + return FPRName(op.GetId(), op.GetBytes()); + return GPRName(op.GetId(), op.GetBytes()); + } + // 把任意立即数加载进暂存寄存器(x16 / x17)。 + void LoadImm(const char* dst, long long v); + // [x29 + offset] 形式访存,offset 过大时用 x16 计算地址。 + void MemAccess(const char* mnem, const std::string& reg, int offset); + int FrameOffset(int stack_index) { + return mf_->StackObjects()[stack_index].offset; + } +}; + +const char* CondStr(Cond c) { + switch (c) { + case Cond::EQ: return "eq"; + case Cond::NE: return "ne"; + case Cond::LT: return "lt"; + case Cond::LE: return "le"; + case Cond::GT: return "gt"; + case Cond::GE: return "ge"; + case Cond::MI: return "mi"; + case Cond::LS: return "ls"; + case Cond::HI: return "hi"; + case Cond::HS: return "hs"; + default: return "al"; } - return function.GetFrameSlot(operand.GetFrameIndex()); } -void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, - int offset) { - os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset - << "]\n"; +void Printer::LoadImm(const char* dst, long long v) { + bool is_w = dst[0] == 'w'; + unsigned long long uv = (unsigned long long)v; + int hi = is_w ? 
32 : 64; // w 寄存器只填低 32 位 + if (is_w) uv &= 0xffffffffULL; + os_ << "\tmov\t" << dst << ", #" << (uv & 0xffff) << "\n"; + for (int sh = 16; sh < hi; sh += 16) { + unsigned chunk = (uv >> sh) & 0xffff; + if (chunk) + os_ << "\tmovk\t" << dst << ", #" << chunk << ", lsl #" << sh << "\n"; + } } -} // namespace +void Printer::MemAccess(const char* mnem, const std::string& reg, int offset) { + if (offset >= -256 && offset <= 4095) { + os_ << "\t" << mnem << "\t" << reg << ", [x29, #" << offset << "]\n"; + } else { + LoadImm("x16", offset); + os_ << "\tadd\tx16, x29, x16\n"; + os_ << "\t" << mnem << "\t" << reg << ", [x16]\n"; + } +} -void PrintAsm(const MachineFunction& function, std::ostream& os) { - os << ".text\n"; - os << ".global " << function.GetName() << "\n"; - os << ".type " << function.GetName() << ", %function\n"; - os << function.GetName() << ":\n"; - - for (const auto& inst : function.GetEntry().GetInstructions()) { - const auto& ops = inst.GetOperands(); - switch (inst.GetOpcode()) { - case Opcode::Prologue: - os << " stp x29, x30, [sp, #-16]!\n"; - os << " mov x29, sp\n"; - if (function.GetFrameSize() > 0) { - os << " sub sp, sp, #" << function.GetFrameSize() << "\n"; - } - break; - case Opcode::Epilogue: - if (function.GetFrameSize() > 0) { - os << " add sp, sp, #" << function.GetFrameSize() << "\n"; - } - os << " ldp x29, x30, [sp], #16\n"; - break; - case Opcode::MovImm: - os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #" - << ops.at(1).GetImm() << "\n"; - break; - case Opcode::LoadStack: { - const auto& slot = GetFrameSlot(function, ops.at(1)); - PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); +void Printer::Run() { + EmitGlobals(); + os_ << "\t.text\n"; + for (const auto& f : m_.Functions()) EmitFunction(*f); +} + +void Printer::EmitGlobals() { + if (m_.Globals().empty()) return; + for (const auto& g : m_.Globals()) { + if (g.zero_init) { + os_ << "\t.bss\n"; + os_ << "\t.align\t" << (g.align == 16 ? 
4 : 2) << "\n"; + os_ << "\t.globl\t" << g.name << "\n"; + os_ << g.name << ":\n"; + os_ << "\t.zero\t" << g.size << "\n"; + } else { + os_ << "\t.data\n"; + os_ << "\t.align\t" << (g.align == 16 ? 4 : 2) << "\n"; + os_ << "\t.globl\t" << g.name << "\n"; + os_ << g.name << ":\n"; + for (unsigned w : g.words) os_ << "\t.word\t" << w << "\n"; + } + } +} + +void Printer::EmitFunction(const MachineFunction& f) { + mf_ = &f; + os_ << "\t.globl\t" << f.GetName() << "\n"; + os_ << "\t.type\t" << f.GetName() << ", %function\n"; + os_ << f.GetName() << ":\n"; + EmitProlog(f); + for (const auto& bb : f.Blocks()) { + os_ << ".L." << f.GetName() << "." << bb->GetName() << ":\n"; + for (const auto& mi : bb->Instrs()) EmitInstr(mi); + } + os_ << "\t.size\t" << f.GetName() << ", .-" << f.GetName() << "\n"; +} + +void Printer::EmitProlog(const MachineFunction& f) { + int frame = f.GetFrameSize(); + // sub sp, sp, #frame ; 保存 fp/lr ; mov x29, sp + if (frame <= 4095) { + os_ << "\tsub\tsp, sp, #" << frame << "\n"; + } else { + LoadImm("x16", frame); + os_ << "\tsub\tsp, sp, x16\n"; + } + os_ << "\tstp\tx29, x30, [sp]\n"; + os_ << "\tmov\tx29, sp\n"; + // 保存 callee-saved(相对 x29 偏移,从 16 开始)。 + int off = 16; + for (int r : f.CalleeSavedGPR()) { + MemAccess("str", GPRName(r, 8), off); + off += 8; + } + for (int r : f.CalleeSavedFPR()) { + MemAccess("str", FPRName(r, 8), off); + off += 8; + } +} + +void Printer::EmitEpilog(const MachineFunction& f) { + int frame = f.GetFrameSize(); + int off = 16; + for (int r : f.CalleeSavedGPR()) { + MemAccess("ldr", GPRName(r, 8), off); + off += 8; + } + for (int r : f.CalleeSavedFPR()) { + MemAccess("ldr", FPRName(r, 8), off); + off += 8; + } + os_ << "\tldp\tx29, x30, [sp]\n"; + if (frame <= 4095) { + os_ << "\tadd\tsp, sp, #" << frame << "\n"; + } else { + LoadImm("x16", frame); + os_ << "\tadd\tsp, sp, x16\n"; + } + os_ << "\tret\n"; +} + +void Printer::EmitInstr(const MachineInstr& mi) { + auto label = [&](const Operand& o) { + return ".L." 
+ mf_->GetName() + "." + o.GetSym(); + }; + switch (mi.op) { + case Opcode::Mov: { + // 同寄存器拷贝可省略(peephole 已处理,这里再兜底)。 + if (mi.ops[0].IsPReg() && mi.ops[1].IsPReg() && + mi.ops[0].GetId() == mi.ops[1].GetId()) break; + os_ << "\tmov\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n"; + break; + } + case Opcode::MovImm: + LoadImm(R(mi.ops[0]).c_str(), mi.ops[1].GetImm()); + break; + case Opcode::Sxtw: + os_ << "\tsxtw\t" << GPRName(mi.ops[0].GetId(), 8) << ", " + << GPRName(mi.ops[1].GetId(), 4) << "\n"; + break; + case Opcode::Add: + os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << "\n"; + break; + case Opcode::Sub: + os_ << "\tsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << "\n"; + break; + case Opcode::Mul: + os_ << "\tmul\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << "\n"; + break; + case Opcode::SDiv: + os_ << "\tsdiv\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << "\n"; + break; + case Opcode::MSub: + os_ << "\tmsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << ", " << R(mi.ops[3]) << "\n"; + break; + case Opcode::AddImm: { + long long v = mi.ops[2].GetImm(); + if (v >= 0 && v <= 4095) { + os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #" << v + << "\n"; + } else { + LoadImm("x16", v); + os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", x16\n"; } - case Opcode::StoreStack: { - const auto& slot = GetFrameSlot(function, ops.at(1)); - PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); + break; + } + case Opcode::SubImm: + os_ << "\tsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #" + << mi.ops[2].GetImm() << "\n"; + break; + case Opcode::LslImm: + os_ << "\tlsl\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #" + << mi.ops[2].GetImm() << "\n"; + break; + case Opcode::Cmp: + os_ << "\tcmp\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n"; + break; + case Opcode::CmpImm: + os_ 
<< "\tcmp\t" << R(mi.ops[0]) << ", #" << mi.ops[1].GetImm() << "\n"; + break; + case Opcode::CSet: + os_ << "\tcset\t" << R(mi.ops[0]) << ", " << CondStr(mi.cond) << "\n"; + break; + case Opcode::FAdd: + os_ << "\tfadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << "\n"; + break; + case Opcode::FSub: + os_ << "\tfsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << "\n"; + break; + case Opcode::FMul: + os_ << "\tfmul\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << "\n"; + break; + case Opcode::FDiv: + os_ << "\tfdiv\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " + << R(mi.ops[2]) << "\n"; + break; + case Opcode::FCmp: + os_ << "\tfcmp\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n"; + break; + case Opcode::FMov: { + if (mi.ops[0].IsPReg() && mi.ops[1].IsPReg() && + mi.ops[0].GetId() == mi.ops[1].GetId()) break; + os_ << "\tfmov\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n"; + break; + } + case Opcode::FMovImm: + // 通过 x16 装入 32 位浮点位模式,再 fmov 到 s 寄存器。 + LoadImm("w16", (long long)(unsigned)mi.ops[1].GetImm()); + os_ << "\tfmov\t" << R(mi.ops[0]) << ", w16\n"; + break; + case Opcode::SCvtF: + os_ << "\tscvtf\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n"; + break; + case Opcode::FCvtZS: + os_ << "\tfcvtzs\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n"; + break; + case Opcode::Ldr: { + long long off = mi.ops[2].GetImm(); + if (off >= -256 && off <= 4095) + os_ << "\tldr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", #" + << off << "]\n"; + else { + LoadImm("x16", off); + os_ << "\tldr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", x16]\n"; } - case Opcode::AddRR: - os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::Ret: - os << " ret\n"; - break; + break; + } + case Opcode::Str: { + long long off = mi.ops[2].GetImm(); + if (off >= -256 && off <= 4095) + 
os_ << "\tstr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", #" + << off << "]\n"; + else { + LoadImm("x16", off); + os_ << "\tstr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", x16]\n"; + } + break; } + case Opcode::LdrStack: + MemAccess("ldr", R(mi.ops[0]), FrameOffset(mi.ops[1].GetFrame())); + break; + case Opcode::StrStack: + MemAccess("str", R(mi.ops[0]), FrameOffset(mi.ops[1].GetFrame())); + break; + case Opcode::AddrFrame: { + int off = FrameOffset(mi.ops[1].GetFrame()); + if (off >= 0 && off <= 4095) + os_ << "\tadd\t" << R(mi.ops[0]) << ", x29, #" << off << "\n"; + else { + LoadImm("x16", off); + os_ << "\tadd\t" << R(mi.ops[0]) << ", x29, x16\n"; + } + break; + } + case Opcode::AddrGlobal: + os_ << "\tadrp\t" << R(mi.ops[0]) << ", " << mi.ops[1].GetSym() << "\n"; + os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[0]) << ", :lo12:" + << mi.ops[1].GetSym() << "\n"; + break; + case Opcode::B: + os_ << "\tb\t" << label(mi.ops[0]) << "\n"; + break; + case Opcode::BCond: + os_ << "\tb." 
<< CondStr(mi.cond) << "\t" << label(mi.ops[0]) << "\n"; + break; + case Opcode::Bl: + os_ << "\tbl\t" << mi.ops[0].GetSym() << "\n"; + break; + case Opcode::Ret: + EmitEpilog(*mf_); + break; } +} + +} // namespace + +void PrintAsm(const MachineModule& module, std::ostream& os) { + Printer p(module, os); + p.Run(); +} + +void RunBackendPipeline(MachineModule& module) { + for (auto& f : module.Functions()) { + RunRegAlloc(*f); + RunFrameLowering(*f); + RunPeephole(*f); + } +} - os << ".size " << function.GetName() << ", .-" << function.GetName() - << "\n"; +void PrintAArch64AsmFromMIR(const ir::Module& module, std::ostream& os) { + auto mm = LowerToMIR(module); + RunBackendPipeline(*mm); + PrintAsm(*mm, os); } } // namespace mir diff --git a/src/mir/CMakeLists.txt b/src/mir/CMakeLists.txt index 78a4659..5eeb438 100644 --- a/src/mir/CMakeLists.txt +++ b/src/mir/CMakeLists.txt @@ -9,6 +9,8 @@ add_library(mir_core STATIC FrameLowering.cpp AsmPrinter.cpp LLVMAsmBackend.cpp + passes/PassManager.cpp + passes/Peephole.cpp ) target_link_libraries(mir_core PUBLIC @@ -16,10 +18,7 @@ target_link_libraries(mir_core PUBLIC ir ) -add_subdirectory(passes) - add_library(mir INTERFACE) target_link_libraries(mir INTERFACE mir_core - mir_passes ) diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp index 679ab68..c1777de 100644 --- a/src/mir/FrameLowering.cpp +++ b/src/mir/FrameLowering.cpp @@ -1,45 +1,35 @@ +// 栈帧布局(Lab5): +// 在寄存器分配产出 spill 槽与 callee-saved 使用集合后,确定每个栈对象 +// 相对 x29 的偏移与总帧大小。实际的 prologue/epilogue 由 AsmPrinter 按 +// 同一套布局公式发射。 +// +// 帧布局(x29 指向帧底,sp == x29): +// [x29 + 0] 保存的 x29 +// [x29 + 8] 保存的 x30(lr) +// [x29 + 16 ...] 
callee-saved GPR、callee-saved FPR(各 8 字节) +// [其后] 局部/spill 栈对象(按声明顺序,按对齐摆放) +// 总大小对齐到 16 字节。 #include "mir/MIR.h" -#include -#include - -#include "utils/Log.h" - namespace mir { -namespace { -int AlignTo(int value, int align) { - return ((value + align - 1) / align) * align; +int CalleeSavedAreaBytes(const MachineFunction& f) { + int n = (int)f.CalleeSavedGPR().size() + (int)f.CalleeSavedFPR().size(); + return n * 8; } -} // namespace - void RunFrameLowering(MachineFunction& function) { - int cursor = 0; - for (const auto& slot : function.GetFrameSlots()) { - cursor += slot.size; - if (-cursor < -256) { - throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧")); - } - } - - cursor = 0; - for (const auto& slot : function.GetFrameSlots()) { - cursor += slot.size; - function.GetFrameSlot(slot.index).offset = -cursor; - } - function.SetFrameSize(AlignTo(cursor, 16)); - - auto& insts = function.GetEntry().GetInstructions(); - std::vector lowered; - lowered.emplace_back(Opcode::Prologue); - for (const auto& inst : insts) { - if (inst.GetOpcode() == Opcode::Ret) { - lowered.emplace_back(Opcode::Epilogue); - } - lowered.push_back(inst); + int base = 16 + CalleeSavedAreaBytes(function); // fp/lr + callee-saved + int off = base; + for (auto& obj : function.StackObjects()) { + int align = obj.align < 4 ? 
4 : obj.align; + off = (off + align - 1) / align * align; + obj.offset = off; // 相对 x29 的正偏移 + off += obj.size; } - insts = std::move(lowered); + int frame = (off + 15) / 16 * 16; + if (frame < 16) frame = 16; + function.SetFrameSize(frame); } } // namespace mir diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 843890c..f5b751d 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -1,7 +1,15 @@ +// IR -> MIR 指令选择(Lab5): +// - 为每个 IR 值分配虚拟寄存器(GPR / FPR 两类) +// - alloca -> 栈对象;gep/global -> 地址计算 +// - 完整覆盖算术、比较、分支、调用、访存、类型转换、浮点 #include "mir/MIR.h" +#include +#include #include +#include #include +#include #include "ir/IR.h" #include "utils/Log.h" @@ -9,123 +17,509 @@ namespace mir { namespace { -using ValueSlotMap = std::unordered_map; +int TypeSize(const ir::Type& t) { + switch (t.GetKind()) { + case ir::Type::Kind::Int1: + case ir::Type::Kind::Int32: + case ir::Type::Kind::Float: + return 4; + case ir::Type::Kind::Pointer: + return 8; + case ir::Type::Kind::Array: + return static_cast(t.GetArraySize()) * + TypeSize(*t.GetElementType()); + default: + return 8; + } +} + +RegClass ClassOf(const ir::Type& t) { + return t.IsFloat() ? RegClass::FPR : RegClass::GPR; +} +int BytesOf(const ir::Type& t) { return t.IsPointer() ? 
8 : 4; } + +bool IsPow2(long long v) { return v > 0 && (v & (v - 1)) == 0; } +int Log2(long long v) { + int n = 0; + while ((1LL << n) < v) ++n; + return n; +} + +class Lowerer { + public: + Lowerer(const ir::Module& m, MachineModule& out) : ir_(m), out_(out) {} + void Run(); + + private: + const ir::Module& ir_; + MachineModule& out_; + MachineFunction* mf_ = nullptr; + MachineBasicBlock* mbb_ = nullptr; + std::unordered_map vmap_; + std::unordered_map bmap_; + std::unordered_map allocas_; + int label_id_ = 0; -void EmitValueToReg(const ir::Value* value, PhysReg target, - const ValueSlotMap& slots, MachineBasicBlock& block) { - if (auto* constant = dynamic_cast(value)) { - block.Append(Opcode::MovImm, - {Operand::Reg(target), Operand::Imm(constant->GetValue())}); - return; + void Emit(Opcode op, std::vector ops, int defs, Cond c = Cond::AL) { + mbb_->Add(MachineInstr(op, std::move(ops), defs, c)); } + Operand NewG(int bytes = 4) { return mf_->NewVRegOp(RegClass::GPR, bytes); } + Operand NewF() { return mf_->NewVRegOp(RegClass::FPR, 4); } - auto it = slots.find(value); - if (it == slots.end()) { - throw std::runtime_error( - FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); + void LowerFunction(const ir::Function& f); + void LowerBlock(const ir::BasicBlock& bb); + void LowerInst(const ir::Instruction& inst); + void LowerInstMem(const ir::Instruction& inst); + + Operand GetReg(const ir::Value* v); // 取值(必要时物化常量/地址) + Operand MaterializeInt(int v); + Operand AddressOf(const ir::Value* ptr); + long long GepConst(const ir::GepInst& gep, Operand* out_base); + void LowerGlobals(); + Cond ICmpCond(ir::ICmpPredicate p); + Cond FCmpCond(ir::FCmpPredicate p); +}; + +void Lowerer::Run() { + LowerGlobals(); + for (const auto& f : ir_.GetFunctions()) { + if (f->IsDeclaration()) continue; + LowerFunction(*f); } +} - block.Append(Opcode::LoadStack, - {Operand::Reg(target), Operand::FrameIndex(it->second)}); +void Lowerer::LowerGlobals() { + for (const auto& g : 
ir_.GetGlobals()) { + MachineGlobal mg; + mg.name = g->GetName(); + const ir::Type& vt = *g->GetValueType(); + mg.size = TypeSize(vt); + mg.align = vt.IsArray() ? 16 : 4; + mg.is_const = g->IsConst(); + ir::ConstantValue* init = g->GetInitializer(); + mg.zero_init = true; + int nwords = (mg.size + 3) / 4; + mg.words.assign(nwords, 0u); + // 收集初始化位(递归展开数组)。 + std::vector flat; + std::function walk = [&](ir::ConstantValue* c) { + if (!c) return; + if (auto* ci = dynamic_cast(c)) { + flat.push_back(static_cast(ci->GetValue())); + } else if (auto* cf = dynamic_cast(c)) { + float v = cf->GetValue(); + unsigned bits; + std::memcpy(&bits, &v, 4); + flat.push_back(bits); + } else if (auto* ca = dynamic_cast(c)) { + for (auto* e : ca->GetElements()) walk(e); + } + }; + walk(init); + for (size_t i = 0; i < flat.size() && i < mg.words.size(); ++i) { + mg.words[i] = flat[i]; + if (flat[i] != 0) mg.zero_init = false; + } + out_.Globals().push_back(std::move(mg)); + } } -void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, - ValueSlotMap& slots) { - auto& block = function.GetEntry(); +void Lowerer::LowerFunction(const ir::Function& f) { + out_.Functions().push_back(std::make_unique(f.GetName())); + mf_ = out_.Functions().back().get(); + vmap_.clear(); + bmap_.clear(); + allocas_.clear(); - switch (inst.GetOpcode()) { - case ir::Opcode::Alloca: { - slots.emplace(&inst, function.CreateFrameIndex()); - return; - } - case ir::Opcode::Store: { - auto& store = static_cast(inst); - auto dst = slots.find(store.GetPtr()); - if (dst == slots.end()) { - throw std::runtime_error( - FormatError("mir", "暂不支持对非栈变量地址进行写入")); - } - EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)}); - return; - } - case ir::Opcode::Load: { - auto& load = static_cast(inst); - auto src = slots.find(load.GetPtr()); - if (src == slots.end()) { - throw std::runtime_error( - 
FormatError("mir", "暂不支持对非栈变量地址进行读取")); - } - int dst_slot = function.CreateFrameIndex(); - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; - } - case ir::Opcode::Add: { - auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(); - EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); - EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); - slots.emplace(&inst, dst_slot); - return; - } - case ir::Opcode::Ret: { - auto& ret = static_cast(inst); - EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block); - block.Append(Opcode::Ret); - return; - } - case ir::Opcode::Sub: - case ir::Opcode::Mul: - case ir::Opcode::SDiv: - case ir::Opcode::SRem: - case ir::Opcode::FAdd: - case ir::Opcode::FSub: - case ir::Opcode::FMul: - case ir::Opcode::FDiv: - throw std::runtime_error(FormatError("mir", "暂不支持该二元运算")); - default: - break; + for (const auto& bb : f.GetBlocks()) { + bmap_[bb.get()] = mf_->CreateBlock(bb->GetName()); + } + // 记录后继,便于活跃性分析。 + for (const auto& bb : f.GetBlocks()) { + MachineBasicBlock* mb = bmap_[bb.get()]; + for (auto* s : bb->GetSuccessors()) mb->Succs().push_back(bmap_[s]); + } + + mbb_ = bmap_[f.GetEntry()]; + + // 形参:整型走 x0.., 浮点走 s0..,超过 8 个的从栈读取(测试未用,简化)。 + int ig = 0, fg = 0; + std::vector arg_copies; + for (size_t i = 0; i < f.GetNumArgs(); ++i) { + ir::Argument* a = const_cast(f).GetArg(i); + const ir::Type& at = *a->GetType(); + if (at.IsFloat()) { + Operand dst = NewF(); + arg_copies.push_back(MachineInstr( + Opcode::FMov, + {dst, Operand::PReg(fg++, RegClass::FPR, 4)}, 1)); + vmap_[a] = dst; + } else { + int bytes = 
at.IsPointer() ? 8 : 4; + Operand dst = NewG(bytes); + arg_copies.push_back(MachineInstr( + Opcode::Mov, {dst, Operand::PReg(ig++, RegClass::GPR, bytes)}, 1)); + vmap_[a] = dst; + } } + mf_->SetArgCounts(ig, fg); + for (auto& mi : arg_copies) mbb_->Add(std::move(mi)); - throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); + for (const auto& bb : f.GetBlocks()) { + mbb_ = bmap_[bb.get()]; + LowerBlock(*bb); + } } -} // namespace +void Lowerer::LowerBlock(const ir::BasicBlock& bb) { + for (const auto& inst : bb.GetInstructions()) { + LowerInst(*inst); + } +} + +Operand Lowerer::MaterializeInt(int v) { + Operand d = NewG(4); + Emit(Opcode::MovImm, {d, Operand::Imm(v)}, 1); + return d; +} + +Operand Lowerer::GetReg(const ir::Value* v) { + auto it = vmap_.find(v); + if (it != vmap_.end()) return it->second; + if (auto* ci = dynamic_cast(v)) { + return MaterializeInt(ci->GetValue()); + } + if (auto* cf = dynamic_cast(v)) { + float f = cf->GetValue(); + unsigned bits; + std::memcpy(&bits, &f, 4); + Operand d = NewF(); + Emit(Opcode::FMovImm, {d, Operand::Imm(bits)}, 1); + return d; + } + // 兜底:未知值视为 0。 + return MaterializeInt(0); +} -std::unique_ptr LowerToMIR(const ir::Module& module) { - DefaultContext(); +// 计算 gep 的常量字节偏移;返回偏移并把基址写入 *out_base。 +// 若存在变量下标,直接生成地址计算指令并把最终地址放入 *out_base、返回 0。 +long long Lowerer::GepConst(const ir::GepInst& gep, Operand* out_base) { + Operand base = AddressOf(gep.GetBasePtr()); + // 推断逐层元素类型:基址指针指向的类型。 + std::shared_ptr cur = gep.GetBasePtr()->GetType()->GetElementType(); + long long const_off = 0; + Operand addr = base; + bool addr_dirty = false; + const auto& idxs = gep.GetIndices(); + for (size_t i = 0; i < idxs.size(); ++i) { + int elem_size = cur ? 
TypeSize(*cur) : 4; + ir::Value* iv = idxs[i]; + if (auto* ci = dynamic_cast(iv)) { + const_off += static_cast(ci->GetValue()) * elem_size; + } else { + // addr += index * elem_size + Operand idx = GetReg(iv); + Operand idx64 = NewG(8); + Emit(Opcode::Sxtw, {idx64, idx}, 1); + Operand scaled = NewG(8); + if (IsPow2(elem_size)) { + Emit(Opcode::LslImm, {scaled, idx64, Operand::Imm(Log2(elem_size))}, 1); + } else { + Operand sz = NewG(8); + Emit(Opcode::MovImm, {sz, Operand::Imm(elem_size)}, 1); + Emit(Opcode::Mul, {scaled, idx64, sz}, 1); + } + Operand na = NewG(8); + Emit(Opcode::Add, {na, addr, scaled}, 1); + addr = na; + addr_dirty = true; + } + if (cur && cur->IsArray()) cur = cur->GetElementType(); + } + if (addr_dirty) { + *out_base = addr; + return const_off; + } + *out_base = base; + return const_off; +} + +// 返回某个指针型 IR 值对应的“地址”寄存器。 +Operand Lowerer::AddressOf(const ir::Value* ptr) { + auto it = vmap_.find(ptr); + if (it != vmap_.end()) return it->second; + if (auto* a = dynamic_cast(ptr)) { + int idx; + auto ai = allocas_.find(a); + if (ai == allocas_.end()) { + idx = mf_->CreateStackObject(TypeSize(*a->GetAllocatedType()), + a->GetAllocatedType()->IsArray() ? 
16 : 4); + allocas_[a] = idx; + } else { + idx = ai->second; + } + Operand d = NewG(8); + Emit(Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1); + vmap_[ptr] = d; + return d; + } + // 全局变量 + Operand d = NewG(8); + Emit(Opcode::AddrGlobal, {d, Operand::Global(ptr->GetName())}, 1); + return d; +} - if (module.GetFunctions().size() != 1) { - throw std::runtime_error(FormatError("mir", "暂不支持多个函数")); +Cond Lowerer::ICmpCond(ir::ICmpPredicate p) { + switch (p) { + case ir::ICmpPredicate::Eq: return Cond::EQ; + case ir::ICmpPredicate::Ne: return Cond::NE; + case ir::ICmpPredicate::Slt: return Cond::LT; + case ir::ICmpPredicate::Sle: return Cond::LE; + case ir::ICmpPredicate::Sgt: return Cond::GT; + case ir::ICmpPredicate::Sge: return Cond::GE; } + return Cond::EQ; +} - const auto& func = *module.GetFunctions().front(); - if (func.GetName() != "main") { - throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数")); +Cond Lowerer::FCmpCond(ir::FCmpPredicate p) { + switch (p) { + case ir::FCmpPredicate::Oeq: return Cond::EQ; + case ir::FCmpPredicate::One: return Cond::NE; + case ir::FCmpPredicate::Olt: return Cond::MI; + case ir::FCmpPredicate::Ole: return Cond::LS; + case ir::FCmpPredicate::Ogt: return Cond::GT; + case ir::FCmpPredicate::Oge: return Cond::GE; } + return Cond::EQ; +} - auto machine_func = std::make_unique(func.GetName()); - ValueSlotMap slots; - const auto* entry = func.GetEntry(); - if (!entry) { - throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块")); +void Lowerer::LowerInst(const ir::Instruction& inst) { + using ir::Opcode; + switch (inst.GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::SDiv: + case Opcode::SRem: { + auto& b = static_cast(inst); + Operand l = GetReg(b.GetLhs()); + Operand r = GetReg(b.GetRhs()); + Operand d = NewG(4); + mir::Opcode mop = mir::Opcode::Add; + if (inst.GetOpcode() == Opcode::Add) mop = mir::Opcode::Add; + else if (inst.GetOpcode() == Opcode::Sub) mop = mir::Opcode::Sub; + 
else if (inst.GetOpcode() == Opcode::Mul) mop = mir::Opcode::Mul; + else mop = mir::Opcode::SDiv; + if (inst.GetOpcode() == Opcode::SRem) { + Operand q = NewG(4); + Emit(mir::Opcode::SDiv, {q, l, r}, 1); + Emit(mir::Opcode::MSub, {d, q, r, l}, 1); // d = l - q*r + } else { + Emit(mop, {d, l, r}, 1); + } + vmap_[&inst] = d; + break; + } + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: { + auto& b = static_cast(inst); + Operand l = GetReg(b.GetLhs()); + Operand r = GetReg(b.GetRhs()); + Operand d = NewF(); + mir::Opcode mop = mir::Opcode::FAdd; + if (inst.GetOpcode() == Opcode::FSub) mop = mir::Opcode::FSub; + else if (inst.GetOpcode() == Opcode::FMul) mop = mir::Opcode::FMul; + else if (inst.GetOpcode() == Opcode::FDiv) mop = mir::Opcode::FDiv; + Emit(mop, {d, l, r}, 1); + vmap_[&inst] = d; + break; + } + case Opcode::SIToFP: { + auto& c = static_cast(inst); + Operand s = GetReg(c.GetValue()); + Operand d = NewF(); + Emit(mir::Opcode::SCvtF, {d, s}, 1); + vmap_[&inst] = d; + break; + } + case Opcode::FPToSI: { + auto& c = static_cast(inst); + Operand s = GetReg(c.GetValue()); + Operand d = NewG(4); + Emit(mir::Opcode::FCvtZS, {d, s}, 1); + vmap_[&inst] = d; + break; + } + case Opcode::ZExt: { + auto& c = static_cast(inst); + Operand s = GetReg(c.GetValue()); + Operand d = NewG(4); + Emit(mir::Opcode::Mov, {d, s}, 1); // i1->i32:cset 已产出 0/1 + vmap_[&inst] = d; + break; + } + case Opcode::ICmp: { + auto& c = static_cast(inst); + Operand l = GetReg(c.GetLhs()); + Operand r = GetReg(c.GetRhs()); + Emit(mir::Opcode::Cmp, {l, r}, 0); + Operand d = NewG(4); + Emit(mir::Opcode::CSet, {d}, 1, ICmpCond(c.GetPredicate())); + vmap_[&inst] = d; + break; + } + case Opcode::FCmp: { + auto& c = static_cast(inst); + Operand l = GetReg(c.GetLhs()); + Operand r = GetReg(c.GetRhs()); + Emit(mir::Opcode::FCmp, {l, r}, 0); + Operand d = NewG(4); + Emit(mir::Opcode::CSet, {d}, 1, FCmpCond(c.GetPredicate())); + vmap_[&inst] = d; + break; + } + default: + 
LowerInstMem(inst); + break; } +} - for (const auto& inst : entry->GetInstructions()) { - LowerInstruction(*inst, *machine_func, slots); +void Lowerer::LowerInstMem(const ir::Instruction& inst) { + using ir::Opcode; + switch (inst.GetOpcode()) { + case Opcode::Alloca: { + auto& a = static_cast(inst); + int idx = mf_->CreateStackObject(TypeSize(*a.GetAllocatedType()), + a.GetAllocatedType()->IsArray() ? 16 : 4); + allocas_[&a] = idx; + Operand d = NewG(8); + Emit(mir::Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1); + vmap_[&inst] = d; + break; + } + case Opcode::Load: { + auto& ld = static_cast(inst); + Operand base; + long long off = 0; + if (auto* gep = dynamic_cast(ld.GetPtr())) { + off = GepConst(*gep, &base); + } else { + base = AddressOf(ld.GetPtr()); + } + bool is_f = inst.GetType()->IsFloat(); + Operand d = is_f ? NewF() : NewG(BytesOf(*inst.GetType())); + Emit(mir::Opcode::Ldr, {d, base, Operand::Imm(off)}, 1); + vmap_[&inst] = d; + break; + } + case Opcode::Store: { + auto& st = static_cast(inst); + Operand val = GetReg(st.GetValue()); + Operand base; + long long off = 0; + if (auto* gep = dynamic_cast(st.GetPtr())) { + off = GepConst(*gep, &base); + } else { + base = AddressOf(st.GetPtr()); + } + Emit(mir::Opcode::Str, {val, base, Operand::Imm(off)}, 0); + break; + } + case Opcode::Gep: { + auto& gep = static_cast(inst); + Operand base; + long long off = GepConst(gep, &base); + Operand d = NewG(8); + if (off == 0) { + Emit(mir::Opcode::Mov, {d, base}, 1); + } else { + Emit(mir::Opcode::AddImm, {d, base, Operand::Imm(off)}, 1); + } + vmap_[&inst] = d; + break; + } + case Opcode::Call: { + auto& call = static_cast(inst); + // 先把所有实参算入虚拟寄存器,再连续搬入物理参数寄存器, + // 避免计算后续实参时分配器复用 x0..x7 破坏已就绪的参数。 + std::vector vals; + for (auto* arg : call.GetArgs()) vals.push_back(GetReg(arg)); + int ig = 0, fg = 0; + std::vector arg_uses; + for (size_t i = 0; i < call.GetArgs().size(); ++i) { + ir::Value* arg = call.GetArgs()[i]; + if (arg->GetType()->IsFloat()) { + Operand p = 
Operand::PReg(fg++, RegClass::FPR, 4); + Emit(mir::Opcode::FMov, {p, vals[i]}, 1); + arg_uses.push_back(p); + } else { + int bytes = arg->GetType()->IsPointer() ? 8 : 4; + Operand p = Operand::PReg(ig++, RegClass::GPR, bytes); + Emit(mir::Opcode::Mov, {p, vals[i]}, 1); + arg_uses.push_back(p); + } + } + std::vector ops; + ops.push_back(Operand::Global(call.GetCallee()->GetName())); + for (auto& u : arg_uses) ops.push_back(u); + Emit(mir::Opcode::Bl, ops, 0); + if (!inst.GetType()->IsVoid()) { + if (inst.GetType()->IsFloat()) { + Operand d = NewF(); + Emit(mir::Opcode::FMov, {d, Operand::PReg(0, RegClass::FPR, 4)}, 1); + vmap_[&inst] = d; + } else { + Operand d = NewG(BytesOf(*inst.GetType())); + Emit(mir::Opcode::Mov, + {d, Operand::PReg(0, RegClass::GPR, d.GetBytes())}, 1); + vmap_[&inst] = d; + } + } + break; + } + case Opcode::Br: { + auto& br = static_cast(inst); + Emit(mir::Opcode::B, {Operand::Label(br.GetDest()->GetName())}, 0); + break; + } + case Opcode::CondBr: { + auto& cbr = static_cast(inst); + Operand c = GetReg(cbr.GetCond()); + Emit(mir::Opcode::CmpImm, {c, Operand::Imm(0)}, 0); + Emit(mir::Opcode::BCond, {Operand::Label(cbr.GetTrueDest()->GetName())}, 0, + Cond::NE); + Emit(mir::Opcode::B, {Operand::Label(cbr.GetFalseDest()->GetName())}, 0); + break; + } + case Opcode::Ret: { + auto& ret = static_cast(inst); + if (ret.HasReturnValue()) { + Operand v = GetReg(ret.GetValue()); + if (ret.GetValue()->GetType()->IsFloat()) { + Emit(mir::Opcode::FMov, {Operand::PReg(0, RegClass::FPR, 4), v}, 1); + } else { + Emit(mir::Opcode::Mov, + {Operand::PReg(0, RegClass::GPR, v.GetBytes()), v}, 1); + } + } + Emit(mir::Opcode::Ret, {}, 0); + break; + } + default: + break; } +} - return machine_func; +} // namespace + +std::unique_ptr LowerToMIR(const ir::Module& module) { + auto out = std::make_unique(); + Lowerer lo(module, *out); + lo.Run(); + return out; } } // namespace mir + diff --git a/src/mir/MIRBasicBlock.cpp b/src/mir/MIRBasicBlock.cpp index 
d42b4b3..b534dfc 100644 --- a/src/mir/MIRBasicBlock.cpp +++ b/src/mir/MIRBasicBlock.cpp @@ -1,16 +1,2 @@ +// 机器基本块:实现已并入头文件,本文件仅保留 TU 占位。 #include "mir/MIR.h" - -#include - -namespace mir { - -MachineBasicBlock::MachineBasicBlock(std::string name) - : name_(std::move(name)) {} - -MachineInstr& MachineBasicBlock::Append(Opcode opcode, - std::initializer_list operands) { - instructions_.emplace_back(opcode, std::vector(operands)); - return instructions_.back(); -} - -} // namespace mir diff --git a/src/mir/MIRContext.cpp b/src/mir/MIRContext.cpp index 30c75c8..3db6ec4 100644 --- a/src/mir/MIRContext.cpp +++ b/src/mir/MIRContext.cpp @@ -1,10 +1,2 @@ +// 机器上下文:实现已并入头文件,本文件仅保留 TU 占位。 #include "mir/MIR.h" - -namespace mir { - -MIRContext& DefaultContext() { - static MIRContext ctx; - return ctx; -} - -} // namespace mir diff --git a/src/mir/MIRFunction.cpp b/src/mir/MIRFunction.cpp index 334f8cc..bc656e8 100644 --- a/src/mir/MIRFunction.cpp +++ b/src/mir/MIRFunction.cpp @@ -1,33 +1,2 @@ +// 机器函数:实现已并入头文件,本文件仅保留 TU 占位。 #include "mir/MIR.h" - -#include -#include - -#include "utils/Log.h" - -namespace mir { - -MachineFunction::MachineFunction(std::string name) - : name_(std::move(name)), entry_("entry") {} - -int MachineFunction::CreateFrameIndex(int size) { - int index = static_cast(frame_slots_.size()); - frame_slots_.push_back(FrameSlot{index, size, 0}); - return index; -} - -FrameSlot& MachineFunction::GetFrameSlot(int index) { - if (index < 0 || index >= static_cast(frame_slots_.size())) { - throw std::runtime_error(FormatError("mir", "非法 FrameIndex")); - } - return frame_slots_[index]; -} - -const FrameSlot& MachineFunction::GetFrameSlot(int index) const { - if (index < 0 || index >= static_cast(frame_slots_.size())) { - throw std::runtime_error(FormatError("mir", "非法 FrameIndex")); - } - return frame_slots_[index]; -} - -} // namespace mir diff --git a/src/mir/MIRInstr.cpp b/src/mir/MIRInstr.cpp index 0a21a03..ee76289 100644 --- a/src/mir/MIRInstr.cpp +++ 
b/src/mir/MIRInstr.cpp @@ -1,23 +1,2 @@ +// 机器指令:实现已并入头文件,本文件仅保留 TU 占位。 #include "mir/MIR.h" - -#include - -namespace mir { - -Operand::Operand(Kind kind, PhysReg reg, int imm) - : kind_(kind), reg_(reg), imm_(imm) {} - -Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); } - -Operand Operand::Imm(int value) { - return Operand(Kind::Imm, PhysReg::W0, value); -} - -Operand Operand::FrameIndex(int index) { - return Operand(Kind::FrameIndex, PhysReg::W0, index); -} - -MachineInstr::MachineInstr(Opcode opcode, std::vector operands) - : opcode_(opcode), operands_(std::move(operands)) {} - -} // namespace mir diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index 5dc5d2b..dae406c 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -1,36 +1,354 @@ +// 寄存器分配(Lab5):线性扫描 + 活跃区间。 +// +// 物理寄存器约定: +// GPR: x0-x8 参数/返回(不参与分配),x9-x11 可分配(caller-saved), +// x12-x15 spill 暂存,x16-x17 汇编寻址暂存,x18 平台保留, +// x19-x28 可分配(callee-saved),x29/x30 fp/lr,x31 sp。 +// FPR: s0-s7 参数/返回,s8-s15 可分配(callee-saved), +// s16-s27 可分配(caller-saved),s28-s31 spill 暂存。 +// +// 跨调用活跃的虚拟寄存器只能落在 callee-saved 寄存器或被 spill。 #include "mir/MIR.h" -#include - -#include "utils/Log.h" +#include +#include +#include +#include namespace mir { namespace { -bool IsAllowedReg(PhysReg reg) { - switch (reg) { - case PhysReg::W0: - case PhysReg::W8: - case PhysReg::W9: - case PhysReg::X29: - case PhysReg::X30: - case PhysReg::SP: - return true; +const std::vector& CallerGPR() { + static const std::vector v{9, 10, 11}; + return v; +} +const std::vector& CalleeGPR() { + static const std::vector v{19, 20, 21, 22, 23, 24, 25, 26, 27, 28}; + return v; +} +const std::vector& CalleeFPR() { + static const std::vector v{8, 9, 10, 11, 12, 13, 14, 15}; + return v; +} +const std::vector& CallerFPR() { + static const std::vector v{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}; + return v; +} +// 4 个暂存寄存器:覆盖单条指令最多 3 use + 1 def 同时为 spill 的情形。 +const int kGprSpillScratch[4] = {12, 13, 14, 15}; +const 
int kFprSpillScratch[4] = {28, 29, 30, 31}; + +bool IsCalleeSavedGPR(int id) { return id >= 19 && id <= 28; } +bool IsCalleeSavedFPR(int id) { return id >= 8 && id <= 15; } +bool IsCalleeSavedGPRorFPR(int id, RegClass cls) { + return cls == RegClass::GPR ? IsCalleeSavedGPR(id) : IsCalleeSavedFPR(id); +} + +// 提取指令的 def/use 寄存器(仅虚拟寄存器,按操作数顺序)。 +void Defs(const MachineInstr& mi, std::vector* out) { + for (int i = 0; i < mi.num_defs && i < (int)mi.ops.size(); ++i) + if (mi.ops[i].IsVReg()) out->push_back(mi.ops[i].GetId()); +} +void Uses(const MachineInstr& mi, std::vector* out) { + for (int i = mi.num_defs; i < (int)mi.ops.size(); ++i) + if (mi.ops[i].IsVReg()) out->push_back(mi.ops[i].GetId()); +} + +struct Interval { + int vreg; + int start; + int end; + RegClass cls; + bool cross_call = false; + int preg = -1; // 分配到的物理寄存器;-1 表示 spill + int spill_slot = -1; +}; + +// 把所有基本块按顺序展开成全局指令编号,并记录每块的 [lo,hi]。 +struct Numbering { + std::vector instrs; + std::vector owner; + std::unordered_map> range; + std::vector call_sites; // Bl 指令的全局编号 +}; + +Numbering NumberInstrs(MachineFunction& f) { + Numbering n; + for (auto& bb : f.Blocks()) { + int lo = (int)n.instrs.size(); + for (auto& mi : bb->Instrs()) { + if (mi.op == Opcode::Bl) n.call_sites.push_back((int)n.instrs.size()); + n.instrs.push_back(&mi); + n.owner.push_back(bb.get()); + } + int hi = (int)n.instrs.size() - 1; + if (hi < lo) hi = lo; // 空块 + n.range[bb.get()] = {lo, hi}; + } + return n; +} + +// 经典反向数据流活跃性分析(按块)。 +void ComputeLiveness( + MachineFunction& f, + std::unordered_map>* live_in, + std::unordered_map>* live_out) { + std::unordered_map> use, def; + for (auto& bb : f.Blocks()) { + std::unordered_set u, d; + for (auto& mi : bb->Instrs()) { + std::vector us, ds; + Uses(mi, &us); + Defs(mi, &ds); + for (int r : us) + if (!d.count(r)) u.insert(r); + for (int r : ds) d.insert(r); + } + use[bb.get()] = std::move(u); + def[bb.get()] = std::move(d); + (*live_in)[bb.get()]; + (*live_out)[bb.get()]; + } + bool 
changed = true; + while (changed) { + changed = false; + for (auto it = f.Blocks().rbegin(); it != f.Blocks().rend(); ++it) { + MachineBasicBlock* b = it->get(); + std::unordered_set out; + for (auto* s : b->Succs()) + for (int r : (*live_in)[s]) out.insert(r); + std::unordered_set in = use[b]; + for (int r : (*live_out)[b]) + if (!def[b].count(r)) in.insert(r); + if (out.size() != (*live_out)[b].size() || + in.size() != (*live_in)[b].size()) { + changed = true; + } + (*live_out)[b] = std::move(out); + (*live_in)[b] = std::move(in); + } } - return false; } -} // namespace +// 用块级活跃信息 + 块内精确编号构造每个 vreg 的活跃区间。 +std::vector BuildIntervals(MachineFunction& f, const Numbering& num) { + std::unordered_map> live_in, + live_out; + ComputeLiveness(f, &live_in, &live_out); -void RunRegAlloc(MachineFunction& function) { - for (const auto& inst : function.GetEntry().GetInstructions()) { - for (const auto& operand : inst.GetOperands()) { - if (operand.GetKind() == Operand::Kind::Reg && - !IsAllowedReg(operand.GetReg())) { - throw std::runtime_error(FormatError("mir", "寄存器分配失败")); + int nv = f.NumVRegs(); + std::vector start(nv, -1), end(nv, -1); + auto extend = [&](int v, int p) { + if (v < 0) return; + if (start[v] == -1 || p < start[v]) start[v] = p; + if (p > end[v]) end[v] = p; + }; + + for (auto& bb : f.Blocks()) { + auto [lo, hi] = num.range.at(bb.get()); + // live-in 在块入口活跃。 + for (int v : live_in[bb.get()]) extend(v, lo); + // live-out 在块出口活跃。 + for (int v : live_out[bb.get()]) extend(v, hi); + int p = lo; + for (auto& mi : bb->Instrs()) { + std::vector us, ds; + Uses(mi, &us); + Defs(mi, &ds); + for (int v : us) extend(v, p); + for (int v : ds) extend(v, p); + ++p; + } + } + + std::vector ivs; + for (int v = 0; v < nv; ++v) { + if (start[v] == -1) continue; + Interval iv; + iv.vreg = v; + iv.start = start[v]; + iv.end = end[v]; + iv.cls = f.VReg(v).cls; + for (int cs : num.call_sites) + if (cs > iv.start && cs <= iv.end) { // 调用点严格落在区间内 + iv.cross_call = true; + break; 
+ } + ivs.push_back(iv); + } + std::sort(ivs.begin(), ivs.end(), + [](const Interval& a, const Interval& b) { + return a.start < b.start; + }); + return ivs; +} + +// 线性扫描分配;对每个区间设定 preg(>=0) 或标记 spill(preg==-1)。 +// 返回 true 表示无需额外 spill 重试(本实现一次性完成,spill 直接落槽)。 +void LinearScan(MachineFunction& f, std::vector& ivs) { + // 为两类寄存器分别准备“优先 caller-saved,再 callee-saved”的池; + // 跨调用区间则只用 callee-saved。 + auto run = [&](RegClass cls) { + const std::vector& caller = + cls == RegClass::GPR ? CallerGPR() : CallerFPR(); + const std::vector& callee = + cls == RegClass::GPR ? CalleeGPR() : CalleeFPR(); + + // active:已分配且尚未结束的区间,按 end 升序。 + std::vector active; + std::unordered_set free_regs; + for (int r : caller) free_regs.insert(r); + for (int r : callee) free_regs.insert(r); + + auto expire = [&](int point) { + std::vector keep; + for (Interval* a : active) { + if (a->end < point) { + free_regs.insert(a->preg); + } else { + keep.push_back(a); + } + } + active = std::move(keep); + }; + + auto pick = [&](bool need_callee) -> int { + // 跨调用:仅 callee-saved;否则优先 caller-saved。 + if (need_callee) { + for (int r : callee) + if (free_regs.count(r)) return r; + return -1; + } + for (int r : caller) + if (free_regs.count(r)) return r; + for (int r : callee) + if (free_regs.count(r)) return r; + return -1; + }; + + for (auto& iv : ivs) { + if (iv.cls != cls) continue; + expire(iv.start); + int r = pick(iv.cross_call); + if (r == -1) { + // spill:在当前区间和 active 里结束最晚者之间选择。 + Interval* victim = nullptr; + for (Interval* a : active) + if (a->cls == cls && + (iv.cross_call ? IsCalleeSavedGPRorFPR(a->preg, cls) : true)) { + if (!victim || a->end > victim->end) victim = a; + } + if (victim && victim->end > iv.end) { + iv.preg = victim->preg; + victim->preg = -1; + victim->spill_slot = + f.CreateStackObject(cls == RegClass::GPR ? 
8 : 8, 8); + // 从 active 移除 victim,加入当前。 + active.erase(std::remove(active.begin(), active.end(), victim), + active.end()); + active.push_back(&iv); + } else { + iv.preg = -1; + iv.spill_slot = f.CreateStackObject(8, 8); + } + } else { + free_regs.erase(r); + iv.preg = r; + active.push_back(&iv); + } + std::sort(active.begin(), active.end(), + [](Interval* a, Interval* b) { return a->end < b->end; }); + } + }; + run(RegClass::GPR); + run(RegClass::FPR); +} + +// 把分配结果写回指令;spill 的虚拟寄存器用暂存寄存器搬运并落槽。 +void Rewrite(MachineFunction& f, const std::vector& ivs) { + int nv = f.NumVRegs(); + std::vector preg(nv, -1); + std::vector slot(nv, -1); + for (const auto& iv : ivs) { + preg[iv.vreg] = iv.preg; + slot[iv.vreg] = iv.spill_slot; + } + + std::unordered_set used_callee_gpr, used_callee_fpr; + for (const auto& iv : ivs) { + if (iv.preg < 0) continue; + if (iv.cls == RegClass::GPR && IsCalleeSavedGPR(iv.preg)) + used_callee_gpr.insert(iv.preg); + if (iv.cls == RegClass::FPR && IsCalleeSavedFPR(iv.preg)) + used_callee_fpr.insert(iv.preg); + } + + for (auto& bb : f.Blocks()) { + std::vector out; + for (auto& mi : bb->Instrs()) { + std::vector pre, post; + int gscr = 0, fscr = 0; + std::unordered_map assigned; // vreg -> scratch preg + + auto scratchFor = [&](const Operand& op) -> int { + int v = op.GetId(); + auto it = assigned.find(v); + if (it != assigned.end()) return it->second; + int s = op.GetClass() == RegClass::GPR ? 
kGprSpillScratch[gscr++] + : kFprSpillScratch[fscr++]; + assigned[v] = s; + return s; + }; + + for (int i = 0; i < (int)mi.ops.size(); ++i) { + Operand& op = mi.ops[i]; + if (!op.IsVReg()) continue; + int v = op.GetId(); + if (preg[v] >= 0) { + op.SetPReg(preg[v]); + continue; + } + // spilled + int s = scratchFor(op); + bool is_def = i < mi.num_defs; + if (is_def) { + MachineInstr st(Opcode::StrStack, + {Operand::PReg(s, op.GetClass(), op.GetBytes()), + Operand::Frame(slot[v])}, + 0); + post.push_back(st); + } else { + MachineInstr ld(Opcode::LdrStack, + {Operand::PReg(s, op.GetClass(), op.GetBytes()), + Operand::Frame(slot[v])}, + 1); + pre.push_back(ld); + } + op.SetPReg(s); } + for (auto& p : pre) out.push_back(std::move(p)); + out.push_back(mi); + for (auto& p : post) out.push_back(std::move(p)); } + bb->Instrs() = std::move(out); } + + for (int r : used_callee_gpr) f.CalleeSavedGPR().push_back(r); + for (int r : used_callee_fpr) f.CalleeSavedFPR().push_back(r); + std::sort(f.CalleeSavedGPR().begin(), f.CalleeSavedGPR().end()); + std::sort(f.CalleeSavedFPR().begin(), f.CalleeSavedFPR().end()); +} + +} // namespace + +void RunRegAlloc(MachineFunction& function) { + Numbering num = NumberInstrs(function); + std::vector ivs = BuildIntervals(function, num); + LinearScan(function, ivs); + Rewrite(function, ivs); } } // namespace mir + + + diff --git a/src/mir/Register.cpp b/src/mir/Register.cpp index 7530470..de9fa06 100644 --- a/src/mir/Register.cpp +++ b/src/mir/Register.cpp @@ -1,27 +1,45 @@ #include "mir/MIR.h" -#include +namespace mir { -#include "utils/Log.h" +MIRContext& DefaultContext() { + static MIRContext ctx; + return ctx; +} -namespace mir { +namespace { +// 64 位通用寄存器名(x0..x30, sp)。 +const char* kX[33] = { + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", + "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20", "x21", + "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x29", "x30", "sp", "xzr"}; +const char* kW[33] = { 
+ "w0", "w1", "w2", "w3", "w4", "w5", "w6", "w7", "w8", "w9", "w10", + "w11", "w12", "w13", "w14", "w15", "w16", "w17", "w18", "w19", "w20", "w21", + "w22", "w23", "w24", "w25", "w26", "w27", "w28", "w29", "w30", "wsp", "wzr"}; +} // namespace + +const char* GPRName(int id, int bytes) { + if (id < 0 || id > 32) return "x0"; + return bytes == 8 ? kX[id] : kW[id]; +} -const char* PhysRegName(PhysReg reg) { - switch (reg) { - case PhysReg::W0: - return "w0"; - case PhysReg::W8: - return "w8"; - case PhysReg::W9: - return "w9"; - case PhysReg::X29: - return "x29"; - case PhysReg::X30: - return "x30"; - case PhysReg::SP: - return "sp"; +const char* FPRName(int id, int bytes) { + static char buf[8][8]; + static int slot = 0; + char* b = buf[slot]; + slot = (slot + 1) & 7; + b[0] = (bytes == 8) ? 'd' : 's'; + int n = id; + if (n < 10) { + b[1] = static_cast('0' + n); + b[2] = '\0'; + } else { + b[1] = static_cast('0' + n / 10); + b[2] = static_cast('0' + n % 10); + b[3] = '\0'; } - throw std::runtime_error(FormatError("mir", "未知物理寄存器")); + return b; } } // namespace mir diff --git a/src/mir/passes/PassManager.cpp b/src/mir/passes/PassManager.cpp index c510460..019ce3b 100644 --- a/src/mir/passes/PassManager.cpp +++ b/src/mir/passes/PassManager.cpp @@ -1,4 +1,4 @@ -// MIR Pass 管理: -// - 组织后端 pass 的运行顺序(PreRA/PostRA/PEI 等阶段) -// - 统一运行 pass 与调试输出(按需要扩展) - +// 后端 Pass 管理:当前后端流水线直接在 RunBackendPipeline 中按 +// RegAlloc -> FrameLowering -> Peephole 顺序驱动(见 AsmPrinter.cpp)。 +// 本文件保留占位,便于后续扩展更细粒度的 Pass 调度。 +#include "mir/MIR.h" diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index c6d9ab7..56c2973 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -1,4 +1,84 @@ -// 窥孔优化(Peephole): -// - 删除冗余 move、合并常见指令模式 -// - 提升最终汇编质量(按实现范围裁剪) +// 后端局部窥孔优化(Lab5)。 +// 在寄存器分配 + 物理寄存器落地之后运行,针对最终机器指令序列做局部清理: +// 1. 删除自拷贝 mov xN, xN / fmov sN, sN; +// 2. 删除恒等运算 add/sub xD, xS, #0(必要时退化为 mov); +// 3. 
删除相邻冗余访存:str R,[B,#o] 紧跟 ldr R,[B,#o] 时删除 ldr; +// 4. 删除写入同一目标且中间无使用的连续 mov(保留最后一条)。 +#include "mir/MIR.h" +#include + +namespace mir { +namespace { + +bool SameReg(const Operand& a, const Operand& b) { + return a.IsPReg() && b.IsPReg() && a.GetId() == b.GetId() && + a.GetClass() == b.GetClass(); +} + +bool IsSelfMove(const MachineInstr& mi) { + if (mi.op == Opcode::Mov || mi.op == Opcode::FMov) + return mi.ops.size() == 2 && SameReg(mi.ops[0], mi.ops[1]); + return false; +} + +bool IsIdentityAddSub(const MachineInstr& mi) { + if ((mi.op == Opcode::AddImm || mi.op == Opcode::SubImm) && + mi.ops.size() == 3 && mi.ops[2].GetImm() == 0) + return true; + return false; +} + +// str R,[B,#o] 之后立即 ldr R,[B,#o](同寄存器、同基址、同偏移):ldr 冗余。 +bool RedundantLoadAfterStore(const MachineInstr& st, const MachineInstr& ld) { + if (st.op != Opcode::Str || ld.op != Opcode::Ldr) return false; + return SameReg(st.ops[0], ld.ops[0]) && SameReg(st.ops[1], ld.ops[1]) && + st.ops[2].GetImm() == ld.ops[2].GetImm(); +} +bool RedundantStackLoadAfterStore(const MachineInstr& st, + const MachineInstr& ld) { + if (st.op != Opcode::StrStack || ld.op != Opcode::LdrStack) return false; + return SameReg(st.ops[0], ld.ops[0]) && + st.ops[1].GetFrame() == ld.ops[1].GetFrame(); +} + +void OptimizeBlock(MachineBasicBlock& bb) { + std::vector out; + out.reserve(bb.Instrs().size()); + for (auto& mi : bb.Instrs()) { + if (IsSelfMove(mi)) continue; + if (IsIdentityAddSub(mi)) { + // add/sub D, S, #0 -> mov D, S(若 D==S 直接删)。 + if (SameReg(mi.ops[0], mi.ops[1])) continue; + MachineInstr mv(Opcode::Mov, {mi.ops[0], mi.ops[1]}, 1); + out.push_back(std::move(mv)); + continue; + } + if (!out.empty()) { + const MachineInstr& prev = out.back(); + if (RedundantLoadAfterStore(prev, mi) || + RedundantStackLoadAfterStore(prev, mi)) { + continue; // 删除冗余 load + } + } + out.push_back(mi); + } + bb.Instrs() = std::move(out); +} + +} // namespace + +void RunPeephole(MachineFunction& function) { + // 迭代到不动点,确保删除后新暴露的模式也被处理。 + 
bool again = true; + int guard = 0; + while (again && guard++ < 8) { + size_t before = 0, after = 0; + for (auto& bb : function.Blocks()) before += bb->Instrs().size(); + for (auto& bb : function.Blocks()) OptimizeBlock(*bb); + for (auto& bb : function.Blocks()) after += bb->Instrs().size(); + again = (after < before); + } +} + +} // namespace mir