Compare commits

..

No commits in common. 'master' and 'master' have entirely different histories.

@ -1,61 +0,0 @@
{
"permissions": {
"allow": [
"Bash(cd \"\\\\\\\\wsl.localhost\\\\Ubuntu-24.04\\\\home\\\\bnk\\\\nudt-compiler-cpp\")",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'echo IN_WSL; which cmake g++ aarch64-linux-gnu-gcc qemu-aarch64 clang 2>/dev/null; pwd')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake -S . -B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo \"cfg ok\" || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" > /tmp/build.log 2>&1 && echo \"build ok\" || tail -40 /tmp/build.log')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && rm -rf build && cmake -S . -B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo \"cfg ok\" || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" > /tmp/build.log 2>&1 && echo \"build ok\" || tail -40 /tmp/build.log')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy 2>&1')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== simple_add no-opt ===\"; ./build/bin/compiler --emit-ir --no-opt test/test_case/functional/simple_add.sy 2>&1 | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for f in 11_add2 13_sub2 29_break 36_op_priority2; do echo \"=== $f \\(opt\\) ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/$f.sy 2>&1 | sed -n \"/define/,/^}/p\"; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 09_func_defn ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/09_func_defn.sy 2>&1 | sed -n \"1,60p\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 25_scope3 ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/25_scope3.sy 2>&1 | sed -n \"/define/,/^}/p\" | head -80')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 05_arr_defn4 ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/05_arr_defn4.sy 2>&1 | sed -n \"/define/,/^}/p\" | head -120')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 22 globals+main head ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/22_matrix_multiply.sy 2>&1 | sed -n \"1,40p\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== gep with var index \\(22\\) ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/22_matrix_multiply.sy 2>&1 | grep -A2 getelementptr | head -20; echo \"=== 95 float head ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/95_float.sy 2>&1 | sed -n \"1,50p\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== CMakeLists \\(mir glob?\\) ===\"; grep -n -i \"mir\\\\|glob\\\\|GLOB\\\\|file\\(\" CMakeLists.txt | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -n \"src/\\\\|GLOB\\\\|add_executable\\\\|add_library\\\\|SOURCES\\\\|\\\\.cpp\" CMakeLists.txt | head -60')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -n \"compiler_core\\\\|compiler\\\\b\\\\|target_sources\\\\|GLOB.*SRC\\\\|PROJECT_SRC\\\\|set\\(SOURCES\\\\|\\\\.cpp\\\\\"\" CMakeLists.txt | head; echo \"---\"; sed -n \"82,160p\" CMakeLists.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/CMakeLists.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== mir/CMakeLists ===\"; cat src/mir/CMakeLists.txt; echo \"=== ls src/mir ===\"; ls -R src/mir; echo \"=== ls include/mir ===\"; ls include/mir')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && tot=0; for f in test/test_case/functional/*.sy; do ./build/bin/compiler --emit-ir \"$f\" 2>/dev/null; done > /tmp/allir.txt; echo \"=== opcodes used ===\"; grep -oE \"= \\(add|sub|mul|sdiv|srem|fadd|fsub|fmul|fdiv|icmp|fcmp|call|getelementptr|load|alloca|phi|sitofp|fptosi|zext\\) \" /tmp/allir.txt | sort | uniq -c; echo \"=== bare ops ===\"; grep -oE \"^ \\(store|br|ret|call\\) \" /tmp/allir.txt | sort | uniq -c; echo \"=== phi count ===\"; grep -c \"phi\" /tmp/allir.txt; echo \"=== float present ===\"; grep -c \"float\" /tmp/allir.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ls test/test_case/functional/*.sy | wc -l; echo \"--- verify uses ---\"; grep -nE \"qemu|aarch64|gcc|clang|gcc-|--target\" scripts/verify_asm.sh | head; echo \"--- a sample .out ---\"; ls test/test_case/functional/*.out 2>/dev/null | head -3; echo \"--- runtime lib? ---\"; ls test/ ; find . -name \"*.a\" -path \"*runtime*\" 2>/dev/null | head')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== sylib.h ===\"; cat sylib/sylib.h 2>/dev/null | head -60; echo \"=== test list ===\"; ls test/test_case/functional/*.sy | xargs -n1 basename')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"int |void |float |#define|starttime|_sysy\" sylib/sylib.c | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && tail -5 include/mir/MIR.h')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetFunctions|GetGlobals|IsDeclaration|GetBlocks|GetEntry|GetSuccessors|GetNumArgs|GetArg\\\\b|class Argument|GetInstructions|GetOpcode|class Module\" include/ir/IR.h | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== Opcode enum ===\"; grep -nA40 \"enum class Opcode\" include/ir/IR.h | head -50')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetLhs|GetRhs|class BinaryInst|class ICmpInst|class FCmpInst|GetPredicate|class CastInst|GetValue\\\\b|class LoadInst|GetPtr|class StoreInst|class AllocaInst|GetAllocatedType|class GepInst|GetBasePtr|GetIndices|class CallInst|GetCallee|GetArgs|class BranchInst|GetDest|class CondBrInst|GetCond|GetTrueDest|GetFalseDest|class ReturnInst|HasReturnValue\" include/ir/IR.h')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== Type ===\"; grep -nE \"class Type|enum class Kind|GetKind|GetArraySize|GetElementType|IsFloat|IsPointer|IsArray|IsVoid|Int1|Int32\" include/ir/IR.h | head -30; echo \"=== Constants/Global ===\"; grep -nE \"class ConstantInt|class ConstantFloat|class ConstantArray|GetElements|class ConstantValue|class GlobalVariable|GetValueType|IsConst|GetInitializer|class Function\\\\b|GetName\" include/ir/IR.h | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetType\\\\\\(\\\\\\)|class Value\\\\b|class GlobalValue\" include/ir/IR.h | head; echo \"=== line 40-60 ===\"; sed -n \"36,60p\" include/ir/IR.h; echo \"=== 120-160 ===\"; sed -n \"118,160p\" include/ir/IR.h')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -E \"error:|Error|错误\" | head -40; echo \"=== exit ${PIPESTATUS[0]} ===\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"undefined|multiple|duplicate|reference to\" | head -30')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/mir/CMakeLists.txt; echo \"=== grep mir in src/CMakeLists ===\"; grep -n mir src/CMakeLists.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/mir/passes/CMakeLists.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake -S . -B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo cfgok || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"error|undefined|错误\" | head -30; echo \"=== done ===\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy 2>&1 | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" /tmp/asmout --run > /tmp/v.log 2>&1; then pass=$\\(\\(pass+1\\)\\); else fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"PASS=$pass FAIL=$fail\"; echo \"FAILED:$fails\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 13_sub2 95_float; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -15; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"error|错误\" | head; for b in 13_sub2 95_float; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -6; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 15_graph_coloring 22_matrix_multiply; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -8; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"no-opt|no_opt|emit-asm|compiler|OPT|opt\" scripts/verify_asm.sh | head')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && CC=aarch64-linux-gnu-gcc; QEMU=qemu-aarch64; SYS=$\\(ls sysroot 2>/dev/null\\); for b in 22_matrix_multiply 95_float 15_graph_coloring 25_scope3; do f=test/test_case/functional/$b.sy; ./build/bin/compiler --no-opt --emit-asm \"$f\" > /tmp/n.s 2>/tmp/n.err || { echo \"$b COMPILE FAIL\"; cat /tmp/n.err; continue; }; aarch64-linux-gnu-gcc /tmp/n.s sylib/libsysy.a -o /tmp/n.exe 2>/tmp/as.err || { echo \"$b ASM/LINK FAIL\"; head -5 /tmp/as.err; continue; }; in=test/test_case/functional/$b.in; if [ -f \"$in\" ]; then qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe < \"$in\" > /tmp/n.out 2>&1; else qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe > /tmp/n.out 2>&1; fi; ec=$?; echo \"$b: exit=$ec out=$\\(tr \"\\\\n\" \"|\" < /tmp/n.out | head -c 80\\)\"; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && sed -n \"58,110p\" scripts/verify_asm.sh')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && aarch64-linux-gnu-gcc -O2 -Wno-unused-result -c sylib/sylib.c -o /tmp/sylib.o 2>/dev/null; pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); ./build/bin/compiler --no-opt --emit-asm \"$f\" > /tmp/n.s 2>/dev/null || { fails=\"$fails $b\\(cc\\)\"; fail=$\\(\\(fail+1\\)\\); continue; }; aarch64-linux-gnu-gcc /tmp/n.s /tmp/sylib.o -o /tmp/n.exe 2>/tmp/as.err || { fails=\"$fails $b\\(as\\)\"; fail=$\\(\\(fail+1\\)\\); continue; }; in=test/test_case/functional/$b.in; if [ -f \"$in\" ]; then qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe < \"$in\" > /tmp/n.out 2>&1; else qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe > /tmp/n.out 2>&1; fi; ec=$?; exp=test/test_case/functional/$b.out; { cat /tmp/n.out; [ -s /tmp/n.out ] && [ $\\(tail -c1 /tmp/n.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/n.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/n.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/n.act > /tmp/n.actn; if diff -q /tmp/n.exp /tmp/n.actn >/dev/null 2>&1; then pass=$\\(\\(pass+1\\)\\); else fails=\"$fails $b\\(diff\\)\"; fail=$\\(\\(fail+1\\)\\); fi; done; echo \"NOOPT PASS=$pass FAIL=$fail\"; echo \"FAILS:$fails\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== check spill present in no-opt 22 ===\"; ./build/bin/compiler --no-opt --emit-asm test/test_case/functional/22_matrix_multiply.sy 2>/dev/null | grep -cE \"x29, #\\(2[0-9]|[3-9][0-9]\\)\"; echo \"\\(stack accesses above = spills/locals\\)\"; echo \"=== self-move check \\(should be 0\\) ===\"; ./build/bin/compiler --emit-asm test/test_case/functional/22_matrix_multiply.sy 2>/dev/null | grep -cE \"\\\\bmov\\\\t\\(w|x\\)\\([0-9]+\\), \\\\1?\\\\2$\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== total instrs across all tests \\(opt\\) ===\"; tot=0; for f in test/test_case/functional/*.sy; do n=$\\(./build/bin/compiler --emit-asm \"$f\" 2>/dev/null | grep -cE \"^\\\\t\\(mov|add|sub|mul|ldr|str|b|bl|cmp|cset|fmov|ret|sxtw|lsl|sdiv|msub|fadd|fsub|fmul|fdiv|fcmp|scvtf|fcvtzs|adrp|stp|ldp\\)\"\\); tot=$\\(\\(tot+n\\)\\); done; echo \"total=$tot\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/function/asm --run > /tmp/v.log 2>&1; then pass=$\\(\\(pass+1\\)\\); else fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"OFFICIAL SCRIPT PASS=$pass FAIL=$fail\"; echo \"FAILS:$fails\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && f=test/test_case/functional/22_matrix_multiply.sy; ./build/bin/compiler --emit-asm \"$f\" 2>/dev/null > /tmp/with.s; wc -l < /tmp/with.s | xargs echo \"lines with peephole:\"; grep -c \"\tmov\t\" /tmp/with.s | xargs echo \"mov count:\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/function/asm --run > /tmp/v.log 2>&1; then echo \" PASS $b\"; pass=$\\(\\(pass+1\\)\\); else echo \" FAIL $b\"; fail=$\\(\\(fail+1\\)\\); fi; done; echo \"========================\"; echo \"PASS=$pass FAIL=$fail\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== test_case 目录结构 ===\"; find test/test_case -type d; echo \"=== 各目录 .sy 数量 ===\"; for d in $\\(find test/test_case -type d\\); do n=$\\(ls \"$d\"/*.sy 2>/dev/null | wc -l\\); [ \"$n\" -gt 0 ] && echo \"$n $d\"; done; echo \"=== 总数 ===\"; find test/test_case -name \"*.sy\" | wc -l')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/performance/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/perf/asm --run > /tmp/v.log 2>&1; then echo \" PASS $b\"; pass=$\\(\\(pass+1\\)\\); else echo \" FAIL $b\"; fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"===== PERF PASS=$pass FAIL=$fail =====\"; echo \"FAILS:$fails\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ls test/test_case/performance/*.sy | xargs -n1 basename')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && aarch64-linux-gnu-gcc -O2 -Wno-unused-result -c sylib/sylib.c -o /tmp/sylib.o 2>/dev/null; for f in test/test_case/performance/*.sy; do b=$\\(basename \"$f\" .sy\\); ./build/bin/compiler --emit-asm \"$f\" > /tmp/p.s 2>/tmp/p.err || { echo \"FAIL\\($b\\) compile\"; continue; }; aarch64-linux-gnu-gcc /tmp/p.s /tmp/sylib.o -o /tmp/p.exe 2>/tmp/p.as || { echo \"FAIL\\($b\\) assemble\"; head -3 /tmp/p.as; continue; }; in=test/test_case/performance/$b.in; if [ -f \"$in\" ]; then timeout 25 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/p.exe < \"$in\" > /tmp/p.out 2>&1; else timeout 25 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/p.exe > /tmp/p.out 2>&1; fi; ec=$?; if [ $ec -eq 124 ]; then echo \"TIMEOUT\\($b\\)\"; continue; fi; exp=test/test_case/performance/$b.out; { cat /tmp/p.out; [ -s /tmp/p.out ] && [ $\\(tail -c1 /tmp/p.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/p.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/p.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/p.act > /tmp/p.actn; if diff -q /tmp/p.exp /tmp/p.actn >/dev/null 2>&1; then echo \"PASS\\($b\\)\"; else echo \"DIFF\\($b\\)\"; fi; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== if-combine3.sy ===\"; cat test/test_case/performance/if-combine3.sy; echo \"=== .out ===\"; cat test/test_case/performance/if-combine3.out; echo \"=== .in? ===\"; ls test/test_case/performance/if-combine3.in 2>/dev/null && cat test/test_case/performance/if-combine3.in')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 循环尾部\\(最后30行\\) ===\"; tail -30 test/test_case/performance/if-combine3.sy; echo \"=== .in ===\"; cat test/test_case/performance/if-combine3.in 2>/dev/null || echo \"\\(无\\)\"; echo \"=== .out ===\"; cat test/test_case/performance/if-combine3.out')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-asm test/test_case/performance/if-combine3.sy > /tmp/ic.s 2>/dev/null && aarch64-linux-gnu-gcc /tmp/ic.s /tmp/sylib.o -o /tmp/ic.exe 2>/dev/null && echo \"5\" | timeout 20 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/ic.exe 2>/dev/null; echo \"exit=$? \\(小输入 n=5验证逻辑\\)\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"50000000\" | timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/ic.exe > /tmp/ic.out 2>&1; ec=$?; echo \"exit=$ec\"; cat /tmp/ic.out')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 2025-MYO-20 gameoflife-oscillator; do f=test/test_case/performance/$b.sy; ./build/bin/compiler --emit-asm \"$f\" > /tmp/x.s 2>/tmp/x.e || { echo \"$b COMPILE FAIL\"; head -3 /tmp/x.e; continue; }; aarch64-linux-gnu-gcc /tmp/x.s /tmp/sylib.o -o /tmp/x.exe 2>/tmp/x.a || { echo \"$b ASM FAIL\"; head -3 /tmp/x.a; continue; }; in=test/test_case/performance/$b.in; if [ -f \"$in\" ]; then timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/x.exe < \"$in\" > /tmp/x.out 2>&1; else timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/x.exe > /tmp/x.out 2>&1; fi; ec=$?; if [ $ec -eq 124 ]; then echo \"$b STILL TIMEOUT\\(>280s\\)\"; continue; fi; exp=test/test_case/performance/$b.out; { cat /tmp/x.out; [ -s /tmp/x.out ] && [ $\\(tail -c1 /tmp/x.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/x.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/x.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/x.act > /tmp/x.actn; if diff -q /tmp/x.exp /tmp/x.actn >/dev/null 2>&1; then echo \"$b PASS\"; else echo \"$b DIFF:\"; diff /tmp/x.exp /tmp/x.actn | head; fi; done')"
]
}
}

5
.gitignore vendored

@ -68,8 +68,3 @@ Thumbs.db
# Project outputs # Project outputs
# ========================= # =========================
test/test_result/ test/test_result/
# Added by cargo
/target

@ -2,8 +2,6 @@ cmake_minimum_required(VERSION 3.20)
project(compiler LANGUAGES C CXX) project(compiler LANGUAGES C CXX)
find_package(Java REQUIRED COMPONENTS Runtime)
# C++ # C++
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
@ -33,7 +31,7 @@ target_include_directories(build_options INTERFACE
option(COMPILER_ENABLE_WARNINGS "Enable common compiler warnings" ON) option(COMPILER_ENABLE_WARNINGS "Enable common compiler warnings" ON)
if(COMPILER_ENABLE_WARNINGS) if(COMPILER_ENABLE_WARNINGS)
if(MSVC) if(MSVC)
target_compile_options(build_options INTERFACE /W4 /utf-8) target_compile_options(build_options INTERFACE /W4)
else() else()
target_compile_options(build_options INTERFACE -Wall -Wextra -Wpedantic) target_compile_options(build_options INTERFACE -Wall -Wextra -Wpedantic)
endif() endif()
@ -41,18 +39,12 @@ endif()
option(COMPILER_PARSE_ONLY "Build only the frontend parser pipeline" OFF) option(COMPILER_PARSE_ONLY "Build only the frontend parser pipeline" OFF)
set(ANTLR4_JAR "${PROJECT_SOURCE_DIR}/third_party/antlr-4.13.2-complete.jar")
if(NOT EXISTS "${ANTLR4_JAR}")
message(FATAL_ERROR "ANTLR jar not found: ${ANTLR4_JAR}")
endif()
# 使 third_party ANTLR4 C++ runtime # 使 third_party ANTLR4 C++ runtime
# third_party runtime third_party/antlr4-runtime-4.13.2/runtime/src # third_party runtime third_party/antlr4-runtime-4.13.2/runtime/src
set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.13.2/runtime/src") set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.13.2/runtime/src")
add_library(antlr4_runtime STATIC) add_library(antlr4_runtime STATIC)
target_compile_features(antlr4_runtime PUBLIC cxx_std_17) target_compile_features(antlr4_runtime PUBLIC cxx_std_17)
target_compile_definitions(antlr4_runtime PUBLIC ANTLR4CPP_STATIC)
target_include_directories(antlr4_runtime PUBLIC target_include_directories(antlr4_runtime PUBLIC
"${ANTLR4_RUNTIME_SRC_DIR}" "${ANTLR4_RUNTIME_SRC_DIR}"

7
Cargo.lock generated

@ -1,7 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "nudt-compiler-cpp"
version = "0.1.0"

@ -1,6 +0,0 @@
[package]
name = "nudt-compiler-cpp"
version = "0.1.0"
edition = "2024"
[dependencies]

@ -109,8 +109,3 @@ cmake --build build -j "$(nproc)"
目标:脚本自动读取同名 `.in`,并将程序输出与退出码和同名 `.out` 比对,确保优化后程序行为与优化前保持一致。 目标:脚本自动读取同名 `.in`,并将程序输出与退出码和同名 `.out` 比对,确保优化后程序行为与优化前保持一致。
完成 Lab4 后,应对 `test/test_case` 下全部测试用例逐个回归;如有需要,也可以自行编写批量测试脚本统一执行。 完成 Lab4 后,应对 `test/test_case` 下全部测试用例逐个回归;如有需要,也可以自行编写批量测试脚本统一执行。
批量测试脚本:
bash test/test_result/lab4_batch/run_all.sh

@ -1,15 +1,37 @@
// 扩展后的 IR 库: // 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。
// - 完整基础类型void/i1/i32/float/ptr/array/function/label //
// - 指令算术、比较、分支、调用、phi、gep、类型转换等 // 当前已经实现:
// - 常量int/float/array // 1. 基础类型系统void / i32 / i32*
// - 基本块/函数/模块/IRBuilder 的完整接口 // 2. Value 体系Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction
// 3. 最小指令集Add / Alloca / Load / Store / Ret
// 4. BasicBlock / Function / Module 三层组织结构
// 5. IRBuilder便捷创建常量和最小指令
// 6. def-use 关系的轻量实现:
// - Instruction 保存 operand 列表
// - Value 保存 uses
// - 支持 ReplaceAllUsesWith 的简化实现
//
// 当前尚未实现或只做了最小占位:
// 1. 完整类型系统数组、函数类型、label 类型等
// 2. 更完整的指令系统br / condbr / call / phi / gep 等
// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构)
// 4. 更完整的 IR verifier 和优化基础设施
//
// 当前需要特别说明的两个简化点:
// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位,
// 后续如果补 label type可以再改成更合理的块标签类型。
// 2. ConstantValue 体系目前只实现了 ConstantInt后续可以继续补 ConstantFloat、
// ConstantArray等更完整的常量种类。
//
// 建议的扩展顺序:
// 1. 先补更多指令和类型
// 2. 再补控制流相关 IR
// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架
#pragma once #pragma once
#include <cstdint>
#include <iosfwd> #include <iosfwd>
#include <memory> #include <memory>
#include <optional>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
@ -23,14 +45,10 @@ class Value;
class User; class User;
class ConstantValue; class ConstantValue;
class ConstantInt; class ConstantInt;
class ConstantFloat;
class ConstantArray;
class GlobalValue; class GlobalValue;
class GlobalVariable;
class Instruction; class Instruction;
class BasicBlock; class BasicBlock;
class Function; class Function;
class Argument;
// Use 表示一个 Value 的一次使用记录。 // Use 表示一个 Value 的一次使用记录。
// 当前实现设计: // 当前实现设计:
@ -65,65 +83,31 @@ class Context {
~Context(); ~Context();
// 去重创建 i32 常量。 // 去重创建 i32 常量。
ConstantInt* GetConstInt(int v); ConstantInt* GetConstInt(int v);
ConstantInt* GetConstBool(bool v);
ConstantFloat* GetConstFloat(float v);
ConstantArray* CreateConstArray(std::shared_ptr<Type> array_ty,
std::vector<ConstantValue*> elements);
std::string NextTemp(); std::string NextTemp();
private: private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_; std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_bools_;
std::unordered_map<uint32_t, std::unique_ptr<ConstantFloat>> const_floats_;
std::vector<std::unique_ptr<ConstantArray>> const_arrays_;
int temp_index_ = -1; int temp_index_ = -1;
}; };
class Type { class Type {
public: public:
enum class Kind { Void, Int1, Int32, Float, Pointer, Array, Function, Label }; enum class Kind { Void, Int32, PtrInt32 };
explicit Type(Kind k); explicit Type(Kind k);
Type(Kind k, std::shared_ptr<Type> elem, size_t count);
Type(Kind k, std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
bool is_vararg);
// 使用静态共享对象获取类型。 // 使用静态共享对象获取类型。
// 同一类型可直接比较返回值是否相等,例如: // 同一类型可直接比较返回值是否相等,例如:
// Type::GetInt32Type() == Type::GetInt32Type() // Type::GetInt32Type() == Type::GetInt32Type()
static const std::shared_ptr<Type>& GetVoidType(); static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type(); static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetFloatType(); static const std::shared_ptr<Type>& GetPtrInt32Type();
static const std::shared_ptr<Type>& GetLabelType();
static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> elem);
static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> elem,
size_t count);
static std::shared_ptr<Type> GetFunctionType(
std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
bool is_vararg = false);
Kind GetKind() const; Kind GetKind() const;
bool IsVoid() const; bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const; bool IsInt32() const;
bool IsFloat() const; bool IsPtrInt32() const;
bool IsPointer() const;
bool IsArray() const;
bool IsFunction() const;
bool IsLabel() const;
const std::shared_ptr<Type>& GetElementType() const;
size_t GetArraySize() const;
const std::shared_ptr<Type>& GetReturnType() const;
const std::vector<std::shared_ptr<Type>>& GetParamTypes() const;
bool IsVarArg() const;
bool Equals(const Type& other) const;
private: private:
Kind kind_; Kind kind_;
std::shared_ptr<Type> elem_type_;
size_t array_size_ = 0;
std::shared_ptr<Type> ret_type_;
std::vector<std::shared_ptr<Type>> param_types_;
bool is_vararg_ = false;
}; };
class Value { class Value {
@ -134,12 +118,7 @@ class Value {
const std::string& GetName() const; const std::string& GetName() const;
void SetName(std::string n); void SetName(std::string n);
bool IsVoid() const; bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const; bool IsInt32() const;
bool IsFloat() const;
bool IsPointer() const;
bool IsArray() const;
bool IsFunctionType() const;
bool IsPtrInt32() const; bool IsPtrInt32() const;
bool IsConstant() const; bool IsConstant() const;
bool IsInstruction() const; bool IsInstruction() const;
@ -172,53 +151,8 @@ class ConstantInt : public ConstantValue {
int value_{}; int value_{};
}; };
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float v);
float GetValue() const { return value_; }
private:
float value_{};
};
class ConstantArray : public ConstantValue {
public:
ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements);
const std::vector<ConstantValue*>& GetElements() const { return elements_; }
private:
std::vector<ConstantValue*> elements_;
};
// 后续还需要扩展更多指令类型。 // 后续还需要扩展更多指令类型。
enum class Opcode { enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret };
Add,
Sub,
Mul,
SDiv,
SRem,
FAdd,
FSub,
FMul,
FDiv,
Alloca,
Load,
Store,
Ret,
Br,
CondBr,
ICmp,
FCmp,
Call,
Phi,
Gep,
SIToFP,
FPToSI,
ZExt
};
enum class ICmpPredicate { Eq, Ne, Slt, Sle, Sgt, Sge };
enum class FCmpPredicate { Oeq, One, Olt, Ole, Ogt, Oge };
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。 // 当前实现中只有 Instruction 继承自 User。
@ -228,13 +162,10 @@ class User : public Value {
size_t GetNumOperands() const; size_t GetNumOperands() const;
Value* GetOperand(size_t index) const; Value* GetOperand(size_t index) const;
void SetOperand(size_t index, Value* value); void SetOperand(size_t index, Value* value);
void RemoveOperand(size_t index);
protected: protected:
// 统一的 operand 入口。 // 统一的 operand 入口。
void AddOperand(Value* value); void AddOperand(Value* value);
virtual void OnOperandChanged(size_t index, Value* value);
virtual void OnOperandRemoving(size_t index);
private: private:
std::vector<Value*> operands_; std::vector<Value*> operands_;
@ -247,20 +178,6 @@ class GlobalValue : public User {
GlobalValue(std::shared_ptr<Type> ty, std::string name); GlobalValue(std::shared_ptr<Type> ty, std::string name);
}; };
class GlobalVariable : public GlobalValue {
public:
GlobalVariable(std::shared_ptr<Type> value_ty, std::string name,
ConstantValue* init, bool is_const);
const std::shared_ptr<Type>& GetValueType() const;
ConstantValue* GetInitializer() const;
bool IsConst() const;
private:
std::shared_ptr<Type> value_type_;
ConstantValue* initializer_ = nullptr;
bool is_const_ = false;
};
class Instruction : public User { class Instruction : public User {
public: public:
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = ""); Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
@ -279,67 +196,18 @@ class BinaryInst : public Instruction {
BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs, BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
std::string name); std::string name);
Value* GetLhs() const; Value* GetLhs() const;
Value* GetRhs() const; Value* GetRhs() const;
};
class ICmpInst : public Instruction {
public:
ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name);
ICmpPredicate GetPredicate() const { return pred_; }
Value* GetLhs() const;
Value* GetRhs() const;
private:
ICmpPredicate pred_;
};
class FCmpInst : public Instruction {
public:
FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name);
FCmpPredicate GetPredicate() const { return pred_; }
Value* GetLhs() const;
Value* GetRhs() const;
private:
FCmpPredicate pred_;
};
class CastInst : public Instruction {
public:
CastInst(Opcode op, std::shared_ptr<Type> dst_ty, Value* src,
std::string name);
Value* GetValue() const;
};
class BranchInst : public Instruction {
public:
explicit BranchInst(BasicBlock* dest);
BasicBlock* GetDest() const;
};
class CondBrInst : public Instruction {
public:
CondBrInst(Value* cond, BasicBlock* true_dest, BasicBlock* false_dest);
Value* GetCond() const;
BasicBlock* GetTrueDest() const;
BasicBlock* GetFalseDest() const;
}; };
class ReturnInst : public Instruction { class ReturnInst : public Instruction {
public: public:
explicit ReturnInst(std::shared_ptr<Type> void_ty);
ReturnInst(std::shared_ptr<Type> void_ty, Value* val); ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
bool HasReturnValue() const;
Value* GetValue() const; Value* GetValue() const;
}; };
class AllocaInst : public Instruction { class AllocaInst : public Instruction {
public: public:
AllocaInst(std::shared_ptr<Type> allocated_ty, std::string name); AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name);
const std::shared_ptr<Type>& GetAllocatedType() const;
private:
std::shared_ptr<Type> allocated_type_;
}; };
class LoadInst : public Instruction { class LoadInst : public Instruction {
@ -355,48 +223,8 @@ class StoreInst : public Instruction {
Value* GetPtr() const; Value* GetPtr() const;
}; };
class CallInst : public Instruction { // BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。
public: // 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。
CallInst(std::shared_ptr<Type> ret_ty, Value* callee,
std::vector<Value*> args, std::string name);
Value* GetCallee() const;
const std::vector<Value*>& GetArgs() const { return args_; }
private:
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> args_;
};
// Phi instruction: keeps a list of incoming values and a list of incoming
// blocks.
class PhiInst : public Instruction {
public:
// Creates a phi of type `ty` named `name`; incoming entries are added
// afterwards through AddIncoming().
PhiInst(std::shared_ptr<Type> ty, std::string name);
// Adds one incoming (value, block) entry.
void AddIncoming(Value* value, BasicBlock* block);
// Removes the incoming entry associated with `block` (defined out of line;
// behaviour when `block` is not an incoming block is not visible here).
void RemoveIncomingFrom(BasicBlock* block);
// Read-only views of the incoming values and incoming blocks.
// NOTE(review): the two vectors are presumably parallel (same length, entry i
// of one pairs with entry i of the other) — confirm in the implementation.
const std::vector<Value*>& GetIncomingValues() const;
const std::vector<BasicBlock*>& GetIncomingBlocks() const;
private:
// Overrides of base-class hooks; presumably used to keep the vectors below in
// sync when an operand is replaced or removed — TODO confirm.
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> incoming_values_;
std::vector<BasicBlock*> incoming_blocks_;
};
// Address-computation instruction built from a base pointer and a list of
// index values ("Gep" presumably stands for get-element-pointer — confirm).
class GepInst : public Instruction {
public:
// result_ptr_ty - type of the instruction's result, supplied by the caller.
// base_ptr      - pointer value the indices are applied to.
// indices       - index values, taken by value.
// name          - name given to the instruction.
GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
std::vector<Value*> indices, std::string name);
// Defined out of line; presumably returns `base_ptr` — TODO confirm.
Value* GetBasePtr() const;
// Returns the index list stored in indices_.
const std::vector<Value*>& GetIndices() const { return indices_; }
private:
// Overrides of base-class hooks; presumably used to keep indices_ in sync
// when an operand is replaced or removed — TODO confirm.
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> indices_;
};
// BasicBlock 已纳入 Value 体系,使用 label type。
class BasicBlock : public Value { class BasicBlock : public Value {
public: public:
explicit BasicBlock(std::string name); explicit BasicBlock(std::string name);
@ -404,17 +232,8 @@ class BasicBlock : public Value {
void SetParent(Function* parent); void SetParent(Function* parent);
bool HasTerminator() const; bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const; const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
std::vector<std::unique_ptr<Instruction>>& GetMutableInstructions();
const std::vector<BasicBlock*>& GetPredecessors() const; const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const; const std::vector<BasicBlock*>& GetSuccessors() const;
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void ClearPredecessors();
void ClearSuccessors();
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
void EraseInstruction(Instruction* inst);
void ReplaceTerminator(std::unique_ptr<Instruction> inst);
template <typename T, typename... Args> template <typename T, typename... Args>
T* Append(Args&&... args) { T* Append(Args&&... args) {
if (HasTerminator()) { if (HasTerminator()) {
@ -425,16 +244,6 @@ class BasicBlock : public Value {
auto* ptr = inst.get(); auto* ptr = inst.get();
ptr->SetParent(this); ptr->SetParent(this);
instructions_.push_back(std::move(inst)); instructions_.push_back(std::move(inst));
LinkSuccessorsIfNeeded(ptr);
return ptr;
}
template <typename T, typename... Args>
T* Prepend(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin(), std::move(inst));
return ptr; return ptr;
} }
@ -443,7 +252,6 @@ class BasicBlock : public Value {
std::vector<std::unique_ptr<Instruction>> instructions_; std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_; std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_; std::vector<BasicBlock*> successors_;
void LinkSuccessorsIfNeeded(Instruction* inst);
}; };
// Function 当前也采用了最小实现。 // Function 当前也采用了最小实现。
@ -454,35 +262,16 @@ class BasicBlock : public Value {
// 形参和调用,通常需要引入专门的函数类型表示。 // 形参和调用,通常需要引入专门的函数类型表示。
class Function : public Value { class Function : public Value {
public: public:
Function(std::string name, std::shared_ptr<Type> func_type, // 当前构造函数接收的也是返回类型,而不是完整函数类型。
bool is_declaration = false); Function(std::string name, std::shared_ptr<Type> ret_type);
BasicBlock* CreateBlock(const std::string& name); BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry(); BasicBlock* GetEntry();
const BasicBlock* GetEntry() const; const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const; const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
std::vector<std::unique_ptr<BasicBlock>>& GetMutableBlocks();
const std::vector<std::unique_ptr<Argument>>& GetArguments() const;
size_t GetNumArgs() const;
Argument* GetArg(size_t index);
std::shared_ptr<Type> GetFunctionType() const;
std::shared_ptr<Type> GetReturnType() const;
bool IsDeclaration() const;
private: private:
BasicBlock* entry_ = nullptr; BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_; std::vector<std::unique_ptr<BasicBlock>> blocks_;
std::vector<std::unique_ptr<Argument>> args_;
std::unordered_map<std::string, size_t> block_name_counts_;
bool is_declaration_ = false;
};
class Argument : public Value {
public:
Argument(std::shared_ptr<Type> ty, std::string name, size_t index);
size_t GetIndex() const { return index_; }
private:
size_t index_ = 0;
}; };
class Module { class Module {
@ -493,20 +282,11 @@ class Module {
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
Function* CreateFunction(const std::string& name, Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type); std::shared_ptr<Type> ret_type);
Function* CreateFunctionWithType(const std::string& name,
std::shared_ptr<Type> func_type);
Function* CreateFunctionDecl(const std::string& name,
std::shared_ptr<Type> func_type);
GlobalVariable* CreateGlobalVariable(const std::string& name,
std::shared_ptr<Type> value_type,
ConstantValue* init, bool is_const);
const std::vector<std::unique_ptr<Function>>& GetFunctions() const; const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobals() const;
private: private:
Context context_; Context context_;
std::vector<std::unique_ptr<Function>> functions_; std::vector<std::unique_ptr<Function>> functions_;
std::vector<std::unique_ptr<GlobalVariable>> globals_;
}; };
class IRBuilder { class IRBuilder {
@ -517,44 +297,13 @@ class IRBuilder {
// 构造常量、二元运算、返回指令的最小集合。 // 构造常量、二元运算、返回指令的最小集合。
ConstantInt* CreateConstInt(int v); ConstantInt* CreateConstInt(int v);
ConstantInt* CreateConstBool(bool v);
ConstantFloat* CreateConstFloat(float v);
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name); const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSDiv(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSRem(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name);
ICmpInst* CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name);
FCmpInst* CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name);
CastInst* CreateSIToFP(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
CastInst* CreateFPToSI(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
CastInst* CreateZExt(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
AllocaInst* CreateAlloca(std::shared_ptr<Type> ty,
const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr); StoreInst* CreateStore(Value* val, Value* ptr);
GepInst* CreateGep(Value* base_ptr, std::vector<Value*> indices,
const std::string& name);
CallInst* CreateCall(Value* callee, std::vector<Value*> args,
const std::string& name);
PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name);
BranchInst* CreateBr(BasicBlock* dest);
CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_dest,
BasicBlock* false_dest);
ReturnInst* CreateRet(Value* v); ReturnInst* CreateRet(Value* v);
ReturnInst* CreateRetVoid();
private: private:
Context& ctx_; Context& ctx_;
@ -566,12 +315,4 @@ class IRPrinter {
void Print(const Module& module, std::ostream& os); void Print(const Module& module, std::ostream& os);
}; };
bool RunMem2Reg(Module& module);
bool RunConstFold(Module& module);
bool RunConstProp(Module& module);
bool RunCSE(Module& module);
bool RunDCE(Module& module);
bool RunCFGSimplify(Module& module);
bool RunScalarOptimizationPipeline(Module& module);
} // namespace ir } // namespace ir

@ -1,114 +1,58 @@
// 将语法树翻译为 IR。
// 实现拆分在 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl。
#pragma once #pragma once
#include <any> #include <any>
#include <memory> #include <memory>
#include <string>
#include <unordered_map> #include <unordered_map>
#include "SysYBaseVisitor.h" #include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "sem/Sema.h" #include "sem/Sema.h"
namespace ir {
class Module;
class Function;
class IRBuilder;
class Value;
}
class IRGenImpl final : public SysYBaseVisitor { class IRGenImpl final : public SysYBaseVisitor {
public: public:
IRGenImpl(ir::Module& module, const SemanticContext& sema); IRGenImpl(ir::Module& module, const SemanticContext& sema);
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlock(SysYParser::BlockContext* ctx) override; std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitExp(SysYParser::ExpContext* ctx) override; std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override; std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override; // 新增 std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; // 新增 std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override; std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override;
std::any visitConstInitVal(SysYParser::ConstInitValContext* ctx) override;
std::any visitInitVal(SysYParser::InitValContext* ctx) override;
private: private:
ir::Value* EvalExp(SysYParser::ExpContext* ctx); enum class BlockFlow {
ir::Value* EvalCondValue(SysYParser::CondContext* ctx); Continue,
void EmitCondBr(SysYParser::CondContext* ctx, ir::BasicBlock* true_bb, Terminated,
ir::BasicBlock* false_bb); };
void EmitLOrCond(SysYParser::LOrExpContext* ctx, ir::BasicBlock* true_bb,
ir::BasicBlock* false_bb);
void EmitLAndCond(SysYParser::LAndExpContext* ctx, ir::BasicBlock* true_bb,
ir::BasicBlock* false_bb);
ir::Value* EmitRelEq(SysYParser::RelExpContext* ctx);
ir::Value* EmitEq(SysYParser::EqExpContext* ctx);
ir::Value* CastToFloat(ir::Value* v);
ir::Value* CastToInt(ir::Value* v);
ir::Value* MakeBool(ir::Value* v);
ir::Value* GetLValAddress(SysYParser::LValContext* ctx);
ir::Value* LoadIfNeeded(ir::Value* addr_or_val, const TypeDesc& ty,
bool as_rvalue);
std::shared_ptr<ir::Type> ToIRType(const TypeDesc& ty);
std::shared_ptr<ir::Type> ToIRParamType(const TypeDesc& ty);
ir::Value* DefaultValue(const TypeDesc& ty);
ir::AllocaInst* CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
const std::string& name);
void InitArray(ir::Value* base_ptr, const TypeDesc& ty,
SysYParser::InitValContext* init);
void InitConstArray(ir::Value* base_ptr, const TypeDesc& ty,
SysYParser::ConstInitValContext* init);
size_t FillArrayValues(const TypeDesc& ty, SysYParser::InitValContext* init,
std::vector<ir::Value*>& values, size_t base,
size_t idx, size_t dim);
size_t FillConstArrayValues(const TypeDesc& ty,
SysYParser::ConstInitValContext* init,
std::vector<ir::Value*>& values, size_t base,
size_t idx, size_t dim);
size_t ArrayStride(const TypeDesc& ty, size_t dim) const;
size_t ArrayTotalSize(const TypeDesc& ty) const;
void PushLoop(ir::BasicBlock* break_bb, ir::BasicBlock* cont_bb);
void PopLoop();
ir::BasicBlock* CurrentBreak() const;
ir::BasicBlock* CurrentContinue() const;
ir::ConstantValue* EvalConstScalar(SysYParser::ExpContext* ctx);
ir::ConstantValue* EvalConstScalar(SysYParser::ConstExpContext* ctx);
ir::ConstantValue* EvalConstAdd(SysYParser::AddExpContext* ctx);
ir::ConstantValue* EvalConstMul(SysYParser::MulExpContext* ctx);
ir::ConstantValue* EvalConstUnary(SysYParser::UnaryExpContext* ctx);
ir::ConstantValue* EvalConstPrimary(SysYParser::PrimaryExpContext* ctx);
ir::ConstantValue* EvalConstNumber(SysYParser::NumberContext* ctx);
ir::ConstantValue* EvalConstLVal(SysYParser::LValContext* ctx);
size_t InitGlobalArray(const TypeDesc& ty, SysYParser::InitValContext* init,
std::vector<ir::ConstantValue*>& values, size_t base,
size_t idx, size_t dim);
size_t InitGlobalConstArray(const TypeDesc& ty,
SysYParser::ConstInitValContext* init,
std::vector<ir::ConstantValue*>& values,
size_t base, size_t idx, size_t dim);
enum class BlockFlow { Continue, Terminated };
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::Module& module_; ir::Module& module_;
const SemanticContext& sema_; const SemanticContext& sema_;
ir::Function* func_; ir::Function* func_;
ir::IRBuilder builder_; ir::IRBuilder builder_;
std::unordered_map<const SysYParser::VarDefContext*, ir::Value*> var_storage_; // 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
std::unordered_map<const SysYParser::ConstDefContext*, ir::Value*> const_storage_; std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
std::unordered_map<const SysYParser::FuncFParamContext*, ir::Value*> param_storage_;
std::unordered_map<const SysYParser::FuncDefContext*, ir::Function*> func_map_;
std::unordered_map<const SysYParser::VarDefContext*, ir::GlobalVariable*>
global_var_storage_;
std::unordered_map<const SysYParser::ConstDefContext*, ir::GlobalVariable*>
global_const_storage_;
std::vector<std::pair<ir::BasicBlock*, ir::BasicBlock*>> loop_stack_;
}; };
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema); const SemanticContext& sema);

@ -1,10 +1,6 @@
// Lab5 后端 MIR 表示:
// - 虚拟寄存器 + 物理寄存器,两类寄存器(GPR/FPR)
// - 多函数、多基本块、全局变量、栈对象
// - 指令携带显式 def/use(约定操作数前 num_defs 个为定值)
#pragma once #pragma once
#include <cstdint> #include <initializer_list>
#include <iosfwd> #include <iosfwd>
#include <memory> #include <memory>
#include <string> #include <string>
@ -20,273 +16,104 @@ class MIRContext {
public: public:
MIRContext() = default; MIRContext() = default;
}; };
MIRContext& DefaultContext();
enum class RegClass { GPR, FPR }; MIRContext& DefaultContext();
// 物理寄存器编号(按类内编号): enum class PhysReg { W0, W8, W9, X29, X30, SP };
// GPR: 0..30 = x0..x30, 31 = sp, 32 = xzr
// FPR: 0..31 = s0..s31
namespace preg {
constexpr int kSP = 31;
constexpr int kXZR = 32;
constexpr int kFP = 29; // x29
constexpr int kLR = 30; // x30
constexpr int kIP0 = 16; // x16 scratch
constexpr int kIP1 = 17; // x17 scratch
} // namespace preg
enum class Cond { AL, EQ, NE, LT, LE, GT, GE, MI, LS, HI, HS }; const char* PhysRegName(PhysReg reg);
enum class Opcode { enum class Opcode {
Mov, // dst<-src (reg copy) Prologue,
MovImm, // dst<-imm (materialize 32/64-bit) Epilogue,
Sxtw, // dst(64) = sign-extend src(32) MovImm,
Add, // dst = a + b LoadStack,
Sub, // dst = a - b StoreStack,
Mul, // dst = a * b AddRR,
SDiv, // dst = a / b
MSub, // dst = a - b*c
AddImm, // dst = a + imm
SubImm, // dst = a - imm
LslImm, // dst = a << imm
Cmp, // a ? b (sets flags)
CmpImm, // a ? imm
CSet, // dst = cond
FAdd,
FSub,
FMul,
FDiv,
FCmp,
FMov, // fpr<-fpr
FMovImm, // fpr <- 32-bit float bits (via scratch gpr)
SCvtF, // fpr = (float)gpr
FCvtZS, // gpr = (int)fpr
Ldr, // dst <- [base, #imm]
Str, // [base, #imm] <- src (def=0, ops: src,base,imm)
LdrStack, // dst <- [frame]
StrStack, // [frame] <- src (def=0, ops: src, frame)
AddrFrame, // dst = addr of frame slot
AddrGlobal, // dst = addr of global symbol
B, // branch label
BCond, // branch cond label (uses flags)
Bl, // call symbol (def=0; ops list arg pregs as uses)
Ret, Ret,
}; };
class Operand { class Operand {
public: public:
enum class Kind { None, VReg, PReg, Imm, Frame, Global, Label }; enum class Kind { Reg, Imm, FrameIndex };
static Operand VReg(int id, RegClass cls, int bytes) { static Operand Reg(PhysReg reg);
Operand o; static Operand Imm(int value);
o.kind_ = Kind::VReg; static Operand FrameIndex(int index);
o.id_ = id;
o.cls_ = cls;
o.bytes_ = bytes;
return o;
}
static Operand PReg(int id, RegClass cls, int bytes) {
Operand o;
o.kind_ = Kind::PReg;
o.id_ = id;
o.cls_ = cls;
o.bytes_ = bytes;
return o;
}
static Operand Imm(long long v) {
Operand o;
o.kind_ = Kind::Imm;
o.imm_ = v;
return o;
}
static Operand Frame(int idx) {
Operand o;
o.kind_ = Kind::Frame;
o.id_ = idx;
return o;
}
static Operand Global(std::string name) {
Operand o;
o.kind_ = Kind::Global;
o.sym_ = std::move(name);
return o;
}
static Operand Label(std::string name) {
Operand o;
o.kind_ = Kind::Label;
o.sym_ = std::move(name);
return o;
}
Kind GetKind() const { return kind_; } Kind GetKind() const { return kind_; }
bool IsReg() const { return kind_ == Kind::VReg || kind_ == Kind::PReg; } PhysReg GetReg() const { return reg_; }
bool IsVReg() const { return kind_ == Kind::VReg; } int GetImm() const { return imm_; }
bool IsPReg() const { return kind_ == Kind::PReg; } int GetFrameIndex() const { return imm_; }
int GetId() const { return id_; }
RegClass GetClass() const { return cls_; }
int GetBytes() const { return bytes_; }
long long GetImm() const { return imm_; }
int GetFrame() const { return id_; }
const std::string& GetSym() const { return sym_; }
void SetPReg(int id) {
kind_ = Kind::PReg;
id_ = id;
}
void SetVReg(int id) {
kind_ = Kind::VReg;
id_ = id;
}
void SetBytes(int b) { bytes_ = b; }
private: private:
Kind kind_ = Kind::None; Operand(Kind kind, PhysReg reg, int imm);
int id_ = 0;
RegClass cls_ = RegClass::GPR;
int bytes_ = 4;
long long imm_ = 0;
std::string sym_;
};
struct MachineInstr { Kind kind_;
Opcode op; PhysReg reg_;
std::vector<Operand> ops; int imm_;
int num_defs = 0;
Cond cond = Cond::AL;
MachineInstr(Opcode o, std::vector<Operand> operands, int defs,
Cond c = Cond::AL)
: op(o), ops(std::move(operands)), num_defs(defs), cond(c) {}
}; };
class MachineBasicBlock { class MachineInstr {
public: public:
explicit MachineBasicBlock(std::string name) : name_(std::move(name)) {} MachineInstr(Opcode opcode, std::vector<Operand> operands = {});
const std::string& GetName() const { return name_; }
std::vector<MachineInstr>& Instrs() { return instrs_; }
const std::vector<MachineInstr>& Instrs() const { return instrs_; }
std::vector<MachineBasicBlock*>& Succs() { return succs_; }
const std::vector<MachineBasicBlock*>& Succs() const { return succs_; }
void Add(MachineInstr mi) { instrs_.push_back(std::move(mi)); } Opcode GetOpcode() const { return opcode_; }
const std::vector<Operand>& GetOperands() const { return operands_; }
private: private:
std::string name_; Opcode opcode_;
std::vector<MachineInstr> instrs_; std::vector<Operand> operands_;
std::vector<MachineBasicBlock*> succs_;
}; };
struct VRegInfo { struct FrameSlot {
RegClass cls = RegClass::GPR;
int bytes = 4;
};
struct StackObject {
int index = 0; int index = 0;
int size = 4; int size = 4;
int align = 4; int offset = 0;
int offset = 0; // 相对 x29负数
}; };
class MachineFunction { class MachineBasicBlock {
public: public:
explicit MachineFunction(std::string name) : name_(std::move(name)) {} explicit MachineBasicBlock(std::string name);
const std::string& GetName() const { return name_; } const std::string& GetName() const { return name_; }
std::vector<MachineInstr>& GetInstructions() { return instructions_; }
const std::vector<MachineInstr>& GetInstructions() const { return instructions_; }
MachineBasicBlock* CreateBlock(const std::string& name) { MachineInstr& Append(Opcode opcode,
blocks_.push_back(std::make_unique<MachineBasicBlock>(name)); std::initializer_list<Operand> operands = {});
return blocks_.back().get();
}
const std::vector<std::unique_ptr<MachineBasicBlock>>& Blocks() const {
return blocks_;
}
std::vector<std::unique_ptr<MachineBasicBlock>>& Blocks() { return blocks_; }
int NewVReg(RegClass cls, int bytes) { private:
int id = static_cast<int>(vregs_.size()); std::string name_;
vregs_.push_back(VRegInfo{cls, bytes}); std::vector<MachineInstr> instructions_;
return id; };
}
Operand NewVRegOp(RegClass cls, int bytes) {
return Operand::VReg(NewVReg(cls, bytes), cls, bytes);
}
int NumVRegs() const { return static_cast<int>(vregs_.size()); }
const VRegInfo& VReg(int id) const { return vregs_[id]; }
int CreateStackObject(int size, int align) { class MachineFunction {
StackObject obj; public:
obj.index = static_cast<int>(stack_.size()); explicit MachineFunction(std::string name);
obj.size = size;
obj.align = align;
stack_.push_back(obj);
return obj.index;
}
std::vector<StackObject>& StackObjects() { return stack_; }
const std::vector<StackObject>& StackObjects() const { return stack_; }
StackObject& Stack(int idx) { return stack_[idx]; }
int GetFrameSize() const { return frame_size_; } const std::string& GetName() const { return name_; }
void SetFrameSize(int s) { frame_size_ = s; } MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; }
// 寄存器分配产物:本函数用到、需要保存恢复的 callee-saved 物理寄存器 int CreateFrameIndex(int size = 4);
std::vector<int>& CalleeSavedGPR() { return callee_gpr_; } FrameSlot& GetFrameSlot(int index);
std::vector<int>& CalleeSavedFPR() { return callee_fpr_; } const FrameSlot& GetFrameSlot(int index) const;
const std::vector<int>& CalleeSavedGPR() const { return callee_gpr_; } const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
const std::vector<int>& CalleeSavedFPR() const { return callee_fpr_; }
int NumIntArgs() const { return num_int_args_; } int GetFrameSize() const { return frame_size_; }
int NumFloatArgs() const { return num_float_args_; } void SetFrameSize(int size) { frame_size_ = size; }
void SetArgCounts(int i, int f) {
num_int_args_ = i;
num_float_args_ = f;
}
private: private:
std::string name_; std::string name_;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_; MachineBasicBlock entry_;
std::vector<VRegInfo> vregs_; std::vector<FrameSlot> frame_slots_;
std::vector<StackObject> stack_;
std::vector<int> callee_gpr_;
std::vector<int> callee_fpr_;
int frame_size_ = 0; int frame_size_ = 0;
int num_int_args_ = 0;
int num_float_args_ = 0;
};
struct MachineGlobal {
std::string name;
int size = 0; // 总字节数
int align = 4;
bool is_const = false;
bool zero_init = true;
std::vector<unsigned> words; // 非零初始化:按 4 字节小端存放的原始位
}; };
class MachineModule { std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
public:
std::vector<std::unique_ptr<MachineFunction>>& Functions() { return funcs_; }
const std::vector<std::unique_ptr<MachineFunction>>& Functions() const {
return funcs_;
}
std::vector<MachineGlobal>& Globals() { return globals_; }
const std::vector<MachineGlobal>& Globals() const { return globals_; }
private:
std::vector<std::unique_ptr<MachineFunction>> funcs_;
std::vector<MachineGlobal> globals_;
};
const char* GPRName(int id, int bytes);
const char* FPRName(int id, int bytes);
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function); void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function); void RunFrameLowering(MachineFunction& function);
void RunPeephole(MachineFunction& function); void PrintAsm(const MachineFunction& function, std::ostream& os);
void RunBackendPipeline(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
void PrintAArch64AsmFromMIR(const ir::Module& module, std::ostream& os);
void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os);
} // namespace mir } // namespace mir

@ -1,91 +1,30 @@
// 基于语法树的语义检查与名称绑定Lab2 扩展) // 基于语法树的语义检查与名称绑定
#pragma once #pragma once
#include <optional>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
#include "sem/SymbolTable.h"
struct FuncTypeDesc {
TypeDesc ret;
std::vector<TypeDesc> params;
};
struct BoundDecl {
enum class Kind { Var, Const, Param } kind = Kind::Var;
SysYParser::VarDefContext* var_decl = nullptr;
SysYParser::ConstDefContext* const_decl = nullptr;
SysYParser::FuncFParamContext* param_decl = nullptr;
};
class SemanticContext { class SemanticContext {
public: public:
void BindVarUse(SysYParser::LValContext* use, BoundDecl decl) { void BindVarUse(SysYParser::VarContext* use,
SysYParser::VarDefContext* decl) {
var_uses_[use] = decl; var_uses_[use] = decl;
} }
BoundDecl ResolveVarUse(const SysYParser::LValContext* use) const { SysYParser::VarDefContext* ResolveVarUse(
const SysYParser::VarContext* use) const {
auto it = var_uses_.find(use); auto it = var_uses_.find(use);
return it == var_uses_.end() ? BoundDecl{} : it->second; return it == var_uses_.end() ? nullptr : it->second;
}
void RegisterVarDecl(SysYParser::VarDefContext* decl, TypeDesc ty) {
var_types_[decl] = std::move(ty);
}
void RegisterConstDecl(SysYParser::ConstDefContext* decl, TypeDesc ty) {
const_types_[decl] = std::move(ty);
}
void RegisterParam(SysYParser::FuncFParamContext* decl, TypeDesc ty) {
param_types_[decl] = std::move(ty);
}
void RegisterFunc(SysYParser::FuncDefContext* decl, FuncTypeDesc ty) {
func_types_[decl] = std::move(ty);
}
const TypeDesc* GetVarType(const SysYParser::VarDefContext* decl) const {
auto it = var_types_.find(decl);
return it == var_types_.end() ? nullptr : &it->second;
}
const TypeDesc* GetConstType(const SysYParser::ConstDefContext* decl) const {
auto it = const_types_.find(decl);
return it == const_types_.end() ? nullptr : &it->second;
}
const TypeDesc* GetParamType(const SysYParser::FuncFParamContext* decl) const {
auto it = param_types_.find(decl);
return it == param_types_.end() ? nullptr : &it->second;
}
const FuncTypeDesc* GetFuncType(const SysYParser::FuncDefContext* decl) const {
auto it = func_types_.find(decl);
return it == func_types_.end() ? nullptr : &it->second;
}
void BindFuncCall(SysYParser::UnaryExpContext* call,
SysYParser::FuncDefContext* decl) {
func_calls_[call] = decl;
}
SysYParser::FuncDefContext* ResolveFuncCall(
const SysYParser::UnaryExpContext* call) const {
auto it = func_calls_.find(call);
return it == func_calls_.end() ? nullptr : it->second;
} }
private: private:
std::unordered_map<const SysYParser::LValContext*, BoundDecl> var_uses_; std::unordered_map<const SysYParser::VarContext*,
std::unordered_map<const SysYParser::VarDefContext*, TypeDesc> var_types_; SysYParser::VarDefContext*>
std::unordered_map<const SysYParser::ConstDefContext*, TypeDesc> const_types_; var_uses_;
std::unordered_map<const SysYParser::FuncFParamContext*, TypeDesc> param_types_;
std::unordered_map<const SysYParser::FuncDefContext*, FuncTypeDesc> func_types_;
std::unordered_map<const SysYParser::UnaryExpContext*, SysYParser::FuncDefContext*>
func_calls_;
}; };
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); // 目前仅检查:
// - 变量先声明后使用
// - 局部变量不允许重复定义
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -1,42 +1,17 @@
// 符号表:记录局部变量/常量/参数定义。 // 极简符号表:记录局部变量定义
#pragma once #pragma once
#include <optional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
enum class BaseTypeKind { Int, Float, Void };
struct TypeDesc {
BaseTypeKind base = BaseTypeKind::Int;
std::vector<int> dims; // 为空表示标量;数组维度允许首维为 -1 表示形参不定长
bool is_const = false;
};
enum class SymbolKind { Var, Const, Param };
struct SymbolEntry {
SymbolKind kind = SymbolKind::Var;
SysYParser::VarDefContext* var_decl = nullptr;
SysYParser::ConstDefContext* const_decl = nullptr;
SysYParser::FuncFParamContext* param_decl = nullptr;
TypeDesc type; // 记录类型信息
bool is_const = false;
std::optional<int> const_value;
};
class SymbolTable { class SymbolTable {
public: public:
void EnterScope(); void Add(const std::string& name, SysYParser::VarDefContext* decl);
void ExitScope(); bool Contains(const std::string& name) const;
SysYParser::VarDefContext* Lookup(const std::string& name) const;
bool ContainsInCurrentScope(const std::string& name) const;
void Add(const std::string& name, const SymbolEntry& entry);
const SymbolEntry* Lookup(const std::string& name) const;
private: private:
std::vector<std::unordered_map<std::string, SymbolEntry>> scopes_; std::unordered_map<std::string, SysYParser::VarDefContext*> table_;
}; };

@ -8,7 +8,6 @@ struct CLIOptions {
bool emit_parse_tree = false; bool emit_parse_tree = false;
bool emit_ir = true; bool emit_ir = true;
bool emit_asm = false; bool emit_asm = false;
bool optimize_ir = true;
bool show_help = false; bool show_help = false;
}; };

@ -1,19 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
# Reconfigure with the IR pipeline enabled, build, then run the Lab2 test
# script. Everything the steps print (stdout and stderr) is shown on the
# terminal and mirrored into RESULT_FILE via tee.
# Must be run from the repository root: all paths below are relative.
readonly RESULT_FILE="test/test_result/run_lab2_result.log"
# Create the log directory. The quotes inside $(...) must not be
# backslash-escaped: an escaped quote there is a literal character, so dirname
# would receive a path wrapped in '"' and mkdir would create a directory whose
# name starts with a quote instead of test/test_result.
mkdir -p -- "$(dirname -- "$RESULT_FILE")"
# Truncate the log up front so an unwritable location fails before the build.
: > "$RESULT_FILE"
{
  echo "[run_lab2] start: $(date '+%Y-%m-%d %H:%M:%S')"
  echo "[run_lab2] logging to: $RESULT_FILE"
  cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF
  cmake --build build -j "$(nproc)"
  CASE_DIR=test/test_case bash scripts/test_lab2.sh
  echo "[run_lab2] end: $(date '+%Y-%m-%d %H:%M:%S')"
} 2>&1 | tee "$RESULT_FILE"

@ -1,124 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
# Lab2 quick/full verification helper.
#
# Flow: sanity-check the tools, probe the compiler for a parse-only build,
# verify one sample case, then run every *.sy case under CASE_DIR through the
# IR verify script and print a pass/fail summary.
#
# Usage:
#   bash scripts/test_lab2.sh
# Optional env vars:
#   COMPILER=./build/bin/compiler
#   CASE_DIR=test/test_case
#   OUT_DIR=test/test_result/lab2_ir
#   LOG_FILE=test/test_result/lab2_test.log
# Exit status: 0 all cases passed, 1 failure or missing prerequisite,
#              2 compiler was built in parse-only mode.
COMPILER="${COMPILER:-./build/bin/compiler}"
CASE_DIR="${CASE_DIR:-test/test_case}"
OUT_DIR="${OUT_DIR:-test/test_result/lab2_ir}"
LOG_FILE="${LOG_FILE:-test/test_result/lab2_test.log}"
VERIFY_SCRIPT="./scripts/verify_ir.sh"

if [[ ! -x "$COMPILER" ]]; then
  echo "compiler not found or not executable: $COMPILER" >&2
  echo "build first:" >&2
  echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release" >&2
  echo " cmake --build build -j \"\$(nproc)\"" >&2
  exit 1
fi
if [[ ! -x "$VERIFY_SCRIPT" ]]; then
  echo "verify script not found or not executable: $VERIFY_SCRIPT" >&2
  exit 1
fi
if [[ ! -d "$CASE_DIR" ]]; then
  echo "case dir not found: $CASE_DIR" >&2
  exit 1
fi
mkdir -p "$OUT_DIR"

# Locate the sample case once, searching recursively, so that the preflight
# probe and Step 1 use the same file. A fixed "$CASE_DIR/simple_add.sy" path
# misses samples stored in a sub-directory, which silently skipped the probe.
sample_input="$(find "$CASE_DIR" -type f -name "simple_add.sy" -print -quit)"
if [[ -z "$sample_input" ]]; then
  # sed reads the whole stream, so sort never receives SIGPIPE; with
  # 'head -n 1' a large listing could abort the script under pipefail.
  sample_input="$(find "$CASE_DIR" -type f -name "*.sy" | sort | sed -n '1p')"
fi

# Preflight: ensure compiler supports IR emission (not parse-only build).
probe_err="$OUT_DIR/.lab2_probe.err"
if [[ -n "$sample_input" ]]; then
  # Capture the exit status without tripping 'set -e'.
  probe_rc=0
  "$COMPILER" --emit-ir "$sample_input" > /dev/null 2> "$probe_err" || probe_rc=$?
  if [[ $probe_rc -ne 0 ]] && grep -Eiq "parse-only|IR/汇编输出已禁用" "$probe_err"; then
    echo "detected parse-only compiler build, cannot run Lab2 IR tests." >&2
    echo "rebuild with IR enabled:" >&2
    echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF" >&2
    echo " cmake --build build -j \"\$(nproc)\"" >&2
    rm -f "$probe_err"
    exit 2
  fi
  rm -f "$probe_err"
fi

mkdir -p "$(dirname "$LOG_FILE")"
# Start from an empty log; everything below appends to it.
: > "$LOG_FILE"
echo "[Lab2] start test" | tee -a "$LOG_FILE"
echo "compiler : $COMPILER" | tee -a "$LOG_FILE"
echo "cases : $CASE_DIR" | tee -a "$LOG_FILE"
echo "out dir : $OUT_DIR" | tee -a "$LOG_FILE"

echo "[Step 1] single sample check: simple_add.sy" | tee -a "$LOG_FILE"
if [[ -z "$sample_input" ]]; then
  echo "single sample: FAIL (no .sy case found under $CASE_DIR)" | tee -a "$LOG_FILE"
  echo "stop here. see log: $LOG_FILE" >&2
  exit 1
fi
if "$VERIFY_SCRIPT" "$sample_input" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
  echo "single sample: PASS" | tee -a "$LOG_FILE"
else
  echo "single sample: FAIL" | tee -a "$LOG_FILE"
  echo "stop here. see log: $LOG_FILE" >&2
  exit 1
fi

echo "[Step 2] full Lab2 regression" | tee -a "$LOG_FILE"
pass=0
fail=0
total=0
failed_list=()
# NUL-delimited find/sort keeps paths with unusual characters intact, and the
# process substitution keeps the counters in the current shell.
while IFS= read -r -d '' sy; do
  total=$((total + 1))
  name="$(basename "$sy")"
  echo "[$total] $name" | tee -a "$LOG_FILE"
  if "$VERIFY_SCRIPT" "$sy" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
    pass=$((pass + 1))
    echo " PASS" | tee -a "$LOG_FILE"
  else
    fail=$((fail + 1))
    failed_list+=("$sy")
    echo " FAIL" | tee -a "$LOG_FILE"
  fi
done < <(find "$CASE_DIR" -type f -name "*.sy" -print0 | sort -z)

echo "" | tee -a "$LOG_FILE"
echo "[Summary]" | tee -a "$LOG_FILE"
echo "total: $total" | tee -a "$LOG_FILE"
echo "pass : $pass" | tee -a "$LOG_FILE"
echo "fail : $fail" | tee -a "$LOG_FILE"
if [[ $fail -gt 0 ]]; then
  echo "failed cases:" | tee -a "$LOG_FILE"
  for f in "${failed_list[@]}"; do
    echo " - $f" | tee -a "$LOG_FILE"
  done
  echo "Lab2 target is not fully met yet." | tee -a "$LOG_FILE"
  echo "see details in $LOG_FILE"
  exit 1
fi
echo "All Lab2 cases passed. Lab2 target regression is met." | tee -a "$LOG_FILE"
echo "see details in $LOG_FILE"

@ -1,119 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
# Lab3 full backend regression helper.
#
# Flow: sanity-check the tools, probe the compiler for a parse-only build,
# verify the simple_add sample, then run every *.sy case under CASE_DIR
# through the asm verify script and print a pass/fail summary.
#
# Usage:
#   bash scripts/test_lab3.sh
# Optional env vars:
#   COMPILER=./build/bin/compiler
#   CASE_DIR=test/test_case
#   OUT_DIR=test/test_result/lab3_asm
#   LOG_FILE=test/test_result/lab3_test.log
# Exit status: 0 all cases passed, 1 failure or missing prerequisite,
#              2 compiler was built in parse-only mode.
COMPILER="${COMPILER:-./build/bin/compiler}"
CASE_DIR="${CASE_DIR:-test/test_case}"
OUT_DIR="${OUT_DIR:-test/test_result/lab3_asm}"
LOG_FILE="${LOG_FILE:-test/test_result/lab3_test.log}"
VERIFY_SCRIPT="./scripts/verify_asm.sh"

if [[ ! -x "$COMPILER" ]]; then
  echo "compiler not found or not executable: $COMPILER" >&2
  echo "build first:" >&2
  echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF" >&2
  echo " cmake --build build -j \"\$(nproc)\"" >&2
  exit 1
fi
if [[ ! -x "$VERIFY_SCRIPT" ]]; then
  echo "verify script not found or not executable: $VERIFY_SCRIPT" >&2
  exit 1
fi
if [[ ! -d "$CASE_DIR" ]]; then
  echo "case dir not found: $CASE_DIR" >&2
  exit 1
fi
mkdir -p "$OUT_DIR"

# Preflight: make sure the compiler can emit assembly (not a parse-only build).
probe_input="$CASE_DIR/functional/simple_add.sy"
probe_err="$OUT_DIR/.lab3_probe.err"
if [[ -f "$probe_input" ]]; then
  # The probe is allowed to fail; only its status and stderr are inspected.
  set +e
  "$COMPILER" --emit-asm "$probe_input" > /dev/null 2> "$probe_err"
  probe_rc=$?
  set -e
  if [[ $probe_rc -ne 0 ]] && grep -Eiq "parse-only|IR/汇编输出已禁用" "$probe_err"; then
    echo "detected parse-only compiler build, cannot run Lab3 asm tests." >&2
    echo "rebuild with MIR/ASM enabled:" >&2
    echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF" >&2
    echo " cmake --build build -j \"\$(nproc)\"" >&2
    rm -f "$probe_err"
    exit 2
  fi
  rm -f "$probe_err"
fi

mkdir -p "$(dirname "$LOG_FILE")"
# Start from an empty log; everything below appends to it.
: > "$LOG_FILE"
echo "[Lab3] start test" | tee -a "$LOG_FILE"
echo "compiler : $COMPILER" | tee -a "$LOG_FILE"
echo "cases : $CASE_DIR" | tee -a "$LOG_FILE"
echo "out dir : $OUT_DIR" | tee -a "$LOG_FILE"

echo "[Step 1] single sample check: simple_add.sy" | tee -a "$LOG_FILE"
if [[ ! -f "$probe_input" ]]; then
  echo "single sample: FAIL (missing $probe_input)" | tee -a "$LOG_FILE"
  exit 1
fi
if "$VERIFY_SCRIPT" "$probe_input" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
  echo "single sample: PASS" | tee -a "$LOG_FILE"
else
  echo "single sample: FAIL" | tee -a "$LOG_FILE"
  echo "stop here. see log: $LOG_FILE" >&2
  exit 1
fi

echo "[Step 2] full Lab3 asm regression" | tee -a "$LOG_FILE"
pass=0
fail=0
total=0
failed_list=()
# NUL-delimited find/sort keeps paths with unusual characters intact, and the
# process substitution keeps the counters in the current shell.
while IFS= read -r -d '' sy; do
  total=$((total + 1))
  # Show the path relative to CASE_DIR. The inner quotes make CASE_DIR a
  # literal prefix; unquoted, characters such as '*' or '[' in the directory
  # name would be interpreted as a glob pattern.
  name="${sy#"$CASE_DIR"/}"
  echo "[$total] $name" | tee -a "$LOG_FILE"
  if "$VERIFY_SCRIPT" "$sy" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
    pass=$((pass + 1))
    echo " PASS" | tee -a "$LOG_FILE"
  else
    fail=$((fail + 1))
    failed_list+=("$sy")
    echo " FAIL" | tee -a "$LOG_FILE"
  fi
done < <(find "$CASE_DIR" -type f -name "*.sy" -print0 | sort -z)

echo "" | tee -a "$LOG_FILE"
echo "[Summary]" | tee -a "$LOG_FILE"
echo "total: $total" | tee -a "$LOG_FILE"
echo "pass : $pass" | tee -a "$LOG_FILE"
echo "fail : $fail" | tee -a "$LOG_FILE"
if [[ $fail -gt 0 ]]; then
  echo "failed cases:" | tee -a "$LOG_FILE"
  for f in "${failed_list[@]}"; do
    echo " - $f" | tee -a "$LOG_FILE"
  done
  echo "Lab3 target is not fully met yet." | tee -a "$LOG_FILE"
  echo "see details in $LOG_FILE"
  exit 1
fi
echo "All Lab3 cases passed. Lab3 target regression is met." | tee -a "$LOG_FILE"
echo "see details in $LOG_FILE"

@ -41,25 +41,18 @@ if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
exit 1 exit 1
fi fi
if ! command -v clang >/dev/null 2>&1; then
echo "未找到 clang无法由 IR 生成 AArch64 汇编。" >&2
exit 1
fi
mkdir -p "$out_dir" mkdir -p "$out_dir"
base=$(basename "$input") base=$(basename "$input")
stem=${base%.sy} stem=${base%.sy}
asm_file="$out_dir/$stem.s" asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem" exe="$out_dir/$stem"
runtime_obj="$out_dir/sylib.aarch64.o"
stdin_file="$input_dir/$stem.in" stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out" expected_file="$input_dir/$stem.out"
"$compiler" --emit-asm "$input" > "$asm_file" "$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file" echo "汇编已生成: $asm_file"
aarch64-linux-gnu-gcc -O2 -Wno-unused-result -c sylib/sylib.c -o "$runtime_obj" aarch64-linux-gnu-gcc "$asm_file" -o "$exe"
aarch64-linux-gnu-gcc "$asm_file" "$runtime_obj" -o "$exe"
echo "可执行文件已生成: $exe" echo "可执行文件已生成: $exe"
if [[ "$run_exec" == true ]]; then if [[ "$run_exec" == true ]]; then
@ -70,8 +63,6 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout" stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out" actual_file="$out_dir/$stem.actual.out"
actual_norm="$out_dir/$stem.actual.norm"
expected_norm="$out_dir/$stem.expected.norm"
echo "运行 $exe ..." echo "运行 $exe ..."
set +e set +e
if [[ -f "$stdin_file" ]]; then if [[ -f "$stdin_file" ]]; then
@ -92,9 +83,7 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file" } > "$actual_file"
if [[ -f "$expected_file" ]]; then if [[ -f "$expected_file" ]]; then
perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n?\z//' "$expected_file" > "$expected_norm" if diff -u "$expected_file" "$actual_file"; then
perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n?\z//' "$actual_file" > "$actual_norm"
if diff -u "$expected_norm" "$actual_norm"; then
echo "输出匹配: $expected_file" echo "输出匹配: $expected_file"
else else
echo "输出不匹配: $expected_file" >&2 echo "输出不匹配: $expected_file" >&2

@ -47,28 +47,20 @@ expected_file="$input_dir/$stem.out"
echo "IR 已生成: $out_file" echo "IR 已生成: $out_file"
if [[ "$run_exec" == true ]]; then if [[ "$run_exec" == true ]]; then
if ! command -v llc >/dev/null 2>&1; then
echo "未找到 llc无法运行 IR。请安装 LLVM。" >&2
exit 1
fi
if ! command -v clang >/dev/null 2>&1; then if ! command -v clang >/dev/null 2>&1; then
echo "未找到 clang无法编译可执行文件。请安装 LLVM/Clang。" >&2 echo "未找到 clang无法链接可执行文件。请安装 LLVM/Clang。" >&2
exit 1 exit 1
fi fi
obj="$out_dir/$stem.o"
exe="$out_dir/$stem" exe="$out_dir/$stem"
stdout_file="$out_dir/$stem.stdout" stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out" actual_file="$out_dir/$stem.actual.out"
actual_norm="$out_dir/$stem.actual.norm" llc -filetype=obj "$out_file" -o "$obj"
expected_norm="$out_dir/$stem.expected.norm" clang "$obj" -o "$exe"
# 直接让 clang 优化 .ll可显著降低大规模 Lab2 performance 用例运行时间。
# -fwrapv 保持有符号整数按补码回绕,避免与当前 IR 的测试语义偏离。
if ! clang -O2 -fwrapv -Wno-override-module "$out_file" sylib/sylib.c -o "$exe"; then
if ! command -v llc >/dev/null 2>&1; then
echo "未找到 llc且 clang 直接编译 IR 失败,无法运行 IR。" >&2
exit 1
fi
obj="$out_dir/$stem.o"
llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" sylib/sylib.c -o "$exe"
fi
echo "运行 $exe ..." echo "运行 $exe ..."
set +e set +e
if [[ -f "$stdin_file" ]]; then if [[ -f "$stdin_file" ]]; then
@ -89,9 +81,7 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file" } > "$actual_file"
if [[ -f "$expected_file" ]]; then if [[ -f "$expected_file" ]]; then
perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n?\z//' "$expected_file" > "$expected_norm" if diff -u "$expected_file" "$actual_file"; then
perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n?\z//' "$actual_file" > "$actual_norm"
if diff -u "$expected_norm" "$actual_norm"; then
echo "输出匹配: $expected_file" echo "输出匹配: $expected_file"
else else
echo "输出不匹配: $expected_file" >&2 echo "输出不匹配: $expected_file" >&2

@ -1,145 +1,67 @@
// SysY 子集语法:支持形如
// int main() { int a = 1; int b = 2; return a + b; }
// 的最小返回表达式编译。
// 后续需要自行添加
grammar SysY; grammar SysY;
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
/* Lexer rules */ /* Lexer rules */
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
CONST: 'const';
INT: 'int'; INT: 'int';
FLOAT: 'float';
VOID: 'void';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
RETURN: 'return'; RETURN: 'return';
ASSIGN: '='; ASSIGN: '=';
EQ: '==';
NE: '!=';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
ADD: '+'; ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LAND: '&&';
LOR: '||';
LPAREN: '('; LPAREN: '(';
RPAREN: ')'; RPAREN: ')';
LBRACE: '{'; LBRACE: '{';
RBRACE: '}'; RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
COMMA: ',';
SEMICOLON: ';'; SEMICOLON: ';';
FLOAT_CONST
: DEC_FLOAT_CONST
| HEX_FLOAT_CONST
;
INT_CONST
: HEX_PREFIX HEX_DIGIT+
| '0' [0-7]+
| '0'
| [1-9] DIGIT*
;
ID: [a-zA-Z_][a-zA-Z_0-9]*; ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip; WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip; LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip; BLOCKCOMMENT: '/*' .*? '*/' -> skip;
fragment DEC_FLOAT_CONST
: DIGIT+ '.' DIGIT* EXP_PART?
| '.' DIGIT+ EXP_PART?
| DIGIT+ EXP_PART
;
fragment HEX_FLOAT_CONST
: HEX_PREFIX HEX_DIGIT+ '.' HEX_DIGIT* BIN_EXP_PART
| HEX_PREFIX '.' HEX_DIGIT+ BIN_EXP_PART
| HEX_PREFIX HEX_DIGIT+ BIN_EXP_PART
;
fragment EXP_PART: [eE] [+-]? DIGIT+;
fragment BIN_EXP_PART: [pP] [+-]? DIGIT+;
fragment HEX_PREFIX: '0' [xX];
fragment HEX_DIGIT: [0-9a-fA-F];
fragment DIGIT: [0-9];
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
/* Syntax rules */ /* Syntax rules */
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
compUnit compUnit
: (decl | funcDef)+ EOF : funcDef EOF
; ;
decl decl
: constDecl : btype varDef SEMICOLON
| varDecl
; ;
constDecl btype
: CONST bType constDef (COMMA constDef)* SEMICOLON
;
varDecl
: bType varDef (COMMA varDef)* SEMICOLON
;
bType
: INT : INT
| FLOAT
;
constDef
: ID (LBRACK constExp RBRACK)* ASSIGN constInitVal
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
; ;
varDef varDef
: ID (LBRACK constExp RBRACK)* (ASSIGN initVal)? : lValue (ASSIGN initValue)?
; ;
initVal initValue
: exp : exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
; ;
funcDef funcDef
: funcType ID LPAREN funcFParams? RPAREN block : funcType ID LPAREN RPAREN blockStmt
; ;
funcType funcType
: VOID : INT
| INT
| FLOAT
;
funcFParams
: funcFParam (COMMA funcFParam)*
;
funcFParam
: bType ID (LBRACK RBRACK (LBRACK exp RBRACK)*)?
; ;
block blockStmt
: LBRACE blockItem* RBRACE : LBRACE blockItem* RBRACE
; ;
@ -149,80 +71,28 @@ blockItem
; ;
stmt stmt
: lVal ASSIGN exp SEMICOLON : returnStmt
| exp? SEMICOLON
| block
| IF LPAREN cond RPAREN stmt (ELSE stmt)?
| WHILE LPAREN cond RPAREN stmt
| BREAK SEMICOLON
| CONTINUE SEMICOLON
| RETURN exp? SEMICOLON
; ;
exp returnStmt
: addExp : RETURN exp SEMICOLON
; ;
cond exp
: lOrExp : LPAREN exp RPAREN # parenExp
| var # varExp
| number # numberExp
| exp ADD exp # additiveExp
; ;
lVal var
: ID (LBRACK exp RBRACK)* : ID
; ;
primaryExp lValue
: LPAREN exp RPAREN : ID
| lVal
| number
; ;
number number
: INT_CONST : ILITERAL
| FLOAT_CONST
;
unaryExp
: primaryExp
| ID LPAREN funcRParams? RPAREN
| unaryOp unaryExp
;
unaryOp
: ADD
| SUB
| NOT
;
funcRParams
: exp (COMMA exp)*
;
mulExp
: unaryExp ((MUL | DIV | MOD) unaryExp)*
;
addExp
: mulExp ((ADD | SUB) mulExp)*
;
relExp
: addExp ((LT | GT | LE | GE) addExp)*
;
eqExp
: relExp ((EQ | NE) relExp)*
;
lAndExp
: eqExp (LAND eqExp)*
;
lOrExp
: lAndExp (LOR lAndExp)*
; ;
constExp
: addExp
;

@ -3,44 +3,15 @@ add_library(frontend STATIC
SyntaxTreePrinter.cpp SyntaxTreePrinter.cpp
) )
set(ANTLR4_GRAMMAR "${PROJECT_SOURCE_DIR}/src/antlr4/SysY.g4")
set(ANTLR4_GENERATED_FILES
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.h"
"${ANTLR4_GENERATED_DIR}/SysYLexer.interp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.tokens"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.h"
"${ANTLR4_GENERATED_DIR}/SysY.interp"
"${ANTLR4_GENERATED_DIR}/SysY.tokens"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.h"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.h"
)
add_custom_command(
OUTPUT ${ANTLR4_GENERATED_FILES}
COMMAND ${CMAKE_COMMAND} -E make_directory "${ANTLR4_GENERATED_DIR}"
COMMAND ${Java_JAVA_EXECUTABLE} -jar "${ANTLR4_JAR}"
-Dlanguage=Cpp
-visitor
-no-listener
-Xexact-output-dir
-o "${ANTLR4_GENERATED_DIR}"
"${ANTLR4_GRAMMAR}"
DEPENDS "${ANTLR4_GRAMMAR}" "${ANTLR4_JAR}"
COMMENT "Generating ANTLR4 parser sources from SysY.g4"
VERBATIM
)
add_custom_target(antlr4_generated DEPENDS ${ANTLR4_GENERATED_FILES})
add_dependencies(frontend antlr4_generated)
target_sources(frontend PRIVATE
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
)
target_link_libraries(frontend PUBLIC target_link_libraries(frontend PUBLIC
build_options build_options
${ANTLR4_RUNTIME_TARGET} ${ANTLR4_RUNTIME_TARGET}
) )
# Lexer/Parser
file(GLOB_RECURSE ANTLR4_GENERATED_SOURCES CONFIGURE_DEPENDS
"${ANTLR4_GENERATED_DIR}/*.cpp"
)
if(ANTLR4_GENERATED_SOURCES)
target_sources(frontend PRIVATE ${ANTLR4_GENERATED_SOURCES})
endif()

@ -9,14 +9,13 @@
#include "ir/IR.h" #include "ir/IR.h"
#include <algorithm>
#include <utility> #include <utility>
namespace ir { namespace ir {
// BasicBlock 使用 label type // 当前 BasicBlock 还没有专门的 label type因此先用 void 作为占位类型
BasicBlock::BasicBlock(std::string name) BasicBlock::BasicBlock(std::string name)
: Value(Type::GetLabelType(), std::move(name)) {} : Value(Type::GetVoidType(), std::move(name)) {}
Function* BasicBlock::GetParent() const { return parent_; } Function* BasicBlock::GetParent() const { return parent_; }
@ -33,10 +32,6 @@ const std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetInstructions()
return instructions_; return instructions_;
} }
std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetMutableInstructions() {
return instructions_;
}
// 前驱/后继接口先保留给后续 CFG 扩展使用。 // 前驱/后继接口先保留给后续 CFG 扩展使用。
// 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。 // 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。
const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const { const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const {
@ -47,83 +42,4 @@ const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_; return successors_;
} }
// Records `pred` as a predecessor of this block.
// Null pointers and blocks that are already listed are ignored, so the
// predecessor list never contains duplicates.
void BasicBlock::AddPredecessor(BasicBlock* pred) {
  if (pred == nullptr) {
    return;
  }
  const bool already_listed =
      std::find(predecessors_.begin(), predecessors_.end(), pred) !=
      predecessors_.end();
  if (!already_listed) {
    predecessors_.push_back(pred);
  }
}
// Records `succ` as a successor of this block.
// Null pointers and blocks that are already listed are ignored, so the
// successor list never contains duplicates.
void BasicBlock::AddSuccessor(BasicBlock* succ) {
  if (succ == nullptr) {
    return;
  }
  const bool already_listed =
      std::find(successors_.begin(), successors_.end(), succ) !=
      successors_.end();
  if (!already_listed) {
    successors_.push_back(succ);
  }
}
void BasicBlock::ClearPredecessors() { predecessors_.clear(); }
void BasicBlock::ClearSuccessors() { successors_.clear(); }
void BasicBlock::RemovePredecessor(BasicBlock* pred) {
predecessors_.erase(std::remove(predecessors_.begin(), predecessors_.end(), pred),
predecessors_.end());
}
void BasicBlock::RemoveSuccessor(BasicBlock* succ) {
successors_.erase(std::remove(successors_.begin(), successors_.end(), succ),
successors_.end());
}
// Removes `inst` from this block and destroys it.
// Does nothing when `inst` is null or is not owned by this block. Before the
// instruction is destroyed, every non-null operand is told to drop the use
// that `inst` held on it, and the instruction's parent link is cleared.
void BasicBlock::EraseInstruction(Instruction* inst) {
  if (inst == nullptr) {
    return;
  }

  // Find the owning slot by raw pointer identity.
  auto slot = instructions_.begin();
  while (slot != instructions_.end() && slot->get() != inst) {
    ++slot;
  }
  if (slot == instructions_.end()) {
    return;
  }

  // Detach the instruction from the use lists of its operands.
  size_t operand_index = 0;
  while (operand_index < inst->GetNumOperands()) {
    auto* operand = inst->GetOperand(operand_index);
    if (operand) {
      operand->RemoveUse(inst, operand_index);
    }
    ++operand_index;
  }

  inst->SetParent(nullptr);
  // Erasing the unique_ptr releases the instruction itself.
  instructions_.erase(slot);
}
// Replaces this block's terminator with `inst` and takes ownership of it.
// If the last instruction is currently a terminator it is detached from its
// operands' use lists and destroyed first; otherwise `inst` is simply appended.
// The call is a no-op when `inst` is null or is not a terminator.
// NOTE(review): the successor/predecessor links recorded for the old
// terminator are not removed here; only new links are added. Confirm that
// callers clean up stale CFG edges themselves.
void BasicBlock::ReplaceTerminator(std::unique_ptr<Instruction> inst) {
// Reject null instructions and anything that is not a terminator.
if (!inst || !inst->IsTerminator()) return;
if (!instructions_.empty() && instructions_.back()->IsTerminator()) {
auto* old = instructions_.back().get();
// Drop the uses the old terminator holds before it is destroyed.
for (size_t i = 0; i < old->GetNumOperands(); ++i) {
if (auto* operand = old->GetOperand(i)) {
operand->RemoveUse(old, i);
}
}
old->SetParent(nullptr);
// pop_back destroys the old terminator (the vector owns it).
instructions_.pop_back();
}
inst->SetParent(this);
instructions_.push_back(std::move(inst));
// `inst` has been moved from; use the stored pointer to wire up CFG edges.
LinkSuccessorsIfNeeded(instructions_.back().get());
}
// Updates the CFG edges implied by `inst` when it is a branch instruction.
// For an unconditional branch the destination becomes a successor of this
// block and this block becomes a predecessor of the destination; for a
// conditional branch the same is done for the true and the false destination.
// Any other instruction (or a null pointer) leaves the CFG untouched.
void BasicBlock::LinkSuccessorsIfNeeded(Instruction* inst) {
  if (!inst) return;

  // Links this block and `target` in both directions. A null target is
  // skipped entirely: AddSuccessor already ignores null, but calling
  // AddPredecessor through a null pointer would be undefined behaviour.
  auto link = [this](BasicBlock* target) {
    if (!target) return;
    AddSuccessor(target);
    target->AddPredecessor(this);
  };

  if (auto* br = dynamic_cast<BranchInst*>(inst)) {
    link(br->GetDest());
    return;
  }
  if (auto* cbr = dynamic_cast<CondBrInst*>(inst)) {
    link(cbr->GetTrueDest());
    link(cbr->GetFalseDest());
  }
}
} // namespace ir } // namespace ir

@ -1,7 +1,6 @@
// 管理基础类型、整型常量池和临时名生成。 // 管理基础类型、整型常量池和临时名生成。
#include "ir/IR.h" #include "ir/IR.h"
#include <cstring>
#include <sstream> #include <sstream>
namespace ir { namespace ir {
@ -16,43 +15,9 @@ ConstantInt* Context::GetConstInt(int v) {
return inserted->second.get(); return inserted->second.get();
} }
// Returns the pooled i1 constant for `v`, creating it on first request.
// The pool is keyed by 0/1, so repeated calls with the same value return the
// same ConstantInt object, which stays owned by the context.
ConstantInt* Context::GetConstBool(bool v) {
  const int key = v ? 1 : 0;
  auto entry = const_bools_.find(key);
  if (entry == const_bools_.end()) {
    entry = const_bools_
                .emplace(key, std::make_unique<ConstantInt>(
                                  Type::GetInt1Type(), key))
                .first;
  }
  return entry->second.get();
}
// Returns the raw object representation of `v` as a 32-bit unsigned integer.
// memcpy is used instead of a pointer cast so no aliasing rules are broken.
static uint32_t FloatToBits(float v) {
  uint32_t raw = 0;
  std::memcpy(&raw, &v, sizeof(v));
  return raw;
}
// Returns the pooled float constant for `v`, creating it on first request.
// The pool is keyed by the bit pattern of the value rather than by the float
// itself, so lookups compare raw bits instead of using floating-point equality.
ConstantFloat* Context::GetConstFloat(float v) {
  const uint32_t key = FloatToBits(v);
  auto entry = const_floats_.find(key);
  if (entry == const_floats_.end()) {
    entry = const_floats_
                .emplace(key, std::make_unique<ConstantFloat>(
                                  Type::GetFloatType(), v))
                .first;
  }
  return entry->second.get();
}
// Creates a new constant array of type `array_ty` holding `elements`.
// The array is owned by the context; the returned pointer is non-owning.
// Unlike the scalar constant getters, every call creates a fresh object.
// Throws std::runtime_error when `array_ty` is null or is not an array type.
ConstantArray* Context::CreateConstArray(std::shared_ptr<Type> array_ty,
                                         std::vector<ConstantValue*> elements) {
  const bool is_array_type = array_ty && array_ty->IsArray();
  if (!is_array_type) {
    throw std::runtime_error("CreateConstArray 需要 array type");
  }
  auto owned = std::make_unique<ConstantArray>(std::move(array_ty),
                                               std::move(elements));
  const_arrays_.push_back(std::move(owned));
  return const_arrays_.back().get();
}
std::string Context::NextTemp() { std::string Context::NextTemp() {
std::ostringstream oss; std::ostringstream oss;
oss << "%t" << ++temp_index_; oss << "%" << ++temp_index_;
return oss.str(); return oss.str();
} }

@ -5,32 +5,13 @@
namespace ir { namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> func_type, Function::Function(std::string name, std::shared_ptr<Type> ret_type)
bool is_declaration) : Value(std::move(ret_type), std::move(name)) {
: Value(std::move(func_type), std::move(name)), entry_ = CreateBlock("entry");
is_declaration_(is_declaration) {
if (!type_ || !type_->IsFunction()) {
throw std::runtime_error("Function 需要 function type");
}
const auto& params = type_->GetParamTypes();
args_.reserve(params.size());
for (size_t i = 0; i < params.size(); ++i) {
args_.push_back(std::make_unique<Argument>(params[i], "%arg" + std::to_string(i), i));
}
if (!is_declaration_) {
entry_ = CreateBlock("entry");
}
} }
BasicBlock* Function::CreateBlock(const std::string& name) { BasicBlock* Function::CreateBlock(const std::string& name) {
std::string base = name.empty() ? "bb" : name; auto block = std::make_unique<BasicBlock>(name);
auto& count = block_name_counts_[base];
std::string final_name = base;
if (count > 0) {
final_name = base + "." + std::to_string(count);
}
++count;
auto block = std::make_unique<BasicBlock>(final_name);
auto* ptr = block.get(); auto* ptr = block.get();
ptr->SetParent(this); ptr->SetParent(this);
blocks_.push_back(std::move(block)); blocks_.push_back(std::move(block));
@ -48,35 +29,4 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_; return blocks_;
} }
std::vector<std::unique_ptr<BasicBlock>>& Function::GetMutableBlocks() {
return blocks_;
}
const std::vector<std::unique_ptr<Argument>>& Function::GetArguments() const {
return args_;
}
size_t Function::GetNumArgs() const { return args_.size(); }
Argument* Function::GetArg(size_t index) {
if (index >= args_.size()) {
throw std::out_of_range("Function arg index out of range");
}
return args_[index].get();
}
std::shared_ptr<Type> Function::GetFunctionType() const { return type_; }
std::shared_ptr<Type> Function::GetReturnType() const {
if (!type_ || !type_->IsFunction()) {
throw std::runtime_error("Function type 缺失");
}
return type_->GetReturnType();
}
bool Function::IsDeclaration() const { return is_declaration_; }
Argument::Argument(std::shared_ptr<Type> ty, std::string name, size_t index)
: Value(std::move(ty), std::move(name)), index_(index) {}
} // namespace ir } // namespace ir

@ -8,23 +8,4 @@ namespace ir {
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name) GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {} : User(std::move(ty), std::move(name)) {}
GlobalVariable::GlobalVariable(std::shared_ptr<Type> value_ty, std::string name,
ConstantValue* init, bool is_const)
: GlobalValue(Type::GetPointerType(value_ty), std::move(name)),
value_type_(std::move(value_ty)),
initializer_(init),
is_const_(is_const) {
if (!value_type_) {
throw std::runtime_error("GlobalVariable 缺少 value type");
}
}
const std::shared_ptr<Type>& GlobalVariable::GetValueType() const {
return value_type_;
}
ConstantValue* GlobalVariable::GetInitializer() const { return initializer_; }
bool GlobalVariable::IsConst() const { return is_const_; }
} // namespace ir } // namespace ir

@ -9,42 +9,6 @@
#include "utils/Log.h" #include "utils/Log.h"
namespace ir { namespace ir {
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
case Type::Kind::Int1:
return "i1";
case Type::Kind::Int32:
return "i32";
case Type::Kind::Float:
return "float";
case Type::Kind::Label:
return "label";
case Type::Kind::Pointer:
return TypeToString(*ty.GetElementType()) + "*";
case Type::Kind::Array: {
return "[" + std::to_string(ty.GetArraySize()) + " x " +
TypeToString(*ty.GetElementType()) + "]";
}
case Type::Kind::Function: {
std::string out = TypeToString(*ty.GetReturnType()) + " (";
const auto& params = ty.GetParamTypes();
for (size_t i = 0; i < params.size(); ++i) {
if (i > 0) out += ", ";
out += TypeToString(*params[i]);
}
if (ty.IsVarArg()) {
if (!params.empty()) out += ", ";
out += "...";
}
out += ")";
return out;
}
}
return "?";
}
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb)
: ctx_(ctx), insert_block_(bb) {} : ctx_(ctx), insert_block_(bb) {}
@ -78,107 +42,11 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
return CreateBinary(Opcode::Add, lhs, rhs, name); return CreateBinary(Opcode::Add, lhs, rhs, name);
} }
BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSDiv(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::SDiv, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSRem(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::SRem, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FAdd, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FSub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FMul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FDiv, lhs, rhs, name);
}
ICmpInst* IRBuilder::CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<ICmpInst>(pred, lhs, rhs, name);
}
FCmpInst* IRBuilder::CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<FCmpInst>(pred, lhs, rhs, name);
}
CastInst* IRBuilder::CreateSIToFP(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::SIToFP, std::move(dst_ty), src,
name);
}
CastInst* IRBuilder::CreateFPToSI(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::FPToSI, std::move(dst_ty), src,
name);
}
CastInst* IRBuilder::CreateZExt(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::ZExt, std::move(dst_ty), src,
name);
}
ConstantInt* IRBuilder::CreateConstBool(bool v) {
return ctx_.GetConstBool(v);
}
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
return ctx_.GetConstFloat(v);
}
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty,
const std::string& name) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
} }
return insert_block_->Append<AllocaInst>(std::move(ty), name); return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
return CreateAlloca(Type::GetInt32Type(), name);
} }
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
@ -189,11 +57,7 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
throw std::runtime_error( throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
} }
if (!ptr->GetType() || !ptr->GetType()->IsPointer()) { return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name);
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateLoad ptr 不是指针"));
}
auto val_ty = ptr->GetType()->GetElementType();
return insert_block_->Append<LoadInst>(val_ty, ptr, name);
} }
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
@ -211,100 +75,6 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr); return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr);
} }
// Computes the pointer type produced by a GEP with `index_count` indices on a
// base pointer of type `base_ptr_ty`.
// The first index only addresses within the current pointee object and does
// not step into it; each later index steps into the element type when the
// current type is an array or a pointer, and leaves it unchanged otherwise.
// Throws std::runtime_error when `base_ptr_ty` is null or not a pointer type.
static std::shared_ptr<Type> ResolveGepResultType(const std::shared_ptr<Type>& base_ptr_ty,
                                                  size_t index_count) {
  const bool base_is_pointer = base_ptr_ty && base_ptr_ty->IsPointer();
  if (!base_is_pointer) {
    throw std::runtime_error("GEP base type 必须是指针");
  }
  auto pointee = base_ptr_ty->GetElementType();
  // Start at 1: index 0 never changes the pointee type.
  for (size_t idx = 1; idx < index_count; ++idx) {
    if (pointee->IsArray() || pointee->IsPointer()) {
      pointee = pointee->GetElementType();
    }
  }
  return Type::GetPointerType(pointee);
}
// Appends a GEP instruction computing an address from `base_ptr` and
// `indices` to the current insertion block and returns it.
// The result type is derived from the base pointer type and the number of
// indices. Throws std::runtime_error when no insertion block is set or when
// `base_ptr` is null, untyped, or not of pointer type.
GepInst* IRBuilder::CreateGep(Value* base_ptr, std::vector<Value*> indices,
                              const std::string& name) {
  if (!insert_block_) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }
  const bool base_is_pointer =
      base_ptr && base_ptr->GetType() && base_ptr->GetType()->IsPointer();
  if (!base_is_pointer) {
    throw std::runtime_error(FormatError("ir", "IRBuilder::CreateGep base_ptr 非指针"));
  }
  // Read the index count before `indices` is moved into the instruction.
  const size_t index_count = indices.size();
  auto gep_type = ResolveGepResultType(base_ptr->GetType(), index_count);
  return insert_block_->Append<GepInst>(gep_type, base_ptr, std::move(indices),
                                        name);
}
// Appends a call to `callee` with `args` to the current insertion block and
// returns the new instruction; its type is the callee's return type.
// `callee` may have a function type or be a pointer to a function type.
// Throws std::runtime_error when no insertion block is set, when the callee
// is missing or not callable, when the argument count does not match a
// non-variadic signature, or when a fixed argument is null, untyped, or of a
// type different from the corresponding parameter. Arguments beyond the fixed
// parameters of a variadic callee are not checked.
CallInst* IRBuilder::CreateCall(Value* callee, std::vector<Value*> args,
                                const std::string& name) {
  if (!insert_block_) {
    throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
  }
  if (!callee || !callee->GetType()) {
    throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 缺少 callee"));
  }
  std::shared_ptr<Type> func_ty;
  if (callee->GetType()->IsFunction()) {
    func_ty = callee->GetType();
  } else if (callee->GetType()->IsPointer() &&
             callee->GetType()->GetElementType()->IsFunction()) {
    func_ty = callee->GetType()->GetElementType();
  } else {
    throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall callee 非函数"));
  }
  const auto& params = func_ty->GetParamTypes();
  if (!func_ty->IsVarArg() && params.size() != args.size()) {
    throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 参数数量不匹配"));
  }
  for (size_t i = 0; i < params.size() && i < args.size(); ++i) {
    // Report a null or untyped argument separately: the mismatch message
    // below prints the argument's type and must not dereference null.
    if (!args[i] || !args[i]->GetType()) {
      std::string msg = "IRBuilder::CreateCall 参数为空或缺少类型: arg" +
                        std::to_string(i);
      throw std::runtime_error(FormatError("ir", msg));
    }
    if (!args[i]->GetType()->Equals(*params[i])) {
      std::string msg = "IRBuilder::CreateCall 参数类型不匹配: arg" +
                        std::to_string(i) + " got " +
                        TypeToString(*args[i]->GetType()) + ", expect " +
                        TypeToString(*params[i]);
      throw std::runtime_error(FormatError("ir", msg));
    }
  }
  auto ret_ty = func_ty->GetReturnType();
  return insert_block_->Append<CallInst>(ret_ty, callee, std::move(args), name);
}
// Appends an empty phi of type `ty` to the current insertion block.
// Throws std::runtime_error when no insertion block is set.
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> ty, const std::string& name) {
  if (insert_block_) {
    return insert_block_->Append<PhiInst>(std::move(ty), name);
  }
  throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
// Appends an unconditional branch to `dest` in the current insertion block.
// Throws std::runtime_error when no insertion block is set.
BranchInst* IRBuilder::CreateBr(BasicBlock* dest) {
  if (insert_block_) {
    return insert_block_->Append<BranchInst>(dest);
  }
  throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
// Appends a conditional branch on `cond` to the current insertion block,
// targeting `true_dest` or `false_dest`.
// Throws std::runtime_error when no insertion block is set.
CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_dest,
                                    BasicBlock* false_dest) {
  if (insert_block_) {
    return insert_block_->Append<CondBrInst>(cond, true_dest, false_dest);
  }
  throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
ReturnInst* IRBuilder::CreateRet(Value* v) { ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -316,11 +86,4 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v); return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
} }
// Appends a value-less return to the current insertion block.
// Throws std::runtime_error when no insertion block is set.
ReturnInst* IRBuilder::CreateRetVoid() {
  if (insert_block_) {
    return insert_block_->Append<ReturnInst>(Type::GetVoidType());
  }
  throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
} // namespace ir } // namespace ir

@ -5,11 +5,6 @@
#include "ir/IR.h" #include "ir/IR.h"
#include <ostream> #include <ostream>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <limits>
#include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@ -17,41 +12,14 @@
namespace ir { namespace ir {
static std::string TypeToString(const Type& ty) { static const char* TypeToString(const Type& ty) {
switch (ty.GetKind()) { switch (ty.GetKind()) {
case Type::Kind::Void: case Type::Kind::Void:
return "void"; return "void";
case Type::Kind::Int1:
return "i1";
case Type::Kind::Int32: case Type::Kind::Int32:
return "i32"; return "i32";
case Type::Kind::Float: case Type::Kind::PtrInt32:
return "float"; return "i32*";
case Type::Kind::Label:
return "label";
case Type::Kind::Pointer:
return TypeToString(*ty.GetElementType()) + "*";
case Type::Kind::Array: {
std::ostringstream oss;
oss << "[" << ty.GetArraySize() << " x "
<< TypeToString(*ty.GetElementType()) << "]";
return oss.str();
}
case Type::Kind::Function: {
std::ostringstream oss;
oss << TypeToString(*ty.GetReturnType()) << " (";
const auto& params = ty.GetParamTypes();
for (size_t i = 0; i < params.size(); ++i) {
if (i > 0) oss << ", ";
oss << TypeToString(*params[i]);
}
if (ty.IsVarArg()) {
if (!params.empty()) oss << ", ";
oss << "...";
}
oss << ")";
return oss.str();
}
} }
throw std::runtime_error(FormatError("ir", "未知类型")); throw std::runtime_error(FormatError("ir", "未知类型"));
} }
@ -64,18 +32,6 @@ static const char* OpcodeToString(Opcode op) {
return "sub"; return "sub";
case Opcode::Mul: case Opcode::Mul:
return "mul"; return "mul";
case Opcode::SDiv:
return "sdiv";
case Opcode::SRem:
return "srem";
case Opcode::FAdd:
return "fadd";
case Opcode::FSub:
return "fsub";
case Opcode::FMul:
return "fmul";
case Opcode::FDiv:
return "fdiv";
case Opcode::Alloca: case Opcode::Alloca:
return "alloca"; return "alloca";
case Opcode::Load: case Opcode::Load:
@ -84,182 +40,21 @@ static const char* OpcodeToString(Opcode op) {
return "store"; return "store";
case Opcode::Ret: case Opcode::Ret:
return "ret"; return "ret";
case Opcode::Br:
return "br";
case Opcode::CondBr:
return "br";
case Opcode::ICmp:
return "icmp";
case Opcode::FCmp:
return "fcmp";
case Opcode::Call:
return "call";
case Opcode::Phi:
return "phi";
case Opcode::Gep:
return "getelementptr";
case Opcode::SIToFP:
return "sitofp";
case Opcode::FPToSI:
return "fptosi";
case Opcode::ZExt:
return "zext";
} }
return "?"; return "?";
} }
// Renders a float constant as `bitcast (i32 <bits> to float)`, where <bits>
// is the value's raw 32-bit pattern printed as an unsigned decimal integer.
static std::string FloatToString(float v) {
  std::uint32_t raw = 0;
  static_assert(sizeof(raw) == sizeof(v), "float size mismatch");
  // memcpy reinterprets the bytes without violating aliasing rules.
  std::memcpy(&raw, &v, sizeof(raw));
  std::string text = "bitcast (i32 ";
  text += std::to_string(static_cast<unsigned long long>(raw));
  text += " to float)";
  return text;
}
// Returns true when constant `c` is entirely made of all-zero-bit values:
// an integer 0, a float whose bit pattern is 0, or an array whose elements
// all satisfy this recursively. Any other constant kind (including a null
// pointer or an array holding a null element) yields false.
static bool IsZeroInitializer(const ConstantValue* c) {
  if (auto* ci = dynamic_cast<const ConstantInt*>(c)) {
    return ci->GetValue() == 0;
  }
  if (auto* cf = dynamic_cast<const ConstantFloat*>(c)) {
    // Compare the bit pattern rather than using `== 0.0f`: -0.0f compares
    // equal to zero but has its sign bit set, so collapsing it into
    // "zeroinitializer" would silently turn it into +0.0.
    const float value = cf->GetValue();
    std::uint32_t bits = 0;
    static_assert(sizeof(bits) == sizeof(value), "float size mismatch");
    std::memcpy(&bits, &value, sizeof(bits));
    return bits == 0;
  }
  if (auto* ca = dynamic_cast<const ConstantArray*>(c)) {
    for (auto* elem : ca->GetElements()) {
      if (!elem || !IsZeroInitializer(elem)) {
        return false;
      }
    }
    return true;
  }
  return false;
}
static std::string ConstantToString(const ConstantValue* c) {
if (auto* ci = dynamic_cast<const ConstantInt*>(c)) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(c)) {
return FloatToString(cf->GetValue());
}
if (auto* ca = dynamic_cast<const ConstantArray*>(c)) {
if (IsZeroInitializer(ca)) {
return "zeroinitializer";
}
std::ostringstream oss;
oss << "[";
const auto& elems = ca->GetElements();
for (size_t i = 0; i < elems.size(); ++i) {
if (i > 0) oss << ", ";
oss << TypeToString(*elems[i]->GetType()) << " "
<< ConstantToString(elems[i]);
}
oss << "]";
return oss.str();
}
return "<const>";
}
static std::string ValueToString(const Value* v) { static std::string ValueToString(const Value* v) {
if (auto* c = dynamic_cast<const ConstantValue*>(v)) { if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return ConstantToString(c); return std::to_string(ci->GetValue());
}
if (auto* func = dynamic_cast<const Function*>(v)) {
const auto& name = func->GetName();
if (!name.empty() && name[0] == '@') return name;
return "@" + name;
}
if (auto* gv = dynamic_cast<const GlobalValue*>(v)) {
const auto& name = gv->GetName();
if (!name.empty() && name[0] == '@') return name;
return "@" + name;
} }
return v ? v->GetName() : "<null>"; return v ? v->GetName() : "<null>";
} }
// Renders a basic block as a label operand: the block name with a leading
// '%' (added only if the name does not already start with one). A null
// block is rendered as "%<null>".
static std::string LabelToString(const BasicBlock* bb) {
  if (bb == nullptr) {
    return "%<null>";
  }
  const std::string& label = bb->GetName();
  const bool already_prefixed = !label.empty() && label.front() == '%';
  if (already_prefixed) {
    return label;
  }
  return "%" + label;
}
// Maps an integer comparison predicate to its textual mnemonic; returns "?"
// for a value outside the handled enumerators.
static const char* ICmpPredToString(ICmpPredicate pred) {
  const char* mnemonic = "?";
  switch (pred) {
    case ICmpPredicate::Eq:
      mnemonic = "eq";
      break;
    case ICmpPredicate::Ne:
      mnemonic = "ne";
      break;
    case ICmpPredicate::Slt:
      mnemonic = "slt";
      break;
    case ICmpPredicate::Sle:
      mnemonic = "sle";
      break;
    case ICmpPredicate::Sgt:
      mnemonic = "sgt";
      break;
    case ICmpPredicate::Sge:
      mnemonic = "sge";
      break;
  }
  return mnemonic;
}
// Maps a floating-point comparison predicate to its textual mnemonic;
// returns "?" for a value outside the handled enumerators.
static const char* FCmpPredToString(FCmpPredicate pred) {
  const char* mnemonic = "?";
  switch (pred) {
    case FCmpPredicate::Oeq:
      mnemonic = "oeq";
      break;
    case FCmpPredicate::One:
      mnemonic = "one";
      break;
    case FCmpPredicate::Olt:
      mnemonic = "olt";
      break;
    case FCmpPredicate::Ole:
      mnemonic = "ole";
      break;
    case FCmpPredicate::Ogt:
      mnemonic = "ogt";
      break;
    case FCmpPredicate::Oge:
      mnemonic = "oge";
      break;
  }
  return mnemonic;
}
void IRPrinter::Print(const Module& module, std::ostream& os) { void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& g : module.GetGlobals()) {
if (!g) continue;
os << "@" << g->GetName() << " = "
<< (g->IsConst() ? "constant " : "global ")
<< TypeToString(*g->GetValueType()) << " ";
if (auto* init = g->GetInitializer()) {
os << ConstantToString(init);
} else {
if (g->GetValueType()->IsArray()) {
os << "zeroinitializer";
} else if (g->GetValueType()->IsFloat()) {
os << "0.0";
} else {
os << "0";
}
}
os << "\n";
}
for (const auto& func : module.GetFunctions()) { for (const auto& func : module.GetFunctions()) {
if (func->IsDeclaration()) { os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
os << "declare " << TypeToString(*func->GetReturnType()) << " @" << "() {\n";
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType());
}
os << ")\n";
continue;
}
os << "define " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName();
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) { for (const auto& bb : func->GetBlocks()) {
if (!bb) { if (!bb) {
continue; continue;
@ -270,13 +65,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
switch (inst->GetOpcode()) { switch (inst->GetOpcode()) {
case Opcode::Add: case Opcode::Add:
case Opcode::Sub: case Opcode::Sub:
case Opcode::Mul: case Opcode::Mul: {
case Opcode::SDiv:
case Opcode::SRem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst); auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = " os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " " << OpcodeToString(bin->GetOpcode()) << " "
@ -285,122 +74,27 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
<< ValueToString(bin->GetRhs()) << "\n"; << ValueToString(bin->GetRhs()) << "\n";
break; break;
} }
case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst);
os << " " << cmp->GetName() << " = icmp "
<< ICmpPredToString(cmp->GetPredicate()) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " " << cmp->GetName() << " = fcmp "
<< FCmpPredToString(cmp->GetPredicate()) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::ZExt: {
auto* cast = static_cast<const CastInst*>(inst);
os << " " << cast->GetName() << " = "
<< OpcodeToString(cast->GetOpcode()) << " "
<< TypeToString(*cast->GetValue()->GetType()) << " "
<< ValueToString(cast->GetValue()) << " to "
<< TypeToString(*cast->GetType()) << "\n";
break;
}
case Opcode::Alloca: { case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst); auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca " os << " " << alloca->GetName() << " = alloca i32\n";
<< TypeToString(*alloca->GetAllocatedType()) << "\n";
break; break;
} }
case Opcode::Load: { case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst); auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load " os << " " << load->GetName() << " = load i32, i32* "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n"; << ValueToString(load->GetPtr()) << "\n";
break; break;
} }
case Opcode::Store: { case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst); auto* store = static_cast<const StoreInst*>(inst);
os << " store " << TypeToString(*store->GetValue()->GetType()) os << " store i32 " << ValueToString(store->GetValue())
<< " " << ValueToString(store->GetValue()) << ", " << ", i32* " << ValueToString(store->GetPtr()) << "\n";
<< TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BranchInst*>(inst);
os << " br label " << LabelToString(br->GetDest()) << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << ValueToString(cbr->GetCond())
<< ", label " << LabelToString(cbr->GetTrueDest())
<< ", label " << LabelToString(cbr->GetFalseDest()) << "\n";
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
const auto& args = call->GetArgs();
if (!call->GetType()->IsVoid()) {
os << " " << call->GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call->GetType()) << " "
<< ValueToString(call->GetCallee()) << "(";
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType()) << " "
<< ValueToString(args[i]);
}
os << ")\n";
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " " << phi->GetName() << " = phi "
<< TypeToString(*phi->GetType()) << " ";
const auto& values = phi->GetIncomingValues();
const auto& blocks = phi->GetIncomingBlocks();
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) os << ", ";
os << "[ " << ValueToString(values[i]) << ", "
<< LabelToString(blocks[i]) << " ]";
}
os << "\n";
break;
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
os << " " << gep->GetName() << " = getelementptr "
<< TypeToString(*gep->GetBasePtr()->GetType()->GetElementType())
<< ", " << TypeToString(*gep->GetBasePtr()->GetType()) << " "
<< ValueToString(gep->GetBasePtr());
const auto& idx = gep->GetIndices();
for (auto* v : idx) {
os << ", i32 " << ValueToString(v);
}
os << "\n";
break; break;
} }
case Opcode::Ret: { case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst); auto* ret = static_cast<const ReturnInst*>(inst);
if (ret->HasReturnValue()) { os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " << ValueToString(ret->GetValue()) << "\n";
<< ValueToString(ret->GetValue()) << "\n";
} else {
os << " ret void\n";
}
break; break;
} }
} }

@ -3,7 +3,6 @@
// - 指令操作数与结果类型管理,支持打印与优化 // - 指令操作数与结果类型管理,支持打印与优化
#include "ir/IR.h" #include "ir/IR.h"
#include <cstddef>
#include <stdexcept> #include <stdexcept>
#include "utils/Log.h" #include "utils/Log.h"
@ -37,7 +36,6 @@ void User::SetOperand(size_t index, Value* value) {
} }
operands_[index] = value; operands_[index] = value;
value->AddUse(this, index); value->AddUse(this, index);
OnOperandChanged(index, value);
} }
void User::AddOperand(Value* value) { void User::AddOperand(Value* value) {
@ -49,56 +47,22 @@ void User::AddOperand(Value* value) {
value->AddUse(this, operand_index); value->AddUse(this, operand_index);
} }
void User::RemoveOperand(size_t index) {
if (index >= operands_.size()) {
throw std::out_of_range("User operand index out of range");
}
OnOperandRemoving(index);
if (auto* old = operands_[index]) {
old->RemoveUse(this, index);
}
for (size_t i = index + 1; i < operands_.size(); ++i) {
if (auto* value = operands_[i]) {
value->RemoveUse(this, i);
value->AddUse(this, i - 1);
}
}
operands_.erase(operands_.begin() + static_cast<std::ptrdiff_t>(index));
}
// Default hook called with (index, new value) when an operand slot is
// rewritten; does nothing here. CallInst, PhiInst and GepInst define their
// own versions to keep their cached operand lists in sync.
void User::OnOperandChanged(size_t, Value*) {}
// Default hook called by RemoveOperand with the index about to be erased;
// does nothing here.
void User::OnOperandRemoving(size_t) {}
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name) Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)), opcode_(op) {} : User(std::move(ty), std::move(name)), opcode_(op) {}
Opcode Instruction::GetOpcode() const { return opcode_; } Opcode Instruction::GetOpcode() const { return opcode_; }
bool Instruction::IsTerminator() const { bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; }
return opcode_ == Opcode::Ret || opcode_ == Opcode::Br ||
opcode_ == Opcode::CondBr;
}
BasicBlock* Instruction::GetParent() const { return parent_; } BasicBlock* Instruction::GetParent() const { return parent_; }
void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
// True for the integer arithmetic opcodes: Add, Sub, Mul, SDiv, SRem.
static bool IsIntBinaryOp(Opcode op) {
  switch (op) {
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::SDiv:
    case Opcode::SRem:
      return true;
    default:
      return false;
  }
}
// True for the floating-point arithmetic opcodes: FAdd, FSub, FMul, FDiv.
static bool IsFloatBinaryOp(Opcode op) {
  switch (op) {
    case Opcode::FAdd:
    case Opcode::FSub:
    case Opcode::FMul:
    case Opcode::FDiv:
      return true;
    default:
      return false;
  }
}
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name) Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) { : Instruction(op, std::move(ty), std::move(name)) {
if (!IsIntBinaryOp(op) && !IsFloatBinaryOp(op)) { if (op != Opcode::Add) {
throw std::runtime_error(FormatError("ir", "BinaryInst 非算术 op")); throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add"));
} }
if (!lhs || !rhs) { if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
@ -106,15 +70,12 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
if (!type_ || !lhs->GetType() || !rhs->GetType()) { if (!type_ || !lhs->GetType() || !rhs->GetType()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息")); throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息"));
} }
if (!lhs->GetType()->Equals(*rhs->GetType()) || if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() ||
!type_->Equals(*lhs->GetType())) { type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
} }
if (IsIntBinaryOp(op) && !type_->IsInt32()) { if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "整数二元只支持 i32")); throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32"));
}
if (IsFloatBinaryOp(op) && !type_->IsFloat()) {
throw std::runtime_error(FormatError("ir", "浮点二元只支持 float"));
} }
AddOperand(lhs); AddOperand(lhs);
AddOperand(rhs); AddOperand(rhs);
@ -124,127 +85,6 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); }
Value* BinaryInst::GetRhs() const { return GetOperand(1); } Value* BinaryInst::GetRhs() const { return GetOperand(1); }
// Builds an integer comparison whose result type is i1.
//
// Both operands must be non-null, carry equal types, and be i1 or i32.
// Throws std::runtime_error otherwise. Operand 0 is lhs, operand 1 is rhs.
ICmpInst::ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name)
    : Instruction(Opcode::ICmp, Type::GetInt1Type(), std::move(name)),
      pred_(pred) {
  if (lhs == nullptr || rhs == nullptr) {
    throw std::runtime_error(FormatError("ir", "ICmpInst 缺少操作数"));
  }
  const bool types_agree = lhs->GetType() && rhs->GetType() &&
                           lhs->GetType()->Equals(*rhs->GetType());
  if (!types_agree) {
    throw std::runtime_error(FormatError("ir", "ICmpInst 类型不匹配"));
  }
  const bool is_integer =
      lhs->GetType()->IsInt1() || lhs->GetType()->IsInt32();
  if (!is_integer) {
    throw std::runtime_error(FormatError("ir", "ICmpInst 仅支持整型"));
  }
  AddOperand(lhs);
  AddOperand(rhs);
}

// Left-hand operand (operand slot 0).
Value* ICmpInst::GetLhs() const {
  return GetOperand(0);
}

// Right-hand operand (operand slot 1).
Value* ICmpInst::GetRhs() const {
  return GetOperand(1);
}
// Builds a floating-point comparison whose result type is i1.
//
// Both operands must be non-null, carry equal types, and be float.
// Throws std::runtime_error otherwise. Operand 0 is lhs, operand 1 is rhs.
FCmpInst::FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name)
    : Instruction(Opcode::FCmp, Type::GetInt1Type(), std::move(name)),
      pred_(pred) {
  if (lhs == nullptr || rhs == nullptr) {
    throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数"));
  }
  const bool types_agree = lhs->GetType() && rhs->GetType() &&
                           lhs->GetType()->Equals(*rhs->GetType());
  if (!types_agree) {
    throw std::runtime_error(FormatError("ir", "FCmpInst 类型不匹配"));
  }
  const bool is_float = lhs->GetType()->IsFloat();
  if (!is_float) {
    throw std::runtime_error(FormatError("ir", "FCmpInst 仅支持 float"));
  }
  AddOperand(lhs);
  AddOperand(rhs);
}

// Left-hand operand (operand slot 0).
Value* FCmpInst::GetLhs() const {
  return GetOperand(0);
}

// Right-hand operand (operand slot 1).
Value* FCmpInst::GetRhs() const {
  return GetOperand(1);
}
// Builds a conversion instruction. Supported opcodes and their constraints:
//   SIToFP : source i32 or i1, destination float
//   FPToSI : source float,     destination i32
//   ZExt   : source i1,        destination i32
//
// Throws std::runtime_error for any other opcode, a null source, a source
// without type information, or a source/destination type that violates the
// table above. The source becomes operand 0.
CastInst::CastInst(Opcode op, std::shared_ptr<Type> dst_ty, Value* src,
                   std::string name)
    : Instruction(op, std::move(dst_ty), std::move(name)) {
  if (op != Opcode::SIToFP && op != Opcode::FPToSI && op != Opcode::ZExt) {
    throw std::runtime_error(FormatError("ir", "CastInst 不支持的 op"));
  }
  if (!src) {
    throw std::runtime_error(FormatError("ir", "CastInst 缺少 src"));
  }
  // The checks below dereference the source type, so reject an untyped
  // source up front instead of crashing on a null pointer.
  if (!src->GetType()) {
    throw std::runtime_error(FormatError("ir", "CastInst 缺少类型信息"));
  }
  if (op == Opcode::SIToFP) {
    if (!src->GetType()->IsInt32() && !src->GetType()->IsInt1()) {
      throw std::runtime_error(FormatError("ir", "SIToFP 仅支持整型"));
    }
    if (!type_ || !type_->IsFloat()) {
      throw std::runtime_error(FormatError("ir", "SIToFP 目标必须是 float"));
    }
  } else if (op == Opcode::FPToSI) {
    if (!src->GetType()->IsFloat()) {
      throw std::runtime_error(FormatError("ir", "FPToSI 仅支持 float"));
    }
    if (!type_ || !type_->IsInt32()) {
      throw std::runtime_error(FormatError("ir", "FPToSI 目标必须是 i32"));
    }
  } else {
    // Only ZExt can reach this branch thanks to the opcode check above.
    if (!src->GetType()->IsInt1()) {
      throw std::runtime_error(FormatError("ir", "ZExt 仅支持 i1"));
    }
    if (!type_ || !type_->IsInt32()) {
      throw std::runtime_error(FormatError("ir", "ZExt 目标必须是 i32"));
    }
  }
  AddOperand(src);
}

// The value being converted (operand slot 0).
Value* CastInst::GetValue() const { return GetOperand(0); }
// Builds an unconditional branch (void-typed, unnamed) to `dest`, which is
// stored as operand 0. Throws std::runtime_error when `dest` is null.
BranchInst::BranchInst(BasicBlock* dest)
    : Instruction(Opcode::Br, Type::GetVoidType(), "") {
  if (dest == nullptr) {
    throw std::runtime_error(FormatError("ir", "BranchInst 缺少目标块"));
  }
  AddOperand(dest);
}

// The branch target (operand slot 0).
BasicBlock* BranchInst::GetDest() const {
  Value* target = GetOperand(0);
  return static_cast<BasicBlock*>(target);
}
// Builds a conditional branch (void-typed, unnamed).
//
// Operand layout: 0 = condition, 1 = true target, 2 = false target.
// Throws std::runtime_error when any argument is null or when the condition
// is not of type i1.
CondBrInst::CondBrInst(Value* cond, BasicBlock* true_dest,
                       BasicBlock* false_dest)
    : Instruction(Opcode::CondBr, Type::GetVoidType(), "") {
  const bool all_present =
      cond != nullptr && true_dest != nullptr && false_dest != nullptr;
  if (!all_present) {
    throw std::runtime_error(FormatError("ir", "CondBrInst 缺少参数"));
  }
  const bool cond_is_i1 = cond->GetType() && cond->GetType()->IsInt1();
  if (!cond_is_i1) {
    throw std::runtime_error(FormatError("ir", "CondBrInst cond 必须是 i1"));
  }
  AddOperand(cond);
  AddOperand(true_dest);
  AddOperand(false_dest);
}

// The i1 condition (operand slot 0).
Value* CondBrInst::GetCond() const {
  return GetOperand(0);
}

// Target taken when the condition is true (operand slot 1).
BasicBlock* CondBrInst::GetTrueDest() const {
  Value* target = GetOperand(1);
  return static_cast<BasicBlock*>(target);
}

// Target taken when the condition is false (operand slot 2).
BasicBlock* CondBrInst::GetFalseDest() const {
  Value* target = GetOperand(2);
  return static_cast<BasicBlock*>(target);
}
// Builds a value-less return. The instruction's own type must be void;
// throws std::runtime_error otherwise. No operands are added.
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty)
    : Instruction(Opcode::Ret, std::move(void_ty), "") {
  const bool is_void = type_ && type_->IsVoid();
  if (!is_void) {
    throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
  }
}
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val) ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") { : Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!val) { if (!val) {
@ -256,36 +96,26 @@ ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
AddOperand(val); AddOperand(val);
} }
bool ReturnInst::HasReturnValue() const { return GetNumOperands() > 0; } Value* ReturnInst::GetValue() const { return GetOperand(0); }
Value* ReturnInst::GetValue() const {
if (!HasReturnValue()) return nullptr;
return GetOperand(0);
}
AllocaInst::AllocaInst(std::shared_ptr<Type> allocated_ty, std::string name) AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, Type::GetPointerType(allocated_ty), : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
std::move(name)), if (!type_ || !type_->IsPtrInt32()) {
allocated_type_(std::move(allocated_ty)) { throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*"));
if (!allocated_type_) {
throw std::runtime_error(FormatError("ir", "AllocaInst 缺少类型"));
} }
} }
const std::shared_ptr<Type>& AllocaInst::GetAllocatedType() const {
return allocated_type_;
}
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name) LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
: Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
if (!ptr) { if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
} }
if (!ptr->GetType() || !ptr->GetType()->IsPointer()) { if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "LoadInst ptr 不是指针")); throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32"));
} }
if (!type_ || !ptr->GetType()->GetElementType()->Equals(*type_)) { if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "LoadInst 类型不匹配")); throw std::runtime_error(
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
} }
AddOperand(ptr); AddOperand(ptr);
} }
@ -303,11 +133,12 @@ StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
if (!type_ || !type_->IsVoid()) { if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
} }
if (!ptr->GetType() || !ptr->GetType()->IsPointer()) { if (!val->GetType() || !val->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst ptr 不是指针")); throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32"));
} }
if (!ptr->GetType()->GetElementType()->Equals(*val->GetType())) { if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst 类型不匹配")); throw std::runtime_error(
FormatError("ir", "StoreInst 当前只支持写入 i32*"));
} }
AddOperand(val); AddOperand(val);
AddOperand(ptr); AddOperand(ptr);
@ -317,141 +148,4 @@ Value* StoreInst::GetValue() const { return GetOperand(0); }
Value* StoreInst::GetPtr() const { return GetOperand(1); } Value* StoreInst::GetPtr() const { return GetOperand(1); }
// Builds a call instruction of type `ret_ty`.
//
// Operand layout: 0 = callee, 1..N = arguments in order. A copy of the
// argument list is also kept in args_. Throws std::runtime_error when the
// callee or any argument is null.
CallInst::CallInst(std::shared_ptr<Type> ret_ty, Value* callee,
                   std::vector<Value*> args, std::string name)
    : Instruction(Opcode::Call, std::move(ret_ty), std::move(name)),
      args_(std::move(args)) {
  if (!callee) {
    throw std::runtime_error(FormatError("ir", "CallInst 缺少 callee"));
  }
  // Validate every argument before touching any use list, so a failing
  // constructor never leaves a use registered for a half-built instruction.
  for (auto* arg : args_) {
    if (!arg) {
      throw std::runtime_error(FormatError("ir", "CallInst arg 为空"));
    }
  }
  AddOperand(callee);
  for (auto* arg : args_) {
    AddOperand(arg);
  }
}

// The called value (operand slot 0).
Value* CallInst::GetCallee() const { return GetOperand(0); }
// Mirrors an operand rewrite into args_. Operand 0 is the callee and has no
// args_ entry; operand k (k >= 1) corresponds to args_[k - 1].
void CallInst::OnOperandChanged(size_t index, Value* value) {
  if (index != 0 && index - 1 < args_.size()) {
    args_[index - 1] = value;
  }
}
// Mirrors an operand removal into args_. Operand 0 is the callee and has no
// args_ entry; operand k (k >= 1) corresponds to args_[k - 1].
void CallInst::OnOperandRemoving(size_t index) {
  if (index != 0 && index - 1 < args_.size()) {
    const auto offset = static_cast<std::ptrdiff_t>(index - 1);
    args_.erase(args_.begin() + offset);
  }
}
// Creates a phi of type `ty` with no incoming entries; (value, block) pairs
// are added afterwards through AddIncoming.
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
// Appends one (value, predecessor block) pair to the phi.
//
// The pair is recorded in incoming_values_ / incoming_blocks_ and also as
// two consecutive operands (value first, then block). Throws
// std::runtime_error when either pointer is null or when the value's type
// differs from the phi's type.
void PhiInst::AddIncoming(Value* value, BasicBlock* block) {
  if (value == nullptr || block == nullptr) {
    throw std::runtime_error(FormatError("ir", "PhiInst incoming 为空"));
  }
  const bool type_matches =
      value->GetType() && type_ && value->GetType()->Equals(*type_);
  if (!type_matches) {
    throw std::runtime_error(FormatError("ir", "PhiInst 类型不匹配"));
  }
  incoming_values_.push_back(value);
  incoming_blocks_.push_back(block);
  AddOperand(value);
  AddOperand(block);
}
// Drops every incoming pair whose predecessor is `block`.
//
// Pair i occupies operands 2*i (value) and 2*i + 1 (block). The block
// operand is removed first so the value operand's index is still 2*i when
// it is removed. After a removal the same position is examined again,
// because the following pairs have shifted down into it.
void PhiInst::RemoveIncomingFrom(BasicBlock* block) {
  size_t slot = 0;
  while (slot < incoming_blocks_.size()) {
    if (incoming_blocks_[slot] == block) {
      RemoveOperand(2 * slot + 1);
      RemoveOperand(2 * slot);
    } else {
      ++slot;
    }
  }
}
// Incoming values, in the order they were added; entry i pairs with entry i
// of GetIncomingBlocks().
const std::vector<Value*>& PhiInst::GetIncomingValues() const {
return incoming_values_;
}
// Incoming predecessor blocks, in the order they were added; entry i pairs
// with entry i of GetIncomingValues().
const std::vector<BasicBlock*>& PhiInst::GetIncomingBlocks() const {
return incoming_blocks_;
}
void PhiInst::OnOperandChanged(size_t index, Value* value) {
size_t incoming_index = index / 2;
if (index % 2 == 0) {
if (incoming_index < incoming_values_.size()) {
incoming_values_[incoming_index] = value;
}
return;
}
if (incoming_index < incoming_blocks_.size()) {
incoming_blocks_[incoming_index] = static_cast<BasicBlock*>(value);
}
}
void PhiInst::OnOperandRemoving(size_t index) {
size_t incoming_index = index / 2;
if (index % 2 == 0) {
if (incoming_index < incoming_values_.size()) {
incoming_values_.erase(incoming_values_.begin() +
static_cast<std::ptrdiff_t>(incoming_index));
}
return;
}
if (incoming_index < incoming_blocks_.size()) {
incoming_blocks_.erase(incoming_blocks_.begin() +
static_cast<std::ptrdiff_t>(incoming_index));
}
}
// Builds a getelementptr instruction.
//
// Operand layout: 0 = base pointer, 1..N = indices in order. A copy of the
// index list is also kept in indices_. Throws std::runtime_error when the
// base pointer is null or not pointer-typed, when the result type is not a
// pointer, or when any index is null / not i32.
GepInst::GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
                 std::vector<Value*> indices, std::string name)
    : Instruction(Opcode::Gep, std::move(result_ptr_ty), std::move(name)),
      indices_(std::move(indices)) {
  if (!base_ptr) {
    throw std::runtime_error(FormatError("ir", "GepInst 缺少 base_ptr"));
  }
  if (!base_ptr->GetType() || !base_ptr->GetType()->IsPointer()) {
    throw std::runtime_error(FormatError("ir", "GepInst base_ptr 不是指针"));
  }
  if (!type_ || !type_->IsPointer()) {
    throw std::runtime_error(FormatError("ir", "GepInst 结果必须是指针"));
  }
  // Validate every index before touching any use list, so a failing
  // constructor never leaves a use registered for a half-built instruction.
  for (auto* idx : indices_) {
    if (!idx || !idx->GetType() || !idx->GetType()->IsInt32()) {
      throw std::runtime_error(FormatError("ir", "GepInst index 必须是 i32"));
    }
  }
  AddOperand(base_ptr);
  for (auto* idx : indices_) {
    AddOperand(idx);
  }
}

// The pointer being indexed (operand slot 0).
Value* GepInst::GetBasePtr() const { return GetOperand(0); }
// Mirrors an operand rewrite into indices_. Operand 0 is the base pointer
// and has no indices_ entry; operand k (k >= 1) corresponds to
// indices_[k - 1].
void GepInst::OnOperandChanged(size_t index, Value* value) {
  if (index != 0 && index - 1 < indices_.size()) {
    indices_[index - 1] = value;
  }
}
// Mirrors an operand removal into indices_. Operand 0 is the base pointer
// and has no indices_ entry; operand k (k >= 1) corresponds to
// indices_[k - 1].
void GepInst::OnOperandRemoving(size_t index) {
  if (index != 0 && index - 1 < indices_.size()) {
    const auto offset = static_cast<std::ptrdiff_t>(index - 1);
    indices_.erase(indices_.begin() + offset);
  }
}
} // namespace ir } // namespace ir

@ -10,39 +10,12 @@ const Context& Module::GetContext() const { return context_; }
Function* Module::CreateFunction(const std::string& name, Function* Module::CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type) { std::shared_ptr<Type> ret_type) {
auto func_ty = Type::GetFunctionType(std::move(ret_type), {}); functions_.push_back(std::make_unique<Function>(name, std::move(ret_type)));
functions_.push_back(std::make_unique<Function>(name, std::move(func_ty)));
return functions_.back().get(); return functions_.back().get();
} }
// Creates a function from a complete function type, constructing it with
// the declaration flag set to false, and stores it in the module. The
// returned pointer is owned by the module.
Function* Module::CreateFunctionWithType(const std::string& name,
                                         std::shared_ptr<Type> func_type) {
  auto owned = std::make_unique<Function>(name, std::move(func_type), false);
  functions_.push_back(std::move(owned));
  return functions_.back().get();
}
// Creates a function from a complete function type, constructing it with
// the declaration flag set to true, and stores it in the module. The
// returned pointer is owned by the module.
Function* Module::CreateFunctionDecl(const std::string& name,
                                     std::shared_ptr<Type> func_type) {
  auto owned = std::make_unique<Function>(name, std::move(func_type), true);
  functions_.push_back(std::move(owned));
  return functions_.back().get();
}
// Creates a global variable of value type `value_type` with the given
// initializer and const flag, and stores it in the module. `init` is
// forwarded as-is (it may be null as far as this function is concerned).
// The returned pointer is owned by the module.
GlobalVariable* Module::CreateGlobalVariable(const std::string& name,
                                             std::shared_ptr<Type> value_type,
                                             ConstantValue* init, bool is_const) {
  auto owned = std::make_unique<GlobalVariable>(std::move(value_type), name,
                                                init, is_const);
  globals_.push_back(std::move(owned));
  return globals_.back().get();
}
const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const { const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
return functions_; return functions_;
} }
// Read-only view of the module's global variables, in creation order.
const std::vector<std::unique_ptr<GlobalVariable>>& Module::GetGlobals() const {
return globals_;
}
} // namespace ir } // namespace ir

@ -1,148 +1,31 @@
// 支持 void/i1/i32/float/ptr/array/function/label // 当前仅支持 void、i32 和 i32*
#include "ir/IR.h" #include "ir/IR.h"
namespace ir { namespace ir {
Type::Type(Kind k) : kind_(k) {} Type::Type(Kind k) : kind_(k) {}
// Constructor for types that carry an element type: stores the kind, the
// element type and a count. GetPointerType passes a count of 0;
// GetArrayType passes the array length.
Type::Type(Kind k, std::shared_ptr<Type> elem, size_t count)
: kind_(k), elem_type_(std::move(elem)), array_size_(count) {}
// Constructor for function types: stores the kind, return type, parameter
// types and the variadic flag.
Type::Type(Kind k, std::shared_ptr<Type> ret,
std::vector<std::shared_ptr<Type>> params, bool is_vararg)
: kind_(k), ret_type_(std::move(ret)), param_types_(std::move(params)),
is_vararg_(is_vararg) {}
const std::shared_ptr<Type>& Type::GetVoidType() { const std::shared_ptr<Type>& Type::GetVoidType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void);
return type; return type;
} }
// Returns the shared i1 type object; it is created on first use and the
// same instance is returned on every call.
const std::shared_ptr<Type>& Type::GetInt1Type() {
  static const std::shared_ptr<Type> int1_type{
      std::make_shared<Type>(Kind::Int1)};
  return int1_type;
}
const std::shared_ptr<Type>& Type::GetInt32Type() { const std::shared_ptr<Type>& Type::GetInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetFloatType() { const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Float); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32);
return type;
}
const std::shared_ptr<Type>& Type::GetLabelType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Label);
return type; return type;
} }
// Builds a new pointer type whose pointee is `elem`.
// Throws std::runtime_error when `elem` is null.
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> elem) {
  if (elem) {
    return std::make_shared<Type>(Kind::Pointer, std::move(elem), 0);
  }
  throw std::runtime_error("PointerType 缺少 element type");
}
// Builds a new array type of `count` elements of type `elem`.
// Throws std::runtime_error when `elem` is null.
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> elem,
                                         size_t count) {
  if (elem) {
    return std::make_shared<Type>(Kind::Array, std::move(elem), count);
  }
  throw std::runtime_error("ArrayType 缺少 element type");
}
// Builds a new function type from a return type, parameter types and a
// variadic flag. Throws std::runtime_error when `ret` is null.
std::shared_ptr<Type> Type::GetFunctionType(
    std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
    bool is_vararg) {
  if (ret) {
    return std::make_shared<Type>(Kind::Function, std::move(ret),
                                  std::move(params), is_vararg);
  }
  throw std::runtime_error("FunctionType 缺少 return type");
}
Type::Kind Type::GetKind() const { return kind_; } Type::Kind Type::GetKind() const { return kind_; }
bool Type::IsVoid() const { return kind_ == Kind::Void; } bool Type::IsVoid() const { return kind_ == Kind::Void; }
// True when this type's kind is Int1.
bool Type::IsInt1() const { return kind_ == Kind::Int1; }
bool Type::IsInt32() const { return kind_ == Kind::Int32; } bool Type::IsInt32() const { return kind_ == Kind::Int32; }
bool Type::IsFloat() const { return kind_ == Kind::Float; } bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
// Kind predicates: each one compares kind_ against a single Kind enumerator.
bool Type::IsPointer() const { return kind_ == Kind::Pointer; }
bool Type::IsArray() const { return kind_ == Kind::Array; }
bool Type::IsFunction() const { return kind_ == Kind::Function; }
bool Type::IsLabel() const { return kind_ == Kind::Label; }
// Returns the stored element type.
// Throws std::runtime_error when no element type is stored.
const std::shared_ptr<Type>& Type::GetElementType() const {
  if (elem_type_) {
    return elem_type_;
  }
  throw std::runtime_error("Type 没有 element type");
}
// Returns the element count of an array type.
// Throws std::runtime_error when this type is not an array.
size_t Type::GetArraySize() const {
  if (IsArray()) {
    return array_size_;
  }
  throw std::runtime_error("Type 不是 array");
}
// Returns the return type of a function type.
// Throws std::runtime_error when this type is not a function.
const std::shared_ptr<Type>& Type::GetReturnType() const {
  if (IsFunction()) {
    return ret_type_;
  }
  throw std::runtime_error("Type 不是 function");
}
// Returns the parameter types of a function type, in declaration order.
// Throws std::runtime_error when this type is not a function.
const std::vector<std::shared_ptr<Type>>& Type::GetParamTypes() const {
  if (IsFunction()) {
    return param_types_;
  }
  throw std::runtime_error("Type 不是 function");
}
// Returns the variadic flag of a function type.
// Throws std::runtime_error when this type is not a function.
bool Type::IsVarArg() const {
  if (IsFunction()) {
    return is_vararg_;
  }
  throw std::runtime_error("Type 不是 function");
}
bool Type::Equals(const Type& other) const {
if (kind_ != other.kind_) return false;
switch (kind_) {
case Kind::Pointer:
return elem_type_ && other.elem_type_ &&
elem_type_->Equals(*other.elem_type_);
case Kind::Array:
return array_size_ == other.array_size_ && elem_type_ &&
other.elem_type_ && elem_type_->Equals(*other.elem_type_);
case Kind::Function: {
if (!ret_type_ || !other.ret_type_ ||
!ret_type_->Equals(*other.ret_type_) ||
is_vararg_ != other.is_vararg_ ||
param_types_.size() != other.param_types_.size()) {
return false;
}
for (size_t i = 0; i < param_types_.size(); ++i) {
if (!param_types_[i] || !other.param_types_[i] ||
!param_types_[i]->Equals(*other.param_types_[i])) {
return false;
}
}
return true;
}
default:
return true;
}
}
} // namespace ir } // namespace ir

@ -18,21 +18,9 @@ void Value::SetName(std::string n) { name_ = std::move(n); }
bool Value::IsVoid() const { return type_ && type_->IsVoid(); } bool Value::IsVoid() const { return type_ && type_->IsVoid(); }
bool Value::IsInt1() const { return type_ && type_->IsInt1(); }
bool Value::IsInt32() const { return type_ && type_->IsInt32(); } bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
bool Value::IsFloat() const { return type_ && type_->IsFloat(); } bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
// Type-kind predicates on a value: each returns false when the value has no
// type, otherwise forwards to the matching Type predicate.
bool Value::IsPointer() const { return type_ && type_->IsPointer(); }
bool Value::IsArray() const { return type_ && type_->IsArray(); }
bool Value::IsFunctionType() const { return type_ && type_->IsFunction(); }
// True when this value's type is a pointer whose element type is i32.
// Returns false when the value has no type or is not a pointer.
bool Value::IsPtrInt32() const {
  if (!type_ || !type_->IsPointer()) {
    return false;
  }
  return type_->GetElementType()->IsInt32();
}
bool Value::IsConstant() const { bool Value::IsConstant() const {
return dynamic_cast<const ConstantValue*>(this) != nullptr; return dynamic_cast<const ConstantValue*>(this) != nullptr;
@ -90,25 +78,6 @@ ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {} : Value(std::move(ty), std::move(name)) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v) ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: ConstantValue(std::move(ty), ""), value_(v) { : ConstantValue(std::move(ty), ""), value_(v) {}
if (!type_ || (!type_->IsInt32() && !type_->IsInt1())) {
throw std::runtime_error("ConstantInt 需要 i1/i32 类型");
}
}
// Creates an unnamed float constant holding `v`.
// Throws std::runtime_error when `ty` is null or is not the float type.
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float v)
    : ConstantValue(std::move(ty), ""), value_(v) {
  const bool is_float = type_ && type_->IsFloat();
  if (!is_float) {
    throw std::runtime_error("ConstantFloat 需要 float 类型");
  }
}
// Creates an unnamed array constant that owns the list `elements`.
// Throws std::runtime_error when `ty` is null or is not an array type.
// The number of elements is not checked against the array type here.
ConstantArray::ConstantArray(std::shared_ptr<Type> ty,
                             std::vector<ConstantValue*> elements)
    : ConstantValue(std::move(ty), ""), elements_(std::move(elements)) {
  const bool is_array = type_ && type_->IsArray();
  if (!is_array) {
    throw std::runtime_error("ConstantArray 需要 array 类型");
  }
}
} // namespace ir } // namespace ir

@ -1,135 +1,4 @@
#include "ir/IR.h" // CFG 简化:
// - 删除不可达块、合并空块、简化分支等
#include <memory> // - 改善 IR 结构,便于后续优化与后端生成
#include <queue>
#include <unordered_set>
namespace ir {
namespace {
// Returns the last instruction of `block` when it is a terminator,
// otherwise nullptr (empty block, null entry, or non-terminator tail).
Instruction* Terminator(BasicBlock& block) {
  auto& insts = block.GetMutableInstructions();
  if (insts.empty()) {
    return nullptr;
  }
  auto* last = insts.back().get();
  if (!last || !last->IsTerminator()) {
    return nullptr;
  }
  return last;
}
// Records the edge from -> to on both endpoints.
// Does nothing when either endpoint is null.
void AddCFGEdge(BasicBlock* from, BasicBlock* to) {
  if (from != nullptr && to != nullptr) {
    from->AddSuccessor(to);
    to->AddPredecessor(from);
  }
}
// Recomputes predecessor/successor lists for every block of `func` from the
// blocks' terminators. All lists are cleared first, so the second loop starts
// from a clean slate. Blocks without a branch terminator get no outgoing edge.
void RebuildCFG(Function& func) {
  for (const auto& bb : func.GetBlocks()) {
    if (bb) {
      bb->ClearPredecessors();
      bb->ClearSuccessors();
    }
  }

  for (const auto& bb : func.GetBlocks()) {
    if (!bb) {
      continue;
    }
    auto* term = Terminator(*bb);
    if (auto* jump = dynamic_cast<BranchInst*>(term)) {
      AddCFGEdge(bb.get(), jump->GetDest());
      continue;
    }
    if (auto* cond_jump = dynamic_cast<CondBrInst*>(term)) {
      AddCFGEdge(bb.get(), cond_jump->GetTrueDest());
      AddCFGEdge(bb.get(), cond_jump->GetFalseDest());
    }
  }
}
// Rewrites conditional branches whose outcome is already known into
// unconditional branches:
//   - both targets are the same block, or
//   - the condition is an integer constant.
// When a constant condition removes the edge to the untaken successor, the
// phi nodes of that successor still list this block as an incoming edge.
// Those entries are removed here, because RemoveUnreachable only prunes
// entries that come from unreachable blocks.
// Returns true when at least one terminator was replaced. Predecessor and
// successor lists are not updated; the caller must rebuild the CFG afterwards.
bool SimplifyBranches(Function& func) {
  bool changed = false;
  for (const auto& block : func.GetBlocks()) {
    if (!block) continue;
    auto* cbr = dynamic_cast<CondBrInst*>(Terminator(*block));
    if (!cbr) continue;

    // Read both targets up front: replacing the terminator invalidates cbr.
    BasicBlock* true_dest = cbr->GetTrueDest();
    BasicBlock* false_dest = cbr->GetFalseDest();
    BasicBlock* dest = nullptr;
    BasicBlock* dropped = nullptr;  // successor that loses its edge, if any
    if (true_dest == false_dest) {
      dest = true_dest;
    } else if (auto* cond = dynamic_cast<ConstantInt*>(cbr->GetCond())) {
      const bool taken = cond->GetValue() != 0;
      dest = taken ? true_dest : false_dest;
      dropped = taken ? false_dest : true_dest;
    }
    if (!dest) continue;

    if (dropped) {
      for (const auto& inst : dropped->GetInstructions()) {
        if (auto* phi = dynamic_cast<PhiInst*>(inst.get())) {
          phi->RemoveIncomingFrom(block.get());
        }
      }
    }
    block->ReplaceTerminator(std::make_unique<BranchInst>(dest));
    changed = true;
  }
  return changed;
}
// Breadth-first walk over successor lists starting at the entry block.
// Returns the set of blocks reached; empty when the function has no entry.
// Relies on successor lists being up to date.
std::unordered_set<BasicBlock*> ReachableBlocks(Function& func) {
  std::unordered_set<BasicBlock*> visited;
  std::queue<BasicBlock*> pending;

  auto* entry = func.GetEntry();
  if (entry != nullptr) {
    visited.insert(entry);
    pending.push(entry);
  }

  while (!pending.empty()) {
    auto* current = pending.front();
    pending.pop();
    for (auto* next : current->GetSuccessors()) {
      if (!next) {
        continue;
      }
      const bool first_visit = visited.insert(next).second;
      if (first_visit) {
        pending.push(next);
      }
    }
  }
  return visited;
}
// Deletes every block that cannot be reached from the entry block.
// Phase 1: each phi in a surviving block drops its incoming entries that come
//          from blocks about to be deleted.
// Phase 2: each unreachable block has its instructions erased back to front
//          and is then removed from the function.
// Returns true when at least one block was removed.
bool RemoveUnreachable(Function& func) {
  const auto live = ReachableBlocks(func);
  const auto is_live = [&live](BasicBlock* bb) { return live.count(bb) != 0; };

  for (const auto& survivor : func.GetBlocks()) {
    if (!survivor || !is_live(survivor.get())) {
      continue;
    }
    for (const auto& inst : survivor->GetInstructions()) {
      auto* phi = dynamic_cast<PhiInst*>(inst.get());
      if (!phi) {
        continue;
      }
      for (const auto& candidate : func.GetBlocks()) {
        if (candidate && !is_live(candidate.get())) {
          phi->RemoveIncomingFrom(candidate.get());
        }
      }
    }
  }

  bool removed_any = false;
  auto& blocks = func.GetMutableBlocks();
  auto it = blocks.begin();
  while (it != blocks.end()) {
    auto* dead = it->get();
    if (!dead || is_live(dead)) {
      ++it;
      continue;
    }
    auto& insts = dead->GetMutableInstructions();
    while (!insts.empty()) {
      dead->EraseInstruction(insts.back().get());
    }
    it = blocks.erase(it);
    removed_any = true;
  }
  return removed_any;
}
} // namespace
// CFG simplification pass over every defined function in `module`:
// rebuild edges, fold decidable conditional branches, then delete unreachable
// blocks. Edges are rebuilt after each step that changed something.
// Returns true when any function was modified.
bool RunCFGSimplify(Module& module) {
  bool any_change = false;
  for (const auto& func : module.GetFunctions()) {
    if (!func || func->IsDeclaration()) {
      continue;
    }
    auto& fn = *func;
    RebuildCFG(fn);

    const bool branches_changed = SimplifyBranches(fn);
    if (branches_changed) {
      RebuildCFG(fn);
    }

    const bool blocks_removed = RemoveUnreachable(fn);
    if (branches_changed || blocks_removed) {
      RebuildCFG(fn);
      any_change = true;
    }
  }
  return any_change;
}
} // namespace ir

@ -1,94 +1,4 @@
#include "ir/IR.h" // 公共子表达式消除CSE
// - 识别并复用重复计算的等价表达式
#include <algorithm> // - 典型放置在 ConstFold 之后、DCE 之前
#include <cstdint> // - 当前为 Lab4 的框架占位,具体算法由实验实现
#include <sstream>
#include <string>
#include <unordered_map>
namespace ir {
namespace {
// True for the opcodes whose two operands may be swapped without changing
// the result: Add, Mul, FAdd and FMul.
bool IsCommutative(Opcode op) {
  switch (op) {
    case Opcode::Add:
    case Opcode::Mul:
    case Opcode::FAdd:
    case Opcode::FMul:
      return true;
    default:
      return false;
  }
}
// Renders the address of `value` as a decimal string, used as an identity
// token when building CSE keys.
std::string ValueId(Value* value) {
  const auto address = reinterpret_cast<std::uintptr_t>(value);
  std::ostringstream text;
  text << address;
  return text.str();
}
std::string KeyFor(const Instruction& inst) {
std::ostringstream oss;
oss << static_cast<int>(inst.GetOpcode()) << ":";
if (auto* icmp = dynamic_cast<const ICmpInst*>(&inst)) {
oss << static_cast<int>(icmp->GetPredicate()) << ":";
} else if (auto* fcmp = dynamic_cast<const FCmpInst*>(&inst)) {
oss << static_cast<int>(fcmp->GetPredicate()) << ":";
}
std::vector<std::string> operands;
for (size_t i = 0; i < inst.GetNumOperands(); ++i) {
operands.push_back(ValueId(inst.GetOperand(i)));
}
if (IsCommutative(inst.GetOpcode()) && operands.size() == 2) {
std::sort(operands.begin(), operands.end());
}
for (const auto& operand : operands) {
oss << operand << ",";
}
return oss.str();
}
bool IsCSECandidate(const Instruction& inst) {
switch (inst.GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::SDiv:
case Opcode::SRem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::ICmp:
case Opcode::FCmp:
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::ZExt:
case Opcode::Gep:
return true;
default:
return false;
}
}
} // namespace
// Block-local common-subexpression elimination.
// Within each basic block, the first instruction with a given key becomes the
// representative; every later instruction with the same key has its uses
// redirected to that representative. The duplicates themselves are left in
// place (they become unused). Nothing is shared across blocks.
// Returns true when any use was redirected.
bool RunCSE(Module& module) {
  bool changed = false;
  for (const auto& func : module.GetFunctions()) {
    if (!func || func->IsDeclaration()) {
      continue;
    }
    for (const auto& block : func->GetBlocks()) {
      if (!block) {
        continue;
      }
      std::unordered_map<std::string, Instruction*> seen;
      for (const auto& owned : block->GetInstructions()) {
        auto* inst = owned.get();
        if (!inst || !IsCSECandidate(*inst)) {
          continue;
        }
        // emplace leaves an existing entry untouched and reports it.
        const auto [slot, inserted] = seen.emplace(KeyFor(*inst), inst);
        if (!inserted) {
          inst->ReplaceAllUsesWith(slot->second);
          changed = true;
        }
      }
    }
  }
  return changed;
}
} // namespace ir

@ -1,239 +1,4 @@
#include "ir/IR.h" // IR 常量折叠:
// - 折叠可判定的常量表达式
#include <cmath> // - 简化常量控制流分支(按实现范围裁剪)
namespace ir {
namespace {
// Returns `value` viewed as a ConstantInt, or nullptr when it is not one.
ConstantInt* AsConstInt(Value* value) {
return dynamic_cast<ConstantInt*>(value);
}
// Returns `value` viewed as a ConstantFloat, or nullptr when it is not one.
ConstantFloat* AsConstFloat(Value* value) {
return dynamic_cast<ConstantFloat*>(value);
}
// True when `value` is an integer constant 0 or a float constant that
// compares equal to 0.0f (which includes negative zero).
bool IsZero(Value* value) {
  if (auto* int_const = AsConstInt(value)) {
    return int_const->GetValue() == 0;
  }
  auto* float_const = AsConstFloat(value);
  return float_const != nullptr && float_const->GetValue() == 0.0f;
}
// True when `value` is an integer constant 1 or a float constant 1.0f.
bool IsOne(Value* value) {
  if (auto* int_const = AsConstInt(value)) {
    return int_const->GetValue() == 1;
  }
  auto* float_const = AsConstFloat(value);
  return float_const != nullptr && float_const->GetValue() == 1.0f;
}
// Folds a binary instruction whose two operands are both constants.
//   inst - the instruction to fold.
//   ctx  - context used to obtain the resulting constant.
// Returns the folded constant, or nullptr when the operands are not both
// integer (or both float) constants, the opcode is not handled, or the
// operation is an integer division/remainder by zero.
//
// Integer results wrap in two's complement. The arithmetic is done on
// unsigned values because signed overflow is undefined behaviour in the
// compiler's own C++ code. A divisor of -1 is handled separately because
// INT_MIN / -1 and INT_MIN % -1 are undefined and trap on some hosts.
ConstantValue* FoldBinary(BinaryInst& inst, Context& ctx) {
  auto* li = AsConstInt(inst.GetLhs());
  auto* ri = AsConstInt(inst.GetRhs());
  if (li && ri) {
    const int lhs = li->GetValue();
    const int rhs = ri->GetValue();
    const unsigned ulhs = static_cast<unsigned>(lhs);
    const unsigned urhs = static_cast<unsigned>(rhs);
    switch (inst.GetOpcode()) {
      case Opcode::Add:
        return ctx.GetConstInt(static_cast<int>(ulhs + urhs));
      case Opcode::Sub:
        return ctx.GetConstInt(static_cast<int>(ulhs - urhs));
      case Opcode::Mul:
        return ctx.GetConstInt(static_cast<int>(ulhs * urhs));
      case Opcode::SDiv:
        if (rhs == 0) break;
        // x / -1 is the wrapping negation of x.
        if (rhs == -1) return ctx.GetConstInt(static_cast<int>(0u - ulhs));
        return ctx.GetConstInt(lhs / rhs);
      case Opcode::SRem:
        if (rhs == 0) break;
        // x % -1 is always 0.
        if (rhs == -1) return ctx.GetConstInt(0);
        return ctx.GetConstInt(lhs % rhs);
      default:
        break;
    }
  }

  auto* lf = AsConstFloat(inst.GetLhs());
  auto* rf = AsConstFloat(inst.GetRhs());
  if (lf && rf) {
    const float lhs = lf->GetValue();
    const float rhs = rf->GetValue();
    switch (inst.GetOpcode()) {
      case Opcode::FAdd:
        return ctx.GetConstFloat(lhs + rhs);
      case Opcode::FSub:
        return ctx.GetConstFloat(lhs - rhs);
      case Opcode::FMul:
        return ctx.GetConstFloat(lhs * rhs);
      case Opcode::FDiv:
        return ctx.GetConstFloat(lhs / rhs);
      default:
        break;
    }
  }
  return nullptr;
}
// Applies algebraic identities to a binary instruction that has at least one
// non-constant operand. Returns the existing value the instruction is
// equivalent to, or nullptr when no identity applies.
//
// Integer rules: x+0, 0+x, x-0, x*1, 1*x, x*0, 0*x, x/1.
// Float rules are restricted to those that hold for every input, signed
// zeros included:
//   x + (-0.0) == x   (x + (+0.0) would turn -0.0 into +0.0)
//   x - (+0.0) == x   (x - (-0.0) would turn -0.0 into +0.0)
//   x * 1.0, 1.0 * x, x / 1.0
// `x % 1` is not simplified here. Its value is 0, but this helper has no
// context to create that constant, and returning the divisor (1) would be
// wrong.
Value* SimplifyBinary(BinaryInst& inst) {
  auto* lhs = inst.GetLhs();
  auto* rhs = inst.GetRhs();

  // True when `value` is a float constant zero with the requested sign bit.
  const auto is_float_zero = [](Value* value, bool negative) {
    auto* f = AsConstFloat(value);
    return f != nullptr && f->GetValue() == 0.0f &&
           std::signbit(f->GetValue()) == negative;
  };

  switch (inst.GetOpcode()) {
    case Opcode::Add:
      if (IsZero(rhs)) return lhs;
      if (IsZero(lhs)) return rhs;
      break;
    case Opcode::FAdd:
      if (is_float_zero(rhs, true)) return lhs;
      if (is_float_zero(lhs, true)) return rhs;
      break;
    case Opcode::Sub:
      if (IsZero(rhs)) return lhs;
      break;
    case Opcode::FSub:
      if (is_float_zero(rhs, false)) return lhs;
      break;
    case Opcode::Mul:
      if (IsOne(rhs)) return lhs;
      if (IsOne(lhs)) return rhs;
      if (IsZero(rhs)) return rhs;
      if (IsZero(lhs)) return lhs;
      break;
    case Opcode::FMul:
      if (IsOne(rhs)) return lhs;
      if (IsOne(lhs)) return rhs;
      break;
    case Opcode::SDiv:
    case Opcode::FDiv:
      if (IsOne(rhs)) return lhs;
      break;
    default:
      break;
  }
  return nullptr;
}
// Folds a signed integer comparison of two integer constants into a boolean
// constant. Returns nullptr when either operand is not an integer constant.
// A predicate that matches none of the handled cases folds to false.
ConstantInt* FoldICmp(ICmpInst& inst, Context& ctx) {
  auto* lhs_const = AsConstInt(inst.GetLhs());
  auto* rhs_const = AsConstInt(inst.GetRhs());
  if (lhs_const == nullptr || rhs_const == nullptr) {
    return nullptr;
  }

  const int a = lhs_const->GetValue();
  const int b = rhs_const->GetValue();
  const auto pred = inst.GetPredicate();

  bool holds = false;
  if (pred == ICmpPredicate::Eq) {
    holds = (a == b);
  } else if (pred == ICmpPredicate::Ne) {
    holds = (a != b);
  } else if (pred == ICmpPredicate::Slt) {
    holds = (a < b);
  } else if (pred == ICmpPredicate::Sle) {
    holds = (a <= b);
  } else if (pred == ICmpPredicate::Sgt) {
    holds = (a > b);
  } else if (pred == ICmpPredicate::Sge) {
    holds = (a >= b);
  }
  return ctx.GetConstBool(holds);
}
// Folds an ordered float comparison of two float constants into a boolean
// constant. Returns nullptr when either operand is not a float constant.
// Every handled predicate is "ordered": if either operand is NaN the result
// is false.
ConstantInt* FoldFCmp(FCmpInst& inst, Context& ctx) {
  auto* lhs_const = AsConstFloat(inst.GetLhs());
  auto* rhs_const = AsConstFloat(inst.GetRhs());
  if (lhs_const == nullptr || rhs_const == nullptr) {
    return nullptr;
  }

  const float a = lhs_const->GetValue();
  const float b = rhs_const->GetValue();
  if (std::isnan(a) || std::isnan(b)) {
    return ctx.GetConstBool(false);
  }

  bool holds = false;
  switch (inst.GetPredicate()) {
    case FCmpPredicate::Oeq:
      holds = (a == b);
      break;
    case FCmpPredicate::One:
      holds = (a != b);
      break;
    case FCmpPredicate::Olt:
      holds = (a < b);
      break;
    case FCmpPredicate::Ole:
      holds = (a <= b);
      break;
    case FCmpPredicate::Ogt:
      holds = (a > b);
      break;
    case FCmpPredicate::Oge:
      holds = (a >= b);
      break;
  }
  return ctx.GetConstBool(holds);
}
// Folds a cast whose operand is a constant.
//   SIToFP: integer constant -> float constant.
//   FPToSI: float constant -> integer constant (truncation toward zero).
//   ZExt:   integer constant -> 1 when non-zero, else 0.
// Returns nullptr when the operand is not a constant of the expected kind or
// the opcode is not handled.
//
// FPToSI is folded only when the float lies in [-2^31, 2^31). Converting NaN
// or an out-of-range value to int is undefined behaviour in C++, so those
// inputs are left for run time.
ConstantValue* FoldCast(CastInst& inst, Context& ctx) {
  switch (inst.GetOpcode()) {
    case Opcode::SIToFP:
      if (auto* c = AsConstInt(inst.GetValue())) {
        return ctx.GetConstFloat(static_cast<float>(c->GetValue()));
      }
      break;
    case Opcode::FPToSI:
      if (auto* c = AsConstFloat(inst.GetValue())) {
        const float v = c->GetValue();
        // Both bounds are exactly representable as float; NaN fails both tests.
        if (v >= -2147483648.0f && v < 2147483648.0f) {
          return ctx.GetConstInt(static_cast<int>(v));
        }
      }
      break;
    case Opcode::ZExt:
      if (auto* c = AsConstInt(inst.GetValue())) {
        return ctx.GetConstInt(c->GetValue() != 0 ? 1 : 0);
      }
      break;
    default:
      break;
  }
  return nullptr;
}
// Returns the single value a phi stands for when every incoming value is the
// same object; returns nullptr when the phi has no incoming values or they
// differ.
Value* SimplifyPhi(PhiInst& phi) {
  const auto& incoming = phi.GetIncomingValues();
  if (incoming.empty()) {
    return nullptr;
  }
  auto* common = incoming.front();
  for (auto it = incoming.begin(); it != incoming.end(); ++it) {
    if (*it != common) {
      return nullptr;
    }
  }
  return common;
}
} // namespace
// Constant-folding pass over every defined function in `module`.
// Each instruction is offered to the folder that matches its class (binary,
// integer compare, float compare, cast, phi). When a replacement other than
// the instruction itself is produced, all uses are redirected to it; the
// original instruction is left in place.
// Returns true when any use was redirected.
bool RunConstFold(Module& module) {
  auto& ctx = module.GetContext();

  // Picks the folder by instruction class; nullptr means "no simplification".
  const auto simplify = [&ctx](Instruction* inst) -> Value* {
    if (auto* bin = dynamic_cast<BinaryInst*>(inst)) {
      if (Value* folded = FoldBinary(*bin, ctx)) {
        return folded;
      }
      return SimplifyBinary(*bin);
    }
    if (auto* icmp = dynamic_cast<ICmpInst*>(inst)) {
      return FoldICmp(*icmp, ctx);
    }
    if (auto* fcmp = dynamic_cast<FCmpInst*>(inst)) {
      return FoldFCmp(*fcmp, ctx);
    }
    if (auto* cast = dynamic_cast<CastInst*>(inst)) {
      return FoldCast(*cast, ctx);
    }
    if (auto* phi = dynamic_cast<PhiInst*>(inst)) {
      return SimplifyPhi(*phi);
    }
    return nullptr;
  };

  bool changed = false;
  for (const auto& func : module.GetFunctions()) {
    if (!func || func->IsDeclaration()) {
      continue;
    }
    for (const auto& block : func->GetBlocks()) {
      if (!block) {
        continue;
      }
      for (const auto& owned : block->GetInstructions()) {
        auto* inst = owned.get();
        Value* replacement = simplify(inst);
        if (replacement == nullptr || replacement == inst) {
          continue;
        }
        inst->ReplaceAllUsesWith(replacement);
        changed = true;
      }
    }
  }
  return changed;
}
} // namespace ir

@ -1,77 +1,5 @@
#include "ir/IR.h" // 常量传播Constant Propagation
// - 沿 use-def 关系传播已知常量
#include <unordered_map> // - 将可替换的 SSA 值改写为常量,暴露更多折叠机会
// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用
namespace ir {
namespace {
// Returns `value` viewed as a ConstantValue, or nullptr when it is not one.
ConstantValue* AsConstant(Value* value) {
return dynamic_cast<ConstantValue*>(value);
}
// Returns the initializer of `ptr` when `ptr` is a const global variable with
// a non-array initializer; returns nullptr in every other case (not a global,
// not const, no initializer, or an array initializer).
ConstantValue* ScalarConstInitializer(Value* ptr) {
  auto* global = dynamic_cast<GlobalVariable*>(ptr);
  if (global == nullptr) {
    return nullptr;
  }
  if (!global->IsConst()) {
    return nullptr;
  }
  auto* init = global->GetInitializer();
  if (!init) {
    return nullptr;
  }
  const bool is_aggregate = dynamic_cast<ConstantArray*>(init) != nullptr;
  if (is_aggregate) {
    return nullptr;
  }
  return init;
}
// True when `ptr` is an alloca whose allocated type is i1, i32 or float.
bool IsScalarStackSlot(Value* ptr) {
  auto* slot = dynamic_cast<AllocaInst*>(ptr);
  if (slot == nullptr) {
    return false;
  }
  const auto& allocated = slot->GetAllocatedType();
  if (!allocated) {
    return false;
  }
  return allocated->IsInt1() || allocated->IsInt32() || allocated->IsFloat();
}
} // namespace
// Block-local constant propagation through memory.
// Within each basic block a map records, per scalar stack slot, the constant
// most recently stored to it:
//   - load:  uses are redirected to the recorded constant, or, when the slot
//            is unknown, to the scalar initializer of a const global.
//   - store: to a scalar stack slot records the constant, or forgets the slot
//            when the stored value is not a constant.
//   - call:  forgets everything recorded so far.
// The map starts empty in every block. Returns true when any load's uses
// were redirected.
bool RunConstProp(Module& module) {
  bool changed = false;
  for (const auto& func : module.GetFunctions()) {
    if (!func || func->IsDeclaration()) {
      continue;
    }
    for (const auto& block : func->GetBlocks()) {
      if (!block) {
        continue;
      }
      std::unordered_map<Value*, ConstantValue*> slot_values;
      for (const auto& owned : block->GetInstructions()) {
        auto* inst = owned.get();

        if (auto* load = dynamic_cast<LoadInst*>(inst)) {
          auto* address = load->GetPtr();
          const auto hit = slot_values.find(address);
          ConstantValue* value = (hit != slot_values.end())
                                     ? hit->second
                                     : ScalarConstInitializer(address);
          if (value) {
            load->ReplaceAllUsesWith(value);
            changed = true;
          }
        } else if (auto* store = dynamic_cast<StoreInst*>(inst)) {
          auto* address = store->GetPtr();
          if (IsScalarStackSlot(address)) {
            auto* stored = AsConstant(store->GetValue());
            if (stored) {
              slot_values[address] = stored;
            } else {
              slot_values.erase(address);
            }
          }
        } else if (inst->GetOpcode() == Opcode::Call) {
          slot_values.clear();
        }
      }
    }
  }
  return changed;
}
} // namespace ir

@ -1,55 +1,4 @@
#include "ir/IR.h" // 死代码删除DCE
// - 删除无用指令与无用基本块
namespace ir { // - 通常与 CFG 简化配合使用
namespace {
// True for instructions that must be kept even when their result is unused:
// stores, returns, both branch forms and calls.
bool HasSideEffect(const Instruction& inst) {
  const auto op = inst.GetOpcode();
  return op == Opcode::Store || op == Opcode::Ret || op == Opcode::Br ||
         op == Opcode::CondBr || op == Opcode::Call;
}
// An instruction is dead when it produces a (non-void) value, has no side
// effect, and that value has no uses.
bool IsDead(const Instruction& inst) {
  const bool removable = !inst.IsVoid() && !HasSideEffect(inst);
  return removable && inst.GetUses().empty();
}
} // namespace
// Dead-code elimination over every defined function in `module`.
// Sweeps are repeated until one completes without erasing anything, because
// removing an instruction can leave its operands unused. After each erase the
// scan of the current block restarts from its first instruction.
// Returns true when at least one instruction was erased.
bool RunDCE(Module& module) {
  bool changed = false;
  bool erased_this_sweep = false;
  do {
    erased_this_sweep = false;
    for (const auto& func : module.GetFunctions()) {
      if (!func || func->IsDeclaration()) {
        continue;
      }
      for (const auto& block : func->GetBlocks()) {
        if (!block) {
          continue;
        }
        auto& insts = block->GetMutableInstructions();
        auto cursor = insts.begin();
        while (cursor != insts.end()) {
          auto* inst = cursor->get();
          if (!inst || !IsDead(*inst)) {
            ++cursor;
            continue;
          }
          block->EraseInstruction(inst);
          erased_this_sweep = true;
          changed = true;
          cursor = insts.begin();
        }
      }
    }
  } while (erased_this_sweep);
  return changed;
}
} // namespace ir

@ -1,109 +1,4 @@
#include "ir/IR.h" // Mem2RegSSA 构造):
// - 将局部变量的 alloca/load/store 提升为 SSA 形式
#include <vector> // - 插入 PHI 并重写使用,依赖支配树等分析
namespace ir {
namespace {
// True when the alloca reserves an i1, i32 or float slot; only such slots
// are considered for promotion.
bool IsPromotableType(const AllocaInst& alloca) {
  const auto& allocated = alloca.GetAllocatedType();
  if (!allocated) {
    return false;
  }
  return allocated->IsInt1() || allocated->IsInt32() || allocated->IsFloat();
}
// True when `inst` is a load from `ptr` or a store whose address is `ptr`.
// Any other instruction counts as an indirect use.
bool IsDirectUseOf(Value* ptr, Instruction* inst) {
  if (auto* load = dynamic_cast<LoadInst*>(inst)) {
    return load->GetPtr() == ptr;
  }
  auto* store = dynamic_cast<StoreInst*>(inst);
  return store != nullptr && store->GetPtr() == ptr;
}
// Returns the one basic block that contains every use of `alloca`, provided
// each use is a direct load/store of the slot. Returns nullptr when there is
// no use, a user is not an instruction or not a direct use, a user has no
// parent block, or the uses span more than one block.
BasicBlock* SingleUseBlock(AllocaInst& alloca) {
  BasicBlock* home = nullptr;
  for (const auto& use : alloca.GetUses()) {
    auto* user = dynamic_cast<Instruction*>(use.GetUser());
    if (user == nullptr || !IsDirectUseOf(&alloca, user)) {
      return nullptr;
    }
    auto* owner = user->GetParent();
    if (owner == nullptr) {
      return nullptr;
    }
    if (home != nullptr && home != owner) {
      return nullptr;
    }
    home = owner;
  }
  return home;
}
// Checks that the slot can be promoted using only the instructions of `block`:
//   - at least one direct load/store of the slot exists in the block,
//   - no store writes the slot's own address into the slot,
//   - no load of the slot appears before the first store to it.
bool CanPromoteInBlock(AllocaInst& alloca, BasicBlock& block) {
  bool slot_written = false;
  bool touched = false;
  for (const auto& owned : block.GetInstructions()) {
    auto* inst = owned.get();
    if (!inst || !IsDirectUseOf(&alloca, inst)) {
      continue;
    }
    touched = true;

    if (auto* store = dynamic_cast<StoreInst*>(inst)) {
      if (store->GetValue() == &alloca) {
        return false;
      }
      slot_written = true;
      continue;
    }
    if (dynamic_cast<LoadInst*>(inst) && !slot_written) {
      return false;
    }
  }
  return touched;
}
// Replaces the slot's loads in `block` with the value most recently stored to
// it, then erases those loads and stores. The alloca itself is erased too
// once it has no remaining uses.
// The caller must have validated the block with CanPromoteInBlock first.
// Returns true when at least one load or store was erased.
bool PromoteInBlock(AllocaInst& alloca, BasicBlock& block) {
  std::vector<Instruction*> doomed;
  Value* reaching = nullptr;  // value currently held by the slot

  for (const auto& owned : block.GetInstructions()) {
    auto* inst = owned.get();
    if (!inst || !IsDirectUseOf(&alloca, inst)) {
      continue;
    }
    if (auto* store = dynamic_cast<StoreInst*>(inst)) {
      reaching = store->GetValue();
    } else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
      load->ReplaceAllUsesWith(reaching);
    } else {
      continue;
    }
    doomed.push_back(inst);
  }

  const bool rewrote = !doomed.empty();
  for (auto* inst : doomed) {
    block.EraseInstruction(inst);
  }

  if (alloca.GetUses().empty()) {
    auto* owner = alloca.GetParent();
    if (owner) {
      owner->EraseInstruction(&alloca);
    }
  }
  return rewrote;
}
} // namespace
// Restricted mem2reg: promotes scalar stack slots whose loads and stores all
// live in one basic block. No phi nodes are inserted, so slots used across
// blocks are left alone.
// Candidate allocas are collected first; promotion erases instructions, so it
// runs only after the collection walk has finished.
// Returns true when any slot was rewritten.
bool RunMem2Reg(Module& module) {
  bool changed = false;
  for (const auto& func : module.GetFunctions()) {
    if (!func || func->IsDeclaration()) {
      continue;
    }

    std::vector<AllocaInst*> candidates;
    for (const auto& block : func->GetBlocks()) {
      if (!block) {
        continue;
      }
      for (const auto& owned : block->GetInstructions()) {
        auto* slot = dynamic_cast<AllocaInst*>(owned.get());
        if (slot && IsPromotableType(*slot)) {
          candidates.push_back(slot);
        }
      }
    }

    for (auto* slot : candidates) {
      if (!slot || slot->GetUses().empty()) {
        continue;
      }
      auto* home = SingleUseBlock(*slot);
      if (home && CanPromoteInBlock(*slot, *home)) {
        changed |= PromoteInBlock(*slot, *home);
      }
    }
  }
  return changed;
}
} // namespace ir

@ -1,23 +1 @@
#include "ir/IR.h" // IR Pass 管理骨架。
namespace ir {
// Runs the scalar optimization pipeline on `module`:
// Mem2Reg once, then up to eight rounds of
// ConstFold -> ConstProp -> CSE -> DCE -> CFGSimplify,
// stopping early after the first round in which no pass reports a change.
// Returns true when any pass changed the module.
bool RunScalarOptimizationPipeline(Module& module) {
  constexpr int kMaxRounds = 8;

  bool changed = RunMem2Reg(module);
  for (int round = 0; round < kMaxRounds; ++round) {
    // Each pass runs unconditionally; results are only accumulated.
    const bool folded = RunConstFold(module);
    const bool propagated = RunConstProp(module);
    const bool deduplicated = RunCSE(module);
    const bool pruned = RunDCE(module);
    const bool simplified = RunCFGSimplify(module);

    const bool round_changed =
        folded || propagated || deduplicated || pruned || simplified;
    if (!round_changed) {
      break;
    }
    changed = true;
  }
  return changed;
}
} // namespace ir

@ -1,32 +1,46 @@
#include "irgen/IRGen.h" #include "irgen/IRGen.h"
#include <any>
#include <stdexcept> #include <stdexcept>
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { namespace {
if (!ctx) return BlockFlow::Continue;
BlockFlow flow = BlockFlow::Continue; std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
return lvalue.ID()->getText();
}
} // namespace
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
for (auto* item : ctx->blockItem()) { for (auto* item : ctx->blockItem()) {
if (item) { if (item) {
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
flow = BlockFlow::Terminated; // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。
break; break;
} }
} }
} }
return flow; return {};
} }
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(SysYParser::BlockItemContext& item) { IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
SysYParser::BlockItemContext& item) {
return std::any_cast<BlockFlow>(item.accept(this)); return std::any_cast<BlockFlow>(item.accept(this));
} }
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (!ctx) return BlockFlow::Continue; if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少块内项"));
}
if (ctx->decl()) { if (ctx->decl()) {
ctx->decl()->accept(this); ctx->decl()->accept(this);
return BlockFlow::Continue; return BlockFlow::Continue;
@ -34,219 +48,60 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (ctx->stmt()) { if (ctx->stmt()) {
return ctx->stmt()->accept(this); return ctx->stmt()->accept(this);
} }
return BlockFlow::Continue; throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明"));
} }
// 变量声明的 IR 生成目前也是最小实现:
// - 先检查声明的基础类型,当前仅支持局部 int
// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。
//
// 和更完整的版本相比,这里还没有:
// - 一个 Decl 中多个变量定义的顺序处理;
// - const、数组、全局变量等不同声明形态
// - 更丰富的类型系统。
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) return {}; if (!ctx) {
if (auto* constDecl = ctx->constDecl()) { throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
for (auto* def : constDecl->constDef()) {
def->accept(this);
}
return {};
} }
if (auto* varDecl = ctx->varDecl()) { if (!ctx->btype() || !ctx->btype()->INT()) {
for (auto* varDef : varDecl->varDef()) { throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
varDef->accept(this); }
} auto* var_def = ctx->varDef();
return {}; if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
} }
var_def->accept(this);
return {}; return {};
} }
// 当前仍是教学用的最小版本,因此这里只支持:
// - 局部 int 变量;
// - 标量初始化;
// - 一个 VarDef 对应一个槽位。
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) return {}; if (!ctx) {
if (!ctx->ID()) { throw std::runtime_error(FormatError("irgen", "缺少变量定义"));
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
}
if (!func_) {
const TypeDesc* ty = sema_.GetVarType(ctx);
if (!ty) {
throw std::runtime_error(FormatError("irgen", "全局变量类型缺失"));
}
if (global_var_storage_.find(ctx) != global_var_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成全局变量"));
}
ir::ConstantValue* init = nullptr;
if (ty->dims.empty()) {
if (auto* initVal = ctx->initVal()) {
if (!initVal->exp()) {
throw std::runtime_error(FormatError("irgen", "全局变量初始化非法"));
}
init = EvalConstScalar(initVal->exp());
if (ty->base == BaseTypeKind::Int &&
dynamic_cast<ir::ConstantFloat*>(init)) {
auto* cf = static_cast<ir::ConstantFloat*>(init);
init = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
} else if (ty->base == BaseTypeKind::Float &&
dynamic_cast<ir::ConstantInt*>(init)) {
auto* ci = static_cast<ir::ConstantInt*>(init);
init = module_.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
}
} else if (auto* initVal = ctx->initVal()) {
size_t total = ArrayTotalSize(*ty);
std::vector<ir::ConstantValue*> values(
total,
ty->base == BaseTypeKind::Float
? static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(0)));
InitGlobalArray(*ty, initVal, values, 0, 0, 0);
init = module_.GetContext().CreateConstArray(ToIRType(*ty), values);
}
auto* gv = module_.CreateGlobalVariable(ctx->ID()->getText(),
ToIRType(*ty), init, false);
global_var_storage_[ctx] = gv;
return {};
} }
if (var_storage_.find(ctx) != var_storage_.end()) { if (!ctx->lValue()) {
throw std::runtime_error(FormatError("irgen", "重复生成存储槽位")); throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
} }
const TypeDesc* ty = sema_.GetVarType(ctx); GetLValueName(*ctx->lValue());
if (!ty) { if (storage_map_.find(ctx) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "变量类型缺失")); throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
} }
auto* slot = CreateEntryAlloca(ToIRType(*ty), module_.GetContext().NextTemp()); auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
var_storage_[ctx] = slot; storage_map_[ctx] = slot;
if (ty->dims.empty()) { ir::Value* init = nullptr;
ir::Value* init = nullptr; if (auto* init_value = ctx->initValue()) {
if (auto* initVal = ctx->initVal()) { if (!init_value->exp()) {
if (!initVal->exp()) { throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化"));
throw std::runtime_error(FormatError("irgen", "标量初始化非法"));
}
init = EvalExp(initVal->exp());
} else {
init = DefaultValue(*ty);
}
if (ty->base == BaseTypeKind::Float) {
if (init->IsInt1() || init->IsInt32()) {
init = CastToFloat(init->IsInt1() ? CastToInt(init) : init);
}
} else if (ty->base == BaseTypeKind::Int) {
if (init->IsFloat() || init->IsInt1()) {
init = CastToInt(init);
}
} }
builder_.CreateStore(init, slot); init = EvalExpr(*init_value->exp());
} else { } else {
if (!ctx->initVal() && ty->dims.size() == 1 && ty->dims[0] >= 1024) { init = builder_.CreateConstInt(0);
auto* idx_slot = CreateEntryAlloca(ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(0), idx_slot);
auto* cond_bb = func_->CreateBlock("arr.zero.cond");
auto* body_bb = func_->CreateBlock("arr.zero.body");
auto* end_bb = func_->CreateBlock("arr.zero.end");
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
auto* idx = builder_.CreateLoad(idx_slot, module_.GetContext().NextTemp());
auto* bound = builder_.CreateConstInt(ty->dims[0]);
auto* cmp = builder_.CreateICmp(ir::ICmpPredicate::Slt, idx, bound,
module_.GetContext().NextTemp());
builder_.CreateCondBr(cmp, body_bb, end_bb);
builder_.SetInsertPoint(body_bb);
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
indices.push_back(idx);
auto* elem_addr = builder_.CreateGep(slot, std::move(indices),
module_.GetContext().NextTemp());
builder_.CreateStore(DefaultValue(*ty), elem_addr);
auto* next = builder_.CreateAdd(idx, builder_.CreateConstInt(1),
module_.GetContext().NextTemp());
builder_.CreateStore(next, idx_slot);
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(end_bb);
} else {
InitArray(slot, *ty, ctx->initVal());
}
} }
builder_.CreateStore(init, slot);
return {}; return {};
} }
// Generates IR for one constant definition.
// - Outside any function (func_ is null) a module-level global variable is
//   created and recorded in global_const_storage_.
// - Inside a function a stack slot is created with CreateEntryAlloca,
//   recorded in const_storage_, and initialized with stores.
// Throws std::runtime_error when the definition has no name, when semantic
// analysis supplies no type, when the definition was already generated, or
// when a scalar initializer is not a single constant expression.
// Always returns an empty std::any.
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "常量声明缺少名称"));
}
// ---- Global constant ----
if (!func_) {
const TypeDesc* ty = sema_.GetConstType(ctx);
if (!ty) {
throw std::runtime_error(FormatError("irgen", "全局常量类型缺失"));
}
if (global_const_storage_.find(ctx) != global_const_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成全局常量"));
}
// Stays null when the definition carries no initializer.
ir::ConstantValue* init = nullptr;
if (ty->dims.empty()) {
// Scalar: evaluate at compile time, then convert between int and float so
// the constant matches the declared base type.
if (auto* initVal = ctx->constInitVal()) {
if (!initVal->constExp()) {
throw std::runtime_error(FormatError("irgen", "全局常量初始化非法"));
}
init = EvalConstScalar(initVal->constExp());
if (ty->base == BaseTypeKind::Int &&
dynamic_cast<ir::ConstantFloat*>(init)) {
auto* cf = static_cast<ir::ConstantFloat*>(init);
init = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
} else if (ty->base == BaseTypeKind::Float &&
dynamic_cast<ir::ConstantInt*>(init)) {
auto* ci = static_cast<ir::ConstantInt*>(init);
init = module_.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
}
} else if (auto* initVal = ctx->constInitVal()) {
// Array: start from `total` zero constants of the base type, let
// InitGlobalConstArray fill in the listed elements, then wrap the list in a
// constant array.
size_t total = ArrayTotalSize(*ty);
std::vector<ir::ConstantValue*> values(
total,
ty->base == BaseTypeKind::Float
? static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(0)));
InitGlobalConstArray(*ty, initVal, values, 0, 0, 0);
init = module_.GetContext().CreateConstArray(ToIRType(*ty), values);
}
// NOTE(review): the trailing `true` is presumably the is-const flag of the
// global — confirm against Module::CreateGlobalVariable.
auto* gv = module_.CreateGlobalVariable(ctx->ID()->getText(),
ToIRType(*ty), init, true);
global_const_storage_[ctx] = gv;
return {};
}
// ---- Local constant ----
if (const_storage_.find(ctx) != const_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成常量存储"));
}
const TypeDesc* ty = sema_.GetConstType(ctx);
if (!ty) {
throw std::runtime_error(FormatError("irgen", "常量类型缺失"));
}
auto* slot = CreateEntryAlloca(ToIRType(*ty), module_.GetContext().NextTemp());
const_storage_[ctx] = slot;
if (ty->dims.empty()) {
// Scalar: visit the initializer expression (or take the default value),
// convert it to the declared base type, and store it into the slot.
ir::Value* init = nullptr;
if (auto* initVal = ctx->constInitVal()) {
if (!initVal->constExp()) {
throw std::runtime_error(FormatError("irgen", "常量初始化非法"));
}
init = std::any_cast<ir::Value*>(initVal->constExp()->accept(this));
} else {
init = DefaultValue(*ty);
}
if (ty->base == BaseTypeKind::Float) {
// An i1 value goes through CastToInt before CastToFloat.
if (init->IsInt1() || init->IsInt32()) {
init = CastToFloat(init->IsInt1() ? CastToInt(init) : init);
}
} else if (ty->base == BaseTypeKind::Int) {
if (init->IsFloat() || init->IsInt1()) {
init = CastToInt(init);
}
}
builder_.CreateStore(init, slot);
} else {
// Array: element initialization is delegated to InitConstArray.
InitConstArray(slot, *ty, ctx->constInitVal());
}
return {};
}

@ -4,11 +4,12 @@
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h"
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema) { const SemanticContext& sema) {
auto module = std::make_unique<ir::Module>(); // 无参构造 auto module = std::make_unique<ir::Module>();
IRGenImpl visitor(*module, sema); IRGenImpl gen(*module, sema);
tree.accept(&visitor); tree.accept(&gen);
return module; return module;
} }

File diff suppressed because it is too large Load Diff

@ -6,116 +6,82 @@
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
namespace {
// Post-generation structural check: every basic block of `func` must exist
// and end in a terminator instruction.
//
// IRGen is still single-entry and emits code sequentially, so this pass is
// run after generation as a final safety net.
//
// Throws std::runtime_error naming the first offending block ("<null>" when
// the block pointer itself is empty).
void VerifyFunctionStructure(const ir::Function& func) {
  for (const auto& block : func.GetBlocks()) {
    const bool well_formed = block && block->HasTerminator();
    if (well_formed) {
      continue;
    }
    const std::string label = block ? block->GetName() : std::string("<null>");
    throw std::runtime_error(
        FormatError("irgen", "基本块未正确终结: " + label));
  }
}
} // namespace
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module), : module_(module),
sema_(sema), sema_(sema),
func_(nullptr), func_(nullptr),
builder_(module.GetContext(), nullptr) {} builder_(module.GetContext(), nullptr) {}
// 编译单元的 IR 生成当前只实现了最小功能:
// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容;
// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR
//
// 当前还没有实现:
// - 多个函数定义的遍历与生成;
// - 全局变量、全局常量的 IR 生成。
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) return {}; if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
func_map_.clear();
global_var_storage_.clear();
global_const_storage_.clear();
func_ = nullptr;
for (auto* decl : ctx->decl()) {
if (decl) decl->accept(this);
}
for (auto* funcDef : ctx->funcDef()) {
if (!funcDef || !funcDef->ID()) continue;
const auto* fty = sema_.GetFuncType(funcDef);
if (!fty) {
throw std::runtime_error(FormatError("irgen", "缺少函数类型"));
}
std::vector<std::shared_ptr<ir::Type>> params;
for (const auto& p : fty->params) {
params.push_back(ToIRParamType(p));
}
auto ret = ToIRType(fty->ret);
auto func_ty = ir::Type::GetFunctionType(ret, params);
auto* fn = module_.CreateFunctionWithType(funcDef->ID()->getText(), func_ty);
func_map_[funcDef] = fn;
} }
auto* func = ctx->funcDef();
auto declare_builtin = [&](const std::string& name, if (!func) {
std::shared_ptr<ir::Type> ret, throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
std::vector<std::shared_ptr<ir::Type>> params) {
for (const auto& fn : module_.GetFunctions()) {
if (fn && fn->GetName() == name) return;
}
auto fty = ir::Type::GetFunctionType(ret, params);
module_.CreateFunctionDecl(name, fty);
};
auto i32 = ir::Type::GetInt32Type();
auto f32 = ir::Type::GetFloatType();
declare_builtin("getint", i32, {});
declare_builtin("getch", i32, {});
declare_builtin("getarray", i32, {ir::Type::GetPointerType(i32)});
declare_builtin("putint", ir::Type::GetVoidType(), {i32});
declare_builtin("putch", ir::Type::GetVoidType(), {i32});
declare_builtin("putarray", ir::Type::GetVoidType(),
{i32, ir::Type::GetPointerType(i32)});
declare_builtin("getfloat", f32, {});
declare_builtin("getfarray", i32, {ir::Type::GetPointerType(f32)});
declare_builtin("putfloat", ir::Type::GetVoidType(), {f32});
declare_builtin("putfarray", ir::Type::GetVoidType(),
{i32, ir::Type::GetPointerType(f32)});
declare_builtin("starttime", ir::Type::GetVoidType(), {});
declare_builtin("stoptime", ir::Type::GetVoidType(), {});
for (auto* funcDef : ctx->funcDef()) {
if (funcDef) funcDef->accept(this);
} }
func->accept(this);
return {}; return {};
} }
// 函数 IR 生成当前实现了:
// 1. 获取函数名;
// 2. 检查函数返回类型;
// 3. 在 Module 中创建 Function
// 4. 将 builder 插入点设置到入口基本块;
// 5. 继续生成函数体。
//
// 当前还没有实现:
// - 通用函数返回类型处理;
// - 形参列表遍历与参数类型收集;
// - FunctionType 这样的函数类型对象;
// - Argument/形式参数 IR 对象;
// - 入口块中的参数初始化逻辑。
// ...
// 因此这里目前只支持最小的“无参 int 函数”生成。
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx || !ctx->block()) { if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
if (!ctx->blockStmt()) {
throw std::runtime_error(FormatError("irgen", "函数体为空")); throw std::runtime_error(FormatError("irgen", "函数体为空"));
} }
auto it = func_map_.find(ctx); if (!ctx->ID()) {
if (it == func_map_.end()) { throw std::runtime_error(FormatError("irgen", "缺少函数名"));
throw std::runtime_error(FormatError("irgen", "函数未注册"));
} }
func_ = it->second; if (!ctx->funcType() || !ctx->funcType()->INT()) {
auto* entry = func_->GetEntry(); throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数"));
builder_.SetInsertPoint(entry);
var_storage_.clear();
const_storage_.clear();
param_storage_.clear();
loop_stack_.clear();
const auto* fty = sema_.GetFuncType(ctx);
if (!fty) {
throw std::runtime_error(FormatError("irgen", "缺少函数类型"));
}
if (ctx->funcFParams()) {
auto params = ctx->funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
auto* param_ctx = params[i];
auto* arg = func_->GetArg(i);
const TypeDesc* pty = sema_.GetParamType(param_ctx);
if (!pty) {
throw std::runtime_error(FormatError("irgen", "缺少参数类型"));
}
auto slot = CreateEntryAlloca(ToIRParamType(*pty),
module_.GetContext().NextTemp());
builder_.CreateStore(arg, slot);
param_storage_[param_ctx] = slot;
}
} }
ctx->block()->accept(this); func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type());
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
if (!builder_.GetInsertBlock()->HasTerminator()) { ctx->blockStmt()->accept(this);
if (func_->GetReturnType()->IsVoid()) { // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。
builder_.CreateRetVoid(); VerifyFunctionStructure(*func_);
} else {
TypeDesc ret = fty->ret;
builder_.CreateRet(DefaultValue(ret));
}
}
return {}; return {};
} }

@ -1,132 +1,39 @@
#include "irgen/IRGen.h" #include "irgen/IRGen.h"
#include <any>
#include <stdexcept> #include <stdexcept>
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
// 语句生成当前只实现了最小子集。
// 目前支持:
// - return <exp>;
//
// 还未支持:
// - 赋值语句
// - if / while 等控制流
// - 空语句、块语句嵌套分发之外的更多语句形态
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) return {}; if (!ctx) {
if (ctx->lVal() && ctx->ASSIGN()) { throw std::runtime_error(FormatError("irgen", "缺少语句"));
ir::Value* addr = GetLValAddress(ctx->lVal());
ir::Value* val = EvalExp(ctx->exp());
BoundDecl bound = sema_.ResolveVarUse(ctx->lVal());
if ((bound.kind == BoundDecl::Kind::Var && !bound.var_decl) ||
(bound.kind == BoundDecl::Kind::Const && !bound.const_decl) ||
(bound.kind == BoundDecl::Kind::Param && !bound.param_decl)) {
throw std::runtime_error(FormatError(
"irgen", "赋值左值缺少语义绑定: " + ctx->lVal()->getText()));
}
const TypeDesc* ty = nullptr;
if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
ty = sema_.GetVarType(bound.var_decl);
} else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) {
ty = sema_.GetParamType(bound.param_decl);
} else if (bound.kind == BoundDecl::Kind::Const) {
throw std::runtime_error(FormatError("irgen", "不能给常量赋值"));
}
if (!ty) {
throw std::runtime_error(FormatError("irgen", "无法解析赋值类型"));
}
if (ty->base == BaseTypeKind::Float) {
if (val->IsInt1()) {
val = CastToFloat(CastToInt(val));
} else if (val->IsInt32()) {
val = CastToFloat(val);
}
} else if (ty->base == BaseTypeKind::Int) {
if (val->IsFloat()) {
val = CastToInt(val);
} else if (val->IsInt1()) {
val = CastToInt(val);
}
}
builder_.CreateStore(val, addr);
return BlockFlow::Continue;
}
if (ctx->block()) {
return ctx->block()->accept(this);
} }
if (ctx->IF()) { if (ctx->returnStmt()) {
auto* then_bb = func_->CreateBlock("if.then"); return ctx->returnStmt()->accept(this);
auto* else_bb = func_->CreateBlock("if.else");
auto* merge_bb = func_->CreateBlock("if.end");
EmitCondBr(ctx->cond(), then_bb, else_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (then_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
builder_.SetInsertPoint(else_bb);
if (ctx->stmt(1)) {
auto else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
if (else_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
} else {
builder_.CreateBr(merge_bb);
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
} }
if (ctx->WHILE()) { throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
auto* cond_bb = func_->CreateBlock("while.cond"); }
auto* body_bb = func_->CreateBlock("while.body");
auto* end_bb = func_->CreateBlock("while.end");
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
EmitCondBr(ctx->cond(), body_bb, end_bb);
builder_.SetInsertPoint(body_bb); std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
PushLoop(end_bb, cond_bb); if (!ctx) {
auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this)); throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
PopLoop();
if (body_flow != BlockFlow::Terminated) {
builder_.CreateBr(cond_bb);
}
builder_.SetInsertPoint(end_bb);
return BlockFlow::Continue;
}
if (ctx->BREAK()) {
auto* target = CurrentBreak();
if (!target) {
throw std::runtime_error(FormatError("irgen", "break 不在循环内"));
}
builder_.CreateBr(target);
return BlockFlow::Terminated;
}
if (ctx->CONTINUE()) {
auto* target = CurrentContinue();
if (!target) {
throw std::runtime_error(FormatError("irgen", "continue 不在循环内"));
}
builder_.CreateBr(target);
return BlockFlow::Terminated;
}
if (ctx->RETURN()) {
if (!ctx->exp()) {
builder_.CreateRetVoid();
return BlockFlow::Terminated;
}
ir::Value* v = EvalExp(ctx->exp());
auto ret_ty = func_->GetReturnType();
if (ret_ty->IsFloat() && v->IsInt32()) {
v = CastToFloat(v);
} else if (ret_ty->IsInt32() && v->IsFloat()) {
v = CastToInt(v);
}
builder_.CreateRet(v);
return BlockFlow::Terminated;
} }
if (ctx->exp()) { if (!ctx->exp()) {
EvalExp(ctx->exp()); throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
} }
return BlockFlow::Continue; ir::Value* v = EvalExpr(*ctx->exp());
} builder_.CreateRet(v);
return BlockFlow::Terminated;
}

@ -36,9 +36,6 @@ int main(int argc, char** argv) {
auto sema = RunSema(*comp_unit); auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema); auto module = GenerateIR(*comp_unit, sema);
if (opts.optimize_ir) {
ir::RunScalarOptimizationPipeline(*module);
}
if (opts.emit_ir) { if (opts.emit_ir) {
ir::IRPrinter printer; ir::IRPrinter printer;
if (need_blank_line) { if (need_blank_line) {
@ -49,10 +46,13 @@ int main(int argc, char** argv) {
} }
if (opts.emit_asm) { if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
if (need_blank_line) { if (need_blank_line) {
std::cout << "\n"; std::cout << "\n";
} }
mir::PrintAArch64AsmFromMIR(*module, std::cout); mir::PrintAsm(*machine_func, std::cout);
} }
#else #else
if (opts.emit_ir || opts.emit_asm) { if (opts.emit_ir || opts.emit_asm) {

@ -1,3 +0,0 @@
/// Program entry point: writes a fixed greeting followed by a newline to
/// standard output.
fn main() {
    const GREETING: &str = "Hello, world!";
    println!("{}", GREETING);
}

@ -1,351 +1,78 @@
// AArch64 汇编发射（Lab5）
// 输入为寄存器分配 + 栈帧布局后的 MachineModule（操作数均为物理寄存器/栈对象）。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <ostream> #include <ostream>
#include <string> #include <stdexcept>
#include <vector>
#include "utils/Log.h"
namespace mir { namespace mir {
namespace { namespace {
int CalleeAreaBytes(const MachineFunction& f) { const FrameSlot& GetFrameSlot(const MachineFunction& function,
return ((int)f.CalleeSavedGPR().size() + (int)f.CalleeSavedFPR().size()) * 8; const Operand& operand) {
} if (operand.GetKind() != Operand::Kind::FrameIndex) {
throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
// Writes the textual assembly for one MachineModule to an output stream.
// All Emit* helpers append to `os_`; the per-function helpers rely on `mf_`
// pointing at the function currently being printed.
class Printer {
public:
// Binds the printer to the module to print and the stream to write into.
// Both are held by reference and must outlive the printer.
Printer(const MachineModule& m, std::ostream& os) : m_(m), os_(os) {}
// Entry point: prints the whole module.
void Run();
private:
// Module being printed.
const MachineModule& m_;
// Destination stream for all emitted text.
std::ostream& os_;
// Function currently being emitted; assigned at the start of EmitFunction.
const MachineFunction* mf_ = nullptr;
// Emits the data definitions for the module's globals.
void EmitGlobals();
// Emits the symbol directives, prologue, labelled blocks and size directive
// for one function.
void EmitFunction(const MachineFunction& f);
// Emits the stack adjustment, fp/lr save and callee-saved register stores.
void EmitProlog(const MachineFunction& f);
// Emits the callee-saved register reloads, fp/lr restore and return.
void EmitEpilog(const MachineFunction& f);
// Emits the assembly text for a single machine instruction.
void EmitInstr(const MachineInstr& mi);
// Returns the register name for `op`: an FPR name when the operand's class
// is FPR, otherwise a GPR name; the operand's byte width selects the form.
std::string R(const Operand& op) {
if (op.GetClass() == RegClass::FPR)
return FPRName(op.GetId(), op.GetBytes());
return GPRName(op.GetId(), op.GetBytes());
}
// Loads an arbitrary immediate into a scratch register (x16 / x17).
void LoadImm(const char* dst, long long v);
// Memory access of the form [x29 + offset]; when the offset is too large,
// x16 is used to compute the address.
void MemAccess(const char* mnem, const std::string& reg, int offset);
// Returns the recorded offset of the stack object at `stack_index` in the
// current function. Requires `mf_` to be set.
int FrameOffset(int stack_index) {
return mf_->StackObjects()[stack_index].offset;
}
};
// Maps a condition code to the suffix text used in emitted assembly.
// Any condition not listed explicitly yields "al".
const char* CondStr(Cond c) {
  const char* suffix = "al";
  switch (c) {
    case Cond::EQ:
      suffix = "eq";
      break;
    case Cond::NE:
      suffix = "ne";
      break;
    case Cond::LT:
      suffix = "lt";
      break;
    case Cond::LE:
      suffix = "le";
      break;
    case Cond::GT:
      suffix = "gt";
      break;
    case Cond::GE:
      suffix = "ge";
      break;
    case Cond::MI:
      suffix = "mi";
      break;
    case Cond::LS:
      suffix = "ls";
      break;
    case Cond::HI:
      suffix = "hi";
      break;
    case Cond::HS:
      suffix = "hs";
      break;
    default:
      break;
  }
  return suffix;
}
// Materialises the immediate `v` in register `dst` as a `mov` of the low
// 16 bits followed by one `movk` per non-zero higher 16-bit chunk.
//
// `dst` is the register name as it should appear in the output. A name
// starting with 'w' is treated as a 32-bit register: only the low 32 bits of
// `v` are loaded. Any other name is treated as 64-bit.
void Printer::LoadImm(const char* dst, long long v) {
  const bool narrow = (dst[0] == 'w');
  const int width = narrow ? 32 : 64;  // a w register only takes the low 32 bits

  auto bits = static_cast<unsigned long long>(v);
  if (narrow) {
    bits &= 0xffffffffULL;
  }

  os_ << "\tmov\t" << dst << ", #" << (bits & 0xffffULL) << "\n";
  for (int shift = 16; shift < width; shift += 16) {
    const unsigned long long piece = (bits >> shift) & 0xffffULL;
    if (piece == 0) {
      continue;  // the initial mov already left this chunk zero
    }
    os_ << "\tmovk\t" << dst << ", #" << piece << ", lsl #" << shift << "\n";
  }
}
void Printer::MemAccess(const char* mnem, const std::string& reg, int offset) {
if (offset >= -256 && offset <= 4095) {
os_ << "\t" << mnem << "\t" << reg << ", [x29, #" << offset << "]\n";
} else {
LoadImm("x16", offset);
os_ << "\tadd\tx16, x29, x16\n";
os_ << "\t" << mnem << "\t" << reg << ", [x16]\n";
} }
return function.GetFrameSlot(operand.GetFrameIndex());
} }
void Printer::Run() { void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
EmitGlobals(); int offset) {
os_ << "\t.text\n"; os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset
for (const auto& f : m_.Functions()) EmitFunction(*f); << "]\n";
}
// Emits one definition per module global.
//
// A global flagged `zero_init` goes to .bss as a `.zero <size>` reservation;
// every other global goes to .data as a sequence of `.word` values. Both
// forms share the same alignment / .globl / label header. The alignment
// directive is 4 when the global's `align` is 16 and 2 otherwise.
void Printer::EmitGlobals() {
  for (const auto& global : m_.Globals()) {
    const int align_directive = (global.align == 16) ? 4 : 2;

    os_ << (global.zero_init ? "\t.bss\n" : "\t.data\n");
    os_ << "\t.align\t" << align_directive << "\n";
    os_ << "\t.globl\t" << global.name << "\n";
    os_ << global.name << ":\n";

    if (global.zero_init) {
      os_ << "\t.zero\t" << global.size << "\n";
      continue;
    }
    for (unsigned word : global.words) {
      os_ << "\t.word\t" << word << "\n";
    }
  }
}
// Emits one complete function: symbol directives, the entry label, the
// prologue, every basic block (labelled ".L.<function>.<block>") with its
// instructions, and the closing `.size` directive.
//
// Records `f` in `mf_` first so the per-instruction helpers can reach the
// function being printed.
void Printer::EmitFunction(const MachineFunction& f) {
  mf_ = &f;

  const auto& fn_name = f.GetName();
  os_ << "\t.globl\t" << fn_name << "\n";
  os_ << "\t.type\t" << fn_name << ", %function\n";
  os_ << fn_name << ":\n";

  EmitProlog(f);

  for (const auto& block : f.Blocks()) {
    os_ << ".L." << fn_name << "." << block->GetName() << ":\n";
    for (const auto& instr : block->Instrs()) {
      EmitInstr(instr);
    }
  }

  os_ << "\t.size\t" << fn_name << ", .-" << fn_name << "\n";
}
void Printer::EmitProlog(const MachineFunction& f) {
int frame = f.GetFrameSize();
// sub sp, sp, #frame ; 保存 fp/lr ; mov x29, sp
if (frame <= 4095) {
os_ << "\tsub\tsp, sp, #" << frame << "\n";
} else {
LoadImm("x16", frame);
os_ << "\tsub\tsp, sp, x16\n";
}
os_ << "\tstp\tx29, x30, [sp]\n";
os_ << "\tmov\tx29, sp\n";
// 保存 callee-saved相对 x29 偏移,从 16 开始)。
int off = 16;
for (int r : f.CalleeSavedGPR()) {
MemAccess("str", GPRName(r, 8), off);
off += 8;
}
for (int r : f.CalleeSavedFPR()) {
MemAccess("str", FPRName(r, 8), off);
off += 8;
}
} }
void Printer::EmitEpilog(const MachineFunction& f) { } // namespace
int frame = f.GetFrameSize();
int off = 16;
for (int r : f.CalleeSavedGPR()) {
MemAccess("ldr", GPRName(r, 8), off);
off += 8;
}
for (int r : f.CalleeSavedFPR()) {
MemAccess("ldr", FPRName(r, 8), off);
off += 8;
}
os_ << "\tldp\tx29, x30, [sp]\n";
if (frame <= 4095) {
os_ << "\tadd\tsp, sp, #" << frame << "\n";
} else {
LoadImm("x16", frame);
os_ << "\tadd\tsp, sp, x16\n";
}
os_ << "\tret\n";
}
void Printer::EmitInstr(const MachineInstr& mi) { void PrintAsm(const MachineFunction& function, std::ostream& os) {
auto label = [&](const Operand& o) { os << ".text\n";
return ".L." + mf_->GetName() + "." + o.GetSym(); os << ".global " << function.GetName() << "\n";
}; os << ".type " << function.GetName() << ", %function\n";
switch (mi.op) { os << function.GetName() << ":\n";
case Opcode::Mov: {
// 同寄存器拷贝可省略peephole 已处理,这里再兜底)。 for (const auto& inst : function.GetEntry().GetInstructions()) {
if (mi.ops[0].IsPReg() && mi.ops[1].IsPReg() && const auto& ops = inst.GetOperands();
mi.ops[0].GetId() == mi.ops[1].GetId()) switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
}
break; break;
os_ << "\tmov\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n"; case Opcode::Epilogue:
break; if (function.GetFrameSize() > 0) {
} os << " add sp, sp, #" << function.GetFrameSize() << "\n";
case Opcode::MovImm: }
LoadImm(R(mi.ops[0]).c_str(), mi.ops[1].GetImm()); os << " ldp x29, x30, [sp], #16\n";
break; break;
case Opcode::Sxtw: case Opcode::MovImm:
os_ << "\tsxtw\t" << GPRName(mi.ops[0].GetId(), 8) << ", " os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #"
<< GPRName(mi.ops[1].GetId(), 4) << "\n"; << ops.at(1).GetImm() << "\n";
break; break;
case Opcode::Add: case Opcode::LoadStack: {
os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", " const auto& slot = GetFrameSlot(function, ops.at(1));
<< R(mi.ops[2]) << "\n"; PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
case Opcode::Sub:
os_ << "\tsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::Mul:
os_ << "\tmul\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::SDiv:
os_ << "\tsdiv\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::MSub:
os_ << "\tmsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << ", " << R(mi.ops[3]) << "\n";
break;
case Opcode::AddImm: {
long long v = mi.ops[2].GetImm();
if (v >= 0 && v <= 4095) {
os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #" << v
<< "\n";
} else {
LoadImm("x16", v);
os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", x16\n";
}
break;
}
case Opcode::SubImm:
os_ << "\tsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #"
<< mi.ops[2].GetImm() << "\n";
break;
case Opcode::LslImm:
os_ << "\tlsl\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #"
<< mi.ops[2].GetImm() << "\n";
break;
case Opcode::Cmp:
os_ << "\tcmp\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
case Opcode::CmpImm:
os_ << "\tcmp\t" << R(mi.ops[0]) << ", #" << mi.ops[1].GetImm() << "\n";
break;
case Opcode::CSet:
os_ << "\tcset\t" << R(mi.ops[0]) << ", " << CondStr(mi.cond) << "\n";
break;
case Opcode::FAdd:
os_ << "\tfadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::FSub:
os_ << "\tfsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::FMul:
os_ << "\tfmul\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::FDiv:
os_ << "\tfdiv\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::FCmp:
os_ << "\tfcmp\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
case Opcode::FMov: {
if (mi.ops[0].IsPReg() && mi.ops[1].IsPReg() &&
mi.ops[0].GetId() == mi.ops[1].GetId())
break; break;
os_ << "\tfmov\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
}
case Opcode::FMovImm:
// 通过 x16 装入 32 位浮点位模式,再 fmov 到 s 寄存器。
LoadImm("w16", (long long)(unsigned)mi.ops[1].GetImm());
os_ << "\tfmov\t" << R(mi.ops[0]) << ", w16\n";
break;
case Opcode::SCvtF:
os_ << "\tscvtf\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
case Opcode::FCvtZS:
os_ << "\tfcvtzs\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
case Opcode::Ldr: {
long long off = mi.ops[2].GetImm();
if (off >= -256 && off <= 4095)
os_ << "\tldr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", #"
<< off << "]\n";
else {
LoadImm("x16", off);
os_ << "\tldr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", x16]\n";
}
break;
}
case Opcode::Str: {
long long off = mi.ops[2].GetImm();
if (off >= -256 && off <= 4095)
os_ << "\tstr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", #"
<< off << "]\n";
else {
LoadImm("x16", off);
os_ << "\tstr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", x16]\n";
} }
break; case Opcode::StoreStack: {
} const auto& slot = GetFrameSlot(function, ops.at(1));
case Opcode::LdrStack: PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
MemAccess("ldr", R(mi.ops[0]), FrameOffset(mi.ops[1].GetFrame())); break;
break;
case Opcode::StrStack:
MemAccess("str", R(mi.ops[0]), FrameOffset(mi.ops[1].GetFrame()));
break;
case Opcode::AddrFrame: {
int off = FrameOffset(mi.ops[1].GetFrame());
if (off >= 0 && off <= 4095)
os_ << "\tadd\t" << R(mi.ops[0]) << ", x29, #" << off << "\n";
else {
LoadImm("x16", off);
os_ << "\tadd\t" << R(mi.ops[0]) << ", x29, x16\n";
} }
break; case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
} }
case Opcode::AddrGlobal:
os_ << "\tadrp\t" << R(mi.ops[0]) << ", " << mi.ops[1].GetSym() << "\n";
os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[0]) << ", :lo12:"
<< mi.ops[1].GetSym() << "\n";
break;
case Opcode::B:
os_ << "\tb\t" << label(mi.ops[0]) << "\n";
break;
case Opcode::BCond:
os_ << "\tb." << CondStr(mi.cond) << "\t" << label(mi.ops[0]) << "\n";
break;
case Opcode::Bl:
os_ << "\tbl\t" << mi.ops[0].GetSym() << "\n";
break;
case Opcode::Ret:
EmitEpilog(*mf_);
break;
} }
}
} // namespace
// Prints the assembly text for `module` to `os` using a temporary Printer.
void PrintAsm(const MachineModule& module, std::ostream& os) {
  Printer(module, os).Run();
}
// Runs the backend passes over every function of `module`, one function at
// a time, in this order: register allocation, frame lowering, peephole.
void RunBackendPipeline(MachineModule& module) {
  for (auto& fn_ptr : module.Functions()) {
    auto& fn = *fn_ptr;
    RunRegAlloc(fn);
    RunFrameLowering(fn);
    RunPeephole(fn);
  }
}
void PrintAArch64AsmFromMIR(const ir::Module& module, std::ostream& os) { os << ".size " << function.GetName() << ", .-" << function.GetName()
auto mm = LowerToMIR(module); << "\n";
RunBackendPipeline(*mm);
PrintAsm(*mm, os);
} }
} // namespace mir } // namespace mir

@ -8,9 +8,6 @@ add_library(mir_core STATIC
RegAlloc.cpp RegAlloc.cpp
FrameLowering.cpp FrameLowering.cpp
AsmPrinter.cpp AsmPrinter.cpp
LLVMAsmBackend.cpp
passes/PassManager.cpp
passes/Peephole.cpp
) )
target_link_libraries(mir_core PUBLIC target_link_libraries(mir_core PUBLIC
@ -18,7 +15,10 @@ target_link_libraries(mir_core PUBLIC
ir ir
) )
add_subdirectory(passes)
add_library(mir INTERFACE) add_library(mir INTERFACE)
target_link_libraries(mir INTERFACE target_link_libraries(mir INTERFACE
mir_core mir_core
mir_passes
) )

@ -1,35 +1,45 @@
// 栈帧布局（Lab5）
// 在寄存器分配产出 spill 槽与 callee-saved 使用集合后,确定每个栈对象
// 相对 x29 的偏移与总帧大小。实际的 prologue/epilogue 由 AsmPrinter 按
// 同一套布局公式发射。
//
// 帧布局（x29 指向帧底，sp == x29）：
// [x29 + 0] 保存的 x29
// [x29 + 8] 保存的 x30(lr)
// [x29 + 16 ...] callee-saved GPR、callee-saved FPR（各 8 字节）
// [其后] 局部/spill 栈对象（按声明顺序，按对齐摆放）
// 总大小对齐到 16 字节。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <stdexcept>
#include <vector>
#include "utils/Log.h"
namespace mir { namespace mir {
namespace {
int CalleeSavedAreaBytes(const MachineFunction& f) { int AlignTo(int value, int align) {
int n = (int)f.CalleeSavedGPR().size() + (int)f.CalleeSavedFPR().size(); return ((value + align - 1) / align) * align;
return n * 8;
} }
} // namespace
void RunFrameLowering(MachineFunction& function) { void RunFrameLowering(MachineFunction& function) {
int base = 16 + CalleeSavedAreaBytes(function); // fp/lr + callee-saved int cursor = 0;
int off = base; for (const auto& slot : function.GetFrameSlots()) {
for (auto& obj : function.StackObjects()) { cursor += slot.size;
int align = obj.align < 4 ? 4 : obj.align; if (-cursor < -256) {
off = (off + align - 1) / align * align; throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
obj.offset = off; // 相对 x29 的正偏移 }
off += obj.size; }
cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
function.GetFrameSlot(slot.index).offset = -cursor;
}
function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue);
}
lowered.push_back(inst);
} }
int frame = (off + 15) / 16 * 16; insts = std::move(lowered);
if (frame < 16) frame = 16;
function.SetFrameSize(frame);
} }
} // namespace mir } // namespace mir

@ -1,132 +0,0 @@
#include "mir/MIR.h"
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
// Wraps `path` in shell quotes so it can be spliced into a command line.
//
// On Windows the text is enclosed in double quotes and every embedded `"`
// becomes `\"`. Elsewhere it is enclosed in single quotes and every embedded
// `'` becomes `'\''` (close the quote, emit an escaped quote, reopen).
std::string ShellQuote(const std::filesystem::path& path) {
#if defined(_WIN32)
  const char quote = '"';
  const char* const escaped_quote = "\\\"";
#else
  const char quote = '\'';
  const char* const escaped_quote = "'\\''";
#endif
  const std::string text = path.string();

  std::string result(1, quote);
  for (const char c : text) {
    if (c == quote) {
      result += escaped_quote;
    } else {
      result += c;
    }
  }
  result += quote;
  return result;
}
// Returns the entire contents of the file at `path`, read in binary mode.
// Throws std::runtime_error (message includes the path) when the file cannot
// be opened.
std::string ReadTextFile(const std::filesystem::path& path) {
  std::ifstream input(path, std::ios::binary);
  if (input.fail()) {
    throw std::runtime_error(
        FormatError("mir", "无法读取临时汇编文件: " + path.string()));
  }

  std::ostringstream buffer;
  buffer << input.rdbuf();
  return buffer.str();
}
// Creates a fresh, uniquely named directory under the system temporary
// directory and returns its path.
//
// On Windows up to 100 candidate names built from a clock-derived seed are
// tried; elsewhere mkdtemp() fills in the "XXXXXX" suffix. Throws
// std::runtime_error when no directory could be created.
std::filesystem::path CreateTempDir() {
  const auto root = std::filesystem::temp_directory_path();
#if defined(_WIN32)
  const auto stamp =
      std::chrono::high_resolution_clock::now().time_since_epoch().count();
  for (int attempt = 0; attempt < 100; ++attempt) {
    const auto candidate = root / ("nudt_lab3_" + std::to_string(stamp) + "_" +
                                   std::to_string(attempt));
    std::error_code ignored;
    if (std::filesystem::create_directory(candidate, ignored)) {
      return candidate;
    }
  }
  throw std::runtime_error(FormatError("mir", "创建临时目录失败"));
#else
  // mkdtemp() rewrites the template in place, so hand it a mutable,
  // NUL-terminated buffer.
  std::string name_template = (root / "nudt_lab3_XXXXXX").string();
  char* const made = mkdtemp(name_template.data());
  if (made == nullptr) {
    throw std::runtime_error(FormatError("mir", "创建临时目录失败"));
  }
  return std::filesystem::path(made);
#endif
}
} // namespace
// Produces AArch64 assembly for `module` by delegating to an external clang.
//
// Steps: print the module's IR to a file in a private temporary directory,
// run clang on it via std::system with stderr captured to a file, then copy
// the generated assembly to `os`. The temporary directory is removed on
// every exit path.
//
// Throws std::runtime_error when the IR file cannot be written, when clang
// exits with a non-zero status (its captured stderr is appended to the
// message), or when the resulting assembly cannot be read.
void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os) {
  std::filesystem::path scratch = CreateTempDir();
  const auto ll_path = scratch / "module.ll";
  const auto s_path = scratch / "module.s";
  const auto stderr_path = scratch / "clang.err";

  // Scope guard: best-effort recursive delete, errors deliberately ignored.
  struct ScratchRemover {
    std::filesystem::path dir;
    ~ScratchRemover() {
      std::error_code ignored;
      std::filesystem::remove_all(dir, ignored);
    }
  } remover{scratch};

  {
    // Inner scope so the stream is flushed and closed before clang runs.
    std::ofstream ll_stream(ll_path, std::ios::binary);
    if (!ll_stream) {
      throw std::runtime_error(
          FormatError("mir", "无法写入临时 IR 文件: " + ll_path.string()));
    }
    ir::IRPrinter ir_printer;
    ir_printer.Print(module, ll_stream);
  }

  std::string command =
      "clang --target=aarch64-linux-gnu -O2 -fwrapv -Wno-override-module "
      "-fno-addrsig -S -x ir ";
  command += ShellQuote(ll_path);
  command += " -o ";
  command += ShellQuote(s_path);
  command += " 2> ";
  command += ShellQuote(stderr_path);

  const int status = std::system(command.c_str());
  if (status != 0) {
    std::string diagnostics;
    if (std::filesystem::exists(stderr_path)) {
      diagnostics = ReadTextFile(stderr_path);
    }
    // Drop a single trailing newline so the message ends cleanly.
    if (!diagnostics.empty() && diagnostics.back() == '\n') {
      diagnostics.pop_back();
    }

    std::string message = "调用 clang 生成 AArch64 汇编失败";
    if (!diagnostics.empty()) {
      message += ": ";
      message += diagnostics;
    }
    throw std::runtime_error(FormatError("mir", message));
  }

  os << ReadTextFile(s_path);
}
} // namespace mir

@ -1,15 +1,7 @@
// IR -> MIR 指令选择（Lab5）
// - 为每个 IR 值分配虚拟寄存器（GPR / FPR 两类）
// - alloca -> 栈对象；gep/global -> 地址计算
// - 完整覆盖算术、比较、分支、调用、访存、类型转换、浮点
#include "mir/MIR.h" #include "mir/MIR.h"
#include <cstring>
#include <functional>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
@ -17,509 +9,115 @@
namespace mir { namespace mir {
namespace { namespace {
int TypeSize(const ir::Type& t) { using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
switch (t.GetKind()) {
case ir::Type::Kind::Int1:
case ir::Type::Kind::Int32:
case ir::Type::Kind::Float:
return 4;
case ir::Type::Kind::Pointer:
return 8;
case ir::Type::Kind::Array:
return static_cast<int>(t.GetArraySize()) *
TypeSize(*t.GetElementType());
default:
return 8;
}
}
// Picks the register class for a value of type `t`: FPR when the type
// reports itself as float, GPR for everything else.
RegClass ClassOf(const ir::Type& t) {
  if (t.IsFloat()) {
    return RegClass::FPR;
  }
  return RegClass::GPR;
}
// Register width in bytes for a value of type `t`: 8 for pointers, 4 for
// every other type.
int BytesOf(const ir::Type& t) {
  if (t.IsPointer()) {
    return 8;
  }
  return 4;
}
// Returns true when `v` is a positive power of two (1, 2, 4, ...).
// Zero and negative values are never powers of two.
bool IsPow2(long long v) {
  if (v <= 0) {
    return false;
  }
  // A power of two has exactly one set bit, which `v - 1` clears.
  const long long without_lowest_bit = v & (v - 1);
  return without_lowest_bit == 0;
}
// Returns the smallest n such that 2^n >= v, i.e. ceil(log2(v)).
// For an exact power of two this is its exponent; for v <= 1 (including
// zero and negative values) the result is 0.
//
// The comparison is done on unsigned values with the shift count capped at
// 63. Without that cap, any v above 2^62 would drive the loop to shift a
// signed 1 into or past the sign bit, which is undefined behaviour and may
// never terminate.
int Log2(long long v) {
  if (v <= 1) {
    return 0;
  }
  const auto target = static_cast<unsigned long long>(v);
  int n = 0;
  while (n < 63 && (1ULL << n) < target) {
    ++n;
  }
  return n;
}
class Lowerer {
public:
Lowerer(const ir::Module& m, MachineModule& out) : ir_(m), out_(out) {}
void Run();
private:
const ir::Module& ir_;
MachineModule& out_;
MachineFunction* mf_ = nullptr;
MachineBasicBlock* mbb_ = nullptr;
std::unordered_map<const ir::Value*, Operand> vmap_;
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bmap_;
std::unordered_map<const ir::AllocaInst*, int> allocas_;
int label_id_ = 0;
void Emit(Opcode op, std::vector<Operand> ops, int defs, Cond c = Cond::AL) { void EmitValueToReg(const ir::Value* value, PhysReg target,
mbb_->Add(MachineInstr(op, std::move(ops), defs, c)); const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())});
return;
} }
Operand NewG(int bytes = 4) { return mf_->NewVRegOp(RegClass::GPR, bytes); }
Operand NewF() { return mf_->NewVRegOp(RegClass::FPR, 4); }
void LowerFunction(const ir::Function& f);
void LowerBlock(const ir::BasicBlock& bb);
void LowerInst(const ir::Instruction& inst);
void LowerInstMem(const ir::Instruction& inst);
Operand GetReg(const ir::Value* v); // 取值(必要时物化常量/地址)
Operand MaterializeInt(int v);
Operand AddressOf(const ir::Value* ptr);
long long GepConst(const ir::GepInst& gep, Operand* out_base);
void LowerGlobals();
Cond ICmpCond(ir::ICmpPredicate p);
Cond FCmpCond(ir::FCmpPredicate p);
};
void Lowerer::Run() { auto it = slots.find(value);
LowerGlobals(); if (it == slots.end()) {
for (const auto& f : ir_.GetFunctions()) { throw std::runtime_error(
if (f->IsDeclaration()) continue; FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
LowerFunction(*f);
} }
}
void Lowerer::LowerGlobals() { block.Append(Opcode::LoadStack,
for (const auto& g : ir_.GetGlobals()) { {Operand::Reg(target), Operand::FrameIndex(it->second)});
MachineGlobal mg;
mg.name = g->GetName();
const ir::Type& vt = *g->GetValueType();
mg.size = TypeSize(vt);
mg.align = vt.IsArray() ? 16 : 4;
mg.is_const = g->IsConst();
ir::ConstantValue* init = g->GetInitializer();
mg.zero_init = true;
int nwords = (mg.size + 3) / 4;
mg.words.assign(nwords, 0u);
// 收集初始化位(递归展开数组)。
std::vector<unsigned> flat;
std::function<void(ir::ConstantValue*)> walk = [&](ir::ConstantValue* c) {
if (!c) return;
if (auto* ci = dynamic_cast<ir::ConstantInt*>(c)) {
flat.push_back(static_cast<unsigned>(ci->GetValue()));
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(c)) {
float v = cf->GetValue();
unsigned bits;
std::memcpy(&bits, &v, 4);
flat.push_back(bits);
} else if (auto* ca = dynamic_cast<ir::ConstantArray*>(c)) {
for (auto* e : ca->GetElements()) walk(e);
}
};
walk(init);
for (size_t i = 0; i < flat.size() && i < mg.words.size(); ++i) {
mg.words[i] = flat[i];
if (flat[i] != 0) mg.zero_init = false;
}
out_.Globals().push_back(std::move(mg));
}
} }
void Lowerer::LowerFunction(const ir::Function& f) { void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
out_.Functions().push_back(std::make_unique<MachineFunction>(f.GetName())); ValueSlotMap& slots) {
mf_ = out_.Functions().back().get(); auto& block = function.GetEntry();
vmap_.clear();
bmap_.clear();
allocas_.clear();
for (const auto& bb : f.GetBlocks()) { switch (inst.GetOpcode()) {
bmap_[bb.get()] = mf_->CreateBlock(bb->GetName()); case ir::Opcode::Alloca: {
} slots.emplace(&inst, function.CreateFrameIndex());
// 记录后继,便于活跃性分析。 return;
for (const auto& bb : f.GetBlocks()) { }
MachineBasicBlock* mb = bmap_[bb.get()]; case ir::Opcode::Store: {
for (auto* s : bb->GetSuccessors()) mb->Succs().push_back(bmap_[s]); auto& store = static_cast<const ir::StoreInst&>(inst);
} auto dst = slots.find(store.GetPtr());
if (dst == slots.end()) {
mbb_ = bmap_[f.GetEntry()]; throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行写入"));
// 形参:整型走 x0.., 浮点走 s0..,超过 8 个的从栈读取(测试未用,简化)。 }
int ig = 0, fg = 0; EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
std::vector<MachineInstr> arg_copies; block.Append(Opcode::StoreStack,
for (size_t i = 0; i < f.GetNumArgs(); ++i) { {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)});
ir::Argument* a = const_cast<ir::Function&>(f).GetArg(i); return;
const ir::Type& at = *a->GetType(); }
if (at.IsFloat()) { case ir::Opcode::Load: {
Operand dst = NewF(); auto& load = static_cast<const ir::LoadInst&>(inst);
arg_copies.push_back(MachineInstr( auto src = slots.find(load.GetPtr());
Opcode::FMov, if (src == slots.end()) {
{dst, Operand::PReg(fg++, RegClass::FPR, 4)}, 1)); throw std::runtime_error(
vmap_[a] = dst; FormatError("mir", "暂不支持对非栈变量地址进行读取"));
} else { }
int bytes = at.IsPointer() ? 8 : 4; int dst_slot = function.CreateFrameIndex();
Operand dst = NewG(bytes); block.Append(Opcode::LoadStack,
arg_copies.push_back(MachineInstr( {Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
Opcode::Mov, {dst, Operand::PReg(ig++, RegClass::GPR, bytes)}, 1)); block.Append(Opcode::StoreStack,
vmap_[a] = dst; {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Add: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
block.Append(Opcode::Ret);
return;
} }
case ir::Opcode::Sub:
case ir::Opcode::Mul:
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算"));
} }
mf_->SetArgCounts(ig, fg);
for (auto& mi : arg_copies) mbb_->Add(std::move(mi));
for (const auto& bb : f.GetBlocks()) {
mbb_ = bmap_[bb.get()];
LowerBlock(*bb);
}
}
// Lowers every IR instruction of `bb`, in program order, into the machine
// block currently selected by the lowerer.
void Lowerer::LowerBlock(const ir::BasicBlock& bb) {
  const auto& insts = bb.GetInstructions();
  for (auto it = insts.begin(); it != insts.end(); ++it) {
    LowerInst(**it);
  }
}
Operand Lowerer::MaterializeInt(int v) { throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令"));
Operand d = NewG(4);
Emit(Opcode::MovImm, {d, Operand::Imm(v)}, 1);
return d;
} }
Operand Lowerer::GetReg(const ir::Value* v) { } // namespace
auto it = vmap_.find(v);
if (it != vmap_.end()) return it->second;
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(v)) {
return MaterializeInt(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(v)) {
float f = cf->GetValue();
unsigned bits;
std::memcpy(&bits, &f, 4);
Operand d = NewF();
Emit(Opcode::FMovImm, {d, Operand::Imm(bits)}, 1);
return d;
}
// 兜底:未知值视为 0。
return MaterializeInt(0);
}
// Splits a GEP into "base address register + constant byte offset".
//
// Constant indices are folded into the returned byte offset. For every
// non-constant index, address arithmetic is emitted immediately
// (sign-extend, scale by the element size, add to the running address).
// On return *out_base holds the running address (the plain base address if
// no variable index was seen) and the return value is the accumulated
// constant offset, which the caller still has to apply.
long long Lowerer::GepConst(const ir::GepInst& gep, Operand* out_base) {
  Operand running = AddressOf(gep.GetBasePtr());
  // Element type addressed by the current index level; starts at the type the
  // base pointer points to.
  std::shared_ptr<ir::Type> level = gep.GetBasePtr()->GetType()->GetElementType();
  long long folded = 0;
  for (ir::Value* index : gep.GetIndices()) {
    const int stride = level ? TypeSize(*level) : 4;
    auto* literal = dynamic_cast<ir::ConstantInt*>(index);
    if (literal != nullptr) {
      folded += static_cast<long long>(literal->GetValue()) * stride;
    } else {
      // running += sext(index) * stride
      Operand narrow = GetReg(index);
      Operand wide = NewG(8);
      Emit(Opcode::Sxtw, {wide, narrow}, 1);
      Operand product = NewG(8);
      if (IsPow2(stride)) {
        Emit(Opcode::LslImm, {product, wide, Operand::Imm(Log2(stride))}, 1);
      } else {
        Operand factor = NewG(8);
        Emit(Opcode::MovImm, {factor, Operand::Imm(stride)}, 1);
        Emit(Opcode::Mul, {product, wide, factor}, 1);
      }
      Operand next = NewG(8);
      Emit(Opcode::Add, {next, running, product}, 1);
      running = next;
    }
    // Step into the array's element type for the next index.
    if (level && level->IsArray()) level = level->GetElementType();
  }
  *out_base = running;
  return folded;
}
// 返回某个指针型 IR 值对应的“地址”寄存器。 std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) {
Operand Lowerer::AddressOf(const ir::Value* ptr) { DefaultContext();
auto it = vmap_.find(ptr);
if (it != vmap_.end()) return it->second;
if (auto* a = dynamic_cast<const ir::AllocaInst*>(ptr)) {
int idx;
auto ai = allocas_.find(a);
if (ai == allocas_.end()) {
idx = mf_->CreateStackObject(TypeSize(*a->GetAllocatedType()),
a->GetAllocatedType()->IsArray() ? 16 : 4);
allocas_[a] = idx;
} else {
idx = ai->second;
}
Operand d = NewG(8);
Emit(Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1);
vmap_[ptr] = d;
return d;
}
// 全局变量
Operand d = NewG(8);
Emit(Opcode::AddrGlobal, {d, Operand::Global(ptr->GetName())}, 1);
return d;
}
Cond Lowerer::ICmpCond(ir::ICmpPredicate p) { if (module.GetFunctions().size() != 1) {
switch (p) { throw std::runtime_error(FormatError("mir", "暂不支持多个函数"));
case ir::ICmpPredicate::Eq: return Cond::EQ;
case ir::ICmpPredicate::Ne: return Cond::NE;
case ir::ICmpPredicate::Slt: return Cond::LT;
case ir::ICmpPredicate::Sle: return Cond::LE;
case ir::ICmpPredicate::Sgt: return Cond::GT;
case ir::ICmpPredicate::Sge: return Cond::GE;
} }
return Cond::EQ;
}
Cond Lowerer::FCmpCond(ir::FCmpPredicate p) { const auto& func = *module.GetFunctions().front();
switch (p) { if (func.GetName() != "main") {
case ir::FCmpPredicate::Oeq: return Cond::EQ; throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
case ir::FCmpPredicate::One: return Cond::NE;
case ir::FCmpPredicate::Olt: return Cond::MI;
case ir::FCmpPredicate::Ole: return Cond::LS;
case ir::FCmpPredicate::Ogt: return Cond::GT;
case ir::FCmpPredicate::Oge: return Cond::GE;
} }
return Cond::EQ;
}
void Lowerer::LowerInst(const ir::Instruction& inst) { auto machine_func = std::make_unique<MachineFunction>(func.GetName());
using ir::Opcode; ValueSlotMap slots;
switch (inst.GetOpcode()) { const auto* entry = func.GetEntry();
case Opcode::Add: if (!entry) {
case Opcode::Sub: throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块"));
case Opcode::Mul:
case Opcode::SDiv:
case Opcode::SRem: {
auto& b = static_cast<const ir::BinaryInst&>(inst);
Operand l = GetReg(b.GetLhs());
Operand r = GetReg(b.GetRhs());
Operand d = NewG(4);
mir::Opcode mop = mir::Opcode::Add;
if (inst.GetOpcode() == Opcode::Add) mop = mir::Opcode::Add;
else if (inst.GetOpcode() == Opcode::Sub) mop = mir::Opcode::Sub;
else if (inst.GetOpcode() == Opcode::Mul) mop = mir::Opcode::Mul;
else mop = mir::Opcode::SDiv;
if (inst.GetOpcode() == Opcode::SRem) {
Operand q = NewG(4);
Emit(mir::Opcode::SDiv, {q, l, r}, 1);
Emit(mir::Opcode::MSub, {d, q, r, l}, 1); // d = l - q*r
} else {
Emit(mop, {d, l, r}, 1);
}
vmap_[&inst] = d;
break;
}
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto& b = static_cast<const ir::BinaryInst&>(inst);
Operand l = GetReg(b.GetLhs());
Operand r = GetReg(b.GetRhs());
Operand d = NewF();
mir::Opcode mop = mir::Opcode::FAdd;
if (inst.GetOpcode() == Opcode::FSub) mop = mir::Opcode::FSub;
else if (inst.GetOpcode() == Opcode::FMul) mop = mir::Opcode::FMul;
else if (inst.GetOpcode() == Opcode::FDiv) mop = mir::Opcode::FDiv;
Emit(mop, {d, l, r}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::SIToFP: {
auto& c = static_cast<const ir::CastInst&>(inst);
Operand s = GetReg(c.GetValue());
Operand d = NewF();
Emit(mir::Opcode::SCvtF, {d, s}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::FPToSI: {
auto& c = static_cast<const ir::CastInst&>(inst);
Operand s = GetReg(c.GetValue());
Operand d = NewG(4);
Emit(mir::Opcode::FCvtZS, {d, s}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::ZExt: {
auto& c = static_cast<const ir::CastInst&>(inst);
Operand s = GetReg(c.GetValue());
Operand d = NewG(4);
Emit(mir::Opcode::Mov, {d, s}, 1); // i1->i32cset 已产出 0/1
vmap_[&inst] = d;
break;
}
case Opcode::ICmp: {
auto& c = static_cast<const ir::ICmpInst&>(inst);
Operand l = GetReg(c.GetLhs());
Operand r = GetReg(c.GetRhs());
Emit(mir::Opcode::Cmp, {l, r}, 0);
Operand d = NewG(4);
Emit(mir::Opcode::CSet, {d}, 1, ICmpCond(c.GetPredicate()));
vmap_[&inst] = d;
break;
}
case Opcode::FCmp: {
auto& c = static_cast<const ir::FCmpInst&>(inst);
Operand l = GetReg(c.GetLhs());
Operand r = GetReg(c.GetRhs());
Emit(mir::Opcode::FCmp, {l, r}, 0);
Operand d = NewG(4);
Emit(mir::Opcode::CSet, {d}, 1, FCmpCond(c.GetPredicate()));
vmap_[&inst] = d;
break;
}
default:
LowerInstMem(inst);
break;
} }
}
void Lowerer::LowerInstMem(const ir::Instruction& inst) { for (const auto& inst : entry->GetInstructions()) {
using ir::Opcode; LowerInstruction(*inst, *machine_func, slots);
switch (inst.GetOpcode()) {
case Opcode::Alloca: {
auto& a = static_cast<const ir::AllocaInst&>(inst);
int idx = mf_->CreateStackObject(TypeSize(*a.GetAllocatedType()),
a.GetAllocatedType()->IsArray() ? 16 : 4);
allocas_[&a] = idx;
Operand d = NewG(8);
Emit(mir::Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::Load: {
auto& ld = static_cast<const ir::LoadInst&>(inst);
Operand base;
long long off = 0;
if (auto* gep = dynamic_cast<const ir::GepInst*>(ld.GetPtr())) {
off = GepConst(*gep, &base);
} else {
base = AddressOf(ld.GetPtr());
}
bool is_f = inst.GetType()->IsFloat();
Operand d = is_f ? NewF() : NewG(BytesOf(*inst.GetType()));
Emit(mir::Opcode::Ldr, {d, base, Operand::Imm(off)}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::Store: {
auto& st = static_cast<const ir::StoreInst&>(inst);
Operand val = GetReg(st.GetValue());
Operand base;
long long off = 0;
if (auto* gep = dynamic_cast<const ir::GepInst*>(st.GetPtr())) {
off = GepConst(*gep, &base);
} else {
base = AddressOf(st.GetPtr());
}
Emit(mir::Opcode::Str, {val, base, Operand::Imm(off)}, 0);
break;
}
case Opcode::Gep: {
auto& gep = static_cast<const ir::GepInst&>(inst);
Operand base;
long long off = GepConst(gep, &base);
Operand d = NewG(8);
if (off == 0) {
Emit(mir::Opcode::Mov, {d, base}, 1);
} else {
Emit(mir::Opcode::AddImm, {d, base, Operand::Imm(off)}, 1);
}
vmap_[&inst] = d;
break;
}
case Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
// 先把所有实参算入虚拟寄存器,再连续搬入物理参数寄存器,
// 避免计算后续实参时分配器复用 x0..x7 破坏已就绪的参数。
std::vector<Operand> vals;
for (auto* arg : call.GetArgs()) vals.push_back(GetReg(arg));
int ig = 0, fg = 0;
std::vector<Operand> arg_uses;
for (size_t i = 0; i < call.GetArgs().size(); ++i) {
ir::Value* arg = call.GetArgs()[i];
if (arg->GetType()->IsFloat()) {
Operand p = Operand::PReg(fg++, RegClass::FPR, 4);
Emit(mir::Opcode::FMov, {p, vals[i]}, 1);
arg_uses.push_back(p);
} else {
int bytes = arg->GetType()->IsPointer() ? 8 : 4;
Operand p = Operand::PReg(ig++, RegClass::GPR, bytes);
Emit(mir::Opcode::Mov, {p, vals[i]}, 1);
arg_uses.push_back(p);
}
}
std::vector<Operand> ops;
ops.push_back(Operand::Global(call.GetCallee()->GetName()));
for (auto& u : arg_uses) ops.push_back(u);
Emit(mir::Opcode::Bl, ops, 0);
if (!inst.GetType()->IsVoid()) {
if (inst.GetType()->IsFloat()) {
Operand d = NewF();
Emit(mir::Opcode::FMov, {d, Operand::PReg(0, RegClass::FPR, 4)}, 1);
vmap_[&inst] = d;
} else {
Operand d = NewG(BytesOf(*inst.GetType()));
Emit(mir::Opcode::Mov,
{d, Operand::PReg(0, RegClass::GPR, d.GetBytes())}, 1);
vmap_[&inst] = d;
}
}
break;
}
case Opcode::Br: {
auto& br = static_cast<const ir::BranchInst&>(inst);
Emit(mir::Opcode::B, {Operand::Label(br.GetDest()->GetName())}, 0);
break;
}
case Opcode::CondBr: {
auto& cbr = static_cast<const ir::CondBrInst&>(inst);
Operand c = GetReg(cbr.GetCond());
Emit(mir::Opcode::CmpImm, {c, Operand::Imm(0)}, 0);
Emit(mir::Opcode::BCond, {Operand::Label(cbr.GetTrueDest()->GetName())}, 0,
Cond::NE);
Emit(mir::Opcode::B, {Operand::Label(cbr.GetFalseDest()->GetName())}, 0);
break;
}
case Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.HasReturnValue()) {
Operand v = GetReg(ret.GetValue());
if (ret.GetValue()->GetType()->IsFloat()) {
Emit(mir::Opcode::FMov, {Operand::PReg(0, RegClass::FPR, 4), v}, 1);
} else {
Emit(mir::Opcode::Mov,
{Operand::PReg(0, RegClass::GPR, v.GetBytes()), v}, 1);
}
}
Emit(mir::Opcode::Ret, {}, 0);
break;
}
default:
break;
} }
}
} // namespace
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) { return machine_func;
auto out = std::make_unique<MachineModule>();
Lowerer lo(module, *out);
lo.Run();
return out;
} }
} // namespace mir } // namespace mir

@ -1,2 +1,16 @@
// 机器基本块:实现已并入头文件,本文件仅保留 TU 占位。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <utility>
namespace mir {
// Creates an empty machine basic block that takes ownership of `name`.
MachineBasicBlock::MachineBasicBlock(std::string name)
    : name_(std::move(name)) {
}
// Appends a new instruction built from `opcode` and `operands` to the end of
// this block and returns a reference to the stored instruction.
MachineInstr& MachineBasicBlock::Append(Opcode opcode,
                                        std::initializer_list<Operand> operands) {
  std::vector<Operand> operand_list(operands);
  instructions_.emplace_back(opcode, std::move(operand_list));
  return instructions_.back();
}
} // namespace mir

@ -1,2 +1,10 @@
// 机器上下文:实现已并入头文件,本文件仅保留 TU 占位。
#include "mir/MIR.h" #include "mir/MIR.h"
namespace mir {
// Returns the single function-local static MIRContext, constructed on the
// first call.
MIRContext& DefaultContext() {
  static MIRContext instance;
  return instance;
}
} // namespace mir

@ -1,2 +1,33 @@
// 机器函数:实现已并入头文件,本文件仅保留 TU 占位。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <stdexcept>
#include <utility>
#include "utils/Log.h"
namespace mir {
// Creates a machine function named `name`; its entry block is labelled
// "entry".
MachineFunction::MachineFunction(std::string name)
    : name_(std::move(name)),
      entry_("entry") {
}
int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size());
frame_slots_.push_back(FrameSlot{index, size, 0});
return index;
}
// Mutable access to the frame slot at `index`.
// Throws std::runtime_error when `index` is outside [0, slot count).
FrameSlot& MachineFunction::GetFrameSlot(int index) {
  const int slot_count = static_cast<int>(frame_slots_.size());
  const bool in_range = index >= 0 && index < slot_count;
  if (!in_range) {
    throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
  }
  return frame_slots_[index];
}
// Read-only access to the frame slot at `index`.
// Throws std::runtime_error when `index` is outside [0, slot count).
const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
  if (index >= 0 && index < static_cast<int>(frame_slots_.size())) {
    return frame_slots_[index];
  }
  throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
}
} // namespace mir

@ -1,2 +1,23 @@
// 机器指令:实现已并入头文件,本文件仅保留 TU 占位。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <utility>
namespace mir {
// Stores the operand kind together with its register and immediate payloads.
// The named factory functions decide which payload is meaningful.
Operand::Operand(Kind kind, PhysReg reg, int imm)
    : kind_(kind),
      reg_(reg),
      imm_(imm) {
}
// Builds a register operand for `reg`; the immediate payload is set to 0.
Operand Operand::Reg(PhysReg reg) {
  Operand result(Kind::Reg, reg, 0);
  return result;
}
// Builds an immediate operand carrying `value`; the register payload is a
// W0 placeholder.
Operand Operand::Imm(int value) {
  Operand result(Kind::Imm, PhysReg::W0, value);
  return result;
}
// Builds a frame-index operand; `index` is stored in the immediate payload
// and the register payload is a W0 placeholder.
Operand Operand::FrameIndex(int index) {
  Operand result(Kind::FrameIndex, PhysReg::W0, index);
  return result;
}
// Creates an instruction with the given opcode, taking ownership of the
// operand list.
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
    : opcode_(opcode),
      operands_(std::move(operands)) {
}
} // namespace mir

@ -1,354 +1,36 @@
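// Register allocation (Lab5): linear scan over live intervals.
//
// Physical register conventions:
//   GPR: x0-x8   arguments/return (not allocated); x9-x11 allocatable (caller-saved);
//        x12-x15 spill scratch; x16-x17 assembler addressing scratch; x18 platform-reserved;
//        x19-x28 allocatable (callee-saved); x29/x30 fp/lr; x31 sp.
//   FPR: s0-s7   arguments/return; s8-s15 allocatable (callee-saved);
//        s16-s27 allocatable (caller-saved); s28-s31 spill scratch.
//
// A virtual register that is live across a call may only be placed in a
// callee-saved register or be spilled.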
#include "mir/MIR.h" #include "mir/MIR.h"
#include <algorithm> #include <stdexcept>
#include <unordered_map>
#include <unordered_set> #include "utils/Log.h"
#include <vector>
namespace mir { namespace mir {
namespace { namespace {
const std::vector<int>& CallerGPR() { bool IsAllowedReg(PhysReg reg) {
static const std::vector<int> v{9, 10, 11}; switch (reg) {
return v; case PhysReg::W0:
} case PhysReg::W8:
const std::vector<int>& CalleeGPR() { case PhysReg::W9:
static const std::vector<int> v{19, 20, 21, 22, 23, 24, 25, 26, 27, 28}; case PhysReg::X29:
return v; case PhysReg::X30:
} case PhysReg::SP:
const std::vector<int>& CalleeFPR() { return true;
static const std::vector<int> v{8, 9, 10, 11, 12, 13, 14, 15};
return v;
}
// Caller-saved floating-point registers available to the allocator
// (ids 16 through 27), in preference order.
const std::vector<int>& CallerFPR() {
  static const std::vector<int> kRegs{16, 17, 18, 19, 20, 21,
                                      22, 23, 24, 25, 26, 27};
  return kRegs;
}
// Four scratch registers per register class: enough for a single instruction
// whose operands (up to 3 uses + 1 def) are all spilled at the same time.
// Entries are physical register ids, indexed in the order scratch registers
// are handed out within one instruction.
const int kGprSpillScratch[4] = {12, 13, 14, 15};
const int kFprSpillScratch[4] = {28, 29, 30, 31};
// True when general-purpose register `id` lies in the callee-saved range
// 19..28 (inclusive).
bool IsCalleeSavedGPR(int id) {
  return !(id < 19 || id > 28);
}
// True when floating-point register `id` lies in the callee-saved range
// 8..15 (inclusive).
bool IsCalleeSavedFPR(int id) {
  return !(id < 8 || id > 15);
}
// Dispatches the callee-saved check on the register class: GPR ids go to
// IsCalleeSavedGPR, every other class goes to IsCalleeSavedFPR.
bool IsCalleeSavedGPRorFPR(int id, RegClass cls) {
  if (cls == RegClass::GPR) {
    return IsCalleeSavedGPR(id);
  }
  return IsCalleeSavedFPR(id);
}
// Def/use extraction (virtual registers only, in operand order).
//
// Defs appends to `out` the id of every virtual register among the first
// `num_defs` operands of `mi`.
void Defs(const MachineInstr& mi, std::vector<int>* out) {
  const int operand_count = static_cast<int>(mi.ops.size());
  const int limit = mi.num_defs < operand_count ? mi.num_defs : operand_count;
  for (int idx = 0; idx < limit; ++idx) {
    const auto& operand = mi.ops[idx];
    if (!operand.IsVReg()) continue;
    out->push_back(operand.GetId());
  }
}
// Appends to `out` the id of every virtual register among the operands of
// `mi` that follow its first `num_defs` operands.
void Uses(const MachineInstr& mi, std::vector<int>* out) {
  const int operand_count = static_cast<int>(mi.ops.size());
  for (int idx = mi.num_defs; idx < operand_count; ++idx) {
    const auto& operand = mi.ops[idx];
    if (!operand.IsVReg()) continue;
    out->push_back(operand.GetId());
  }
}
// Live interval of one virtual register, expressed in global instruction
// numbers, together with the allocation decision made for it.
struct Interval {
// Id of the virtual register this interval describes.
int vreg;
// First global instruction number at which the register is live.
int start;
// Last global instruction number at which the register is live.
int end;
// Register class (GPR or FPR) of the virtual register.
RegClass cls;
// Set when a call site falls inside the interval.
bool cross_call = false;
int preg = -1; // Assigned physical register; -1 means the interval is spilled.
// Stack object holding the value when spilled; -1 when no slot was created.
int spill_slot = -1;
};
// Flattens all basic blocks, in order, into one global instruction numbering
// and records the [lo, hi] number range of every block.
struct Numbering {
// Instruction pointers indexed by global instruction number.
std::vector<MachineInstr*> instrs;
// Block that owns the instruction with the same index in `instrs`.
std::vector<MachineBasicBlock*> owner;
// Per-block (lo, hi) global instruction numbers.
std::unordered_map<MachineBasicBlock*, std::pair<int, int>> range;
std::vector<int> call_sites; // Global numbers of the Bl (call) instructions.
};
// Walks the blocks of `f` in order and assigns consecutive global numbers to
// their instructions. Records each Bl instruction as a call site and each
// block's [first, last] number range. An empty block gets the degenerate
// range [first, first].
Numbering NumberInstrs(MachineFunction& f) {
  Numbering result;
  for (auto& block : f.Blocks()) {
    const int first = static_cast<int>(result.instrs.size());
    for (auto& inst : block->Instrs()) {
      const int index = static_cast<int>(result.instrs.size());
      if (inst.op == Opcode::Bl) {
        result.call_sites.push_back(index);
      }
      result.instrs.push_back(&inst);
      result.owner.push_back(block.get());
    }
    const int last_seen = static_cast<int>(result.instrs.size()) - 1;
    const int last = last_seen < first ? first : last_seen;  // empty block
    result.range[block.get()] = {first, last};
  }
  return result;
}
// Classic backward data-flow liveness analysis at block granularity.
//
// Fills `live_in` / `live_out` with the virtual registers live at the entry /
// exit of every block of `f`. Blocks are visited in reverse order until a
// full sweep changes no set size.
void ComputeLiveness(
    MachineFunction& f,
    std::unordered_map<MachineBasicBlock*, std::unordered_set<int>>* live_in,
    std::unordered_map<MachineBasicBlock*, std::unordered_set<int>>* live_out) {
  using RegSet = std::unordered_set<int>;
  std::unordered_map<MachineBasicBlock*, RegSet> upward_uses, block_defs;

  // Per-block summaries: registers read before any write in the block, and
  // registers written anywhere in the block.
  for (auto& bb : f.Blocks()) {
    MachineBasicBlock* key = bb.get();
    RegSet& exposed = upward_uses[key];
    RegSet& written = block_defs[key];
    for (auto& mi : bb->Instrs()) {
      std::vector<int> reads, writes;
      Uses(mi, &reads);
      Defs(mi, &writes);
      for (int r : reads) {
        if (written.find(r) == written.end()) exposed.insert(r);
      }
      written.insert(writes.begin(), writes.end());
    }
    // Make sure every block has an entry in both result maps.
    live_in->emplace(key, RegSet{});
    live_out->emplace(key, RegSet{});
  }

  bool changed;
  do {
    changed = false;
    for (auto it = f.Blocks().rbegin(); it != f.Blocks().rend(); ++it) {
      MachineBasicBlock* b = it->get();

      // out[b] = union of in[s] over all successors s.
      RegSet next_out;
      for (auto* s : b->Succs()) {
        const RegSet& succ_in = (*live_in)[s];
        next_out.insert(succ_in.begin(), succ_in.end());
      }

      // in[b] = use[b] + (out[b] - def[b]), using out[b] as it was before
      // this visit.
      RegSet& cur_out = (*live_out)[b];
      RegSet& cur_in = (*live_in)[b];
      RegSet next_in = upward_uses[b];
      const RegSet& killed = block_defs[b];
      for (int r : cur_out) {
        if (killed.find(r) == killed.end()) next_in.insert(r);
      }

      const bool same_sizes = next_out.size() == cur_out.size() &&
                              next_in.size() == cur_in.size();
      if (!same_sizes) changed = true;
      cur_out = std::move(next_out);
      cur_in = std::move(next_in);
    }
  } while (changed);
}
// 用块级活跃信息 + 块内精确编号构造每个 vreg 的活跃区间。
std::vector<Interval> BuildIntervals(MachineFunction& f, const Numbering& num) {
std::unordered_map<MachineBasicBlock*, std::unordered_set<int>> live_in,
live_out;
ComputeLiveness(f, &live_in, &live_out);
int nv = f.NumVRegs();
std::vector<int> start(nv, -1), end(nv, -1);
auto extend = [&](int v, int p) {
if (v < 0) return;
if (start[v] == -1 || p < start[v]) start[v] = p;
if (p > end[v]) end[v] = p;
};
for (auto& bb : f.Blocks()) {
auto [lo, hi] = num.range.at(bb.get());
// live-in 在块入口活跃。
for (int v : live_in[bb.get()]) extend(v, lo);
// live-out 在块出口活跃。
for (int v : live_out[bb.get()]) extend(v, hi);
int p = lo;
for (auto& mi : bb->Instrs()) {
std::vector<int> us, ds;
Uses(mi, &us);
Defs(mi, &ds);
for (int v : us) extend(v, p);
for (int v : ds) extend(v, p);
++p;
}
}
std::vector<Interval> ivs;
for (int v = 0; v < nv; ++v) {
if (start[v] == -1) continue;
Interval iv;
iv.vreg = v;
iv.start = start[v];
iv.end = end[v];
iv.cls = f.VReg(v).cls;
for (int cs : num.call_sites)
if (cs > iv.start && cs <= iv.end) { // 调用点严格落在区间内
iv.cross_call = true;
break;
}
ivs.push_back(iv);
} }
std::sort(ivs.begin(), ivs.end(), return false;
[](const Interval& a, const Interval& b) {
return a.start < b.start;
});
return ivs;
}
// 线性扫描分配;对每个区间设定 preg(>=0) 或标记 spill(preg==-1)。
// 返回 true 表示无需额外 spill 重试本实现一次性完成spill 直接落槽)。
void LinearScan(MachineFunction& f, std::vector<Interval>& ivs) {
// 为两类寄存器分别准备“优先 caller-saved再 callee-saved”的池
// 跨调用区间则只用 callee-saved。
auto run = [&](RegClass cls) {
const std::vector<int>& caller =
cls == RegClass::GPR ? CallerGPR() : CallerFPR();
const std::vector<int>& callee =
cls == RegClass::GPR ? CalleeGPR() : CalleeFPR();
// active已分配且尚未结束的区间按 end 升序。
std::vector<Interval*> active;
std::unordered_set<int> free_regs;
for (int r : caller) free_regs.insert(r);
for (int r : callee) free_regs.insert(r);
auto expire = [&](int point) {
std::vector<Interval*> keep;
for (Interval* a : active) {
if (a->end < point) {
free_regs.insert(a->preg);
} else {
keep.push_back(a);
}
}
active = std::move(keep);
};
auto pick = [&](bool need_callee) -> int {
// 跨调用:仅 callee-saved否则优先 caller-saved。
if (need_callee) {
for (int r : callee)
if (free_regs.count(r)) return r;
return -1;
}
for (int r : caller)
if (free_regs.count(r)) return r;
for (int r : callee)
if (free_regs.count(r)) return r;
return -1;
};
for (auto& iv : ivs) {
if (iv.cls != cls) continue;
expire(iv.start);
int r = pick(iv.cross_call);
if (r == -1) {
// spill在当前区间和 active 里结束最晚者之间选择。
Interval* victim = nullptr;
for (Interval* a : active)
if (a->cls == cls &&
(iv.cross_call ? IsCalleeSavedGPRorFPR(a->preg, cls) : true)) {
if (!victim || a->end > victim->end) victim = a;
}
if (victim && victim->end > iv.end) {
iv.preg = victim->preg;
victim->preg = -1;
victim->spill_slot =
f.CreateStackObject(cls == RegClass::GPR ? 8 : 8, 8);
// 从 active 移除 victim加入当前。
active.erase(std::remove(active.begin(), active.end(), victim),
active.end());
active.push_back(&iv);
} else {
iv.preg = -1;
iv.spill_slot = f.CreateStackObject(8, 8);
}
} else {
free_regs.erase(r);
iv.preg = r;
active.push_back(&iv);
}
std::sort(active.begin(), active.end(),
[](Interval* a, Interval* b) { return a->end < b->end; });
}
};
run(RegClass::GPR);
run(RegClass::FPR);
} }
// 把分配结果写回指令spill 的虚拟寄存器用暂存寄存器搬运并落槽。 } // namespace
void Rewrite(MachineFunction& f, const std::vector<Interval>& ivs) {
int nv = f.NumVRegs();
std::vector<int> preg(nv, -1);
std::vector<int> slot(nv, -1);
for (const auto& iv : ivs) {
preg[iv.vreg] = iv.preg;
slot[iv.vreg] = iv.spill_slot;
}
std::unordered_set<int> used_callee_gpr, used_callee_fpr;
for (const auto& iv : ivs) {
if (iv.preg < 0) continue;
if (iv.cls == RegClass::GPR && IsCalleeSavedGPR(iv.preg))
used_callee_gpr.insert(iv.preg);
if (iv.cls == RegClass::FPR && IsCalleeSavedFPR(iv.preg))
used_callee_fpr.insert(iv.preg);
}
for (auto& bb : f.Blocks()) {
std::vector<MachineInstr> out;
for (auto& mi : bb->Instrs()) {
std::vector<MachineInstr> pre, post;
int gscr = 0, fscr = 0;
std::unordered_map<int, int> assigned; // vreg -> scratch preg
auto scratchFor = [&](const Operand& op) -> int {
int v = op.GetId();
auto it = assigned.find(v);
if (it != assigned.end()) return it->second;
int s = op.GetClass() == RegClass::GPR ? kGprSpillScratch[gscr++]
: kFprSpillScratch[fscr++];
assigned[v] = s;
return s;
};
for (int i = 0; i < (int)mi.ops.size(); ++i) { void RunRegAlloc(MachineFunction& function) {
Operand& op = mi.ops[i]; for (const auto& inst : function.GetEntry().GetInstructions()) {
if (!op.IsVReg()) continue; for (const auto& operand : inst.GetOperands()) {
int v = op.GetId(); if (operand.GetKind() == Operand::Kind::Reg &&
if (preg[v] >= 0) { !IsAllowedReg(operand.GetReg())) {
op.SetPReg(preg[v]); throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
continue;
}
// spilled
int s = scratchFor(op);
bool is_def = i < mi.num_defs;
if (is_def) {
MachineInstr st(Opcode::StrStack,
{Operand::PReg(s, op.GetClass(), op.GetBytes()),
Operand::Frame(slot[v])},
0);
post.push_back(st);
} else {
MachineInstr ld(Opcode::LdrStack,
{Operand::PReg(s, op.GetClass(), op.GetBytes()),
Operand::Frame(slot[v])},
1);
pre.push_back(ld);
}
op.SetPReg(s);
} }
for (auto& p : pre) out.push_back(std::move(p));
out.push_back(mi);
for (auto& p : post) out.push_back(std::move(p));
} }
bb->Instrs() = std::move(out);
} }
for (int r : used_callee_gpr) f.CalleeSavedGPR().push_back(r);
for (int r : used_callee_fpr) f.CalleeSavedFPR().push_back(r);
std::sort(f.CalleeSavedGPR().begin(), f.CalleeSavedGPR().end());
std::sort(f.CalleeSavedFPR().begin(), f.CalleeSavedFPR().end());
}
} // namespace
void RunRegAlloc(MachineFunction& function) {
Numbering num = NumberInstrs(function);
std::vector<Interval> ivs = BuildIntervals(function, num);
LinearScan(function, ivs);
Rewrite(function, ivs);
} }
} // namespace mir } // namespace mir

@ -1,45 +1,27 @@
#include "mir/MIR.h" #include "mir/MIR.h"
namespace mir { #include <stdexcept>
MIRContext& DefaultContext() {
static MIRContext ctx;
return ctx;
}
namespace { #include "utils/Log.h"
// 64 位通用寄存器名x0..x30, sp
const char* kX[33] = {
"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10",
"x11", "x12", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20", "x21",
"x22", "x23", "x24", "x25", "x26", "x27", "x28", "x29", "x30", "sp", "xzr"};
const char* kW[33] = {
"w0", "w1", "w2", "w3", "w4", "w5", "w6", "w7", "w8", "w9", "w10",
"w11", "w12", "w13", "w14", "w15", "w16", "w17", "w18", "w19", "w20", "w21",
"w22", "w23", "w24", "w25", "w26", "w27", "w28", "w29", "w30", "wsp", "wzr"};
} // namespace
const char* GPRName(int id, int bytes) { namespace mir {
if (id < 0 || id > 32) return "x0";
return bytes == 8 ? kX[id] : kW[id];
}
const char* FPRName(int id, int bytes) { const char* PhysRegName(PhysReg reg) {
static char buf[8][8]; switch (reg) {
static int slot = 0; case PhysReg::W0:
char* b = buf[slot]; return "w0";
slot = (slot + 1) & 7; case PhysReg::W8:
b[0] = (bytes == 8) ? 'd' : 's'; return "w8";
int n = id; case PhysReg::W9:
if (n < 10) { return "w9";
b[1] = static_cast<char>('0' + n); case PhysReg::X29:
b[2] = '\0'; return "x29";
} else { case PhysReg::X30:
b[1] = static_cast<char>('0' + n / 10); return "x30";
b[2] = static_cast<char>('0' + n % 10); case PhysReg::SP:
b[3] = '\0'; return "sp";
} }
return b; throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
} }
} // namespace mir } // namespace mir

@ -1,4 +1,4 @@
// 后端 Pass 管理:当前后端流水线直接在 RunBackendPipeline 中按 // MIR Pass 管理:
// RegAlloc -> FrameLowering -> Peephole 顺序驱动(见 AsmPrinter.cpp // - 组织后端 pass 的运行顺序PreRA/PostRA/PEI 等阶段)
// 本文件保留占位,便于后续扩展更细粒度的 Pass 调度。 // - 统一运行 pass 与调试输出(按需要扩展)
#include "mir/MIR.h"

@ -1,84 +1,4 @@
// 后端局部窥孔优化Lab5 // 窥孔优化Peephole
// 在寄存器分配 + 物理寄存器落地之后运行,针对最终机器指令序列做局部清理: // - 删除冗余 move、合并常见指令模式
// 1. 删除自拷贝 mov xN, xN / fmov sN, sN // - 提升最终汇编质量(按实现范围裁剪)
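// 2. Remove identity arithmetic add/sub xD, xS, #0 (rewritten to mov when xD != xS).
// 3. Remove an adjacent redundant load: when str R,[B,#o] is immediately followed by ldr R,[B,#o], the ldr is dropped.
// 4. Remove consecutive movs that write the same destination with no use in between, keeping only the last one
//    (NOTE: listed as a goal, but not implemented by the pass in this file).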
#include "mir/MIR.h"
#include <vector>
namespace mir {
namespace {
// True when both operands are physical registers with the same id and the
// same register class.
bool SameReg(const Operand& a, const Operand& b) {
  if (!a.IsPReg() || !b.IsPReg()) return false;
  if (a.GetId() != b.GetId()) return false;
  return a.GetClass() == b.GetClass();
}
// True for a two-operand Mov/FMov whose source and destination are the same
// physical register.
bool IsSelfMove(const MachineInstr& mi) {
  const bool is_move = mi.op == Opcode::Mov || mi.op == Opcode::FMov;
  if (!is_move) return false;
  if (mi.ops.size() != 2) return false;
  return SameReg(mi.ops[0], mi.ops[1]);
}
// True for a three-operand AddImm/SubImm whose immediate (third operand) is 0.
bool IsIdentityAddSub(const MachineInstr& mi) {
  if (mi.op != Opcode::AddImm && mi.op != Opcode::SubImm) return false;
  return mi.ops.size() == 3 && mi.ops[2].GetImm() == 0;
}
// A `str R,[B,#o]` directly followed by `ldr R,[B,#o]` (same register, same
// base, same offset) makes the ldr redundant. Returns true for that pair.
bool RedundantLoadAfterStore(const MachineInstr& st, const MachineInstr& ld) {
  const bool is_pair = st.op == Opcode::Str && ld.op == Opcode::Ldr;
  if (!is_pair) return false;
  if (!SameReg(st.ops[0], ld.ops[0])) return false;
  if (!SameReg(st.ops[1], ld.ops[1])) return false;
  return st.ops[2].GetImm() == ld.ops[2].GetImm();
}
// Stack-slot variant: a StrStack followed by a LdrStack of the same register
// from the same frame object makes the load redundant.
bool RedundantStackLoadAfterStore(const MachineInstr& st,
                                  const MachineInstr& ld) {
  const bool is_pair = st.op == Opcode::StrStack && ld.op == Opcode::LdrStack;
  if (!is_pair) return false;
  if (!SameReg(st.ops[0], ld.ops[0])) return false;
  return st.ops[1].GetFrame() == ld.ops[1].GetFrame();
}
// Runs one peephole sweep over `bb` and replaces its instruction list with
// the cleaned-up sequence:
//   - self moves are dropped;
//   - add/sub D, S, #0 becomes mov D, S (or is dropped when D == S);
//   - a load that directly follows a store of the same register to the same
//     location is dropped.
void OptimizeBlock(MachineBasicBlock& bb) {
  std::vector<MachineInstr> kept;
  kept.reserve(bb.Instrs().size());
  for (auto& inst : bb.Instrs()) {
    if (IsSelfMove(inst)) continue;

    if (IsIdentityAddSub(inst)) {
      if (!SameReg(inst.ops[0], inst.ops[1])) {
        kept.push_back(MachineInstr(Opcode::Mov, {inst.ops[0], inst.ops[1]}, 1));
      }
      continue;
    }

    // Compare against the last instruction actually kept, so patterns exposed
    // by earlier deletions in this sweep are seen too.
    const bool dead_reload =
        !kept.empty() && (RedundantLoadAfterStore(kept.back(), inst) ||
                          RedundantStackLoadAfterStore(kept.back(), inst));
    if (dead_reload) continue;

    kept.push_back(inst);
  }
  bb.Instrs() = std::move(kept);
}
} // namespace
// Repeats the block-level peephole sweep until a sweep removes no
// instruction, so patterns uncovered by a deletion are handled as well.
// At most 8 sweeps are run.
void RunPeephole(MachineFunction& function) {
  auto count_instrs = [&function]() {
    size_t total = 0;
    for (auto& bb : function.Blocks()) total += bb->Instrs().size();
    return total;
  };
  for (int round = 0; round < 8; ++round) {
    const size_t before = count_instrs();
    for (auto& bb : function.Blocks()) OptimizeBlock(*bb);
    const size_t after = count_instrs();
    if (!(after < before)) break;
  }
}
} // namespace mir

@ -3,7 +3,6 @@
#include <any> #include <any>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_set>
#include "SysYBaseVisitor.h" #include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h" #include "sem/SymbolTable.h"
@ -11,500 +10,185 @@
namespace { namespace {
static BaseTypeKind BaseTypeFromBType(SysYParser::BTypeContext* ctx) { std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!ctx) { if (!lvalue.ID()) {
throw std::runtime_error(FormatError("sema", "缺少 bType")); throw std::runtime_error(FormatError("sema", "非法左值"));
} }
if (ctx->INT()) return BaseTypeKind::Int; return lvalue.ID()->getText();
if (ctx->FLOAT()) return BaseTypeKind::Float;
throw std::runtime_error(FormatError("sema", "未知基础类型"));
} }
// Maps a funcType parse node to its BaseTypeKind (void / int / float).
// Throws std::runtime_error when the node is missing or matches none of the
// three keywords.
static BaseTypeKind BaseTypeFromFuncType(SysYParser::FuncTypeContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("sema", "缺少 funcType"));
  }
  if (ctx->VOID() != nullptr) {
    return BaseTypeKind::Void;
  }
  if (ctx->INT() != nullptr) {
    return BaseTypeKind::Int;
  }
  if (ctx->FLOAT() != nullptr) {
    return BaseTypeKind::Float;
  }
  throw std::runtime_error(FormatError("sema", "未知函数返回类型"));
}
// Evaluates a constExp parse tree to an `int` at compile time.
//
// Every visit method returns a std::any holding an int. Names are resolved
// through the symbol table passed to the constructor; only entries marked
// const that carry a known value are accepted. Unsupported constructs
// (logical not, function calls, non-constant names, division by zero)
// raise std::runtime_error.
class ConstEvalVisitor final : public SysYBaseVisitor {
 public:
  explicit ConstEvalVisitor(const SymbolTable& table) : table_(table) {}

  // constExp is just an addExp.
  std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
    return visitAddExp(ctx->addExp());
  }

  // Folds `mulExp (('+' | '-') mulExp)*` from left to right.
  std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
    auto muls = ctx->mulExp();
    if (muls.empty()) return 0;
    int value = std::any_cast<int>(muls[0]->accept(this));
    for (size_t i = 1; i < muls.size(); ++i) {
      int rhs = std::any_cast<int>(muls[i]->accept(this));
      // Children alternate operand, operator, operand, ...; the operator
      // preceding operand i sits at child index 2*i - 1.
      auto* node = ctx->children.at(2 * i - 1);
      auto text = node ? node->getText() : "+";
      if (text == "+") {
        value += rhs;
      } else if (text == "-") {
        value -= rhs;
      } else {
        throw std::runtime_error(FormatError("sema", "非法加法运算符"));
      }
    }
    return value;
  }

  // Folds `unaryExp (('*' | '/' | '%') unaryExp)*` from left to right.
  // Throws on a zero divisor instead of executing an undefined division.
  std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
    auto unaries = ctx->unaryExp();
    if (unaries.empty()) return 0;
    int value = std::any_cast<int>(unaries[0]->accept(this));
    for (size_t i = 1; i < unaries.size(); ++i) {
      int rhs = std::any_cast<int>(unaries[i]->accept(this));
      auto* node = ctx->children.at(2 * i - 1);
      auto text = node ? node->getText() : "*";
      if (text == "*") {
        value *= rhs;
      } else if (text == "/" || text == "%") {
        if (rhs == 0) {
          // Integer division by zero is undefined behaviour and usually
          // kills the compiler with SIGFPE; report it as a sema error.
          throw std::runtime_error(FormatError("sema", "constExp 除数为 0"));
        }
        if (rhs == -1) {
          // INT_MIN / -1 and INT_MIN % -1 overflow (also undefined). Compute
          // the negation with wrapping unsigned arithmetic; the remainder of
          // a division by -1 is always 0.
          value = (text == "/")
                      ? static_cast<int>(0u - static_cast<unsigned>(value))
                      : 0;
        } else if (text == "/") {
          value /= rhs;
        } else {
          value %= rhs;
        }
      } else {
        throw std::runtime_error(FormatError("sema", "非法乘法运算符"));
      }
    }
    return value;
  }

  // Handles a primary expression or a unary +/- applied to a unaryExp.
  // Logical not and function calls are rejected.
  std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
    if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
    if (ctx->unaryOp() && ctx->unaryExp()) {
      int val = std::any_cast<int>(ctx->unaryExp()->accept(this));
      auto op = ctx->unaryOp()->getText();
      if (op == "+") return val;
      if (op == "-") return -val;
      throw std::runtime_error(FormatError("sema", "constExp 不支持 !"));
    }
    throw std::runtime_error(FormatError("sema", "constExp 不支持函数调用"));
  }

  // Dispatches to the parenthesised expression, lVal or number child;
  // yields 0 when none of them is present.
  std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
    if (ctx->exp()) return ctx->exp()->accept(this);
    if (ctx->lVal()) return ctx->lVal()->accept(this);
    if (ctx->number()) return ctx->number()->accept(this);
    return 0;
  }

  // Parses a numeric literal. Integer literals go through std::stoll with
  // base auto-detection (base 0) and are narrowed to int; float literals are
  // parsed with std::stof and truncated to int.
  std::any visitNumber(SysYParser::NumberContext* ctx) override {
    if (ctx->INT_CONST()) {
      const std::string text = ctx->getText();
      size_t idx = 0;
      long long val = std::stoll(text, &idx, 0);
      if (idx != text.size()) {
        // Trailing characters that stoll did not consume.
        throw std::runtime_error(FormatError("sema", "非法整数常量"));
      }
      return static_cast<int>(val);
    }
    if (ctx->FLOAT_CONST()) {
      return static_cast<int>(std::stof(ctx->getText()));
    }
    throw std::runtime_error(FormatError("sema", "constExp 仅支持整数"));
  }

  // Resolves a name to the value recorded in the symbol table. The entry
  // must exist, be marked const and carry a value.
  std::any visitLVal(SysYParser::LValContext* ctx) override {
    if (!ctx || !ctx->ID()) {
      throw std::runtime_error(FormatError("sema", "constExp 非法变量"));
    }
    const auto* entry = table_.Lookup(ctx->ID()->getText());
    if (!entry || !entry->is_const || !entry->const_value.has_value()) {
      throw std::runtime_error(FormatError("sema", "constExp 使用了非常量"));
    }
    return entry->const_value.value();
  }

 private:
  const SymbolTable& table_;  // Symbol table used for constant lookups; not owned.
};
class SemaVisitor final : public SysYBaseVisitor { class SemaVisitor final : public SysYBaseVisitor {
public: public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) { if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元")); throw std::runtime_error(FormatError("sema", "缺少编译单元"));
} }
for (auto* func : ctx->funcDef()) { auto* func = ctx->funcDef();
if (!func || !func->ID()) continue; if (!func || !func->blockStmt()) {
std::string name = func->ID()->getText(); throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
if (func_table_.find(name) != func_table_.end()) {
throw std::runtime_error(FormatError("sema", "重复定义函数: " + name));
}
func_table_[name] = func;
}
for (auto* decl : ctx->decl()) {
if (decl) decl->accept(this);
}
for (auto* func : ctx->funcDef()) {
if (func) func->accept(this);
} }
if (!func->ID() || func->ID()->getText() != "main") {
if (func_table_.find("main") == func_table_.end()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
} }
func->accept(this);
if (!seen_return_) {
throw std::runtime_error(
FormatError("sema", "main 函数必须包含 return 语句"));
}
return {}; return {};
} }
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->block()) { if (!ctx || !ctx->blockStmt()) {
throw std::runtime_error(FormatError("sema", "函数体为空")); throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
if (!ctx->ID()) {
throw std::runtime_error(FormatError("sema", "缺少函数名"));
}
FuncTypeDesc fty;
fty.ret.base = BaseTypeFromFuncType(ctx->funcType());
if (ctx->funcFParams()) {
for (auto* param : ctx->funcFParams()->funcFParam()) {
fty.params.push_back(BuildParamType(param));
}
} }
sema_.RegisterFunc(ctx, fty); if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
current_ret_ = fty.ret.base;
seen_return_ = false;
table_.EnterScope();
if (ctx->funcFParams()) {
for (auto* param : ctx->funcFParams()->funcFParam()) {
RegisterParam(param);
}
} }
ctx->block()->accept(this); const auto& items = ctx->blockStmt()->blockItem();
table_.ExitScope(); if (items.empty()) {
throw std::runtime_error(
if (current_ret_ != BaseTypeKind::Void && !seen_return_) { FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
throw std::runtime_error(FormatError("sema", "非 void 函数缺少 return"));
} }
ctx->blockStmt()->accept(this);
return {}; return {};
} }
std::any visitBlock(SysYParser::BlockContext* ctx) override { std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
if (!ctx) return {}; if (!ctx) {
table_.EnterScope(); throw std::runtime_error(FormatError("sema", "缺少语句块"));
for (auto* item : ctx->blockItem()) {
if (item) item->accept(this);
} }
table_.ExitScope(); const auto& items = ctx->blockItem();
return {}; for (size_t i = 0; i < items.size(); ++i) {
} auto* item = items[i];
if (!item) {
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { continue;
if (!ctx) return {}; }
if (ctx->decl()) return ctx->decl()->accept(this); if (seen_return_) {
if (ctx->stmt()) return ctx->stmt()->accept(this); throw std::runtime_error(
return {}; FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
} }
current_item_index_ = i;
std::any visitDecl(SysYParser::DeclContext* ctx) override { total_items_ = items.size();
if (!ctx) return {}; item->accept(this);
if (auto* c = ctx->constDecl()) return c->accept(this);
if (auto* v = ctx->varDecl()) return v->accept(this);
return {};
}
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
if (!ctx || !ctx->bType()) return {};
BaseTypeKind base = BaseTypeFromBType(ctx->bType());
for (auto* def : ctx->constDef()) {
RegisterConst(def, base);
} }
return {}; return {};
} }
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx || !ctx->bType()) return {}; if (!ctx) {
BaseTypeKind base = BaseTypeFromBType(ctx->bType()); throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
for (auto* def : ctx->varDef()) {
RegisterVar(def, base);
} }
return {}; if (ctx->decl()) {
} ctx->decl()->accept(this);
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx) return {};
if (ctx->lVal() && ctx->ASSIGN()) {
ctx->lVal()->accept(this);
if (ctx->exp()) ctx->exp()->accept(this);
return {}; return {};
} }
if (ctx->block()) return ctx->block()->accept(this); if (ctx->stmt()) {
if (ctx->IF()) { ctx->stmt()->accept(this);
if (ctx->cond()) ctx->cond()->accept(this);
if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
if (ctx->stmt(1)) ctx->stmt(1)->accept(this);
return {}; return {};
} }
if (ctx->WHILE()) { throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
loop_depth_++; }
if (ctx->cond()) ctx->cond()->accept(this);
if (ctx->stmt(0)) ctx->stmt(0)->accept(this); std::any visitDecl(SysYParser::DeclContext* ctx) override {
loop_depth_--; if (!ctx) {
return {}; throw std::runtime_error(FormatError("sema", "非法变量声明"));
} }
if (ctx->BREAK()) { if (!ctx->btype() || !ctx->btype()->INT()) {
if (loop_depth_ == 0) { throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
throw std::runtime_error(FormatError("sema", "break 不在循环内"));
}
return {};
} }
if (ctx->CONTINUE()) { auto* var_def = ctx->varDef();
if (loop_depth_ == 0) { if (!var_def || !var_def->lValue()) {
throw std::runtime_error(FormatError("sema", "continue 不在循环内")); throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
return {};
} }
if (ctx->RETURN()) { const std::string name = GetLValueName(*var_def->lValue());
if (ctx->exp()) ctx->exp()->accept(this); if (table_.Contains(name)) {
if (current_ret_ == BaseTypeKind::Void && ctx->exp()) { throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
throw std::runtime_error(FormatError("sema", "void 函数不能返回值")); }
} if (auto* init = var_def->initValue()) {
if (current_ret_ != BaseTypeKind::Void && !ctx->exp()) { if (!init->exp()) {
throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值")); throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
} }
seen_return_ = true; init->exp()->accept(this);
return {};
} }
if (ctx->exp()) ctx->exp()->accept(this); table_.Add(name, var_def);
return {}; return {};
} }
std::any visitExp(SysYParser::ExpContext* ctx) override { std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (ctx->addExp()) return ctx->addExp()->accept(this); if (!ctx || !ctx->returnStmt()) {
return {}; throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
} }
ctx->returnStmt()->accept(this);
std::any visitCond(SysYParser::CondContext* ctx) override {
if (ctx->lOrExp()) return ctx->lOrExp()->accept(this);
return {};
}
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
for (auto* e : ctx->lAndExp()) e->accept(this);
return {};
}
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
for (auto* e : ctx->eqExp()) e->accept(this);
return {};
}
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
for (auto* e : ctx->relExp()) e->accept(this);
return {};
}
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
for (auto* e : ctx->addExp()) e->accept(this);
return {};
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
for (auto* mul : ctx->mulExp()) mul->accept(this);
return {}; return {};
} }
std::any visitMulExp(SysYParser::MulExpContext* ctx) override { std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
for (auto* unary : ctx->unaryExp()) unary->accept(this); if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
}
ctx->exp()->accept(this);
seen_return_ = true;
if (current_item_index_ + 1 != total_items_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
return {}; return {};
} }
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override { std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
if (ctx->primaryExp()) return ctx->primaryExp()->accept(this); if (!ctx || !ctx->exp()) {
if (ctx->ID() && ctx->LPAREN()) { throw std::runtime_error(FormatError("sema", "非法括号表达式"));
std::string name = ctx->ID()->getText();
auto it = func_table_.find(name);
if (it == func_table_.end()) {
if (builtin_funcs_.find(name) == builtin_funcs_.end()) {
throw std::runtime_error(FormatError("sema", "未定义的函数: " + name));
}
} else {
sema_.BindFuncCall(ctx, it->second);
}
if (ctx->funcRParams()) ctx->funcRParams()->accept(this);
return {};
} }
if (ctx->unaryExp()) return ctx->unaryExp()->accept(this); ctx->exp()->accept(this);
return {}; return {};
} }
std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override { std::any visitVarExp(SysYParser::VarExpContext* ctx) override {
for (auto* e : ctx->exp()) e->accept(this); if (!ctx || !ctx->var()) {
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
}
ctx->var()->accept(this);
return {}; return {};
} }
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override { std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
if (ctx->exp()) return ctx->exp()->accept(this); if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
if (ctx->lVal()) return ctx->lVal()->accept(this); throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量"));
if (ctx->number()) return ctx->number()->accept(this); }
return {}; return {};
} }
std::any visitNumber(SysYParser::NumberContext* ctx) override { std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override {
if (!ctx->INT_CONST() && !ctx->FLOAT_CONST()) { if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
throw std::runtime_error(FormatError("sema", "非法常量")); throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
} }
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {}; return {};
} }
std::any visitLVal(SysYParser::LValContext* ctx) override { std::any visitVar(SysYParser::VarContext* ctx) override {
if (!ctx || !ctx->ID()) { if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用")); throw std::runtime_error(FormatError("sema", "非法变量引用"));
} }
std::string name = ctx->ID()->getText(); const std::string name = ctx->ID()->getText();
const SymbolEntry* entry = table_.Lookup(name); auto* decl = table_.Lookup(name);
if (!entry) { if (!decl) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
} }
BoundDecl bound; sema_.BindVarUse(ctx, decl);
if (entry->kind == SymbolKind::Var) {
bound.kind = BoundDecl::Kind::Var;
bound.var_decl = entry->var_decl;
} else if (entry->kind == SymbolKind::Const) {
bound.kind = BoundDecl::Kind::Const;
bound.const_decl = entry->const_decl;
} else {
bound.kind = BoundDecl::Kind::Param;
bound.param_decl = entry->param_decl;
}
sema_.BindVarUse(ctx, bound);
for (auto* exp : ctx->exp()) {
if (exp) {
exp->accept(this);
}
}
return {}; return {};
} }
SemanticContext TakeSemanticContext() { return std::move(sema_); } SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
TypeDesc BuildParamType(SysYParser::FuncFParamContext* ctx) {
if (!ctx || !ctx->bType()) {
throw std::runtime_error(FormatError("sema", "非法参数"));
}
TypeDesc ty;
ty.base = BaseTypeFromBType(ctx->bType());
if (ctx->LBRACK().size() > 0) {
ty.dims.push_back(-1);
for (auto* exp : ctx->exp()) {
ty.dims.push_back(EvalConstExp(exp));
}
}
return ty;
}
void RegisterParam(SysYParser::FuncFParamContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "参数缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义参数: " + name));
}
TypeDesc ty = BuildParamType(ctx);
SymbolEntry entry;
entry.kind = SymbolKind::Param;
entry.param_decl = ctx;
entry.is_const = false;
entry.type = ty;
table_.Add(name, entry);
sema_.RegisterParam(ctx, ty);
}
void RegisterVar(SysYParser::VarDefContext* ctx, BaseTypeKind base) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "变量声明缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
}
TypeDesc ty;
ty.base = base;
for (auto* dim : ctx->constExp()) {
ty.dims.push_back(EvalConstExp(dim));
}
SymbolEntry entry;
entry.kind = SymbolKind::Var;
entry.var_decl = ctx;
entry.is_const = false;
entry.type = ty;
table_.Add(name, entry);
sema_.RegisterVarDecl(ctx, ty);
if (auto* init = ctx->initVal()) {
init->accept(this);
}
}
void RegisterConst(SysYParser::ConstDefContext* ctx, BaseTypeKind base) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "常量声明缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
}
TypeDesc ty;
ty.base = base;
ty.is_const = true;
for (auto* dim : ctx->constExp()) {
ty.dims.push_back(EvalConstExp(dim));
}
SymbolEntry entry;
entry.kind = SymbolKind::Const;
entry.const_decl = ctx;
entry.is_const = true;
entry.type = ty;
if (ctx->constInitVal() && ty.dims.empty() && ty.base == BaseTypeKind::Int) {
if (auto* exp = ctx->constInitVal()->constExp()) {
entry.const_value = EvalConstExp(exp);
}
}
table_.Add(name, entry);
sema_.RegisterConstDecl(ctx, ty);
if (auto* init = ctx->constInitVal()) {
init->accept(this);
}
}
int EvalConstExp(SysYParser::ConstExpContext* ctx) {
ConstEvalVisitor visitor(table_);
return std::any_cast<int>(ctx->accept(&visitor));
}
int EvalConstExp(SysYParser::ExpContext* ctx) {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("sema", "非法常量表达式"));
}
ConstEvalVisitor visitor(table_);
return std::any_cast<int>(ctx->addExp()->accept(&visitor));
}
private: private:
SymbolTable table_; SymbolTable table_;
SemanticContext sema_; SemanticContext sema_;
std::unordered_map<std::string, SysYParser::FuncDefContext*> func_table_;
const std::unordered_set<std::string> builtin_funcs_ = {
"getint", "getch", "getarray", "putint", "putch", "putarray",
"getfloat", "getfarray", "putfloat", "putfarray", "starttime",
"stoptime"};
BaseTypeKind current_ret_ = BaseTypeKind::Void;
bool seen_return_ = false; bool seen_return_ = false;
int loop_depth_ = 0; size_t current_item_index_ = 0;
size_t total_items_ = 0;
}; };
} // namespace } // namespace
@ -513,4 +197,4 @@ SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor; SemaVisitor visitor;
comp_unit.accept(&visitor); comp_unit.accept(&visitor);
return visitor.TakeSemanticContext(); return visitor.TakeSemanticContext();
} }

@ -2,34 +2,16 @@
#include "sem/SymbolTable.h" #include "sem/SymbolTable.h"
void SymbolTable::EnterScope() { scopes_.emplace_back(); } void SymbolTable::Add(const std::string& name,
SysYParser::VarDefContext* decl) {
void SymbolTable::ExitScope() { table_[name] = decl;
if (!scopes_.empty()) {
scopes_.pop_back();
}
}
bool SymbolTable::ContainsInCurrentScope(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.back().find(name) != scopes_.back().end();
} }
void SymbolTable::Add(const std::string& name, const SymbolEntry& entry) { bool SymbolTable::Contains(const std::string& name) const {
if (scopes_.empty()) { return table_.find(name) != table_.end();
EnterScope();
}
scopes_.back()[name] = entry;
} }
const SymbolEntry* SymbolTable::Lookup(const std::string& name) const { SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { auto it = table_.find(name);
auto found = it->find(name); return it == table_.end() ? nullptr : it->second;
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
} }

@ -15,7 +15,7 @@ CLIOptions ParseCLI(int argc, char** argv) {
if (argc <= 1) { if (argc <= 1) {
throw std::runtime_error(FormatError( throw std::runtime_error(FormatError(
"cli", "cli",
"用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] [--no-opt] <input.sy>")); "用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>"));
} }
for (int i = 1; i < argc; ++i) { for (int i = 1; i < argc; ++i) {
@ -58,11 +58,6 @@ CLIOptions ParseCLI(int argc, char** argv) {
continue; continue;
} }
if (std::strcmp(arg, "--no-opt") == 0) {
opt.optimize_ir = false;
continue;
}
if (arg[0] == '-') { if (arg[0] == '-') {
throw std::runtime_error( throw std::runtime_error(
FormatError("cli", std::string("未知参数: ") + arg + FormatError("cli", std::string("未知参数: ") + arg +

@ -50,14 +50,13 @@ void PrintHelp(std::ostream& os) {
os << "SysY Compiler\n" os << "SysY Compiler\n"
<< "\n" << "\n"
<< "用法:\n" << "用法:\n"
<< " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] [--no-opt] <input.sy>\n" << " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>\n"
<< "\n" << "\n"
<< "选项:\n" << "选项:\n"
<< " -h, --help 打印帮助信息并退出\n" << " -h, --help 打印帮助信息并退出\n"
<< " --emit-parse-tree 仅在显式模式下启用语法树输出\n" << " --emit-parse-tree 仅在显式模式下启用语法树输出\n"
<< " --emit-ir 仅在显式模式下启用 IR 输出\n" << " --emit-ir 仅在显式模式下启用 IR 输出\n"
<< " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n" << " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n"
<< " --no-opt 关闭 Lab4 IR 标量优化管线\n"
<< "\n" << "\n"
<< "说明:\n" << "说明:\n"
<< " - 默认输出 IR\n" << " - 默认输出 IR\n"

@ -1,71 +1,4 @@
#include <stdio.h> // SysY 运行库实现:
// - 按实验/评测规范提供 I/O 等函数实现
// - linked with the target code generated by the compiler to support runtime behaviour

// Reads one decimal integer from stdin.
// Returns the parsed value, or 0 when scanf cannot convert an integer.
int getint() {
  int value = 0;
  return (scanf("%d", &value) == 1) ? value : 0;
}
// Reads one character from stdin and normalises carriage returns:
// a '\r' (optionally followed by '\n') is reported as a single '\n'.
// Every other value returned by getchar(), EOF included, is passed through.
int getch() {
  int ch = getchar();
  if (ch != '\r') {
    return ch;
  }
  // Peek at the character after '\r'; swallow it only when it is the '\n'
  // of a "\r\n" pair, otherwise push it back for the next read.
  int lookahead = getchar();
  if (!(lookahead == '\n' || lookahead == EOF)) {
    ungetc(lookahead, stdin);
  }
  return '\n';
}
// Reads an element count n from stdin followed by n integers into a[].
// Returns n as read, or 0 when the count itself cannot be parsed.
// The caller must supply an array with room for at least n elements.
// An element that cannot be parsed (truncated or malformed input) is stored
// as 0, matching getint()'s "0 on failure" behaviour, instead of being left
// with whatever the array held before the call.
int getarray(int a[]) {
  int n = 0;
  if (scanf("%d", &n) != 1) return 0;
  for (int i = 0; i < n; ++i) {
    if (scanf("%d", &a[i]) != 1) {
      a[i] = 0;
    }
  }
  return n;
}
// Writes x to stdout in decimal, with no separator or trailing newline.
void putint(int x) {
  fprintf(stdout, "%d", x);
}
// Writes the single character whose code is x to stdout.
void putch(int x) {
  putc(x, stdout);
}
// Prints "n:" followed by the first n elements of a[], each preceded by a
// space, and terminates the line with '\n'.
void putarray(int n, int a[]) {
  printf("%d:", n);
  int idx = 0;
  while (idx < n) {
    printf(" %d", a[idx]);
    ++idx;
  }
  printf("\n");
}
// Reads one floating-point number from stdin.
// Returns the parsed value, or 0.0f when scanf cannot convert a float.
float getfloat() {
  float value = 0.0f;
  return (scanf("%f", &value) == 1) ? value : 0.0f;
}
// Reads an element count n from stdin followed by n floats into a[].
// Returns n as read, or 0 when the count itself cannot be parsed.
// The caller must supply an array with room for at least n elements.
// An element that cannot be parsed (truncated or malformed input) is stored
// as 0.0f, matching getfloat()'s "0 on failure" behaviour, instead of being
// left with whatever the array held before the call.
int getfarray(float a[]) {
  int n = 0;
  if (scanf("%d", &n) != 1) return 0;
  for (int i = 0; i < n; ++i) {
    if (scanf("%f", &a[i]) != 1) {
      a[i] = 0.0f;
    }
  }
  return n;
}
// Writes x to stdout in C99 hexadecimal floating-point notation ("%a"),
// with no separator or trailing newline.
void putfloat(float x) {
  fprintf(stdout, "%a", x);
}
// Prints "n:" followed by the first n elements of a[] in "%a" notation,
// each preceded by a space, and terminates the line with '\n'.
void putfarray(int n, float a[]) {
  printf("%d:", n);
  int idx = 0;
  while (idx < n) {
    printf(" %a", a[idx]);
    ++idx;
  }
  printf("\n");
}
// Performance timing hook: start of a timed region.
// Deliberately a no-op stub; it only has to exist so that programs calling
// it link and run during correctness testing.
void starttime() {
  /* intentionally empty */
}
// Performance timing hook: end of a timed region. No-op stub.
void stoptime() {
  /* intentionally empty */
}

@ -1,78 +0,0 @@
#!/bin/bash
# ================================================
# SysY compiler Lab1 batch parse-test script
# File: scripts/test_parse.sh
# Environment: Arch Linux (bash is available out of the box, nothing else to
#              install). Uses bash-only features: [[ ]], $(( )), read -d '',
#              process substitution.
# What it does:
#   - Walks every .sy file under test/test_case, sub-directories included
#     (functional + performance)
#   - Runs the compiler with --emit-parse-tree to check the file parses
#   - Prints a compact PASS/FAIL line per file plus a final summary
#   - For each failing file, appends the last 10 lines of compiler output to
#     the log (handy for debugging)
#   - Mirrors all results into test/test_result/parse_test.log
# Must be run from the repository root: all paths below are relative.
# ================================================
set -u  # treat use of an unset variable as an error

# ================== Configuration ==================
COMPILER="./build/bin/compiler"
TEST_DIR="test/test_case"
LOG_FILE="test/test_result/parse_test.log"
MAX_ERROR_LINES=10

# Bail out early when the compiler has not been built yet.
if [[ ! -x "$COMPILER" ]]; then
  echo "❌ 错误:找不到编译器 $COMPILER"
  echo " 请先执行 Lab1 构建命令:"
  echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON"
  echo " cmake --build build -j \"\$(nproc)\""
  exit 1
fi

# Without this check a missing directory makes find fail, the loop runs zero
# times, and the summary would wrongly announce that everything passed.
if [[ ! -d "$TEST_DIR" ]]; then
  echo "❌ 错误:找不到测试目录 $TEST_DIR"
  exit 1
fi

# Create the log directory if needed and start from an empty log.
mkdir -p "$(dirname "$LOG_FILE")"
: > "$LOG_FILE"

echo "开始 Lab1 批量语法树测试..." | tee -a "$LOG_FILE"
echo "测试目录:$TEST_DIR" | tee -a "$LOG_FILE"
echo "编译器:$COMPILER" | tee -a "$LOG_FILE"
echo "========================================" | tee -a "$LOG_FILE"

pass=0
fail=0
total=0

# Visit every .sy file (sub-directories included) in sorted order. Names are
# NUL-delimited so paths containing spaces or newlines are handled correctly.
while IFS= read -r -d '' sy_file; do
  # Plain arithmetic assignment: ((total++)) returns a non-zero status when
  # the old value is 0, which would abort the script if `set -e` were added.
  total=$((total + 1))
  echo -n "[$total] 测试: $sy_file ... " | tee -a "$LOG_FILE"

  # Run the parser exactly once and keep its combined stdout/stderr, so a
  # failure can be reported without invoking the compiler a second time.
  # stdin is redirected from /dev/null so the compiler can never consume the
  # file list that feeds this loop.
  if output=$("$COMPILER" --emit-parse-tree "$sy_file" 2>&1 < /dev/null); then
    echo "✅PASS" | tee -a "$LOG_FILE"
    pass=$((pass + 1))
  else
    echo "FAIL" | tee -a "$LOG_FILE"
    fail=$((fail + 1))
    # Error details go to the log only (last few lines of compiler output).
    {
      echo " └── 错误详情(最后 $MAX_ERROR_LINES 行):"
      printf '%s\n' "$output" | tail -n "$MAX_ERROR_LINES"
      echo ""
    } >> "$LOG_FILE"
  fi
done < <(find "$TEST_DIR" -type f -name "*.sy" -print0 | sort -z)

# ================== Summary ==================
echo "========================================" | tee -a "$LOG_FILE"
echo "测试完成!" | tee -a "$LOG_FILE"
echo "总文件数 : $total" | tee -a "$LOG_FILE"
echo "通过 : $pass" | tee -a "$LOG_FILE"
echo "失败 : $fail" | tee -a "$LOG_FILE"
if [[ $total -eq 0 ]]; then
  # Nothing was tested, so do not claim success.
  echo "未在 $TEST_DIR 下找到任何 .sy 文件" | tee -a "$LOG_FILE"
elif [[ $fail -eq 0 ]]; then
  echo "恭喜Lab1 语法树构建全部通过!可以进入 Lab2 啦~" | tee -a "$LOG_FILE"
else
  echo "$fail 个文件解析失败,请检查 SysY.g4 或报错日志" | tee -a "$LOG_FILE"
  echo " 日志文件:$LOG_FILE" | tee -a "$LOG_FILE"
fi
echo "========================================" | tee -a "$LOG_FILE"

@ -11,8 +11,6 @@
#include "atn/ProfilingATNSimulator.h" #include "atn/ProfilingATNSimulator.h"
#include <chrono>
using namespace antlr4; using namespace antlr4;
using namespace antlr4::atn; using namespace antlr4::atn;
using namespace antlr4::dfa; using namespace antlr4::dfa;

Loading…
Cancel
Save