Compare commits

..

18 Commits

@ -0,0 +1,61 @@
{
"permissions": {
"allow": [
"Bash(cd \"\\\\\\\\wsl.localhost\\\\Ubuntu-24.04\\\\home\\\\bnk\\\\nudt-compiler-cpp\")",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'echo IN_WSL; which cmake g++ aarch64-linux-gnu-gcc qemu-aarch64 clang 2>/dev/null; pwd')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake -S . -B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo \"cfg ok\" || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" > /tmp/build.log 2>&1 && echo \"build ok\" || tail -40 /tmp/build.log')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && rm -rf build && cmake -S . -B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo \"cfg ok\" || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" > /tmp/build.log 2>&1 && echo \"build ok\" || tail -40 /tmp/build.log')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-ir test/test_case/functional/simple_add.sy 2>&1')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== simple_add no-opt ===\"; ./build/bin/compiler --emit-ir --no-opt test/test_case/functional/simple_add.sy 2>&1 | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for f in 11_add2 13_sub2 29_break 36_op_priority2; do echo \"=== $f \\(opt\\) ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/$f.sy 2>&1 | sed -n \"/define/,/^}/p\"; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 09_func_defn ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/09_func_defn.sy 2>&1 | sed -n \"1,60p\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 25_scope3 ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/25_scope3.sy 2>&1 | sed -n \"/define/,/^}/p\" | head -80')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 05_arr_defn4 ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/05_arr_defn4.sy 2>&1 | sed -n \"/define/,/^}/p\" | head -120')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 22 globals+main head ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/22_matrix_multiply.sy 2>&1 | sed -n \"1,40p\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== gep with var index \\(22\\) ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/22_matrix_multiply.sy 2>&1 | grep -A2 getelementptr | head -20; echo \"=== 95 float head ===\"; ./build/bin/compiler --emit-ir test/test_case/functional/95_float.sy 2>&1 | sed -n \"1,50p\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== CMakeLists \\(mir glob?\\) ===\"; grep -n -i \"mir\\\\|glob\\\\|GLOB\\\\|file\\(\" CMakeLists.txt | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -n \"src/\\\\|GLOB\\\\|add_executable\\\\|add_library\\\\|SOURCES\\\\|\\\\.cpp\" CMakeLists.txt | head -60')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -n \"compiler_core\\\\|compiler\\\\b\\\\|target_sources\\\\|GLOB.*SRC\\\\|PROJECT_SRC\\\\|set\\(SOURCES\\\\|\\\\.cpp\\\\\"\" CMakeLists.txt | head; echo \"---\"; sed -n \"82,160p\" CMakeLists.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/CMakeLists.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== mir/CMakeLists ===\"; cat src/mir/CMakeLists.txt; echo \"=== ls src/mir ===\"; ls -R src/mir; echo \"=== ls include/mir ===\"; ls include/mir')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && tot=0; for f in test/test_case/functional/*.sy; do ./build/bin/compiler --emit-ir \"$f\" 2>/dev/null; done > /tmp/allir.txt; echo \"=== opcodes used ===\"; grep -oE \"= \\(add|sub|mul|sdiv|srem|fadd|fsub|fmul|fdiv|icmp|fcmp|call|getelementptr|load|alloca|phi|sitofp|fptosi|zext\\) \" /tmp/allir.txt | sort | uniq -c; echo \"=== bare ops ===\"; grep -oE \"^ \\(store|br|ret|call\\) \" /tmp/allir.txt | sort | uniq -c; echo \"=== phi count ===\"; grep -c \"phi\" /tmp/allir.txt; echo \"=== float present ===\"; grep -c \"float\" /tmp/allir.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ls test/test_case/functional/*.sy | wc -l; echo \"--- verify uses ---\"; grep -nE \"qemu|aarch64|gcc|clang|gcc-|--target\" scripts/verify_asm.sh | head; echo \"--- a sample .out ---\"; ls test/test_case/functional/*.out 2>/dev/null | head -3; echo \"--- runtime lib? ---\"; ls test/ ; find . -name \"*.a\" -path \"*runtime*\" 2>/dev/null | head')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== sylib.h ===\"; cat sylib/sylib.h 2>/dev/null | head -60; echo \"=== test list ===\"; ls test/test_case/functional/*.sy | xargs -n1 basename')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"int |void |float |#define|starttime|_sysy\" sylib/sylib.c | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && tail -5 include/mir/MIR.h')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetFunctions|GetGlobals|IsDeclaration|GetBlocks|GetEntry|GetSuccessors|GetNumArgs|GetArg\\\\b|class Argument|GetInstructions|GetOpcode|class Module\" include/ir/IR.h | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== Opcode enum ===\"; grep -nA40 \"enum class Opcode\" include/ir/IR.h | head -50')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetLhs|GetRhs|class BinaryInst|class ICmpInst|class FCmpInst|GetPredicate|class CastInst|GetValue\\\\b|class LoadInst|GetPtr|class StoreInst|class AllocaInst|GetAllocatedType|class GepInst|GetBasePtr|GetIndices|class CallInst|GetCallee|GetArgs|class BranchInst|GetDest|class CondBrInst|GetCond|GetTrueDest|GetFalseDest|class ReturnInst|HasReturnValue\" include/ir/IR.h')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== Type ===\"; grep -nE \"class Type|enum class Kind|GetKind|GetArraySize|GetElementType|IsFloat|IsPointer|IsArray|IsVoid|Int1|Int32\" include/ir/IR.h | head -30; echo \"=== Constants/Global ===\"; grep -nE \"class ConstantInt|class ConstantFloat|class ConstantArray|GetElements|class ConstantValue|class GlobalVariable|GetValueType|IsConst|GetInitializer|class Function\\\\b|GetName\" include/ir/IR.h | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"GetType\\\\\\(\\\\\\)|class Value\\\\b|class GlobalValue\" include/ir/IR.h | head; echo \"=== line 40-60 ===\"; sed -n \"36,60p\" include/ir/IR.h; echo \"=== 120-160 ===\"; sed -n \"118,160p\" include/ir/IR.h')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -E \"error:|Error|错误\" | head -40; echo \"=== exit ${PIPESTATUS[0]} ===\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"undefined|multiple|duplicate|reference to\" | head -30')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/mir/CMakeLists.txt; echo \"=== grep mir in src/CMakeLists ===\"; grep -n mir src/CMakeLists.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cat src/mir/passes/CMakeLists.txt')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake -S . -B build -DCMAKE_BUILD_TYPE=Release > /tmp/cm.log 2>&1 && echo cfgok || tail -20 /tmp/cm.log; cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"error|undefined|错误\" | head -30; echo \"=== done ===\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-asm test/test_case/functional/simple_add.sy 2>&1 | head -40')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" /tmp/asmout --run > /tmp/v.log 2>&1; then pass=$\\(\\(pass+1\\)\\); else fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"PASS=$pass FAIL=$fail\"; echo \"FAILED:$fails\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 13_sub2 95_float; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -15; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && cmake --build build -j \"$\\(nproc\\)\" 2>&1 | grep -iE \"error|错误\" | head; for b in 13_sub2 95_float; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -6; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 15_graph_coloring 22_matrix_multiply; do echo \"===== $b =====\"; ./scripts/verify_asm.sh test/test_case/functional/$b.sy /tmp/asmout --run 2>&1 | tail -8; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && grep -nE \"no-opt|no_opt|emit-asm|compiler|OPT|opt\" scripts/verify_asm.sh | head')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && CC=aarch64-linux-gnu-gcc; QEMU=qemu-aarch64; SYS=$\\(ls sysroot 2>/dev/null\\); for b in 22_matrix_multiply 95_float 15_graph_coloring 25_scope3; do f=test/test_case/functional/$b.sy; ./build/bin/compiler --no-opt --emit-asm \"$f\" > /tmp/n.s 2>/tmp/n.err || { echo \"$b COMPILE FAIL\"; cat /tmp/n.err; continue; }; aarch64-linux-gnu-gcc /tmp/n.s sylib/libsysy.a -o /tmp/n.exe 2>/tmp/as.err || { echo \"$b ASM/LINK FAIL\"; head -5 /tmp/as.err; continue; }; in=test/test_case/functional/$b.in; if [ -f \"$in\" ]; then qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe < \"$in\" > /tmp/n.out 2>&1; else qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe > /tmp/n.out 2>&1; fi; ec=$?; echo \"$b: exit=$ec out=$\\(tr \"\\\\n\" \"|\" < /tmp/n.out | head -c 80\\)\"; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && sed -n \"58,110p\" scripts/verify_asm.sh')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && aarch64-linux-gnu-gcc -O2 -Wno-unused-result -c sylib/sylib.c -o /tmp/sylib.o 2>/dev/null; pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); ./build/bin/compiler --no-opt --emit-asm \"$f\" > /tmp/n.s 2>/dev/null || { fails=\"$fails $b\\(cc\\)\"; fail=$\\(\\(fail+1\\)\\); continue; }; aarch64-linux-gnu-gcc /tmp/n.s /tmp/sylib.o -o /tmp/n.exe 2>/tmp/as.err || { fails=\"$fails $b\\(as\\)\"; fail=$\\(\\(fail+1\\)\\); continue; }; in=test/test_case/functional/$b.in; if [ -f \"$in\" ]; then qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe < \"$in\" > /tmp/n.out 2>&1; else qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/n.exe > /tmp/n.out 2>&1; fi; ec=$?; exp=test/test_case/functional/$b.out; { cat /tmp/n.out; [ -s /tmp/n.out ] && [ $\\(tail -c1 /tmp/n.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/n.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/n.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/n.act > /tmp/n.actn; if diff -q /tmp/n.exp /tmp/n.actn >/dev/null 2>&1; then pass=$\\(\\(pass+1\\)\\); else fails=\"$fails $b\\(diff\\)\"; fail=$\\(\\(fail+1\\)\\); fi; done; echo \"NOOPT PASS=$pass FAIL=$fail\"; echo \"FAILS:$fails\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== check spill present in no-opt 22 ===\"; ./build/bin/compiler --no-opt --emit-asm test/test_case/functional/22_matrix_multiply.sy 2>/dev/null | grep -cE \"x29, #\\(2[0-9]|[3-9][0-9]\\)\"; echo \"\\(stack accesses above = spills/locals\\)\"; echo \"=== self-move check \\(should be 0\\) ===\"; ./build/bin/compiler --emit-asm test/test_case/functional/22_matrix_multiply.sy 2>/dev/null | grep -cE \"\\\\bmov\\\\t\\(w|x\\)\\([0-9]+\\), \\\\1?\\\\2$\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== total instrs across all tests \\(opt\\) ===\"; tot=0; for f in test/test_case/functional/*.sy; do n=$\\(./build/bin/compiler --emit-asm \"$f\" 2>/dev/null | grep -cE \"^\\\\t\\(mov|add|sub|mul|ldr|str|b|bl|cmp|cset|fmov|ret|sxtw|lsl|sdiv|msub|fadd|fsub|fmul|fdiv|fcmp|scvtf|fcvtzs|adrp|stp|ldp\\)\"\\); tot=$\\(\\(tot+n\\)\\); done; echo \"total=$tot\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/function/asm --run > /tmp/v.log 2>&1; then pass=$\\(\\(pass+1\\)\\); else fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"OFFICIAL SCRIPT PASS=$pass FAIL=$fail\"; echo \"FAILS:$fails\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && f=test/test_case/functional/22_matrix_multiply.sy; ./build/bin/compiler --emit-asm \"$f\" 2>/dev/null > /tmp/with.s; wc -l < /tmp/with.s | xargs echo \"lines with peephole:\"; grep -c \"\tmov\t\" /tmp/with.s | xargs echo \"mov count:\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; for f in test/test_case/functional/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/function/asm --run > /tmp/v.log 2>&1; then echo \" PASS $b\"; pass=$\\(\\(pass+1\\)\\); else echo \" FAIL $b\"; fail=$\\(\\(fail+1\\)\\); fi; done; echo \"========================\"; echo \"PASS=$pass FAIL=$fail\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== test_case 目录结构 ===\"; find test/test_case -type d; echo \"=== 各目录 .sy 数量 ===\"; for d in $\\(find test/test_case -type d\\); do n=$\\(ls \"$d\"/*.sy 2>/dev/null | wc -l\\); [ \"$n\" -gt 0 ] && echo \"$n $d\"; done; echo \"=== 总数 ===\"; find test/test_case -name \"*.sy\" | wc -l')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && pass=0; fail=0; fails=\"\"; for f in test/test_case/performance/*.sy; do b=$\\(basename \"$f\" .sy\\); if ./scripts/verify_asm.sh \"$f\" test/test_result/perf/asm --run > /tmp/v.log 2>&1; then echo \" PASS $b\"; pass=$\\(\\(pass+1\\)\\); else echo \" FAIL $b\"; fail=$\\(\\(fail+1\\)\\); fails=\"$fails $b\"; fi; done; echo \"===== PERF PASS=$pass FAIL=$fail =====\"; echo \"FAILS:$fails\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ls test/test_case/performance/*.sy | xargs -n1 basename')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && aarch64-linux-gnu-gcc -O2 -Wno-unused-result -c sylib/sylib.c -o /tmp/sylib.o 2>/dev/null; for f in test/test_case/performance/*.sy; do b=$\\(basename \"$f\" .sy\\); ./build/bin/compiler --emit-asm \"$f\" > /tmp/p.s 2>/tmp/p.err || { echo \"FAIL\\($b\\) compile\"; continue; }; aarch64-linux-gnu-gcc /tmp/p.s /tmp/sylib.o -o /tmp/p.exe 2>/tmp/p.as || { echo \"FAIL\\($b\\) assemble\"; head -3 /tmp/p.as; continue; }; in=test/test_case/performance/$b.in; if [ -f \"$in\" ]; then timeout 25 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/p.exe < \"$in\" > /tmp/p.out 2>&1; else timeout 25 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/p.exe > /tmp/p.out 2>&1; fi; ec=$?; if [ $ec -eq 124 ]; then echo \"TIMEOUT\\($b\\)\"; continue; fi; exp=test/test_case/performance/$b.out; { cat /tmp/p.out; [ -s /tmp/p.out ] && [ $\\(tail -c1 /tmp/p.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/p.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/p.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/p.act > /tmp/p.actn; if diff -q /tmp/p.exp /tmp/p.actn >/dev/null 2>&1; then echo \"PASS\\($b\\)\"; else echo \"DIFF\\($b\\)\"; fi; done')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== if-combine3.sy ===\"; cat test/test_case/performance/if-combine3.sy; echo \"=== .out ===\"; cat test/test_case/performance/if-combine3.out; echo \"=== .in? ===\"; ls test/test_case/performance/if-combine3.in 2>/dev/null && cat test/test_case/performance/if-combine3.in')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"=== 循环尾部\\(最后30行\\) ===\"; tail -30 test/test_case/performance/if-combine3.sy; echo \"=== .in ===\"; cat test/test_case/performance/if-combine3.in 2>/dev/null || echo \"\\(无\\)\"; echo \"=== .out ===\"; cat test/test_case/performance/if-combine3.out')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && ./build/bin/compiler --emit-asm test/test_case/performance/if-combine3.sy > /tmp/ic.s 2>/dev/null && aarch64-linux-gnu-gcc /tmp/ic.s /tmp/sylib.o -o /tmp/ic.exe 2>/dev/null && echo \"5\" | timeout 20 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/ic.exe 2>/dev/null; echo \"exit=$? \\(小输入 n=5验证逻辑\\)\"')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && echo \"50000000\" | timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/ic.exe > /tmp/ic.out 2>&1; ec=$?; echo \"exit=$ec\"; cat /tmp/ic.out')",
"Bash(wsl.exe -d Ubuntu-24.04 -e bash -lc 'cd /home/bnk/nudt-compiler-cpp && for b in 2025-MYO-20 gameoflife-oscillator; do f=test/test_case/performance/$b.sy; ./build/bin/compiler --emit-asm \"$f\" > /tmp/x.s 2>/tmp/x.e || { echo \"$b COMPILE FAIL\"; head -3 /tmp/x.e; continue; }; aarch64-linux-gnu-gcc /tmp/x.s /tmp/sylib.o -o /tmp/x.exe 2>/tmp/x.a || { echo \"$b ASM FAIL\"; head -3 /tmp/x.a; continue; }; in=test/test_case/performance/$b.in; if [ -f \"$in\" ]; then timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/x.exe < \"$in\" > /tmp/x.out 2>&1; else timeout 280 qemu-aarch64 -L /usr/aarch64-linux-gnu /tmp/x.exe > /tmp/x.out 2>&1; fi; ec=$?; if [ $ec -eq 124 ]; then echo \"$b STILL TIMEOUT\\(>280s\\)\"; continue; fi; exp=test/test_case/performance/$b.out; { cat /tmp/x.out; [ -s /tmp/x.out ] && [ $\\(tail -c1 /tmp/x.out|wc -l\\) -eq 0 ] && printf \"\\\\n\"; printf \"%s\\\\n\" \"$ec\"; } > /tmp/x.act; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" \"$exp\" > /tmp/x.exp 2>/dev/null; perl -0pe \"s/\\\\r\\\\n/\\\\n/g;s/\\\\r/\\\\n/g;s/\\\\n?\\\\z//\" /tmp/x.act > /tmp/x.actn; if diff -q /tmp/x.exp /tmp/x.actn >/dev/null 2>&1; then echo \"$b PASS\"; else echo \"$b DIFF:\"; diff /tmp/x.exp /tmp/x.actn | head; fi; done')"
]
}
}

5
.gitignore vendored

@ -68,3 +68,8 @@ Thumbs.db
# Project outputs # Project outputs
# ========================= # =========================
test/test_result/ test/test_result/
# Added by cargo
/target

@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.20)
project(compiler LANGUAGES C CXX) project(compiler LANGUAGES C CXX)
find_package(Java REQUIRED COMPONENTS Runtime)
# C++ # C++
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
@ -31,7 +33,7 @@ target_include_directories(build_options INTERFACE
option(COMPILER_ENABLE_WARNINGS "Enable common compiler warnings" ON) option(COMPILER_ENABLE_WARNINGS "Enable common compiler warnings" ON)
if(COMPILER_ENABLE_WARNINGS) if(COMPILER_ENABLE_WARNINGS)
if(MSVC) if(MSVC)
target_compile_options(build_options INTERFACE /W4) target_compile_options(build_options INTERFACE /W4 /utf-8)
else() else()
target_compile_options(build_options INTERFACE -Wall -Wextra -Wpedantic) target_compile_options(build_options INTERFACE -Wall -Wextra -Wpedantic)
endif() endif()
@ -39,12 +41,18 @@ endif()
option(COMPILER_PARSE_ONLY "Build only the frontend parser pipeline" OFF) option(COMPILER_PARSE_ONLY "Build only the frontend parser pipeline" OFF)
set(ANTLR4_JAR "${PROJECT_SOURCE_DIR}/third_party/antlr-4.13.2-complete.jar")
if(NOT EXISTS "${ANTLR4_JAR}")
message(FATAL_ERROR "ANTLR jar not found: ${ANTLR4_JAR}")
endif()
# 使 third_party ANTLR4 C++ runtime # 使 third_party ANTLR4 C++ runtime
# third_party runtime third_party/antlr4-runtime-4.13.2/runtime/src # third_party runtime third_party/antlr4-runtime-4.13.2/runtime/src
set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.13.2/runtime/src") set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.13.2/runtime/src")
add_library(antlr4_runtime STATIC) add_library(antlr4_runtime STATIC)
target_compile_features(antlr4_runtime PUBLIC cxx_std_17) target_compile_features(antlr4_runtime PUBLIC cxx_std_17)
target_compile_definitions(antlr4_runtime PUBLIC ANTLR4CPP_STATIC)
target_include_directories(antlr4_runtime PUBLIC target_include_directories(antlr4_runtime PUBLIC
"${ANTLR4_RUNTIME_SRC_DIR}" "${ANTLR4_RUNTIME_SRC_DIR}"

7
Cargo.lock generated

@ -0,0 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "nudt-compiler-cpp"
version = "0.1.0"

@ -0,0 +1,6 @@
[package]
name = "nudt-compiler-cpp"
version = "0.1.0"
edition = "2024"
[dependencies]

@ -109,3 +109,8 @@ cmake --build build -j "$(nproc)"
目标:脚本自动读取同名 `.in`,并将程序输出与退出码和同名 `.out` 比对,确保优化后程序行为与优化前保持一致。 目标:脚本自动读取同名 `.in`,并将程序输出与退出码和同名 `.out` 比对,确保优化后程序行为与优化前保持一致。
完成 Lab4 后,应对 `test/test_case` 下全部测试用例逐个回归;如有需要,也可以自行编写批量测试脚本统一执行。 完成 Lab4 后,应对 `test/test_case` 下全部测试用例逐个回归;如有需要,也可以自行编写批量测试脚本统一执行。
批量测试脚本:
bash test/test_result/lab4_batch/run_all.sh

@ -1,37 +1,15 @@
// 当前只支撑 i32、i32*、void 以及最小的内存/算术指令,演示用。 // 扩展后的 IR 库:
// // - 完整基础类型void/i1/i32/float/ptr/array/function/label
// 当前已经实现: // - 指令算术、比较、分支、调用、phi、gep、类型转换等
// 1. 基础类型系统void / i32 / i32* // - 常量int/float/array
// 2. Value 体系Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction // - 基本块/函数/模块/IRBuilder 的完整接口
// 3. 最小指令集Add / Alloca / Load / Store / Ret
// 4. BasicBlock / Function / Module 三层组织结构
// 5. IRBuilder便捷创建常量和最小指令
// 6. def-use 关系的轻量实现:
// - Instruction 保存 operand 列表
// - Value 保存 uses
// - 支持 ReplaceAllUsesWith 的简化实现
//
// 当前尚未实现或只做了最小占位:
// 1. 完整类型系统数组、函数类型、label 类型等
// 2. 更完整的指令系统br / condbr / call / phi / gep 等
// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构)
// 4. 更完整的 IR verifier 和优化基础设施
//
// 当前需要特别说明的两个简化点:
// 1. BasicBlock 虽然已经纳入 Value 体系,但其类型目前仍用 void 作为占位,
// 后续如果补 label type可以再改成更合理的块标签类型。
// 2. ConstantValue 体系目前只实现了 ConstantInt后续可以继续补 ConstantFloat、
// ConstantArray等更完整的常量种类。
//
// 建议的扩展顺序:
// 1. 先补更多指令和类型
// 2. 再补控制流相关 IR
// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架
#pragma once #pragma once
#include <cstdint>
#include <iosfwd> #include <iosfwd>
#include <memory> #include <memory>
#include <optional>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
@ -45,10 +23,14 @@ class Value;
class User; class User;
class ConstantValue; class ConstantValue;
class ConstantInt; class ConstantInt;
class ConstantFloat;
class ConstantArray;
class GlobalValue; class GlobalValue;
class GlobalVariable;
class Instruction; class Instruction;
class BasicBlock; class BasicBlock;
class Function; class Function;
class Argument;
// Use 表示一个 Value 的一次使用记录。 // Use 表示一个 Value 的一次使用记录。
// 当前实现设计: // 当前实现设计:
@ -83,31 +65,65 @@ class Context {
~Context(); ~Context();
// 去重创建 i32 常量。 // 去重创建 i32 常量。
ConstantInt* GetConstInt(int v); ConstantInt* GetConstInt(int v);
ConstantInt* GetConstBool(bool v);
ConstantFloat* GetConstFloat(float v);
ConstantArray* CreateConstArray(std::shared_ptr<Type> array_ty,
std::vector<ConstantValue*> elements);
std::string NextTemp(); std::string NextTemp();
private: private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_; std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_bools_;
std::unordered_map<uint32_t, std::unique_ptr<ConstantFloat>> const_floats_;
std::vector<std::unique_ptr<ConstantArray>> const_arrays_;
int temp_index_ = -1; int temp_index_ = -1;
}; };
class Type { class Type {
public: public:
enum class Kind { Void, Int32, PtrInt32 }; enum class Kind { Void, Int1, Int32, Float, Pointer, Array, Function, Label };
explicit Type(Kind k); explicit Type(Kind k);
Type(Kind k, std::shared_ptr<Type> elem, size_t count);
Type(Kind k, std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
bool is_vararg);
// 使用静态共享对象获取类型。 // 使用静态共享对象获取类型。
// 同一类型可直接比较返回值是否相等,例如: // 同一类型可直接比较返回值是否相等,例如:
// Type::GetInt32Type() == Type::GetInt32Type() // Type::GetInt32Type() == Type::GetInt32Type()
static const std::shared_ptr<Type>& GetVoidType(); static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type(); static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetPtrInt32Type(); static const std::shared_ptr<Type>& GetFloatType();
static const std::shared_ptr<Type>& GetLabelType();
static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> elem);
static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> elem,
size_t count);
static std::shared_ptr<Type> GetFunctionType(
std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
bool is_vararg = false);
Kind GetKind() const; Kind GetKind() const;
bool IsVoid() const; bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const; bool IsInt32() const;
bool IsPtrInt32() const; bool IsFloat() const;
bool IsPointer() const;
bool IsArray() const;
bool IsFunction() const;
bool IsLabel() const;
const std::shared_ptr<Type>& GetElementType() const;
size_t GetArraySize() const;
const std::shared_ptr<Type>& GetReturnType() const;
const std::vector<std::shared_ptr<Type>>& GetParamTypes() const;
bool IsVarArg() const;
bool Equals(const Type& other) const;
private: private:
Kind kind_; Kind kind_;
std::shared_ptr<Type> elem_type_;
size_t array_size_ = 0;
std::shared_ptr<Type> ret_type_;
std::vector<std::shared_ptr<Type>> param_types_;
bool is_vararg_ = false;
}; };
class Value { class Value {
@ -118,7 +134,12 @@ class Value {
const std::string& GetName() const; const std::string& GetName() const;
void SetName(std::string n); void SetName(std::string n);
bool IsVoid() const; bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const; bool IsInt32() const;
bool IsFloat() const;
bool IsPointer() const;
bool IsArray() const;
bool IsFunctionType() const;
bool IsPtrInt32() const; bool IsPtrInt32() const;
bool IsConstant() const; bool IsConstant() const;
bool IsInstruction() const; bool IsInstruction() const;
@ -151,8 +172,53 @@ class ConstantInt : public ConstantValue {
int value_{}; int value_{};
}; };
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float v);
float GetValue() const { return value_; }
private:
float value_{};
};
class ConstantArray : public ConstantValue {
public:
ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements);
const std::vector<ConstantValue*>& GetElements() const { return elements_; }
private:
std::vector<ConstantValue*> elements_;
};
// 后续还需要扩展更多指令类型。 // 后续还需要扩展更多指令类型。
enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; enum class Opcode {
Add,
Sub,
Mul,
SDiv,
SRem,
FAdd,
FSub,
FMul,
FDiv,
Alloca,
Load,
Store,
Ret,
Br,
CondBr,
ICmp,
FCmp,
Call,
Phi,
Gep,
SIToFP,
FPToSI,
ZExt
};
enum class ICmpPredicate { Eq, Ne, Slt, Sle, Sgt, Sge };
enum class FCmpPredicate { Oeq, One, Olt, Ole, Ogt, Oge };
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
// 当前实现中只有 Instruction 继承自 User。 // 当前实现中只有 Instruction 继承自 User。
@ -162,10 +228,13 @@ class User : public Value {
size_t GetNumOperands() const; size_t GetNumOperands() const;
Value* GetOperand(size_t index) const; Value* GetOperand(size_t index) const;
void SetOperand(size_t index, Value* value); void SetOperand(size_t index, Value* value);
void RemoveOperand(size_t index);
protected: protected:
// 统一的 operand 入口。 // 统一的 operand 入口。
void AddOperand(Value* value); void AddOperand(Value* value);
virtual void OnOperandChanged(size_t index, Value* value);
virtual void OnOperandRemoving(size_t index);
private: private:
std::vector<Value*> operands_; std::vector<Value*> operands_;
@ -178,6 +247,20 @@ class GlobalValue : public User {
GlobalValue(std::shared_ptr<Type> ty, std::string name); GlobalValue(std::shared_ptr<Type> ty, std::string name);
}; };
class GlobalVariable : public GlobalValue {
public:
GlobalVariable(std::shared_ptr<Type> value_ty, std::string name,
ConstantValue* init, bool is_const);
const std::shared_ptr<Type>& GetValueType() const;
ConstantValue* GetInitializer() const;
bool IsConst() const;
private:
std::shared_ptr<Type> value_type_;
ConstantValue* initializer_ = nullptr;
bool is_const_ = false;
};
class Instruction : public User { class Instruction : public User {
public: public:
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = ""); Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
@ -196,18 +279,67 @@ class BinaryInst : public Instruction {
BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs, BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
std::string name); std::string name);
Value* GetLhs() const; Value* GetLhs() const;
Value* GetRhs() const; Value* GetRhs() const;
};
class ICmpInst : public Instruction {
public:
ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name);
ICmpPredicate GetPredicate() const { return pred_; }
Value* GetLhs() const;
Value* GetRhs() const;
private:
ICmpPredicate pred_;
};
class FCmpInst : public Instruction {
public:
FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name);
FCmpPredicate GetPredicate() const { return pred_; }
Value* GetLhs() const;
Value* GetRhs() const;
private:
FCmpPredicate pred_;
};
class CastInst : public Instruction {
public:
CastInst(Opcode op, std::shared_ptr<Type> dst_ty, Value* src,
std::string name);
Value* GetValue() const;
};
class BranchInst : public Instruction {
public:
explicit BranchInst(BasicBlock* dest);
BasicBlock* GetDest() const;
};
class CondBrInst : public Instruction {
public:
CondBrInst(Value* cond, BasicBlock* true_dest, BasicBlock* false_dest);
Value* GetCond() const;
BasicBlock* GetTrueDest() const;
BasicBlock* GetFalseDest() const;
}; };
class ReturnInst : public Instruction { class ReturnInst : public Instruction {
public: public:
explicit ReturnInst(std::shared_ptr<Type> void_ty);
ReturnInst(std::shared_ptr<Type> void_ty, Value* val); ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
bool HasReturnValue() const;
Value* GetValue() const; Value* GetValue() const;
}; };
class AllocaInst : public Instruction { class AllocaInst : public Instruction {
public: public:
AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name); AllocaInst(std::shared_ptr<Type> allocated_ty, std::string name);
const std::shared_ptr<Type>& GetAllocatedType() const;
private:
std::shared_ptr<Type> allocated_type_;
}; };
class LoadInst : public Instruction { class LoadInst : public Instruction {
@ -223,8 +355,48 @@ class StoreInst : public Instruction {
Value* GetPtr() const; Value* GetPtr() const;
}; };
// BasicBlock 已纳入 Value 体系,便于后续向更完整 IR 类图靠拢。 class CallInst : public Instruction {
// 当前其类型仍使用 void 作为占位,后续可替换为专门的 label type。 public:
CallInst(std::shared_ptr<Type> ret_ty, Value* callee,
std::vector<Value*> args, std::string name);
Value* GetCallee() const;
const std::vector<Value*>& GetArgs() const { return args_; }
private:
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> args_;
};
class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> ty, std::string name);
void AddIncoming(Value* value, BasicBlock* block);
void RemoveIncomingFrom(BasicBlock* block);
const std::vector<Value*>& GetIncomingValues() const;
const std::vector<BasicBlock*>& GetIncomingBlocks() const;
private:
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> incoming_values_;
std::vector<BasicBlock*> incoming_blocks_;
};
class GepInst : public Instruction {
public:
GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
std::vector<Value*> indices, std::string name);
Value* GetBasePtr() const;
const std::vector<Value*>& GetIndices() const { return indices_; }
private:
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> indices_;
};
// BasicBlock 已纳入 Value 体系,使用 label type。
class BasicBlock : public Value { class BasicBlock : public Value {
public: public:
explicit BasicBlock(std::string name); explicit BasicBlock(std::string name);
@ -232,8 +404,17 @@ class BasicBlock : public Value {
void SetParent(Function* parent); void SetParent(Function* parent);
bool HasTerminator() const; bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const; const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
std::vector<std::unique_ptr<Instruction>>& GetMutableInstructions();
const std::vector<BasicBlock*>& GetPredecessors() const; const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const; const std::vector<BasicBlock*>& GetSuccessors() const;
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void ClearPredecessors();
void ClearSuccessors();
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
void EraseInstruction(Instruction* inst);
void ReplaceTerminator(std::unique_ptr<Instruction> inst);
template <typename T, typename... Args> template <typename T, typename... Args>
T* Append(Args&&... args) { T* Append(Args&&... args) {
if (HasTerminator()) { if (HasTerminator()) {
@ -244,6 +425,16 @@ class BasicBlock : public Value {
auto* ptr = inst.get(); auto* ptr = inst.get();
ptr->SetParent(this); ptr->SetParent(this);
instructions_.push_back(std::move(inst)); instructions_.push_back(std::move(inst));
LinkSuccessorsIfNeeded(ptr);
return ptr;
}
template <typename T, typename... Args>
T* Prepend(Args&&... args) {
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin(), std::move(inst));
return ptr; return ptr;
} }
@ -252,6 +443,7 @@ class BasicBlock : public Value {
std::vector<std::unique_ptr<Instruction>> instructions_; std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_; std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_; std::vector<BasicBlock*> successors_;
void LinkSuccessorsIfNeeded(Instruction* inst);
}; };
// Function 当前也采用了最小实现。 // Function 当前也采用了最小实现。
@ -262,16 +454,35 @@ class BasicBlock : public Value {
// 形参和调用,通常需要引入专门的函数类型表示。 // 形参和调用,通常需要引入专门的函数类型表示。
class Function : public Value { class Function : public Value {
public: public:
// 当前构造函数接收的也是返回类型,而不是完整函数类型。 Function(std::string name, std::shared_ptr<Type> func_type,
Function(std::string name, std::shared_ptr<Type> ret_type); bool is_declaration = false);
BasicBlock* CreateBlock(const std::string& name); BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry(); BasicBlock* GetEntry();
const BasicBlock* GetEntry() const; const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const; const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
std::vector<std::unique_ptr<BasicBlock>>& GetMutableBlocks();
const std::vector<std::unique_ptr<Argument>>& GetArguments() const;
size_t GetNumArgs() const;
Argument* GetArg(size_t index);
std::shared_ptr<Type> GetFunctionType() const;
std::shared_ptr<Type> GetReturnType() const;
bool IsDeclaration() const;
private: private:
BasicBlock* entry_ = nullptr; BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_; std::vector<std::unique_ptr<BasicBlock>> blocks_;
std::vector<std::unique_ptr<Argument>> args_;
std::unordered_map<std::string, size_t> block_name_counts_;
bool is_declaration_ = false;
};
class Argument : public Value {
public:
Argument(std::shared_ptr<Type> ty, std::string name, size_t index);
size_t GetIndex() const { return index_; }
private:
size_t index_ = 0;
}; };
class Module { class Module {
@ -282,11 +493,20 @@ class Module {
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
Function* CreateFunction(const std::string& name, Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type); std::shared_ptr<Type> ret_type);
Function* CreateFunctionWithType(const std::string& name,
std::shared_ptr<Type> func_type);
Function* CreateFunctionDecl(const std::string& name,
std::shared_ptr<Type> func_type);
GlobalVariable* CreateGlobalVariable(const std::string& name,
std::shared_ptr<Type> value_type,
ConstantValue* init, bool is_const);
const std::vector<std::unique_ptr<Function>>& GetFunctions() const; const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobals() const;
private: private:
Context context_; Context context_;
std::vector<std::unique_ptr<Function>> functions_; std::vector<std::unique_ptr<Function>> functions_;
std::vector<std::unique_ptr<GlobalVariable>> globals_;
}; };
class IRBuilder { class IRBuilder {
@ -297,13 +517,44 @@ class IRBuilder {
// 构造常量、二元运算、返回指令的最小集合。 // 构造常量、二元运算、返回指令的最小集合。
ConstantInt* CreateConstInt(int v); ConstantInt* CreateConstInt(int v);
ConstantInt* CreateConstBool(bool v);
ConstantFloat* CreateConstFloat(float v);
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name); const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSDiv(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSRem(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name);
ICmpInst* CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name);
FCmpInst* CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name);
CastInst* CreateSIToFP(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
CastInst* CreateFPToSI(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
CastInst* CreateZExt(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name);
AllocaInst* CreateAlloca(std::shared_ptr<Type> ty,
const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr); StoreInst* CreateStore(Value* val, Value* ptr);
GepInst* CreateGep(Value* base_ptr, std::vector<Value*> indices,
const std::string& name);
CallInst* CreateCall(Value* callee, std::vector<Value*> args,
const std::string& name);
PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name);
BranchInst* CreateBr(BasicBlock* dest);
CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_dest,
BasicBlock* false_dest);
ReturnInst* CreateRet(Value* v); ReturnInst* CreateRet(Value* v);
ReturnInst* CreateRetVoid();
private: private:
Context& ctx_; Context& ctx_;
@ -315,4 +566,12 @@ class IRPrinter {
void Print(const Module& module, std::ostream& os); void Print(const Module& module, std::ostream& os);
}; };
bool RunMem2Reg(Module& module);
bool RunConstFold(Module& module);
bool RunConstProp(Module& module);
bool RunCSE(Module& module);
bool RunDCE(Module& module);
bool RunCFGSimplify(Module& module);
bool RunScalarOptimizationPipeline(Module& module);
} // namespace ir } // namespace ir

@ -1,58 +1,114 @@
// 将语法树翻译为 IR。
// 实现拆分在 IRGenFunc/IRGenStmt/IRGenExp/IRGenDecl。
#pragma once #pragma once
#include <any> #include <any>
#include <memory> #include <memory>
#include <string>
#include <unordered_map> #include <unordered_map>
#include "SysYBaseVisitor.h" #include "SysYBaseVisitor.h"
#include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "sem/Sema.h" #include "sem/Sema.h"
namespace ir {
class Module;
class Function;
class IRBuilder;
class Value;
}
class IRGenImpl final : public SysYBaseVisitor { class IRGenImpl final : public SysYBaseVisitor {
public: public:
IRGenImpl(ir::Module& module, const SemanticContext& sema); IRGenImpl(ir::Module& module, const SemanticContext& sema);
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitBlock(SysYParser::BlockContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
private: std::any visitExp(SysYParser::ExpContext* ctx) override;
enum class BlockFlow { std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
Continue, std::any visitMulExp(SysYParser::MulExpContext* ctx) override; // 新增
Terminated, std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; // 新增
}; std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override;
std::any visitConstInitVal(SysYParser::ConstInitValContext* ctx) override;
std::any visitInitVal(SysYParser::InitValContext* ctx) override;
private:
ir::Value* EvalExp(SysYParser::ExpContext* ctx);
ir::Value* EvalCondValue(SysYParser::CondContext* ctx);
void EmitCondBr(SysYParser::CondContext* ctx, ir::BasicBlock* true_bb,
ir::BasicBlock* false_bb);
void EmitLOrCond(SysYParser::LOrExpContext* ctx, ir::BasicBlock* true_bb,
ir::BasicBlock* false_bb);
void EmitLAndCond(SysYParser::LAndExpContext* ctx, ir::BasicBlock* true_bb,
ir::BasicBlock* false_bb);
ir::Value* EmitRelEq(SysYParser::RelExpContext* ctx);
ir::Value* EmitEq(SysYParser::EqExpContext* ctx);
ir::Value* CastToFloat(ir::Value* v);
ir::Value* CastToInt(ir::Value* v);
ir::Value* MakeBool(ir::Value* v);
ir::Value* GetLValAddress(SysYParser::LValContext* ctx);
ir::Value* LoadIfNeeded(ir::Value* addr_or_val, const TypeDesc& ty,
bool as_rvalue);
std::shared_ptr<ir::Type> ToIRType(const TypeDesc& ty);
std::shared_ptr<ir::Type> ToIRParamType(const TypeDesc& ty);
ir::Value* DefaultValue(const TypeDesc& ty);
ir::AllocaInst* CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
const std::string& name);
void InitArray(ir::Value* base_ptr, const TypeDesc& ty,
SysYParser::InitValContext* init);
void InitConstArray(ir::Value* base_ptr, const TypeDesc& ty,
SysYParser::ConstInitValContext* init);
size_t FillArrayValues(const TypeDesc& ty, SysYParser::InitValContext* init,
std::vector<ir::Value*>& values, size_t base,
size_t idx, size_t dim);
size_t FillConstArrayValues(const TypeDesc& ty,
SysYParser::ConstInitValContext* init,
std::vector<ir::Value*>& values, size_t base,
size_t idx, size_t dim);
size_t ArrayStride(const TypeDesc& ty, size_t dim) const;
size_t ArrayTotalSize(const TypeDesc& ty) const;
void PushLoop(ir::BasicBlock* break_bb, ir::BasicBlock* cont_bb);
void PopLoop();
ir::BasicBlock* CurrentBreak() const;
ir::BasicBlock* CurrentContinue() const;
ir::ConstantValue* EvalConstScalar(SysYParser::ExpContext* ctx);
ir::ConstantValue* EvalConstScalar(SysYParser::ConstExpContext* ctx);
ir::ConstantValue* EvalConstAdd(SysYParser::AddExpContext* ctx);
ir::ConstantValue* EvalConstMul(SysYParser::MulExpContext* ctx);
ir::ConstantValue* EvalConstUnary(SysYParser::UnaryExpContext* ctx);
ir::ConstantValue* EvalConstPrimary(SysYParser::PrimaryExpContext* ctx);
ir::ConstantValue* EvalConstNumber(SysYParser::NumberContext* ctx);
ir::ConstantValue* EvalConstLVal(SysYParser::LValContext* ctx);
size_t InitGlobalArray(const TypeDesc& ty, SysYParser::InitValContext* init,
std::vector<ir::ConstantValue*>& values, size_t base,
size_t idx, size_t dim);
size_t InitGlobalConstArray(const TypeDesc& ty,
SysYParser::ConstInitValContext* init,
std::vector<ir::ConstantValue*>& values,
size_t base, size_t idx, size_t dim);
enum class BlockFlow { Continue, Terminated };
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::Module& module_; ir::Module& module_;
const SemanticContext& sema_; const SemanticContext& sema_;
ir::Function* func_; ir::Function* func_;
ir::IRBuilder builder_; ir::IRBuilder builder_;
// 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 std::unordered_map<const SysYParser::VarDefContext*, ir::Value*> var_storage_;
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_; std::unordered_map<const SysYParser::ConstDefContext*, ir::Value*> const_storage_;
std::unordered_map<const SysYParser::FuncFParamContext*, ir::Value*> param_storage_;
std::unordered_map<const SysYParser::FuncDefContext*, ir::Function*> func_map_;
std::unordered_map<const SysYParser::VarDefContext*, ir::GlobalVariable*>
global_var_storage_;
std::unordered_map<const SysYParser::ConstDefContext*, ir::GlobalVariable*>
global_const_storage_;
std::vector<std::pair<ir::BasicBlock*, ir::BasicBlock*>> loop_stack_;
}; };
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema); const SemanticContext& sema);

@ -1,6 +1,10 @@
// Lab5 后端 MIR 表示:
// - 虚拟寄存器 + 物理寄存器,两类寄存器(GPR/FPR)
// - 多函数、多基本块、全局变量、栈对象
// - 指令携带显式 def/use约定操作数前 num_defs 个为定值)
#pragma once #pragma once
#include <initializer_list> #include <cstdint>
#include <iosfwd> #include <iosfwd>
#include <memory> #include <memory>
#include <string> #include <string>
@ -16,104 +20,273 @@ class MIRContext {
public: public:
MIRContext() = default; MIRContext() = default;
}; };
MIRContext& DefaultContext(); MIRContext& DefaultContext();
enum class PhysReg { W0, W8, W9, X29, X30, SP }; enum class RegClass { GPR, FPR };
// 物理寄存器编号(按类内编号):
// GPR: 0..30 = x0..x30, 31 = sp, 32 = xzr
// FPR: 0..31 = s0..s31
namespace preg {
constexpr int kSP = 31;
constexpr int kXZR = 32;
constexpr int kFP = 29; // x29
constexpr int kLR = 30; // x30
constexpr int kIP0 = 16; // x16 scratch
constexpr int kIP1 = 17; // x17 scratch
} // namespace preg
const char* PhysRegName(PhysReg reg); enum class Cond { AL, EQ, NE, LT, LE, GT, GE, MI, LS, HI, HS };
enum class Opcode { enum class Opcode {
Prologue, Mov, // dst<-src (reg copy)
Epilogue, MovImm, // dst<-imm (materialize 32/64-bit)
MovImm, Sxtw, // dst(64) = sign-extend src(32)
LoadStack, Add, // dst = a + b
StoreStack, Sub, // dst = a - b
AddRR, Mul, // dst = a * b
SDiv, // dst = a / b
MSub, // dst = a - b*c
AddImm, // dst = a + imm
SubImm, // dst = a - imm
LslImm, // dst = a << imm
Cmp, // a ? b (sets flags)
CmpImm, // a ? imm
CSet, // dst = cond
FAdd,
FSub,
FMul,
FDiv,
FCmp,
FMov, // fpr<-fpr
FMovImm, // fpr <- 32-bit float bits (via scratch gpr)
SCvtF, // fpr = (float)gpr
FCvtZS, // gpr = (int)fpr
Ldr, // dst <- [base, #imm]
Str, // [base, #imm] <- src (def=0, ops: src,base,imm)
LdrStack, // dst <- [frame]
StrStack, // [frame] <- src (def=0, ops: src, frame)
AddrFrame, // dst = addr of frame slot
AddrGlobal, // dst = addr of global symbol
B, // branch label
BCond, // branch cond label (uses flags)
Bl, // call symbol (def=0; ops list arg pregs as uses)
Ret, Ret,
}; };
class Operand { class Operand {
public: public:
enum class Kind { Reg, Imm, FrameIndex }; enum class Kind { None, VReg, PReg, Imm, Frame, Global, Label };
static Operand Reg(PhysReg reg); static Operand VReg(int id, RegClass cls, int bytes) {
static Operand Imm(int value); Operand o;
static Operand FrameIndex(int index); o.kind_ = Kind::VReg;
o.id_ = id;
o.cls_ = cls;
o.bytes_ = bytes;
return o;
}
static Operand PReg(int id, RegClass cls, int bytes) {
Operand o;
o.kind_ = Kind::PReg;
o.id_ = id;
o.cls_ = cls;
o.bytes_ = bytes;
return o;
}
static Operand Imm(long long v) {
Operand o;
o.kind_ = Kind::Imm;
o.imm_ = v;
return o;
}
static Operand Frame(int idx) {
Operand o;
o.kind_ = Kind::Frame;
o.id_ = idx;
return o;
}
static Operand Global(std::string name) {
Operand o;
o.kind_ = Kind::Global;
o.sym_ = std::move(name);
return o;
}
static Operand Label(std::string name) {
Operand o;
o.kind_ = Kind::Label;
o.sym_ = std::move(name);
return o;
}
Kind GetKind() const { return kind_; } Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; } bool IsReg() const { return kind_ == Kind::VReg || kind_ == Kind::PReg; }
int GetImm() const { return imm_; } bool IsVReg() const { return kind_ == Kind::VReg; }
int GetFrameIndex() const { return imm_; } bool IsPReg() const { return kind_ == Kind::PReg; }
int GetId() const { return id_; }
RegClass GetClass() const { return cls_; }
int GetBytes() const { return bytes_; }
long long GetImm() const { return imm_; }
int GetFrame() const { return id_; }
const std::string& GetSym() const { return sym_; }
void SetPReg(int id) {
kind_ = Kind::PReg;
id_ = id;
}
void SetVReg(int id) {
kind_ = Kind::VReg;
id_ = id;
}
void SetBytes(int b) { bytes_ = b; }
private: private:
Operand(Kind kind, PhysReg reg, int imm); Kind kind_ = Kind::None;
int id_ = 0;
RegClass cls_ = RegClass::GPR;
int bytes_ = 4;
long long imm_ = 0;
std::string sym_;
};
Kind kind_; struct MachineInstr {
PhysReg reg_; Opcode op;
int imm_; std::vector<Operand> ops;
int num_defs = 0;
Cond cond = Cond::AL;
MachineInstr(Opcode o, std::vector<Operand> operands, int defs,
Cond c = Cond::AL)
: op(o), ops(std::move(operands)), num_defs(defs), cond(c) {}
}; };
class MachineInstr { class MachineBasicBlock {
public: public:
MachineInstr(Opcode opcode, std::vector<Operand> operands = {}); explicit MachineBasicBlock(std::string name) : name_(std::move(name)) {}
const std::string& GetName() const { return name_; }
std::vector<MachineInstr>& Instrs() { return instrs_; }
const std::vector<MachineInstr>& Instrs() const { return instrs_; }
std::vector<MachineBasicBlock*>& Succs() { return succs_; }
const std::vector<MachineBasicBlock*>& Succs() const { return succs_; }
Opcode GetOpcode() const { return opcode_; } void Add(MachineInstr mi) { instrs_.push_back(std::move(mi)); }
const std::vector<Operand>& GetOperands() const { return operands_; }
private: private:
Opcode opcode_; std::string name_;
std::vector<Operand> operands_; std::vector<MachineInstr> instrs_;
std::vector<MachineBasicBlock*> succs_;
}; };
struct FrameSlot { struct VRegInfo {
RegClass cls = RegClass::GPR;
int bytes = 4;
};
struct StackObject {
int index = 0; int index = 0;
int size = 4; int size = 4;
int offset = 0; int align = 4;
int offset = 0; // 相对 x29负数
}; };
class MachineBasicBlock { class MachineFunction {
public: public:
explicit MachineBasicBlock(std::string name); explicit MachineFunction(std::string name) : name_(std::move(name)) {}
const std::string& GetName() const { return name_; } const std::string& GetName() const { return name_; }
std::vector<MachineInstr>& GetInstructions() { return instructions_; }
const std::vector<MachineInstr>& GetInstructions() const { return instructions_; }
MachineInstr& Append(Opcode opcode, MachineBasicBlock* CreateBlock(const std::string& name) {
std::initializer_list<Operand> operands = {}); blocks_.push_back(std::make_unique<MachineBasicBlock>(name));
return blocks_.back().get();
}
const std::vector<std::unique_ptr<MachineBasicBlock>>& Blocks() const {
return blocks_;
}
std::vector<std::unique_ptr<MachineBasicBlock>>& Blocks() { return blocks_; }
private: int NewVReg(RegClass cls, int bytes) {
std::string name_; int id = static_cast<int>(vregs_.size());
std::vector<MachineInstr> instructions_; vregs_.push_back(VRegInfo{cls, bytes});
}; return id;
}
Operand NewVRegOp(RegClass cls, int bytes) {
return Operand::VReg(NewVReg(cls, bytes), cls, bytes);
}
int NumVRegs() const { return static_cast<int>(vregs_.size()); }
const VRegInfo& VReg(int id) const { return vregs_[id]; }
class MachineFunction { int CreateStackObject(int size, int align) {
public: StackObject obj;
explicit MachineFunction(std::string name); obj.index = static_cast<int>(stack_.size());
obj.size = size;
obj.align = align;
stack_.push_back(obj);
return obj.index;
}
std::vector<StackObject>& StackObjects() { return stack_; }
const std::vector<StackObject>& StackObjects() const { return stack_; }
StackObject& Stack(int idx) { return stack_[idx]; }
const std::string& GetName() const { return name_; } int GetFrameSize() const { return frame_size_; }
MachineBasicBlock& GetEntry() { return entry_; } void SetFrameSize(int s) { frame_size_ = s; }
const MachineBasicBlock& GetEntry() const { return entry_; }
int CreateFrameIndex(int size = 4); // 寄存器分配产物:本函数用到、需要保存恢复的 callee-saved 物理寄存器
FrameSlot& GetFrameSlot(int index); std::vector<int>& CalleeSavedGPR() { return callee_gpr_; }
const FrameSlot& GetFrameSlot(int index) const; std::vector<int>& CalleeSavedFPR() { return callee_fpr_; }
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; } const std::vector<int>& CalleeSavedGPR() const { return callee_gpr_; }
const std::vector<int>& CalleeSavedFPR() const { return callee_fpr_; }
int GetFrameSize() const { return frame_size_; } int NumIntArgs() const { return num_int_args_; }
void SetFrameSize(int size) { frame_size_ = size; } int NumFloatArgs() const { return num_float_args_; }
void SetArgCounts(int i, int f) {
num_int_args_ = i;
num_float_args_ = f;
}
private: private:
std::string name_; std::string name_;
MachineBasicBlock entry_; std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
std::vector<FrameSlot> frame_slots_; std::vector<VRegInfo> vregs_;
std::vector<StackObject> stack_;
std::vector<int> callee_gpr_;
std::vector<int> callee_fpr_;
int frame_size_ = 0; int frame_size_ = 0;
int num_int_args_ = 0;
int num_float_args_ = 0;
};
struct MachineGlobal {
std::string name;
int size = 0; // 总字节数
int align = 4;
bool is_const = false;
bool zero_init = true;
std::vector<unsigned> words; // 非零初始化:按 4 字节小端存放的原始位
}; };
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module); class MachineModule {
public:
std::vector<std::unique_ptr<MachineFunction>>& Functions() { return funcs_; }
const std::vector<std::unique_ptr<MachineFunction>>& Functions() const {
return funcs_;
}
std::vector<MachineGlobal>& Globals() { return globals_; }
const std::vector<MachineGlobal>& Globals() const { return globals_; }
private:
std::vector<std::unique_ptr<MachineFunction>> funcs_;
std::vector<MachineGlobal> globals_;
};
const char* GPRName(int id, int bytes);
const char* FPRName(int id, int bytes);
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function); void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function); void RunFrameLowering(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os); void RunPeephole(MachineFunction& function);
void RunBackendPipeline(MachineModule& module);
void PrintAsm(const MachineModule& module, std::ostream& os);
void PrintAArch64AsmFromMIR(const ir::Module& module, std::ostream& os);
void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os);
} // namespace mir } // namespace mir

@ -1,30 +1,91 @@
// 基于语法树的语义检查与名称绑定 // 基于语法树的语义检查与名称绑定Lab2 扩展)
#pragma once #pragma once
#include <optional>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
#include "sem/SymbolTable.h"
struct FuncTypeDesc {
TypeDesc ret;
std::vector<TypeDesc> params;
};
struct BoundDecl {
enum class Kind { Var, Const, Param } kind = Kind::Var;
SysYParser::VarDefContext* var_decl = nullptr;
SysYParser::ConstDefContext* const_decl = nullptr;
SysYParser::FuncFParamContext* param_decl = nullptr;
};
class SemanticContext { class SemanticContext {
public: public:
void BindVarUse(SysYParser::VarContext* use, void BindVarUse(SysYParser::LValContext* use, BoundDecl decl) {
SysYParser::VarDefContext* decl) {
var_uses_[use] = decl; var_uses_[use] = decl;
} }
SysYParser::VarDefContext* ResolveVarUse( BoundDecl ResolveVarUse(const SysYParser::LValContext* use) const {
const SysYParser::VarContext* use) const {
auto it = var_uses_.find(use); auto it = var_uses_.find(use);
return it == var_uses_.end() ? nullptr : it->second; return it == var_uses_.end() ? BoundDecl{} : it->second;
}
void RegisterVarDecl(SysYParser::VarDefContext* decl, TypeDesc ty) {
var_types_[decl] = std::move(ty);
}
void RegisterConstDecl(SysYParser::ConstDefContext* decl, TypeDesc ty) {
const_types_[decl] = std::move(ty);
}
void RegisterParam(SysYParser::FuncFParamContext* decl, TypeDesc ty) {
param_types_[decl] = std::move(ty);
}
void RegisterFunc(SysYParser::FuncDefContext* decl, FuncTypeDesc ty) {
func_types_[decl] = std::move(ty);
}
const TypeDesc* GetVarType(const SysYParser::VarDefContext* decl) const {
auto it = var_types_.find(decl);
return it == var_types_.end() ? nullptr : &it->second;
}
const TypeDesc* GetConstType(const SysYParser::ConstDefContext* decl) const {
auto it = const_types_.find(decl);
return it == const_types_.end() ? nullptr : &it->second;
}
const TypeDesc* GetParamType(const SysYParser::FuncFParamContext* decl) const {
auto it = param_types_.find(decl);
return it == param_types_.end() ? nullptr : &it->second;
}
const FuncTypeDesc* GetFuncType(const SysYParser::FuncDefContext* decl) const {
auto it = func_types_.find(decl);
return it == func_types_.end() ? nullptr : &it->second;
}
void BindFuncCall(SysYParser::UnaryExpContext* call,
SysYParser::FuncDefContext* decl) {
func_calls_[call] = decl;
}
SysYParser::FuncDefContext* ResolveFuncCall(
const SysYParser::UnaryExpContext* call) const {
auto it = func_calls_.find(call);
return it == func_calls_.end() ? nullptr : it->second;
} }
private: private:
std::unordered_map<const SysYParser::VarContext*, std::unordered_map<const SysYParser::LValContext*, BoundDecl> var_uses_;
SysYParser::VarDefContext*> std::unordered_map<const SysYParser::VarDefContext*, TypeDesc> var_types_;
var_uses_; std::unordered_map<const SysYParser::ConstDefContext*, TypeDesc> const_types_;
std::unordered_map<const SysYParser::FuncFParamContext*, TypeDesc> param_types_;
std::unordered_map<const SysYParser::FuncDefContext*, FuncTypeDesc> func_types_;
std::unordered_map<const SysYParser::UnaryExpContext*, SysYParser::FuncDefContext*>
func_calls_;
}; };
// 目前仅检查: SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
// - 变量先声明后使用
// - 局部变量不允许重复定义
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);

@ -1,17 +1,42 @@
// 极简符号表:记录局部变量定义 // 符号表:记录局部变量/常量/参数定义。
#pragma once #pragma once
#include <optional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
enum class BaseTypeKind { Int, Float, Void };
struct TypeDesc {
BaseTypeKind base = BaseTypeKind::Int;
std::vector<int> dims; // 为空表示标量;数组维度允许首维为 -1 表示形参不定长
bool is_const = false;
};
enum class SymbolKind { Var, Const, Param };
struct SymbolEntry {
SymbolKind kind = SymbolKind::Var;
SysYParser::VarDefContext* var_decl = nullptr;
SysYParser::ConstDefContext* const_decl = nullptr;
SysYParser::FuncFParamContext* param_decl = nullptr;
TypeDesc type; // 记录类型信息
bool is_const = false;
std::optional<int> const_value;
};
class SymbolTable { class SymbolTable {
public: public:
void Add(const std::string& name, SysYParser::VarDefContext* decl); void EnterScope();
bool Contains(const std::string& name) const; void ExitScope();
SysYParser::VarDefContext* Lookup(const std::string& name) const;
bool ContainsInCurrentScope(const std::string& name) const;
void Add(const std::string& name, const SymbolEntry& entry);
const SymbolEntry* Lookup(const std::string& name) const;
private: private:
std::unordered_map<std::string, SysYParser::VarDefContext*> table_; std::vector<std::unordered_map<std::string, SymbolEntry>> scopes_;
}; };

@ -8,6 +8,7 @@ struct CLIOptions {
bool emit_parse_tree = false; bool emit_parse_tree = false;
bool emit_ir = true; bool emit_ir = true;
bool emit_asm = false; bool emit_asm = false;
bool optimize_ir = true;
bool show_help = false; bool show_help = false;
}; };

@ -0,0 +1,19 @@
#!/usr/bin/env bash
set -euo pipefail
# Reconfigure with IR pipeline enabled, build, then run Lab2 test script.
RESULT_FILE="test/test_result/run_lab2_result.log"
mkdir -p "$(dirname \"$RESULT_FILE\")"
: > "$RESULT_FILE"
{
echo "[run_lab2] start: $(date '+%Y-%m-%d %H:%M:%S')"
echo "[run_lab2] logging to: $RESULT_FILE"
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF
cmake --build build -j "$(nproc)"
CASE_DIR=test/test_case bash scripts/test_lab2.sh
echo "[run_lab2] end: $(date '+%Y-%m-%d %H:%M:%S')"
} 2>&1 | tee "$RESULT_FILE"

@ -0,0 +1,124 @@
#!/usr/bin/env bash
set -euo pipefail
# Lab2 quick/full verification helper.
# Usage:
# bash scripts/test_lab2.sh
# Optional env vars:
# COMPILER=./build/bin/compiler
# CASE_DIR=test/test_case
# OUT_DIR=test/test_result/lab2_ir
# LOG_FILE=test/test_result/lab2_test.log
COMPILER="${COMPILER:-./build/bin/compiler}"
CASE_DIR="${CASE_DIR:-test/test_case}"
OUT_DIR="${OUT_DIR:-test/test_result/lab2_ir}"
LOG_FILE="${LOG_FILE:-test/test_result/lab2_test.log}"
VERIFY_SCRIPT="./scripts/verify_ir.sh"
if [[ ! -x "$COMPILER" ]]; then
echo "compiler not found or not executable: $COMPILER" >&2
echo "build first:" >&2
echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release" >&2
echo " cmake --build build -j \"\$(nproc)\"" >&2
exit 1
fi
if [[ ! -x "$VERIFY_SCRIPT" ]]; then
echo "verify script not found or not executable: $VERIFY_SCRIPT" >&2
exit 1
fi
if [[ ! -d "$CASE_DIR" ]]; then
echo "case dir not found: $CASE_DIR" >&2
exit 1
fi
mkdir -p "$OUT_DIR"
# Preflight: ensure compiler supports IR emission (not parse-only build).
probe_input="$CASE_DIR/simple_add.sy"
probe_err="$OUT_DIR/.lab2_probe.err"
if [[ -f "$probe_input" ]]; then
set +e
"$COMPILER" --emit-ir "$probe_input" > /dev/null 2> "$probe_err"
probe_rc=$?
set -e
if [[ $probe_rc -ne 0 ]] && grep -Eiq "parse-only|IR/汇编输出已禁用" "$probe_err"; then
echo "detected parse-only compiler build, cannot run Lab2 IR tests." >&2
echo "rebuild with IR enabled:" >&2
echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF" >&2
echo " cmake --build build -j \"\$(nproc)\"" >&2
rm -f "$probe_err"
exit 2
fi
rm -f "$probe_err"
fi
mkdir -p "$(dirname "$LOG_FILE")"
: > "$LOG_FILE"
echo "[Lab2] start test" | tee -a "$LOG_FILE"
echo "compiler : $COMPILER" | tee -a "$LOG_FILE"
echo "cases : $CASE_DIR" | tee -a "$LOG_FILE"
echo "out dir : $OUT_DIR" | tee -a "$LOG_FILE"
echo "[Step 1] single sample check: simple_add.sy" | tee -a "$LOG_FILE"
sample_input="$(find "$CASE_DIR" -type f -name "simple_add.sy" -print -quit)"
if [[ -z "$sample_input" ]]; then
sample_input="$(find "$CASE_DIR" -type f -name "*.sy" | sort | head -n 1)"
fi
if [[ -z "$sample_input" ]]; then
echo "single sample: FAIL (no .sy case found under $CASE_DIR)" | tee -a "$LOG_FILE"
echo "stop here. see log: $LOG_FILE" >&2
exit 1
fi
if "$VERIFY_SCRIPT" "$sample_input" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
echo "single sample: PASS" | tee -a "$LOG_FILE"
else
echo "single sample: FAIL" | tee -a "$LOG_FILE"
echo "stop here. see log: $LOG_FILE" >&2
exit 1
fi
echo "[Step 2] full Lab2 regression" | tee -a "$LOG_FILE"
pass=0
fail=0
total=0
failed_list=()
while IFS= read -r -d '' sy; do
total=$((total + 1))
name="$(basename "$sy")"
echo "[$total] $name" | tee -a "$LOG_FILE"
if "$VERIFY_SCRIPT" "$sy" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
pass=$((pass + 1))
echo " PASS" | tee -a "$LOG_FILE"
else
fail=$((fail + 1))
failed_list+=("$sy")
echo " FAIL" | tee -a "$LOG_FILE"
fi
done < <(find "$CASE_DIR" -type f -name "*.sy" -print0 | sort -z)
echo "" | tee -a "$LOG_FILE"
echo "[Summary]" | tee -a "$LOG_FILE"
echo "total: $total" | tee -a "$LOG_FILE"
echo "pass : $pass" | tee -a "$LOG_FILE"
echo "fail : $fail" | tee -a "$LOG_FILE"
if [[ $fail -gt 0 ]]; then
echo "failed cases:" | tee -a "$LOG_FILE"
for f in "${failed_list[@]}"; do
echo " - $f" | tee -a "$LOG_FILE"
done
echo "Lab2 target is not fully met yet." | tee -a "$LOG_FILE"
echo "see details in $LOG_FILE"
exit 1
fi
echo "All Lab2 cases passed. Lab2 target regression is met." | tee -a "$LOG_FILE"
echo "see details in $LOG_FILE"

@ -0,0 +1,119 @@
#!/usr/bin/env bash
set -euo pipefail
# Lab3 full backend regression helper.
# Usage:
# bash scripts/test_lab3.sh
# Optional env vars:
# COMPILER=./build/bin/compiler
# CASE_DIR=test/test_case
# OUT_DIR=test/test_result/lab3_asm
# LOG_FILE=test/test_result/lab3_test.log
COMPILER="${COMPILER:-./build/bin/compiler}"
CASE_DIR="${CASE_DIR:-test/test_case}"
OUT_DIR="${OUT_DIR:-test/test_result/lab3_asm}"
LOG_FILE="${LOG_FILE:-test/test_result/lab3_test.log}"
VERIFY_SCRIPT="./scripts/verify_asm.sh"
if [[ ! -x "$COMPILER" ]]; then
echo "compiler not found or not executable: $COMPILER" >&2
echo "build first:" >&2
echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF" >&2
echo " cmake --build build -j \"\$(nproc)\"" >&2
exit 1
fi
if [[ ! -x "$VERIFY_SCRIPT" ]]; then
echo "verify script not found or not executable: $VERIFY_SCRIPT" >&2
exit 1
fi
if [[ ! -d "$CASE_DIR" ]]; then
echo "case dir not found: $CASE_DIR" >&2
exit 1
fi
mkdir -p "$OUT_DIR"
probe_input="$CASE_DIR/functional/simple_add.sy"
probe_err="$OUT_DIR/.lab3_probe.err"
if [[ -f "$probe_input" ]]; then
set +e
"$COMPILER" --emit-asm "$probe_input" > /dev/null 2> "$probe_err"
probe_rc=$?
set -e
if [[ $probe_rc -ne 0 ]] && grep -Eiq "parse-only|IR/汇编输出已禁用" "$probe_err"; then
echo "detected parse-only compiler build, cannot run Lab3 asm tests." >&2
echo "rebuild with MIR/ASM enabled:" >&2
echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=OFF" >&2
echo " cmake --build build -j \"\$(nproc)\"" >&2
rm -f "$probe_err"
exit 2
fi
rm -f "$probe_err"
fi
mkdir -p "$(dirname "$LOG_FILE")"
: > "$LOG_FILE"
echo "[Lab3] start test" | tee -a "$LOG_FILE"
echo "compiler : $COMPILER" | tee -a "$LOG_FILE"
echo "cases : $CASE_DIR" | tee -a "$LOG_FILE"
echo "out dir : $OUT_DIR" | tee -a "$LOG_FILE"
echo "[Step 1] single sample check: simple_add.sy" | tee -a "$LOG_FILE"
if [[ ! -f "$probe_input" ]]; then
echo "single sample: FAIL (missing $probe_input)" | tee -a "$LOG_FILE"
exit 1
fi
if "$VERIFY_SCRIPT" "$probe_input" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
echo "single sample: PASS" | tee -a "$LOG_FILE"
else
echo "single sample: FAIL" | tee -a "$LOG_FILE"
echo "stop here. see log: $LOG_FILE" >&2
exit 1
fi
echo "[Step 2] full Lab3 asm regression" | tee -a "$LOG_FILE"
pass=0
fail=0
total=0
failed_list=()
while IFS= read -r -d '' sy; do
total=$((total + 1))
name="${sy#$CASE_DIR/}"
echo "[$total] $name" | tee -a "$LOG_FILE"
if "$VERIFY_SCRIPT" "$sy" "$OUT_DIR" --run >> "$LOG_FILE" 2>&1; then
pass=$((pass + 1))
echo " PASS" | tee -a "$LOG_FILE"
else
fail=$((fail + 1))
failed_list+=("$sy")
echo " FAIL" | tee -a "$LOG_FILE"
fi
done < <(find "$CASE_DIR" -type f -name "*.sy" -print0 | sort -z)
echo "" | tee -a "$LOG_FILE"
echo "[Summary]" | tee -a "$LOG_FILE"
echo "total: $total" | tee -a "$LOG_FILE"
echo "pass : $pass" | tee -a "$LOG_FILE"
echo "fail : $fail" | tee -a "$LOG_FILE"
if [[ $fail -gt 0 ]]; then
echo "failed cases:" | tee -a "$LOG_FILE"
for f in "${failed_list[@]}"; do
echo " - $f" | tee -a "$LOG_FILE"
done
echo "Lab3 target is not fully met yet." | tee -a "$LOG_FILE"
echo "see details in $LOG_FILE"
exit 1
fi
echo "All Lab3 cases passed. Lab3 target regression is met." | tee -a "$LOG_FILE"
echo "see details in $LOG_FILE"

@ -41,18 +41,25 @@ if ! command -v aarch64-linux-gnu-gcc >/dev/null 2>&1; then
exit 1 exit 1
fi fi
if ! command -v clang >/dev/null 2>&1; then
echo "未找到 clang无法由 IR 生成 AArch64 汇编。" >&2
exit 1
fi
mkdir -p "$out_dir" mkdir -p "$out_dir"
base=$(basename "$input") base=$(basename "$input")
stem=${base%.sy} stem=${base%.sy}
asm_file="$out_dir/$stem.s" asm_file="$out_dir/$stem.s"
exe="$out_dir/$stem" exe="$out_dir/$stem"
runtime_obj="$out_dir/sylib.aarch64.o"
stdin_file="$input_dir/$stem.in" stdin_file="$input_dir/$stem.in"
expected_file="$input_dir/$stem.out" expected_file="$input_dir/$stem.out"
"$compiler" --emit-asm "$input" > "$asm_file" "$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file" echo "汇编已生成: $asm_file"
aarch64-linux-gnu-gcc "$asm_file" -o "$exe" aarch64-linux-gnu-gcc -O2 -Wno-unused-result -c sylib/sylib.c -o "$runtime_obj"
aarch64-linux-gnu-gcc "$asm_file" "$runtime_obj" -o "$exe"
echo "可执行文件已生成: $exe" echo "可执行文件已生成: $exe"
if [[ "$run_exec" == true ]]; then if [[ "$run_exec" == true ]]; then
@ -63,6 +70,8 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout" stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out" actual_file="$out_dir/$stem.actual.out"
actual_norm="$out_dir/$stem.actual.norm"
expected_norm="$out_dir/$stem.expected.norm"
echo "运行 $exe ..." echo "运行 $exe ..."
set +e set +e
if [[ -f "$stdin_file" ]]; then if [[ -f "$stdin_file" ]]; then
@ -83,7 +92,9 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file" } > "$actual_file"
if [[ -f "$expected_file" ]]; then if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n?\z//' "$expected_file" > "$expected_norm"
perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n?\z//' "$actual_file" > "$actual_norm"
if diff -u "$expected_norm" "$actual_norm"; then
echo "输出匹配: $expected_file" echo "输出匹配: $expected_file"
else else
echo "输出不匹配: $expected_file" >&2 echo "输出不匹配: $expected_file" >&2

@ -47,20 +47,28 @@ expected_file="$input_dir/$stem.out"
echo "IR 已生成: $out_file" echo "IR 已生成: $out_file"
if [[ "$run_exec" == true ]]; then if [[ "$run_exec" == true ]]; then
if ! command -v llc >/dev/null 2>&1; then
echo "未找到 llc无法运行 IR。请安装 LLVM。" >&2
exit 1
fi
if ! command -v clang >/dev/null 2>&1; then if ! command -v clang >/dev/null 2>&1; then
echo "未找到 clang无法链接可执行文件。请安装 LLVM/Clang。" >&2 echo "未找到 clang无法编译可执行文件。请安装 LLVM/Clang。" >&2
exit 1 exit 1
fi fi
obj="$out_dir/$stem.o"
exe="$out_dir/$stem" exe="$out_dir/$stem"
stdout_file="$out_dir/$stem.stdout" stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out" actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj" actual_norm="$out_dir/$stem.actual.norm"
clang "$obj" -o "$exe" expected_norm="$out_dir/$stem.expected.norm"
# 直接让 clang 优化 .ll可显著降低大规模 Lab2 performance 用例运行时间。
# -fwrapv 保持有符号整数按补码回绕,避免与当前 IR 的测试语义偏离。
if ! clang -O2 -fwrapv -Wno-override-module "$out_file" sylib/sylib.c -o "$exe"; then
if ! command -v llc >/dev/null 2>&1; then
echo "未找到 llc且 clang 直接编译 IR 失败,无法运行 IR。" >&2
exit 1
fi
obj="$out_dir/$stem.o"
llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" sylib/sylib.c -o "$exe"
fi
echo "运行 $exe ..." echo "运行 $exe ..."
set +e set +e
if [[ -f "$stdin_file" ]]; then if [[ -f "$stdin_file" ]]; then
@ -81,7 +89,9 @@ if [[ "$run_exec" == true ]]; then
} > "$actual_file" } > "$actual_file"
if [[ -f "$expected_file" ]]; then if [[ -f "$expected_file" ]]; then
if diff -u "$expected_file" "$actual_file"; then perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n?\z//' "$expected_file" > "$expected_norm"
perl -0pe 's/\r\n/\n/g; s/\r/\n/g; s/\n?\z//' "$actual_file" > "$actual_norm"
if diff -u "$expected_norm" "$actual_norm"; then
echo "输出匹配: $expected_file" echo "输出匹配: $expected_file"
else else
echo "输出不匹配: $expected_file" >&2 echo "输出不匹配: $expected_file" >&2

@ -1,67 +1,145 @@
// SysY 子集语法:支持形如
// int main() { int a = 1; int b = 2; return a + b; }
// 的最小返回表达式编译。
// 后续需要自行添加
grammar SysY; grammar SysY;
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
/* Lexer rules */ /* Lexer rules */
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
CONST: 'const';
INT: 'int'; INT: 'int';
FLOAT: 'float';
VOID: 'void';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
RETURN: 'return'; RETURN: 'return';
ASSIGN: '='; ASSIGN: '=';
EQ: '==';
NE: '!=';
LT: '<';
GT: '>';
LE: '<=';
GE: '>=';
ADD: '+'; ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
NOT: '!';
LAND: '&&';
LOR: '||';
LPAREN: '('; LPAREN: '(';
RPAREN: ')'; RPAREN: ')';
LBRACE: '{'; LBRACE: '{';
RBRACE: '}'; RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
COMMA: ',';
SEMICOLON: ';'; SEMICOLON: ';';
FLOAT_CONST
: DEC_FLOAT_CONST
| HEX_FLOAT_CONST
;
INT_CONST
: HEX_PREFIX HEX_DIGIT+
| '0' [0-7]+
| '0'
| [1-9] DIGIT*
;
ID: [a-zA-Z_][a-zA-Z_0-9]*; ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+;
WS: [ \t\r\n] -> skip; WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip; LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip; BLOCKCOMMENT: '/*' .*? '*/' -> skip;
fragment DEC_FLOAT_CONST
: DIGIT+ '.' DIGIT* EXP_PART?
| '.' DIGIT+ EXP_PART?
| DIGIT+ EXP_PART
;
fragment HEX_FLOAT_CONST
: HEX_PREFIX HEX_DIGIT+ '.' HEX_DIGIT* BIN_EXP_PART
| HEX_PREFIX '.' HEX_DIGIT+ BIN_EXP_PART
| HEX_PREFIX HEX_DIGIT+ BIN_EXP_PART
;
fragment EXP_PART: [eE] [+-]? DIGIT+;
fragment BIN_EXP_PART: [pP] [+-]? DIGIT+;
fragment HEX_PREFIX: '0' [xX];
fragment HEX_DIGIT: [0-9a-fA-F];
fragment DIGIT: [0-9];
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
/* Syntax rules */ /* Syntax rules */
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
compUnit compUnit
: funcDef EOF : (decl | funcDef)+ EOF
; ;
decl decl
: btype varDef SEMICOLON : constDecl
| varDecl
; ;
btype constDecl
: CONST bType constDef (COMMA constDef)* SEMICOLON
;
varDecl
: bType varDef (COMMA varDef)* SEMICOLON
;
bType
: INT : INT
| FLOAT
;
constDef
: ID (LBRACK constExp RBRACK)* ASSIGN constInitVal
;
constInitVal
: constExp
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE
; ;
varDef varDef
: lValue (ASSIGN initValue)? : ID (LBRACK constExp RBRACK)* (ASSIGN initVal)?
; ;
initValue initVal
: exp : exp
| LBRACE (initVal (COMMA initVal)*)? RBRACE
; ;
funcDef funcDef
: funcType ID LPAREN RPAREN blockStmt : funcType ID LPAREN funcFParams? RPAREN block
; ;
funcType funcType
: INT : VOID
| INT
| FLOAT
;
funcFParams
: funcFParam (COMMA funcFParam)*
;
funcFParam
: bType ID (LBRACK RBRACK (LBRACK exp RBRACK)*)?
; ;
blockStmt block
: LBRACE blockItem* RBRACE : LBRACE blockItem* RBRACE
; ;
@ -71,28 +149,80 @@ blockItem
; ;
stmt stmt
: returnStmt : lVal ASSIGN exp SEMICOLON
| exp? SEMICOLON
| block
| IF LPAREN cond RPAREN stmt (ELSE stmt)?
| WHILE LPAREN cond RPAREN stmt
| BREAK SEMICOLON
| CONTINUE SEMICOLON
| RETURN exp? SEMICOLON
; ;
returnStmt exp
: RETURN exp SEMICOLON : addExp
; ;
exp cond
: LPAREN exp RPAREN # parenExp : lOrExp
| var # varExp
| number # numberExp
| exp ADD exp # additiveExp
; ;
var lVal
: ID : ID (LBRACK exp RBRACK)*
; ;
lValue primaryExp
: ID : LPAREN exp RPAREN
| lVal
| number
; ;
number number
: ILITERAL : INT_CONST
| FLOAT_CONST
;
unaryExp
: primaryExp
| ID LPAREN funcRParams? RPAREN
| unaryOp unaryExp
;
unaryOp
: ADD
| SUB
| NOT
;
funcRParams
: exp (COMMA exp)*
;
mulExp
: unaryExp ((MUL | DIV | MOD) unaryExp)*
;
addExp
: mulExp ((ADD | SUB) mulExp)*
;
relExp
: addExp ((LT | GT | LE | GE) addExp)*
;
eqExp
: relExp ((EQ | NE) relExp)*
;
lAndExp
: eqExp (LAND eqExp)*
;
lOrExp
: lAndExp (LOR lAndExp)*
; ;
constExp
: addExp
;

@ -3,15 +3,44 @@ add_library(frontend STATIC
SyntaxTreePrinter.cpp SyntaxTreePrinter.cpp
) )
set(ANTLR4_GRAMMAR "${PROJECT_SOURCE_DIR}/src/antlr4/SysY.g4")
set(ANTLR4_GENERATED_FILES
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.h"
"${ANTLR4_GENERATED_DIR}/SysYLexer.interp"
"${ANTLR4_GENERATED_DIR}/SysYLexer.tokens"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.h"
"${ANTLR4_GENERATED_DIR}/SysY.interp"
"${ANTLR4_GENERATED_DIR}/SysY.tokens"
"${ANTLR4_GENERATED_DIR}/SysYBaseVisitor.h"
"${ANTLR4_GENERATED_DIR}/SysYVisitor.h"
)
add_custom_command(
OUTPUT ${ANTLR4_GENERATED_FILES}
COMMAND ${CMAKE_COMMAND} -E make_directory "${ANTLR4_GENERATED_DIR}"
COMMAND ${Java_JAVA_EXECUTABLE} -jar "${ANTLR4_JAR}"
-Dlanguage=Cpp
-visitor
-no-listener
-Xexact-output-dir
-o "${ANTLR4_GENERATED_DIR}"
"${ANTLR4_GRAMMAR}"
DEPENDS "${ANTLR4_GRAMMAR}" "${ANTLR4_JAR}"
COMMENT "Generating ANTLR4 parser sources from SysY.g4"
VERBATIM
)
add_custom_target(antlr4_generated DEPENDS ${ANTLR4_GENERATED_FILES})
add_dependencies(frontend antlr4_generated)
target_sources(frontend PRIVATE
"${ANTLR4_GENERATED_DIR}/SysYLexer.cpp"
"${ANTLR4_GENERATED_DIR}/SysYParser.cpp"
)
target_link_libraries(frontend PUBLIC target_link_libraries(frontend PUBLIC
build_options build_options
${ANTLR4_RUNTIME_TARGET} ${ANTLR4_RUNTIME_TARGET}
) )
# Lexer/Parser
file(GLOB_RECURSE ANTLR4_GENERATED_SOURCES CONFIGURE_DEPENDS
"${ANTLR4_GENERATED_DIR}/*.cpp"
)
if(ANTLR4_GENERATED_SOURCES)
target_sources(frontend PRIVATE ${ANTLR4_GENERATED_SOURCES})
endif()

@ -9,13 +9,14 @@
#include "ir/IR.h" #include "ir/IR.h"
#include <algorithm>
#include <utility> #include <utility>
namespace ir { namespace ir {
// 当前 BasicBlock 还没有专门的 label type因此先用 void 作为占位类型 // BasicBlock 使用 label type
BasicBlock::BasicBlock(std::string name) BasicBlock::BasicBlock(std::string name)
: Value(Type::GetVoidType(), std::move(name)) {} : Value(Type::GetLabelType(), std::move(name)) {}
Function* BasicBlock::GetParent() const { return parent_; } Function* BasicBlock::GetParent() const { return parent_; }
@ -32,6 +33,10 @@ const std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetInstructions()
return instructions_; return instructions_;
} }
std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetMutableInstructions() {
return instructions_;
}
// 前驱/后继接口先保留给后续 CFG 扩展使用。 // 前驱/后继接口先保留给后续 CFG 扩展使用。
// 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。 // 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。
const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const { const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const {
@ -42,4 +47,83 @@ const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_; return successors_;
} }
void BasicBlock::AddPredecessor(BasicBlock* pred) {
if (!pred) return;
for (auto* p : predecessors_) {
if (p == pred) return;
}
predecessors_.push_back(pred);
}
void BasicBlock::AddSuccessor(BasicBlock* succ) {
if (!succ) return;
for (auto* s : successors_) {
if (s == succ) return;
}
successors_.push_back(succ);
}
void BasicBlock::ClearPredecessors() { predecessors_.clear(); }
void BasicBlock::ClearSuccessors() { successors_.clear(); }
void BasicBlock::RemovePredecessor(BasicBlock* pred) {
predecessors_.erase(std::remove(predecessors_.begin(), predecessors_.end(), pred),
predecessors_.end());
}
void BasicBlock::RemoveSuccessor(BasicBlock* succ) {
successors_.erase(std::remove(successors_.begin(), successors_.end(), succ),
successors_.end());
}
void BasicBlock::EraseInstruction(Instruction* inst) {
if (!inst) return;
auto it = std::find_if(instructions_.begin(), instructions_.end(),
[&](const auto& ptr) { return ptr.get() == inst; });
if (it == instructions_.end()) return;
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (auto* operand = inst->GetOperand(i)) {
operand->RemoveUse(inst, i);
}
}
inst->SetParent(nullptr);
instructions_.erase(it);
}
void BasicBlock::ReplaceTerminator(std::unique_ptr<Instruction> inst) {
if (!inst || !inst->IsTerminator()) return;
if (!instructions_.empty() && instructions_.back()->IsTerminator()) {
auto* old = instructions_.back().get();
for (size_t i = 0; i < old->GetNumOperands(); ++i) {
if (auto* operand = old->GetOperand(i)) {
operand->RemoveUse(old, i);
}
}
old->SetParent(nullptr);
instructions_.pop_back();
}
inst->SetParent(this);
instructions_.push_back(std::move(inst));
LinkSuccessorsIfNeeded(instructions_.back().get());
}
void BasicBlock::LinkSuccessorsIfNeeded(Instruction* inst) {
if (!inst) return;
if (auto* br = dynamic_cast<BranchInst*>(inst)) {
auto* dest = br->GetDest();
AddSuccessor(dest);
dest->AddPredecessor(this);
return;
}
if (auto* cbr = dynamic_cast<CondBrInst*>(inst)) {
auto* t = cbr->GetTrueDest();
auto* f = cbr->GetFalseDest();
AddSuccessor(t);
AddSuccessor(f);
t->AddPredecessor(this);
f->AddPredecessor(this);
}
}
} // namespace ir } // namespace ir

@ -1,6 +1,7 @@
// 管理基础类型、整型常量池和临时名生成。 // 管理基础类型、整型常量池和临时名生成。
#include "ir/IR.h" #include "ir/IR.h"
#include <cstring>
#include <sstream> #include <sstream>
namespace ir { namespace ir {
@ -15,9 +16,43 @@ ConstantInt* Context::GetConstInt(int v) {
return inserted->second.get(); return inserted->second.get();
} }
ConstantInt* Context::GetConstBool(bool v) {
int iv = v ? 1 : 0;
auto it = const_bools_.find(iv);
if (it != const_bools_.end()) return it->second.get();
auto inserted = const_bools_.emplace(
iv, std::make_unique<ConstantInt>(Type::GetInt1Type(), iv)).first;
return inserted->second.get();
}
static uint32_t FloatToBits(float v) {
uint32_t bits = 0;
std::memcpy(&bits, &v, sizeof(float));
return bits;
}
ConstantFloat* Context::GetConstFloat(float v) {
uint32_t bits = FloatToBits(v);
auto it = const_floats_.find(bits);
if (it != const_floats_.end()) return it->second.get();
auto inserted = const_floats_.emplace(
bits, std::make_unique<ConstantFloat>(Type::GetFloatType(), v)).first;
return inserted->second.get();
}
ConstantArray* Context::CreateConstArray(std::shared_ptr<Type> array_ty,
std::vector<ConstantValue*> elements) {
if (!array_ty || !array_ty->IsArray()) {
throw std::runtime_error("CreateConstArray 需要 array type");
}
const_arrays_.push_back(
std::make_unique<ConstantArray>(std::move(array_ty), std::move(elements)));
return const_arrays_.back().get();
}
std::string Context::NextTemp() { std::string Context::NextTemp() {
std::ostringstream oss; std::ostringstream oss;
oss << "%" << ++temp_index_; oss << "%t" << ++temp_index_;
return oss.str(); return oss.str();
} }

@ -5,13 +5,32 @@
namespace ir { namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> ret_type) Function::Function(std::string name, std::shared_ptr<Type> func_type,
: Value(std::move(ret_type), std::move(name)) { bool is_declaration)
entry_ = CreateBlock("entry"); : Value(std::move(func_type), std::move(name)),
is_declaration_(is_declaration) {
if (!type_ || !type_->IsFunction()) {
throw std::runtime_error("Function 需要 function type");
}
const auto& params = type_->GetParamTypes();
args_.reserve(params.size());
for (size_t i = 0; i < params.size(); ++i) {
args_.push_back(std::make_unique<Argument>(params[i], "%arg" + std::to_string(i), i));
}
if (!is_declaration_) {
entry_ = CreateBlock("entry");
}
} }
BasicBlock* Function::CreateBlock(const std::string& name) { BasicBlock* Function::CreateBlock(const std::string& name) {
auto block = std::make_unique<BasicBlock>(name); std::string base = name.empty() ? "bb" : name;
auto& count = block_name_counts_[base];
std::string final_name = base;
if (count > 0) {
final_name = base + "." + std::to_string(count);
}
++count;
auto block = std::make_unique<BasicBlock>(final_name);
auto* ptr = block.get(); auto* ptr = block.get();
ptr->SetParent(this); ptr->SetParent(this);
blocks_.push_back(std::move(block)); blocks_.push_back(std::move(block));
@ -29,4 +48,35 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_; return blocks_;
} }
std::vector<std::unique_ptr<BasicBlock>>& Function::GetMutableBlocks() {
return blocks_;
}
const std::vector<std::unique_ptr<Argument>>& Function::GetArguments() const {
return args_;
}
size_t Function::GetNumArgs() const { return args_.size(); }
Argument* Function::GetArg(size_t index) {
if (index >= args_.size()) {
throw std::out_of_range("Function arg index out of range");
}
return args_[index].get();
}
std::shared_ptr<Type> Function::GetFunctionType() const { return type_; }
std::shared_ptr<Type> Function::GetReturnType() const {
if (!type_ || !type_->IsFunction()) {
throw std::runtime_error("Function type 缺失");
}
return type_->GetReturnType();
}
bool Function::IsDeclaration() const { return is_declaration_; }
Argument::Argument(std::shared_ptr<Type> ty, std::string name, size_t index)
: Value(std::move(ty), std::move(name)), index_(index) {}
} // namespace ir } // namespace ir

@ -8,4 +8,23 @@ namespace ir {
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name) GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {} : User(std::move(ty), std::move(name)) {}
GlobalVariable::GlobalVariable(std::shared_ptr<Type> value_ty, std::string name,
ConstantValue* init, bool is_const)
: GlobalValue(Type::GetPointerType(value_ty), std::move(name)),
value_type_(std::move(value_ty)),
initializer_(init),
is_const_(is_const) {
if (!value_type_) {
throw std::runtime_error("GlobalVariable 缺少 value type");
}
}
const std::shared_ptr<Type>& GlobalVariable::GetValueType() const {
return value_type_;
}
ConstantValue* GlobalVariable::GetInitializer() const { return initializer_; }
bool GlobalVariable::IsConst() const { return is_const_; }
} // namespace ir } // namespace ir

@ -9,6 +9,42 @@
#include "utils/Log.h" #include "utils/Log.h"
namespace ir { namespace ir {
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
case Type::Kind::Int1:
return "i1";
case Type::Kind::Int32:
return "i32";
case Type::Kind::Float:
return "float";
case Type::Kind::Label:
return "label";
case Type::Kind::Pointer:
return TypeToString(*ty.GetElementType()) + "*";
case Type::Kind::Array: {
return "[" + std::to_string(ty.GetArraySize()) + " x " +
TypeToString(*ty.GetElementType()) + "]";
}
case Type::Kind::Function: {
std::string out = TypeToString(*ty.GetReturnType()) + " (";
const auto& params = ty.GetParamTypes();
for (size_t i = 0; i < params.size(); ++i) {
if (i > 0) out += ", ";
out += TypeToString(*params[i]);
}
if (ty.IsVarArg()) {
if (!params.empty()) out += ", ";
out += "...";
}
out += ")";
return out;
}
}
return "?";
}
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb)
: ctx_(ctx), insert_block_(bb) {} : ctx_(ctx), insert_block_(bb) {}
@ -42,11 +78,107 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
return CreateBinary(Opcode::Add, lhs, rhs, name); return CreateBinary(Opcode::Add, lhs, rhs, name);
} }
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSDiv(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::SDiv, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateSRem(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::SRem, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FAdd, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FSub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FMul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FDiv, lhs, rhs, name);
}
ICmpInst* IRBuilder::CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<ICmpInst>(pred, lhs, rhs, name);
}
FCmpInst* IRBuilder::CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<FCmpInst>(pred, lhs, rhs, name);
}
CastInst* IRBuilder::CreateSIToFP(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::SIToFP, std::move(dst_ty), src,
name);
}
CastInst* IRBuilder::CreateFPToSI(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::FPToSI, std::move(dst_ty), src,
name);
}
CastInst* IRBuilder::CreateZExt(Value* src, std::shared_ptr<Type> dst_ty,
const std::string& name) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
} }
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name); return insert_block_->Append<CastInst>(Opcode::ZExt, std::move(dst_ty), src,
name);
}
ConstantInt* IRBuilder::CreateConstBool(bool v) {
return ctx_.GetConstBool(v);
}
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
return ctx_.GetConstFloat(v);
}
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(std::move(ty), name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
return CreateAlloca(Type::GetInt32Type(), name);
} }
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
@ -57,7 +189,11 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
throw std::runtime_error( throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
} }
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name); if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateLoad ptr 不是指针"));
}
auto val_ty = ptr->GetType()->GetElementType();
return insert_block_->Append<LoadInst>(val_ty, ptr, name);
} }
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
@ -75,6 +211,100 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr); return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr);
} }
static std::shared_ptr<Type> ResolveGepResultType(const std::shared_ptr<Type>& base_ptr_ty,
size_t index_count) {
if (!base_ptr_ty || !base_ptr_ty->IsPointer()) {
throw std::runtime_error("GEP base type 必须是指针");
}
auto cur = base_ptr_ty->GetElementType();
for (size_t i = 0; i < index_count; ++i) {
// LLVM GEP 的第一个索引只是在当前 pointee 对象上做寻址,
// 不会立刻深入到数组元素类型;真正进入聚合类型从第二个索引开始。
if (i == 0) {
continue;
}
if (cur->IsArray()) {
cur = cur->GetElementType();
continue;
}
if (cur->IsPointer()) {
cur = cur->GetElementType();
continue;
}
}
return Type::GetPointerType(cur);
}
GepInst* IRBuilder::CreateGep(Value* base_ptr, std::vector<Value*> indices,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!base_ptr || !base_ptr->GetType() || !base_ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateGep base_ptr 非指针"));
}
auto result_ty = ResolveGepResultType(base_ptr->GetType(), indices.size());
return insert_block_->Append<GepInst>(result_ty, base_ptr, std::move(indices),
name);
}
CallInst* IRBuilder::CreateCall(Value* callee, std::vector<Value*> args,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
if (!callee || !callee->GetType()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 缺少 callee"));
}
std::shared_ptr<Type> func_ty;
if (callee->GetType()->IsFunction()) {
func_ty = callee->GetType();
} else if (callee->GetType()->IsPointer() &&
callee->GetType()->GetElementType()->IsFunction()) {
func_ty = callee->GetType()->GetElementType();
} else {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall callee 非函数"));
}
const auto& params = func_ty->GetParamTypes();
if (!func_ty->IsVarArg() && params.size() != args.size()) {
throw std::runtime_error(FormatError("ir", "IRBuilder::CreateCall 参数数量不匹配"));
}
for (size_t i = 0; i < params.size() && i < args.size(); ++i) {
if (!args[i] || !args[i]->GetType() ||
!args[i]->GetType()->Equals(*params[i])) {
std::string msg = "IRBuilder::CreateCall 参数类型不匹配: arg" +
std::to_string(i) + " got " +
TypeToString(*args[i]->GetType()) + ", expect " +
TypeToString(*params[i]);
throw std::runtime_error(FormatError("ir", msg));
}
}
auto ret_ty = func_ty->GetReturnType();
return insert_block_->Append<CallInst>(ret_ty, callee, std::move(args), name);
}
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> ty, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<PhiInst>(std::move(ty), name);
}
BranchInst* IRBuilder::CreateBr(BasicBlock* dest) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<BranchInst>(dest);
}
CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* true_dest,
BasicBlock* false_dest) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CondBrInst>(cond, true_dest, false_dest);
}
ReturnInst* IRBuilder::CreateRet(Value* v) { ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -86,4 +316,11 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v); return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
} }
ReturnInst* IRBuilder::CreateRetVoid() {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<ReturnInst>(Type::GetVoidType());
}
} // namespace ir } // namespace ir

@ -5,6 +5,11 @@
#include "ir/IR.h" #include "ir/IR.h"
#include <ostream> #include <ostream>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <limits>
#include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@ -12,14 +17,41 @@
namespace ir { namespace ir {
static const char* TypeToString(const Type& ty) { static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) { switch (ty.GetKind()) {
case Type::Kind::Void: case Type::Kind::Void:
return "void"; return "void";
case Type::Kind::Int1:
return "i1";
case Type::Kind::Int32: case Type::Kind::Int32:
return "i32"; return "i32";
case Type::Kind::PtrInt32: case Type::Kind::Float:
return "i32*"; return "float";
case Type::Kind::Label:
return "label";
case Type::Kind::Pointer:
return TypeToString(*ty.GetElementType()) + "*";
case Type::Kind::Array: {
std::ostringstream oss;
oss << "[" << ty.GetArraySize() << " x "
<< TypeToString(*ty.GetElementType()) << "]";
return oss.str();
}
case Type::Kind::Function: {
std::ostringstream oss;
oss << TypeToString(*ty.GetReturnType()) << " (";
const auto& params = ty.GetParamTypes();
for (size_t i = 0; i < params.size(); ++i) {
if (i > 0) oss << ", ";
oss << TypeToString(*params[i]);
}
if (ty.IsVarArg()) {
if (!params.empty()) oss << ", ";
oss << "...";
}
oss << ")";
return oss.str();
}
} }
throw std::runtime_error(FormatError("ir", "未知类型")); throw std::runtime_error(FormatError("ir", "未知类型"));
} }
@ -32,6 +64,18 @@ static const char* OpcodeToString(Opcode op) {
return "sub"; return "sub";
case Opcode::Mul: case Opcode::Mul:
return "mul"; return "mul";
case Opcode::SDiv:
return "sdiv";
case Opcode::SRem:
return "srem";
case Opcode::FAdd:
return "fadd";
case Opcode::FSub:
return "fsub";
case Opcode::FMul:
return "fmul";
case Opcode::FDiv:
return "fdiv";
case Opcode::Alloca: case Opcode::Alloca:
return "alloca"; return "alloca";
case Opcode::Load: case Opcode::Load:
@ -40,21 +84,182 @@ static const char* OpcodeToString(Opcode op) {
return "store"; return "store";
case Opcode::Ret: case Opcode::Ret:
return "ret"; return "ret";
case Opcode::Br:
return "br";
case Opcode::CondBr:
return "br";
case Opcode::ICmp:
return "icmp";
case Opcode::FCmp:
return "fcmp";
case Opcode::Call:
return "call";
case Opcode::Phi:
return "phi";
case Opcode::Gep:
return "getelementptr";
case Opcode::SIToFP:
return "sitofp";
case Opcode::FPToSI:
return "fptosi";
case Opcode::ZExt:
return "zext";
} }
return "?"; return "?";
} }
static std::string ValueToString(const Value* v) { static std::string FloatToString(float v) {
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) { std::uint32_t bits = 0;
static_assert(sizeof(bits) == sizeof(v), "float size mismatch");
std::memcpy(&bits, &v, sizeof(bits));
std::ostringstream oss;
oss << "bitcast (i32 " << std::dec << static_cast<std::uint64_t>(bits)
<< " to float)";
return oss.str();
}
static bool IsZeroInitializer(const ConstantValue* c) {
if (auto* ci = dynamic_cast<const ConstantInt*>(c)) {
return ci->GetValue() == 0;
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(c)) {
return cf->GetValue() == 0.0f;
}
if (auto* ca = dynamic_cast<const ConstantArray*>(c)) {
for (auto* elem : ca->GetElements()) {
if (!elem || !IsZeroInitializer(elem)) {
return false;
}
}
return true;
}
return false;
}
static std::string ConstantToString(const ConstantValue* c) {
if (auto* ci = dynamic_cast<const ConstantInt*>(c)) {
return std::to_string(ci->GetValue()); return std::to_string(ci->GetValue());
} }
if (auto* cf = dynamic_cast<const ConstantFloat*>(c)) {
return FloatToString(cf->GetValue());
}
if (auto* ca = dynamic_cast<const ConstantArray*>(c)) {
if (IsZeroInitializer(ca)) {
return "zeroinitializer";
}
std::ostringstream oss;
oss << "[";
const auto& elems = ca->GetElements();
for (size_t i = 0; i < elems.size(); ++i) {
if (i > 0) oss << ", ";
oss << TypeToString(*elems[i]->GetType()) << " "
<< ConstantToString(elems[i]);
}
oss << "]";
return oss.str();
}
return "<const>";
}
static std::string ValueToString(const Value* v) {
if (auto* c = dynamic_cast<const ConstantValue*>(v)) {
return ConstantToString(c);
}
if (auto* func = dynamic_cast<const Function*>(v)) {
const auto& name = func->GetName();
if (!name.empty() && name[0] == '@') return name;
return "@" + name;
}
if (auto* gv = dynamic_cast<const GlobalValue*>(v)) {
const auto& name = gv->GetName();
if (!name.empty() && name[0] == '@') return name;
return "@" + name;
}
return v ? v->GetName() : "<null>"; return v ? v->GetName() : "<null>";
} }
static std::string LabelToString(const BasicBlock* bb) {
if (!bb) return "%<null>";
const auto& name = bb->GetName();
if (!name.empty() && name[0] == '%') return name;
return "%" + name;
}
static const char* ICmpPredToString(ICmpPredicate pred) {
switch (pred) {
case ICmpPredicate::Eq:
return "eq";
case ICmpPredicate::Ne:
return "ne";
case ICmpPredicate::Slt:
return "slt";
case ICmpPredicate::Sle:
return "sle";
case ICmpPredicate::Sgt:
return "sgt";
case ICmpPredicate::Sge:
return "sge";
}
return "?";
}
static const char* FCmpPredToString(FCmpPredicate pred) {
switch (pred) {
case FCmpPredicate::Oeq:
return "oeq";
case FCmpPredicate::One:
return "one";
case FCmpPredicate::Olt:
return "olt";
case FCmpPredicate::Ole:
return "ole";
case FCmpPredicate::Ogt:
return "ogt";
case FCmpPredicate::Oge:
return "oge";
}
return "?";
}
void IRPrinter::Print(const Module& module, std::ostream& os) { void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& g : module.GetGlobals()) {
if (!g) continue;
os << "@" << g->GetName() << " = "
<< (g->IsConst() ? "constant " : "global ")
<< TypeToString(*g->GetValueType()) << " ";
if (auto* init = g->GetInitializer()) {
os << ConstantToString(init);
} else {
if (g->GetValueType()->IsArray()) {
os << "zeroinitializer";
} else if (g->GetValueType()->IsFloat()) {
os << "0.0";
} else {
os << "0";
}
}
os << "\n";
}
for (const auto& func : module.GetFunctions()) { for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() if (func->IsDeclaration()) {
<< "() {\n"; os << "declare " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType());
}
os << ")\n";
continue;
}
os << "define " << TypeToString(*func->GetReturnType()) << " @"
<< func->GetName() << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType()) << " " << args[i]->GetName();
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) { for (const auto& bb : func->GetBlocks()) {
if (!bb) { if (!bb) {
continue; continue;
@ -65,7 +270,13 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
switch (inst->GetOpcode()) { switch (inst->GetOpcode()) {
case Opcode::Add: case Opcode::Add:
case Opcode::Sub: case Opcode::Sub:
case Opcode::Mul: { case Opcode::Mul:
case Opcode::SDiv:
case Opcode::SRem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst); auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = " os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " " << OpcodeToString(bin->GetOpcode()) << " "
@ -74,27 +285,122 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
<< ValueToString(bin->GetRhs()) << "\n"; << ValueToString(bin->GetRhs()) << "\n";
break; break;
} }
case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst);
os << " " << cmp->GetName() << " = icmp "
<< ICmpPredToString(cmp->GetPredicate()) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " " << cmp->GetName() << " = fcmp "
<< FCmpPredToString(cmp->GetPredicate()) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
break;
}
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::ZExt: {
auto* cast = static_cast<const CastInst*>(inst);
os << " " << cast->GetName() << " = "
<< OpcodeToString(cast->GetOpcode()) << " "
<< TypeToString(*cast->GetValue()->GetType()) << " "
<< ValueToString(cast->GetValue()) << " to "
<< TypeToString(*cast->GetType()) << "\n";
break;
}
case Opcode::Alloca: { case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst); auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca i32\n"; os << " " << alloca->GetName() << " = alloca "
<< TypeToString(*alloca->GetAllocatedType()) << "\n";
break; break;
} }
case Opcode::Load: { case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst); auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load i32, i32* " os << " " << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n"; << ValueToString(load->GetPtr()) << "\n";
break; break;
} }
case Opcode::Store: { case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst); auto* store = static_cast<const StoreInst*>(inst);
os << " store i32 " << ValueToString(store->GetValue()) os << " store " << TypeToString(*store->GetValue()->GetType())
<< ", i32* " << ValueToString(store->GetPtr()) << "\n"; << " " << ValueToString(store->GetValue()) << ", "
<< TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BranchInst*>(inst);
os << " br label " << LabelToString(br->GetDest()) << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << ValueToString(cbr->GetCond())
<< ", label " << LabelToString(cbr->GetTrueDest())
<< ", label " << LabelToString(cbr->GetFalseDest()) << "\n";
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
const auto& args = call->GetArgs();
if (!call->GetType()->IsVoid()) {
os << " " << call->GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call->GetType()) << " "
<< ValueToString(call->GetCallee()) << "(";
for (size_t i = 0; i < args.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToString(*args[i]->GetType()) << " "
<< ValueToString(args[i]);
}
os << ")\n";
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " " << phi->GetName() << " = phi "
<< TypeToString(*phi->GetType()) << " ";
const auto& values = phi->GetIncomingValues();
const auto& blocks = phi->GetIncomingBlocks();
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) os << ", ";
os << "[ " << ValueToString(values[i]) << ", "
<< LabelToString(blocks[i]) << " ]";
}
os << "\n";
break;
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
os << " " << gep->GetName() << " = getelementptr "
<< TypeToString(*gep->GetBasePtr()->GetType()->GetElementType())
<< ", " << TypeToString(*gep->GetBasePtr()->GetType()) << " "
<< ValueToString(gep->GetBasePtr());
const auto& idx = gep->GetIndices();
for (auto* v : idx) {
os << ", i32 " << ValueToString(v);
}
os << "\n";
break; break;
} }
case Opcode::Ret: { case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst); auto* ret = static_cast<const ReturnInst*>(inst);
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " if (ret->HasReturnValue()) {
<< ValueToString(ret->GetValue()) << "\n"; os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
} else {
os << " ret void\n";
}
break; break;
} }
} }

@ -3,6 +3,7 @@
// - 指令操作数与结果类型管理,支持打印与优化 // - 指令操作数与结果类型管理,支持打印与优化
#include "ir/IR.h" #include "ir/IR.h"
#include <cstddef>
#include <stdexcept> #include <stdexcept>
#include "utils/Log.h" #include "utils/Log.h"
@ -36,6 +37,7 @@ void User::SetOperand(size_t index, Value* value) {
} }
operands_[index] = value; operands_[index] = value;
value->AddUse(this, index); value->AddUse(this, index);
OnOperandChanged(index, value);
} }
void User::AddOperand(Value* value) { void User::AddOperand(Value* value) {
@ -47,22 +49,56 @@ void User::AddOperand(Value* value) {
value->AddUse(this, operand_index); value->AddUse(this, operand_index);
} }
void User::RemoveOperand(size_t index) {
if (index >= operands_.size()) {
throw std::out_of_range("User operand index out of range");
}
OnOperandRemoving(index);
if (auto* old = operands_[index]) {
old->RemoveUse(this, index);
}
for (size_t i = index + 1; i < operands_.size(); ++i) {
if (auto* value = operands_[i]) {
value->RemoveUse(this, i);
value->AddUse(this, i - 1);
}
}
operands_.erase(operands_.begin() + static_cast<std::ptrdiff_t>(index));
}
void User::OnOperandChanged(size_t, Value*) {}
void User::OnOperandRemoving(size_t) {}
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name) Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)), opcode_(op) {} : User(std::move(ty), std::move(name)), opcode_(op) {}
Opcode Instruction::GetOpcode() const { return opcode_; } Opcode Instruction::GetOpcode() const { return opcode_; }
bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } bool Instruction::IsTerminator() const {
return opcode_ == Opcode::Ret || opcode_ == Opcode::Br ||
opcode_ == Opcode::CondBr;
}
BasicBlock* Instruction::GetParent() const { return parent_; } BasicBlock* Instruction::GetParent() const { return parent_; }
void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
static bool IsIntBinaryOp(Opcode op) {
return op == Opcode::Add || op == Opcode::Sub || op == Opcode::Mul ||
op == Opcode::SDiv || op == Opcode::SRem;
}
static bool IsFloatBinaryOp(Opcode op) {
return op == Opcode::FAdd || op == Opcode::FSub || op == Opcode::FMul ||
op == Opcode::FDiv;
}
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name) Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) { : Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add) { if (!IsIntBinaryOp(op) && !IsFloatBinaryOp(op)) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); throw std::runtime_error(FormatError("ir", "BinaryInst 非算术 op"));
} }
if (!lhs || !rhs) { if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
@ -70,12 +106,15 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
if (!type_ || !lhs->GetType() || !rhs->GetType()) { if (!type_ || !lhs->GetType() || !rhs->GetType()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息")); throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息"));
} }
if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() || if (!lhs->GetType()->Equals(*rhs->GetType()) ||
type_->GetKind() != lhs->GetType()->GetKind()) { !type_->Equals(*lhs->GetType())) {
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
} }
if (!type_->IsInt32()) { if (IsIntBinaryOp(op) && !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); throw std::runtime_error(FormatError("ir", "整数二元只支持 i32"));
}
if (IsFloatBinaryOp(op) && !type_->IsFloat()) {
throw std::runtime_error(FormatError("ir", "浮点二元只支持 float"));
} }
AddOperand(lhs); AddOperand(lhs);
AddOperand(rhs); AddOperand(rhs);
@ -85,6 +124,127 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); }
Value* BinaryInst::GetRhs() const { return GetOperand(1); } Value* BinaryInst::GetRhs() const { return GetOperand(1); }
ICmpInst::ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name)
: Instruction(Opcode::ICmp, Type::GetInt1Type(), std::move(name)),
pred_(pred) {
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "ICmpInst 缺少操作数"));
}
if (!lhs->GetType() || !rhs->GetType() ||
!lhs->GetType()->Equals(*rhs->GetType())) {
throw std::runtime_error(FormatError("ir", "ICmpInst 类型不匹配"));
}
if (!lhs->GetType()->IsInt1() && !lhs->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "ICmpInst 仅支持整型"));
}
AddOperand(lhs);
AddOperand(rhs);
}
Value* ICmpInst::GetLhs() const { return GetOperand(0); }
Value* ICmpInst::GetRhs() const { return GetOperand(1); }
FCmpInst::FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name)
: Instruction(Opcode::FCmp, Type::GetInt1Type(), std::move(name)),
pred_(pred) {
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "FCmpInst 缺少操作数"));
}
if (!lhs->GetType() || !rhs->GetType() ||
!lhs->GetType()->Equals(*rhs->GetType())) {
throw std::runtime_error(FormatError("ir", "FCmpInst 类型不匹配"));
}
if (!lhs->GetType()->IsFloat()) {
throw std::runtime_error(FormatError("ir", "FCmpInst 仅支持 float"));
}
AddOperand(lhs);
AddOperand(rhs);
}
Value* FCmpInst::GetLhs() const { return GetOperand(0); }
Value* FCmpInst::GetRhs() const { return GetOperand(1); }
CastInst::CastInst(Opcode op, std::shared_ptr<Type> dst_ty, Value* src,
std::string name)
: Instruction(op, std::move(dst_ty), std::move(name)) {
if (op != Opcode::SIToFP && op != Opcode::FPToSI && op != Opcode::ZExt) {
throw std::runtime_error(FormatError("ir", "CastInst 不支持的 op"));
}
if (!src) {
throw std::runtime_error(FormatError("ir", "CastInst 缺少 src"));
}
if (op == Opcode::SIToFP) {
if (!src->GetType()->IsInt32() && !src->GetType()->IsInt1()) {
throw std::runtime_error(FormatError("ir", "SIToFP 仅支持整型"));
}
if (!type_ || !type_->IsFloat()) {
throw std::runtime_error(FormatError("ir", "SIToFP 目标必须是 float"));
}
} else if (op == Opcode::FPToSI) {
if (!src->GetType()->IsFloat()) {
throw std::runtime_error(FormatError("ir", "FPToSI 仅支持 float"));
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "FPToSI 目标必须是 i32"));
}
} else {
if (!src->GetType()->IsInt1()) {
throw std::runtime_error(FormatError("ir", "ZExt 仅支持 i1"));
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "ZExt 目标必须是 i32"));
}
}
AddOperand(src);
}
Value* CastInst::GetValue() const { return GetOperand(0); }
BranchInst::BranchInst(BasicBlock* dest)
: Instruction(Opcode::Br, Type::GetVoidType(), "") {
if (!dest) {
throw std::runtime_error(FormatError("ir", "BranchInst 缺少目标块"));
}
AddOperand(dest);
}
BasicBlock* BranchInst::GetDest() const {
return static_cast<BasicBlock*>(GetOperand(0));
}
CondBrInst::CondBrInst(Value* cond, BasicBlock* true_dest,
BasicBlock* false_dest)
: Instruction(Opcode::CondBr, Type::GetVoidType(), "") {
if (!cond || !true_dest || !false_dest) {
throw std::runtime_error(FormatError("ir", "CondBrInst 缺少参数"));
}
if (!cond->GetType() || !cond->GetType()->IsInt1()) {
throw std::runtime_error(FormatError("ir", "CondBrInst cond 必须是 i1"));
}
AddOperand(cond);
AddOperand(true_dest);
AddOperand(false_dest);
}
Value* CondBrInst::GetCond() const { return GetOperand(0); }
BasicBlock* CondBrInst::GetTrueDest() const {
return static_cast<BasicBlock*>(GetOperand(1));
}
BasicBlock* CondBrInst::GetFalseDest() const {
return static_cast<BasicBlock*>(GetOperand(2));
}
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty)
: Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
}
}
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val) ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), "") { : Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!val) { if (!val) {
@ -96,26 +256,36 @@ ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
AddOperand(val); AddOperand(val);
} }
Value* ReturnInst::GetValue() const { return GetOperand(0); } bool ReturnInst::HasReturnValue() const { return GetNumOperands() > 0; }
Value* ReturnInst::GetValue() const {
if (!HasReturnValue()) return nullptr;
return GetOperand(0);
}
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name) AllocaInst::AllocaInst(std::shared_ptr<Type> allocated_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { : Instruction(Opcode::Alloca, Type::GetPointerType(allocated_ty),
if (!type_ || !type_->IsPtrInt32()) { std::move(name)),
throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); allocated_type_(std::move(allocated_ty)) {
if (!allocated_type_) {
throw std::runtime_error(FormatError("ir", "AllocaInst 缺少类型"));
} }
} }
const std::shared_ptr<Type>& AllocaInst::GetAllocatedType() const {
return allocated_type_;
}
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name) LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
: Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
if (!ptr) { if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
} }
if (!type_ || !type_->IsInt32()) { if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); throw std::runtime_error(FormatError("ir", "LoadInst ptr 不是指针"));
} }
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { if (!type_ || !ptr->GetType()->GetElementType()->Equals(*type_)) {
throw std::runtime_error( throw std::runtime_error(FormatError("ir", "LoadInst 类型不匹配"));
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
} }
AddOperand(ptr); AddOperand(ptr);
} }
@ -133,12 +303,11 @@ StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
if (!type_ || !type_->IsVoid()) { if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
} }
if (!val->GetType() || !val->GetType()->IsInt32()) { if (!ptr->GetType() || !ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); throw std::runtime_error(FormatError("ir", "StoreInst ptr 不是指针"));
} }
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { if (!ptr->GetType()->GetElementType()->Equals(*val->GetType())) {
throw std::runtime_error( throw std::runtime_error(FormatError("ir", "StoreInst 类型不匹配"));
FormatError("ir", "StoreInst 当前只支持写入 i32*"));
} }
AddOperand(val); AddOperand(val);
AddOperand(ptr); AddOperand(ptr);
@ -148,4 +317,141 @@ Value* StoreInst::GetValue() const { return GetOperand(0); }
Value* StoreInst::GetPtr() const { return GetOperand(1); } Value* StoreInst::GetPtr() const { return GetOperand(1); }
CallInst::CallInst(std::shared_ptr<Type> ret_ty, Value* callee,
std::vector<Value*> args, std::string name)
: Instruction(Opcode::Call, std::move(ret_ty), std::move(name)),
args_(std::move(args)) {
if (!callee) {
throw std::runtime_error(FormatError("ir", "CallInst 缺少 callee"));
}
AddOperand(callee);
for (auto* arg : args_) {
if (!arg) {
throw std::runtime_error(FormatError("ir", "CallInst arg 为空"));
}
AddOperand(arg);
}
}
Value* CallInst::GetCallee() const { return GetOperand(0); }
void CallInst::OnOperandChanged(size_t index, Value* value) {
if (index == 0) return;
size_t arg_index = index - 1;
if (arg_index < args_.size()) {
args_[arg_index] = value;
}
}
void CallInst::OnOperandRemoving(size_t index) {
if (index == 0) return;
size_t arg_index = index - 1;
if (arg_index < args_.size()) {
args_.erase(args_.begin() + static_cast<std::ptrdiff_t>(arg_index));
}
}
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
void PhiInst::AddIncoming(Value* value, BasicBlock* block) {
if (!value || !block) {
throw std::runtime_error(FormatError("ir", "PhiInst incoming 为空"));
}
if (!value->GetType() || !type_ || !value->GetType()->Equals(*type_)) {
throw std::runtime_error(FormatError("ir", "PhiInst 类型不匹配"));
}
incoming_values_.push_back(value);
incoming_blocks_.push_back(block);
AddOperand(value);
AddOperand(block);
}
void PhiInst::RemoveIncomingFrom(BasicBlock* block) {
for (size_t i = 0; i < incoming_blocks_.size();) {
if (incoming_blocks_[i] != block) {
++i;
continue;
}
RemoveOperand(2 * i + 1);
RemoveOperand(2 * i);
}
}
const std::vector<Value*>& PhiInst::GetIncomingValues() const {
return incoming_values_;
}
const std::vector<BasicBlock*>& PhiInst::GetIncomingBlocks() const {
return incoming_blocks_;
}
void PhiInst::OnOperandChanged(size_t index, Value* value) {
size_t incoming_index = index / 2;
if (index % 2 == 0) {
if (incoming_index < incoming_values_.size()) {
incoming_values_[incoming_index] = value;
}
return;
}
if (incoming_index < incoming_blocks_.size()) {
incoming_blocks_[incoming_index] = static_cast<BasicBlock*>(value);
}
}
void PhiInst::OnOperandRemoving(size_t index) {
size_t incoming_index = index / 2;
if (index % 2 == 0) {
if (incoming_index < incoming_values_.size()) {
incoming_values_.erase(incoming_values_.begin() +
static_cast<std::ptrdiff_t>(incoming_index));
}
return;
}
if (incoming_index < incoming_blocks_.size()) {
incoming_blocks_.erase(incoming_blocks_.begin() +
static_cast<std::ptrdiff_t>(incoming_index));
}
}
GepInst::GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
std::vector<Value*> indices, std::string name)
: Instruction(Opcode::Gep, std::move(result_ptr_ty), std::move(name)),
indices_(std::move(indices)) {
if (!base_ptr) {
throw std::runtime_error(FormatError("ir", "GepInst 缺少 base_ptr"));
}
if (!base_ptr->GetType() || !base_ptr->GetType()->IsPointer()) {
throw std::runtime_error(FormatError("ir", "GepInst base_ptr 不是指针"));
}
if (!type_ || !type_->IsPointer()) {
throw std::runtime_error(FormatError("ir", "GepInst 结果必须是指针"));
}
AddOperand(base_ptr);
for (auto* idx : indices_) {
if (!idx || !idx->GetType() || !idx->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "GepInst index 必须是 i32"));
}
AddOperand(idx);
}
}
Value* GepInst::GetBasePtr() const { return GetOperand(0); }
void GepInst::OnOperandChanged(size_t index, Value* value) {
if (index == 0) return;
size_t idx = index - 1;
if (idx < indices_.size()) {
indices_[idx] = value;
}
}
void GepInst::OnOperandRemoving(size_t index) {
if (index == 0) return;
size_t idx = index - 1;
if (idx < indices_.size()) {
indices_.erase(indices_.begin() + static_cast<std::ptrdiff_t>(idx));
}
}
} // namespace ir } // namespace ir

@ -10,12 +10,39 @@ const Context& Module::GetContext() const { return context_; }
Function* Module::CreateFunction(const std::string& name, Function* Module::CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type) { std::shared_ptr<Type> ret_type) {
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type))); auto func_ty = Type::GetFunctionType(std::move(ret_type), {});
functions_.push_back(std::make_unique<Function>(name, std::move(func_ty)));
return functions_.back().get(); return functions_.back().get();
} }
Function* Module::CreateFunctionWithType(const std::string& name,
std::shared_ptr<Type> func_type) {
functions_.push_back(
std::make_unique<Function>(name, std::move(func_type), false));
return functions_.back().get();
}
Function* Module::CreateFunctionDecl(const std::string& name,
std::shared_ptr<Type> func_type) {
functions_.push_back(
std::make_unique<Function>(name, std::move(func_type), true));
return functions_.back().get();
}
GlobalVariable* Module::CreateGlobalVariable(const std::string& name,
std::shared_ptr<Type> value_type,
ConstantValue* init, bool is_const) {
globals_.push_back(std::make_unique<GlobalVariable>(
std::move(value_type), name, init, is_const));
return globals_.back().get();
}
const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const { const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
return functions_; return functions_;
} }
const std::vector<std::unique_ptr<GlobalVariable>>& Module::GetGlobals() const {
return globals_;
}
} // namespace ir } // namespace ir

@ -1,31 +1,148 @@
// 当前仅支持 void、i32 和 i32* // 支持 void/i1/i32/float/ptr/array/function/label
#include "ir/IR.h" #include "ir/IR.h"
namespace ir { namespace ir {
Type::Type(Kind k) : kind_(k) {} Type::Type(Kind k) : kind_(k) {}
Type::Type(Kind k, std::shared_ptr<Type> elem, size_t count)
: kind_(k), elem_type_(std::move(elem)), array_size_(count) {}
Type::Type(Kind k, std::shared_ptr<Type> ret,
std::vector<std::shared_ptr<Type>> params, bool is_vararg)
: kind_(k), ret_type_(std::move(ret)), param_types_(std::move(params)),
is_vararg_(is_vararg) {}
const std::shared_ptr<Type>& Type::GetVoidType() { const std::shared_ptr<Type>& Type::GetVoidType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetInt1Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int1);
return type;
}
const std::shared_ptr<Type>& Type::GetInt32Type() { const std::shared_ptr<Type>& Type::GetInt32Type() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32);
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetPtrInt32Type() { const std::shared_ptr<Type>& Type::GetFloatType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32); static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Float);
return type;
}
const std::shared_ptr<Type>& Type::GetLabelType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Label);
return type; return type;
} }
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> elem) {
if (!elem) {
throw std::runtime_error("PointerType 缺少 element type");
}
return std::make_shared<Type>(Kind::Pointer, std::move(elem), 0);
}
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> elem,
size_t count) {
if (!elem) {
throw std::runtime_error("ArrayType 缺少 element type");
}
return std::make_shared<Type>(Kind::Array, std::move(elem), count);
}
std::shared_ptr<Type> Type::GetFunctionType(
std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
bool is_vararg) {
if (!ret) {
throw std::runtime_error("FunctionType 缺少 return type");
}
return std::make_shared<Type>(Kind::Function, std::move(ret),
std::move(params), is_vararg);
}
Type::Kind Type::GetKind() const { return kind_; } Type::Kind Type::GetKind() const { return kind_; }
bool Type::IsVoid() const { return kind_ == Kind::Void; } bool Type::IsVoid() const { return kind_ == Kind::Void; }
bool Type::IsInt1() const { return kind_ == Kind::Int1; }
bool Type::IsInt32() const { return kind_ == Kind::Int32; } bool Type::IsInt32() const { return kind_ == Kind::Int32; }
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } bool Type::IsFloat() const { return kind_ == Kind::Float; }
bool Type::IsPointer() const { return kind_ == Kind::Pointer; }
bool Type::IsArray() const { return kind_ == Kind::Array; }
bool Type::IsFunction() const { return kind_ == Kind::Function; }
bool Type::IsLabel() const { return kind_ == Kind::Label; }
const std::shared_ptr<Type>& Type::GetElementType() const {
if (!elem_type_) {
throw std::runtime_error("Type 没有 element type");
}
return elem_type_;
}
size_t Type::GetArraySize() const {
if (!IsArray()) {
throw std::runtime_error("Type 不是 array");
}
return array_size_;
}
const std::shared_ptr<Type>& Type::GetReturnType() const {
if (!IsFunction()) {
throw std::runtime_error("Type 不是 function");
}
return ret_type_;
}
const std::vector<std::shared_ptr<Type>>& Type::GetParamTypes() const {
if (!IsFunction()) {
throw std::runtime_error("Type 不是 function");
}
return param_types_;
}
bool Type::IsVarArg() const {
if (!IsFunction()) {
throw std::runtime_error("Type 不是 function");
}
return is_vararg_;
}
bool Type::Equals(const Type& other) const {
if (kind_ != other.kind_) return false;
switch (kind_) {
case Kind::Pointer:
return elem_type_ && other.elem_type_ &&
elem_type_->Equals(*other.elem_type_);
case Kind::Array:
return array_size_ == other.array_size_ && elem_type_ &&
other.elem_type_ && elem_type_->Equals(*other.elem_type_);
case Kind::Function: {
if (!ret_type_ || !other.ret_type_ ||
!ret_type_->Equals(*other.ret_type_) ||
is_vararg_ != other.is_vararg_ ||
param_types_.size() != other.param_types_.size()) {
return false;
}
for (size_t i = 0; i < param_types_.size(); ++i) {
if (!param_types_[i] || !other.param_types_[i] ||
!param_types_[i]->Equals(*other.param_types_[i])) {
return false;
}
}
return true;
}
default:
return true;
}
}
} // namespace ir } // namespace ir

@ -18,9 +18,21 @@ void Value::SetName(std::string n) { name_ = std::move(n); }
bool Value::IsVoid() const { return type_ && type_->IsVoid(); } bool Value::IsVoid() const { return type_ && type_->IsVoid(); }
bool Value::IsInt1() const { return type_ && type_->IsInt1(); }
bool Value::IsInt32() const { return type_ && type_->IsInt32(); } bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } bool Value::IsFloat() const { return type_ && type_->IsFloat(); }
bool Value::IsPointer() const { return type_ && type_->IsPointer(); }
bool Value::IsArray() const { return type_ && type_->IsArray(); }
bool Value::IsFunctionType() const { return type_ && type_->IsFunction(); }
bool Value::IsPtrInt32() const {
return type_ && type_->IsPointer() && type_->GetElementType()->IsInt32();
}
bool Value::IsConstant() const { bool Value::IsConstant() const {
return dynamic_cast<const ConstantValue*>(this) != nullptr; return dynamic_cast<const ConstantValue*>(this) != nullptr;
@ -78,6 +90,25 @@ ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {} : Value(std::move(ty), std::move(name)) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v) ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: ConstantValue(std::move(ty), ""), value_(v) {} : ConstantValue(std::move(ty), ""), value_(v) {
if (!type_ || (!type_->IsInt32() && !type_->IsInt1())) {
throw std::runtime_error("ConstantInt 需要 i1/i32 类型");
}
}
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float v)
: ConstantValue(std::move(ty), ""), value_(v) {
if (!type_ || !type_->IsFloat()) {
throw std::runtime_error("ConstantFloat 需要 float 类型");
}
}
ConstantArray::ConstantArray(std::shared_ptr<Type> ty,
std::vector<ConstantValue*> elements)
: ConstantValue(std::move(ty), ""), elements_(std::move(elements)) {
if (!type_ || !type_->IsArray()) {
throw std::runtime_error("ConstantArray 需要 array 类型");
}
}
} // namespace ir } // namespace ir

@ -1,4 +1,135 @@
// CFG 简化: #include "ir/IR.h"
// - 删除不可达块、合并空块、简化分支等
// - 改善 IR 结构,便于后续优化与后端生成 #include <memory>
#include <queue>
#include <unordered_set>
namespace ir {
namespace {
Instruction* Terminator(BasicBlock& block) {
auto& insts = block.GetMutableInstructions();
if (insts.empty()) return nullptr;
auto* inst = insts.back().get();
return inst && inst->IsTerminator() ? inst : nullptr;
}
void AddCFGEdge(BasicBlock* from, BasicBlock* to) {
if (!from || !to) return;
from->AddSuccessor(to);
to->AddPredecessor(from);
}
void RebuildCFG(Function& func) {
for (const auto& block : func.GetBlocks()) {
if (!block) continue;
block->ClearPredecessors();
block->ClearSuccessors();
}
for (const auto& block : func.GetBlocks()) {
if (!block) continue;
auto* term = Terminator(*block);
if (auto* br = dynamic_cast<BranchInst*>(term)) {
AddCFGEdge(block.get(), br->GetDest());
} else if (auto* cbr = dynamic_cast<CondBrInst*>(term)) {
AddCFGEdge(block.get(), cbr->GetTrueDest());
AddCFGEdge(block.get(), cbr->GetFalseDest());
}
}
}
bool SimplifyBranches(Function& func) {
bool changed = false;
for (const auto& block : func.GetBlocks()) {
if (!block) continue;
auto* cbr = dynamic_cast<CondBrInst*>(Terminator(*block));
if (!cbr) continue;
BasicBlock* dest = nullptr;
if (cbr->GetTrueDest() == cbr->GetFalseDest()) {
dest = cbr->GetTrueDest();
} else if (auto* cond = dynamic_cast<ConstantInt*>(cbr->GetCond())) {
dest = cond->GetValue() != 0 ? cbr->GetTrueDest() : cbr->GetFalseDest();
}
if (dest) {
block->ReplaceTerminator(std::make_unique<BranchInst>(dest));
changed = true;
}
}
return changed;
}
std::unordered_set<BasicBlock*> ReachableBlocks(Function& func) {
std::unordered_set<BasicBlock*> reachable;
std::queue<BasicBlock*> work;
if (auto* entry = func.GetEntry()) {
reachable.insert(entry);
work.push(entry);
}
while (!work.empty()) {
auto* block = work.front();
work.pop();
for (auto* succ : block->GetSuccessors()) {
if (succ && reachable.insert(succ).second) {
work.push(succ);
}
}
}
return reachable;
}
bool RemoveUnreachable(Function& func) {
auto reachable = ReachableBlocks(func);
bool changed = false;
for (const auto& block : func.GetBlocks()) {
if (!block || reachable.count(block.get()) == 0) continue;
for (const auto& inst : block->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst.get());
if (!phi) continue;
for (const auto& other : func.GetBlocks()) {
if (other && reachable.count(other.get()) == 0) {
phi->RemoveIncomingFrom(other.get());
}
}
}
}
auto& blocks = func.GetMutableBlocks();
for (auto it = blocks.begin(); it != blocks.end();) {
auto* block = it->get();
if (!block || reachable.count(block) != 0) {
++it;
continue;
}
auto& insts = block->GetMutableInstructions();
while (!insts.empty()) {
block->EraseInstruction(insts.back().get());
}
it = blocks.erase(it);
changed = true;
}
return changed;
}
} // namespace
bool RunCFGSimplify(Module& module) {
bool changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
RebuildCFG(*func);
bool local_changed = SimplifyBranches(*func);
if (local_changed) {
RebuildCFG(*func);
}
local_changed |= RemoveUnreachable(*func);
if (local_changed) {
RebuildCFG(*func);
}
changed |= local_changed;
}
return changed;
}
} // namespace ir

@ -1,4 +1,94 @@
// 公共子表达式消除CSE #include "ir/IR.h"
// - 识别并复用重复计算的等价表达式
// - 典型放置在 ConstFold 之后、DCE 之前 #include <algorithm>
// - 当前为 Lab4 的框架占位,具体算法由实验实现 #include <cstdint>
#include <sstream>
#include <string>
#include <unordered_map>
namespace ir {
namespace {
bool IsCommutative(Opcode op) {
return op == Opcode::Add || op == Opcode::Mul || op == Opcode::FAdd ||
op == Opcode::FMul;
}
std::string ValueId(Value* value) {
std::ostringstream oss;
oss << reinterpret_cast<std::uintptr_t>(value);
return oss.str();
}
std::string KeyFor(const Instruction& inst) {
std::ostringstream oss;
oss << static_cast<int>(inst.GetOpcode()) << ":";
if (auto* icmp = dynamic_cast<const ICmpInst*>(&inst)) {
oss << static_cast<int>(icmp->GetPredicate()) << ":";
} else if (auto* fcmp = dynamic_cast<const FCmpInst*>(&inst)) {
oss << static_cast<int>(fcmp->GetPredicate()) << ":";
}
std::vector<std::string> operands;
for (size_t i = 0; i < inst.GetNumOperands(); ++i) {
operands.push_back(ValueId(inst.GetOperand(i)));
}
if (IsCommutative(inst.GetOpcode()) && operands.size() == 2) {
std::sort(operands.begin(), operands.end());
}
for (const auto& operand : operands) {
oss << operand << ",";
}
return oss.str();
}
bool IsCSECandidate(const Instruction& inst) {
switch (inst.GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::SDiv:
case Opcode::SRem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::ICmp:
case Opcode::FCmp:
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::ZExt:
case Opcode::Gep:
return true;
default:
return false;
}
}
} // namespace
bool RunCSE(Module& module) {
bool changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
std::unordered_map<std::string, Instruction*> available;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst || !IsCSECandidate(*inst)) continue;
std::string key = KeyFor(*inst);
auto it = available.find(key);
if (it != available.end()) {
inst->ReplaceAllUsesWith(it->second);
changed = true;
} else {
available.emplace(std::move(key), inst);
}
}
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,239 @@
// IR 常量折叠: #include "ir/IR.h"
// - 折叠可判定的常量表达式
// - 简化常量控制流分支(按实现范围裁剪) #include <cmath>
namespace ir {
namespace {
ConstantInt* AsConstInt(Value* value) {
return dynamic_cast<ConstantInt*>(value);
}
ConstantFloat* AsConstFloat(Value* value) {
return dynamic_cast<ConstantFloat*>(value);
}
bool IsZero(Value* value) {
if (auto* i = AsConstInt(value)) return i->GetValue() == 0;
if (auto* f = AsConstFloat(value)) return f->GetValue() == 0.0f;
return false;
}
bool IsOne(Value* value) {
if (auto* i = AsConstInt(value)) return i->GetValue() == 1;
if (auto* f = AsConstFloat(value)) return f->GetValue() == 1.0f;
return false;
}
ConstantValue* FoldBinary(BinaryInst& inst, Context& ctx) {
auto* li = AsConstInt(inst.GetLhs());
auto* ri = AsConstInt(inst.GetRhs());
if (li && ri) {
int lhs = li->GetValue();
int rhs = ri->GetValue();
switch (inst.GetOpcode()) {
case Opcode::Add:
return ctx.GetConstInt(lhs + rhs);
case Opcode::Sub:
return ctx.GetConstInt(lhs - rhs);
case Opcode::Mul:
return ctx.GetConstInt(lhs * rhs);
case Opcode::SDiv:
if (rhs != 0) return ctx.GetConstInt(lhs / rhs);
break;
case Opcode::SRem:
if (rhs != 0) return ctx.GetConstInt(lhs % rhs);
break;
default:
break;
}
}
auto* lf = AsConstFloat(inst.GetLhs());
auto* rf = AsConstFloat(inst.GetRhs());
if (lf && rf) {
float lhs = lf->GetValue();
float rhs = rf->GetValue();
switch (inst.GetOpcode()) {
case Opcode::FAdd:
return ctx.GetConstFloat(lhs + rhs);
case Opcode::FSub:
return ctx.GetConstFloat(lhs - rhs);
case Opcode::FMul:
return ctx.GetConstFloat(lhs * rhs);
case Opcode::FDiv:
return ctx.GetConstFloat(lhs / rhs);
default:
break;
}
}
return nullptr;
}
Value* SimplifyBinary(BinaryInst& inst) {
auto op = inst.GetOpcode();
auto* lhs = inst.GetLhs();
auto* rhs = inst.GetRhs();
switch (op) {
case Opcode::Add:
case Opcode::FAdd:
if (IsZero(rhs)) return lhs;
if (IsZero(lhs)) return rhs;
break;
case Opcode::Sub:
case Opcode::FSub:
if (IsZero(rhs)) return lhs;
break;
case Opcode::Mul:
if (IsOne(rhs)) return lhs;
if (IsOne(lhs)) return rhs;
if (IsZero(rhs)) return rhs;
if (IsZero(lhs)) return lhs;
break;
case Opcode::FMul:
if (IsOne(rhs)) return lhs;
if (IsOne(lhs)) return rhs;
break;
case Opcode::SDiv:
case Opcode::FDiv:
if (IsOne(rhs)) return lhs;
break;
case Opcode::SRem:
if (IsOne(rhs)) return rhs;
break;
default:
break;
}
return nullptr;
}
ConstantInt* FoldICmp(ICmpInst& inst, Context& ctx) {
auto* lhs = AsConstInt(inst.GetLhs());
auto* rhs = AsConstInt(inst.GetRhs());
if (!lhs || !rhs) return nullptr;
int l = lhs->GetValue();
int r = rhs->GetValue();
bool result = false;
switch (inst.GetPredicate()) {
case ICmpPredicate::Eq:
result = l == r;
break;
case ICmpPredicate::Ne:
result = l != r;
break;
case ICmpPredicate::Slt:
result = l < r;
break;
case ICmpPredicate::Sle:
result = l <= r;
break;
case ICmpPredicate::Sgt:
result = l > r;
break;
case ICmpPredicate::Sge:
result = l >= r;
break;
}
return ctx.GetConstBool(result);
}
ConstantInt* FoldFCmp(FCmpInst& inst, Context& ctx) {
auto* lhs = AsConstFloat(inst.GetLhs());
auto* rhs = AsConstFloat(inst.GetRhs());
if (!lhs || !rhs) return nullptr;
float l = lhs->GetValue();
float r = rhs->GetValue();
bool ordered = !std::isnan(l) && !std::isnan(r);
bool result = false;
switch (inst.GetPredicate()) {
case FCmpPredicate::Oeq:
result = ordered && l == r;
break;
case FCmpPredicate::One:
result = ordered && l != r;
break;
case FCmpPredicate::Olt:
result = ordered && l < r;
break;
case FCmpPredicate::Ole:
result = ordered && l <= r;
break;
case FCmpPredicate::Ogt:
result = ordered && l > r;
break;
case FCmpPredicate::Oge:
result = ordered && l >= r;
break;
}
return ctx.GetConstBool(result);
}
ConstantValue* FoldCast(CastInst& inst, Context& ctx) {
switch (inst.GetOpcode()) {
case Opcode::SIToFP:
if (auto* c = AsConstInt(inst.GetValue())) {
return ctx.GetConstFloat(static_cast<float>(c->GetValue()));
}
break;
case Opcode::FPToSI:
if (auto* c = AsConstFloat(inst.GetValue())) {
return ctx.GetConstInt(static_cast<int>(c->GetValue()));
}
break;
case Opcode::ZExt:
if (auto* c = AsConstInt(inst.GetValue())) {
return ctx.GetConstInt(c->GetValue() != 0 ? 1 : 0);
}
break;
default:
break;
}
return nullptr;
}
Value* SimplifyPhi(PhiInst& phi) {
const auto& values = phi.GetIncomingValues();
if (values.empty()) return nullptr;
auto* first = values.front();
for (auto* value : values) {
if (value != first) return nullptr;
}
return first;
}
} // namespace
bool RunConstFold(Module& module) {
bool changed = false;
auto& ctx = module.GetContext();
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
Value* replacement = nullptr;
if (auto* bin = dynamic_cast<BinaryInst*>(inst)) {
replacement = FoldBinary(*bin, ctx);
if (!replacement) replacement = SimplifyBinary(*bin);
} else if (auto* icmp = dynamic_cast<ICmpInst*>(inst)) {
replacement = FoldICmp(*icmp, ctx);
} else if (auto* fcmp = dynamic_cast<FCmpInst*>(inst)) {
replacement = FoldFCmp(*fcmp, ctx);
} else if (auto* cast = dynamic_cast<CastInst*>(inst)) {
replacement = FoldCast(*cast, ctx);
} else if (auto* phi = dynamic_cast<PhiInst*>(inst)) {
replacement = SimplifyPhi(*phi);
}
if (replacement && replacement != inst) {
inst->ReplaceAllUsesWith(replacement);
changed = true;
}
}
}
}
return changed;
}
} // namespace ir

@ -1,5 +1,77 @@
// 常量传播Constant Propagation #include "ir/IR.h"
// - 沿 use-def 关系传播已知常量
// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会 #include <unordered_map>
// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用
namespace ir {
namespace {
ConstantValue* AsConstant(Value* value) {
return dynamic_cast<ConstantValue*>(value);
}
ConstantValue* ScalarConstInitializer(Value* ptr) {
auto* global = dynamic_cast<GlobalVariable*>(ptr);
if (!global || !global->IsConst()) return nullptr;
auto* init = global->GetInitializer();
if (!init) return nullptr;
if (dynamic_cast<ConstantArray*>(init)) return nullptr;
return init;
}
bool IsScalarStackSlot(Value* ptr) {
auto* alloca = dynamic_cast<AllocaInst*>(ptr);
if (!alloca) return false;
const auto& ty = alloca->GetAllocatedType();
return ty && (ty->IsInt1() || ty->IsInt32() || ty->IsFloat());
}
} // namespace
bool RunConstProp(Module& module) {
bool changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
std::unordered_map<Value*, ConstantValue*> known_memory;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
auto* ptr = load->GetPtr();
ConstantValue* known = nullptr;
auto it = known_memory.find(ptr);
if (it != known_memory.end()) {
known = it->second;
} else {
known = ScalarConstInitializer(ptr);
}
if (known) {
load->ReplaceAllUsesWith(known);
changed = true;
}
continue;
}
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
auto* ptr = store->GetPtr();
if (IsScalarStackSlot(ptr)) {
if (auto* c = AsConstant(store->GetValue())) {
known_memory[ptr] = c;
} else {
known_memory.erase(ptr);
}
}
continue;
}
if (inst->GetOpcode() == Opcode::Call) {
known_memory.clear();
}
}
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,55 @@
// 死代码删除DCE #include "ir/IR.h"
// - 删除无用指令与无用基本块
// - 通常与 CFG 简化配合使用 namespace ir {
namespace {
bool HasSideEffect(const Instruction& inst) {
switch (inst.GetOpcode()) {
case Opcode::Store:
case Opcode::Ret:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Call:
return true;
default:
return false;
}
}
bool IsDead(const Instruction& inst) {
if (inst.IsVoid()) return false;
if (HasSideEffect(inst)) return false;
return inst.GetUses().empty();
}
} // namespace
bool RunDCE(Module& module) {
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
auto& insts = block->GetMutableInstructions();
for (auto it = insts.begin(); it != insts.end();) {
auto* inst = it->get();
if (inst && IsDead(*inst)) {
block->EraseInstruction(inst);
local_changed = true;
changed = true;
it = insts.begin();
} else {
++it;
}
}
}
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,109 @@
// Mem2RegSSA 构造): #include "ir/IR.h"
// - 将局部变量的 alloca/load/store 提升为 SSA 形式
// - 插入 PHI 并重写使用,依赖支配树等分析 #include <vector>
namespace ir {
namespace {
bool IsPromotableType(const AllocaInst& alloca) {
const auto& ty = alloca.GetAllocatedType();
return ty && (ty->IsInt1() || ty->IsInt32() || ty->IsFloat());
}
bool IsDirectUseOf(Value* ptr, Instruction* inst) {
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
return load->GetPtr() == ptr;
}
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
return store->GetPtr() == ptr;
}
return false;
}
BasicBlock* SingleUseBlock(AllocaInst& alloca) {
BasicBlock* use_block = nullptr;
for (const auto& use : alloca.GetUses()) {
auto* inst = dynamic_cast<Instruction*>(use.GetUser());
if (!inst || !IsDirectUseOf(&alloca, inst)) return nullptr;
auto* parent = inst->GetParent();
if (!parent) return nullptr;
if (!use_block) {
use_block = parent;
} else if (use_block != parent) {
return nullptr;
}
}
return use_block;
}
bool CanPromoteInBlock(AllocaInst& alloca, BasicBlock& block) {
bool has_value = false;
bool saw_use = false;
for (const auto& inst_ptr : block.GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst || !IsDirectUseOf(&alloca, inst)) continue;
saw_use = true;
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
if (store->GetValue() == &alloca) return false;
has_value = true;
} else if (dynamic_cast<LoadInst*>(inst)) {
if (!has_value) return false;
}
}
return saw_use;
}
bool PromoteInBlock(AllocaInst& alloca, BasicBlock& block) {
Value* current = nullptr;
std::vector<Instruction*> erase;
for (const auto& inst_ptr : block.GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst || !IsDirectUseOf(&alloca, inst)) continue;
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
current = store->GetValue();
erase.push_back(store);
} else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
load->ReplaceAllUsesWith(current);
erase.push_back(load);
}
}
for (auto* inst : erase) {
block.EraseInstruction(inst);
}
if (alloca.GetUses().empty()) {
if (auto* parent = alloca.GetParent()) {
parent->EraseInstruction(&alloca);
}
}
return !erase.empty();
}
} // namespace
bool RunMem2Reg(Module& module) {
bool changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
std::vector<AllocaInst*> allocas;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
for (const auto& inst : block->GetInstructions()) {
if (auto* alloca = dynamic_cast<AllocaInst*>(inst.get())) {
if (IsPromotableType(*alloca)) {
allocas.push_back(alloca);
}
}
}
}
for (auto* alloca : allocas) {
if (!alloca || alloca->GetUses().empty()) continue;
auto* block = SingleUseBlock(*alloca);
if (!block || !CanPromoteInBlock(*alloca, *block)) continue;
changed |= PromoteInBlock(*alloca, *block);
}
}
return changed;
}
} // namespace ir

@ -1 +1,23 @@
// IR Pass 管理骨架。 #include "ir/IR.h"
namespace ir {
bool RunScalarOptimizationPipeline(Module& module) {
bool changed = false;
changed |= RunMem2Reg(module);
for (int i = 0; i < 8; ++i) {
bool iter_changed = false;
iter_changed |= RunConstFold(module);
iter_changed |= RunConstProp(module);
iter_changed |= RunCSE(module);
iter_changed |= RunDCE(module);
iter_changed |= RunCFGSimplify(module);
changed |= iter_changed;
if (!iter_changed) break;
}
return changed;
}
} // namespace ir

@ -1,46 +1,32 @@
#include "irgen/IRGen.h" #include "irgen/IRGen.h"
#include <any>
#include <stdexcept> #include <stdexcept>
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
namespace { std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
if (!ctx) return BlockFlow::Continue;
std::string GetLValueName(SysYParser::LValueContext& lvalue) { BlockFlow flow = BlockFlow::Continue;
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
return lvalue.ID()->getText();
}
} // namespace
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
for (auto* item : ctx->blockItem()) { for (auto* item : ctx->blockItem()) {
if (item) { if (item) {
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
// 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 flow = BlockFlow::Terminated;
break; break;
} }
} }
} }
return {}; return flow;
} }
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(SysYParser::BlockItemContext& item) {
SysYParser::BlockItemContext& item) {
return std::any_cast<BlockFlow>(item.accept(this)); return std::any_cast<BlockFlow>(item.accept(this));
} }
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (!ctx) { if (!ctx) return BlockFlow::Continue;
throw std::runtime_error(FormatError("irgen", "缺少块内项"));
}
if (ctx->decl()) { if (ctx->decl()) {
ctx->decl()->accept(this); ctx->decl()->accept(this);
return BlockFlow::Continue; return BlockFlow::Continue;
@ -48,60 +34,219 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (ctx->stmt()) { if (ctx->stmt()) {
return ctx->stmt()->accept(this); return ctx->stmt()->accept(this);
} }
throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); return BlockFlow::Continue;
} }
// 变量声明的 IR 生成目前也是最小实现:
// - 先检查声明的基础类型,当前仅支持局部 int
// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。
//
// 和更完整的版本相比,这里还没有:
// - 一个 Decl 中多个变量定义的顺序处理;
// - const、数组、全局变量等不同声明形态
// - 更丰富的类型系统。
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("irgen", "缺少变量声明")); if (auto* constDecl = ctx->constDecl()) {
} for (auto* def : constDecl->constDef()) {
if (!ctx->btype() || !ctx->btype()->INT()) { def->accept(this);
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); }
return {};
} }
auto* var_def = ctx->varDef(); if (auto* varDecl = ctx->varDecl()) {
if (!var_def) { for (auto* varDef : varDecl->varDef()) {
throw std::runtime_error(FormatError("irgen", "非法变量声明")); varDef->accept(this);
}
return {};
} }
var_def->accept(this);
return {}; return {};
} }
// 当前仍是教学用的最小版本,因此这里只支持:
// - 局部 int 变量;
// - 标量初始化;
// - 一个 VarDef 对应一个槽位。
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("irgen", "缺少变量定义")); if (!ctx->ID()) {
}
if (!ctx->lValue()) {
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
} }
GetLValueName(*ctx->lValue()); if (!func_) {
if (storage_map_.find(ctx) != storage_map_.end()) { const TypeDesc* ty = sema_.GetVarType(ctx);
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); if (!ty) {
throw std::runtime_error(FormatError("irgen", "全局变量类型缺失"));
}
if (global_var_storage_.find(ctx) != global_var_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成全局变量"));
}
ir::ConstantValue* init = nullptr;
if (ty->dims.empty()) {
if (auto* initVal = ctx->initVal()) {
if (!initVal->exp()) {
throw std::runtime_error(FormatError("irgen", "全局变量初始化非法"));
}
init = EvalConstScalar(initVal->exp());
if (ty->base == BaseTypeKind::Int &&
dynamic_cast<ir::ConstantFloat*>(init)) {
auto* cf = static_cast<ir::ConstantFloat*>(init);
init = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
} else if (ty->base == BaseTypeKind::Float &&
dynamic_cast<ir::ConstantInt*>(init)) {
auto* ci = static_cast<ir::ConstantInt*>(init);
init = module_.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
}
} else if (auto* initVal = ctx->initVal()) {
size_t total = ArrayTotalSize(*ty);
std::vector<ir::ConstantValue*> values(
total,
ty->base == BaseTypeKind::Float
? static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(0)));
InitGlobalArray(*ty, initVal, values, 0, 0, 0);
init = module_.GetContext().CreateConstArray(ToIRType(*ty), values);
}
auto* gv = module_.CreateGlobalVariable(ctx->ID()->getText(),
ToIRType(*ty), init, false);
global_var_storage_[ctx] = gv;
return {};
}
if (var_storage_.find(ctx) != var_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成存储槽位"));
} }
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); const TypeDesc* ty = sema_.GetVarType(ctx);
storage_map_[ctx] = slot; if (!ty) {
throw std::runtime_error(FormatError("irgen", "变量类型缺失"));
}
auto* slot = CreateEntryAlloca(ToIRType(*ty), module_.GetContext().NextTemp());
var_storage_[ctx] = slot;
ir::Value* init = nullptr; if (ty->dims.empty()) {
if (auto* init_value = ctx->initValue()) { ir::Value* init = nullptr;
if (!init_value->exp()) { if (auto* initVal = ctx->initVal()) {
throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); if (!initVal->exp()) {
throw std::runtime_error(FormatError("irgen", "标量初始化非法"));
}
init = EvalExp(initVal->exp());
} else {
init = DefaultValue(*ty);
}
if (ty->base == BaseTypeKind::Float) {
if (init->IsInt1() || init->IsInt32()) {
init = CastToFloat(init->IsInt1() ? CastToInt(init) : init);
}
} else if (ty->base == BaseTypeKind::Int) {
if (init->IsFloat() || init->IsInt1()) {
init = CastToInt(init);
}
} }
init = EvalExpr(*init_value->exp()); builder_.CreateStore(init, slot);
} else { } else {
init = builder_.CreateConstInt(0); if (!ctx->initVal() && ty->dims.size() == 1 && ty->dims[0] >= 1024) {
auto* idx_slot = CreateEntryAlloca(ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(0), idx_slot);
auto* cond_bb = func_->CreateBlock("arr.zero.cond");
auto* body_bb = func_->CreateBlock("arr.zero.body");
auto* end_bb = func_->CreateBlock("arr.zero.end");
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
auto* idx = builder_.CreateLoad(idx_slot, module_.GetContext().NextTemp());
auto* bound = builder_.CreateConstInt(ty->dims[0]);
auto* cmp = builder_.CreateICmp(ir::ICmpPredicate::Slt, idx, bound,
module_.GetContext().NextTemp());
builder_.CreateCondBr(cmp, body_bb, end_bb);
builder_.SetInsertPoint(body_bb);
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
indices.push_back(idx);
auto* elem_addr = builder_.CreateGep(slot, std::move(indices),
module_.GetContext().NextTemp());
builder_.CreateStore(DefaultValue(*ty), elem_addr);
auto* next = builder_.CreateAdd(idx, builder_.CreateConstInt(1),
module_.GetContext().NextTemp());
builder_.CreateStore(next, idx_slot);
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(end_bb);
} else {
InitArray(slot, *ty, ctx->initVal());
}
} }
builder_.CreateStore(init, slot);
return {}; return {};
} }
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "常量声明缺少名称"));
}
if (!func_) {
const TypeDesc* ty = sema_.GetConstType(ctx);
if (!ty) {
throw std::runtime_error(FormatError("irgen", "全局常量类型缺失"));
}
if (global_const_storage_.find(ctx) != global_const_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成全局常量"));
}
ir::ConstantValue* init = nullptr;
if (ty->dims.empty()) {
if (auto* initVal = ctx->constInitVal()) {
if (!initVal->constExp()) {
throw std::runtime_error(FormatError("irgen", "全局常量初始化非法"));
}
init = EvalConstScalar(initVal->constExp());
if (ty->base == BaseTypeKind::Int &&
dynamic_cast<ir::ConstantFloat*>(init)) {
auto* cf = static_cast<ir::ConstantFloat*>(init);
init = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
} else if (ty->base == BaseTypeKind::Float &&
dynamic_cast<ir::ConstantInt*>(init)) {
auto* ci = static_cast<ir::ConstantInt*>(init);
init = module_.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
}
} else if (auto* initVal = ctx->constInitVal()) {
size_t total = ArrayTotalSize(*ty);
std::vector<ir::ConstantValue*> values(
total,
ty->base == BaseTypeKind::Float
? static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(0.0f))
: static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(0)));
InitGlobalConstArray(*ty, initVal, values, 0, 0, 0);
init = module_.GetContext().CreateConstArray(ToIRType(*ty), values);
}
auto* gv = module_.CreateGlobalVariable(ctx->ID()->getText(),
ToIRType(*ty), init, true);
global_const_storage_[ctx] = gv;
return {};
}
if (const_storage_.find(ctx) != const_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "重复生成常量存储"));
}
const TypeDesc* ty = sema_.GetConstType(ctx);
if (!ty) {
throw std::runtime_error(FormatError("irgen", "常量类型缺失"));
}
auto* slot = CreateEntryAlloca(ToIRType(*ty), module_.GetContext().NextTemp());
const_storage_[ctx] = slot;
if (ty->dims.empty()) {
ir::Value* init = nullptr;
if (auto* initVal = ctx->constInitVal()) {
if (!initVal->constExp()) {
throw std::runtime_error(FormatError("irgen", "常量初始化非法"));
}
init = std::any_cast<ir::Value*>(initVal->constExp()->accept(this));
} else {
init = DefaultValue(*ty);
}
if (ty->base == BaseTypeKind::Float) {
if (init->IsInt1() || init->IsInt32()) {
init = CastToFloat(init->IsInt1() ? CastToInt(init) : init);
}
} else if (ty->base == BaseTypeKind::Int) {
if (init->IsFloat() || init->IsInt1()) {
init = CastToInt(init);
}
}
builder_.CreateStore(init, slot);
} else {
InitConstArray(slot, *ty, ctx->constInitVal());
}
return {};
}

@ -4,12 +4,11 @@
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h"
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
const SemanticContext& sema) { const SemanticContext& sema) {
auto module = std::make_unique<ir::Module>(); auto module = std::make_unique<ir::Module>(); // 无参构造
IRGenImpl gen(*module, sema); IRGenImpl visitor(*module, sema);
tree.accept(&gen); tree.accept(&visitor);
return module; return module;
} }

File diff suppressed because it is too large Load Diff

@ -6,82 +6,116 @@
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
namespace {
void VerifyFunctionStructure(const ir::Function& func) {
// 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。
for (const auto& bb : func.GetBlocks()) {
if (!bb || !bb->HasTerminator()) {
throw std::runtime_error(
FormatError("irgen", "基本块未正确终结: " +
(bb ? bb->GetName() : std::string("<null>"))));
}
}
}
} // namespace
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module), : module_(module),
sema_(sema), sema_(sema),
func_(nullptr), func_(nullptr),
builder_(module.GetContext(), nullptr) {} builder_(module.GetContext(), nullptr) {}
// 编译单元的 IR 生成当前只实现了最小功能:
// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容;
// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR
//
// 当前还没有实现:
// - 多个函数定义的遍历与生成;
// - 全局变量、全局常量的 IR 生成。
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
func_map_.clear();
global_var_storage_.clear();
global_const_storage_.clear();
func_ = nullptr;
for (auto* decl : ctx->decl()) {
if (decl) decl->accept(this);
}
for (auto* funcDef : ctx->funcDef()) {
if (!funcDef || !funcDef->ID()) continue;
const auto* fty = sema_.GetFuncType(funcDef);
if (!fty) {
throw std::runtime_error(FormatError("irgen", "缺少函数类型"));
}
std::vector<std::shared_ptr<ir::Type>> params;
for (const auto& p : fty->params) {
params.push_back(ToIRParamType(p));
}
auto ret = ToIRType(fty->ret);
auto func_ty = ir::Type::GetFunctionType(ret, params);
auto* fn = module_.CreateFunctionWithType(funcDef->ID()->getText(), func_ty);
func_map_[funcDef] = fn;
} }
auto* func = ctx->funcDef();
if (!func) { auto declare_builtin = [&](const std::string& name,
throw std::runtime_error(FormatError("irgen", "缺少函数定义")); std::shared_ptr<ir::Type> ret,
std::vector<std::shared_ptr<ir::Type>> params) {
for (const auto& fn : module_.GetFunctions()) {
if (fn && fn->GetName() == name) return;
}
auto fty = ir::Type::GetFunctionType(ret, params);
module_.CreateFunctionDecl(name, fty);
};
auto i32 = ir::Type::GetInt32Type();
auto f32 = ir::Type::GetFloatType();
declare_builtin("getint", i32, {});
declare_builtin("getch", i32, {});
declare_builtin("getarray", i32, {ir::Type::GetPointerType(i32)});
declare_builtin("putint", ir::Type::GetVoidType(), {i32});
declare_builtin("putch", ir::Type::GetVoidType(), {i32});
declare_builtin("putarray", ir::Type::GetVoidType(),
{i32, ir::Type::GetPointerType(i32)});
declare_builtin("getfloat", f32, {});
declare_builtin("getfarray", i32, {ir::Type::GetPointerType(f32)});
declare_builtin("putfloat", ir::Type::GetVoidType(), {f32});
declare_builtin("putfarray", ir::Type::GetVoidType(),
{i32, ir::Type::GetPointerType(f32)});
declare_builtin("starttime", ir::Type::GetVoidType(), {});
declare_builtin("stoptime", ir::Type::GetVoidType(), {});
for (auto* funcDef : ctx->funcDef()) {
if (funcDef) funcDef->accept(this);
} }
func->accept(this);
return {}; return {};
} }
// 函数 IR 生成当前实现了:
// 1. 获取函数名;
// 2. 检查函数返回类型;
// 3. 在 Module 中创建 Function
// 4. 将 builder 插入点设置到入口基本块;
// 5. 继续生成函数体。
//
// 当前还没有实现:
// - 通用函数返回类型处理;
// - 形参列表遍历与参数类型收集;
// - FunctionType 这样的函数类型对象;
// - Argument/形式参数 IR 对象;
// - 入口块中的参数初始化逻辑。
// ...
// 因此这里目前只支持最小的“无参 int 函数”生成。
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx) { if (!ctx || !ctx->block()) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
if (!ctx->blockStmt()) {
throw std::runtime_error(FormatError("irgen", "函数体为空")); throw std::runtime_error(FormatError("irgen", "函数体为空"));
} }
if (!ctx->ID()) { auto it = func_map_.find(ctx);
throw std::runtime_error(FormatError("irgen", "缺少函数名")); if (it == func_map_.end()) {
throw std::runtime_error(FormatError("irgen", "函数未注册"));
} }
if (!ctx->funcType() || !ctx->funcType()->INT()) { func_ = it->second;
throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); auto* entry = func_->GetEntry();
builder_.SetInsertPoint(entry);
var_storage_.clear();
const_storage_.clear();
param_storage_.clear();
loop_stack_.clear();
const auto* fty = sema_.GetFuncType(ctx);
if (!fty) {
throw std::runtime_error(FormatError("irgen", "缺少函数类型"));
}
if (ctx->funcFParams()) {
auto params = ctx->funcFParams()->funcFParam();
for (size_t i = 0; i < params.size(); ++i) {
auto* param_ctx = params[i];
auto* arg = func_->GetArg(i);
const TypeDesc* pty = sema_.GetParamType(param_ctx);
if (!pty) {
throw std::runtime_error(FormatError("irgen", "缺少参数类型"));
}
auto slot = CreateEntryAlloca(ToIRParamType(*pty),
module_.GetContext().NextTemp());
builder_.CreateStore(arg, slot);
param_storage_[param_ctx] = slot;
}
} }
func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); ctx->block()->accept(this);
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
ctx->blockStmt()->accept(this); if (!builder_.GetInsertBlock()->HasTerminator()) {
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 if (func_->GetReturnType()->IsVoid()) {
VerifyFunctionStructure(*func_); builder_.CreateRetVoid();
} else {
TypeDesc ret = fty->ret;
builder_.CreateRet(DefaultValue(ret));
}
}
return {}; return {};
} }

@ -1,39 +1,132 @@
#include "irgen/IRGen.h" #include "irgen/IRGen.h"
#include <any>
#include <stdexcept> #include <stdexcept>
#include "SysYParser.h" #include "SysYParser.h"
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
// 语句生成当前只实现了最小子集。
// 目前支持:
// - return <exp>;
//
// 还未支持:
// - 赋值语句
// - if / while 等控制流
// - 空语句、块语句嵌套分发之外的更多语句形态
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("irgen", "缺少语句")); if (ctx->lVal() && ctx->ASSIGN()) {
ir::Value* addr = GetLValAddress(ctx->lVal());
ir::Value* val = EvalExp(ctx->exp());
BoundDecl bound = sema_.ResolveVarUse(ctx->lVal());
if ((bound.kind == BoundDecl::Kind::Var && !bound.var_decl) ||
(bound.kind == BoundDecl::Kind::Const && !bound.const_decl) ||
(bound.kind == BoundDecl::Kind::Param && !bound.param_decl)) {
throw std::runtime_error(FormatError(
"irgen", "赋值左值缺少语义绑定: " + ctx->lVal()->getText()));
}
const TypeDesc* ty = nullptr;
if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
ty = sema_.GetVarType(bound.var_decl);
} else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) {
ty = sema_.GetParamType(bound.param_decl);
} else if (bound.kind == BoundDecl::Kind::Const) {
throw std::runtime_error(FormatError("irgen", "不能给常量赋值"));
}
if (!ty) {
throw std::runtime_error(FormatError("irgen", "无法解析赋值类型"));
}
if (ty->base == BaseTypeKind::Float) {
if (val->IsInt1()) {
val = CastToFloat(CastToInt(val));
} else if (val->IsInt32()) {
val = CastToFloat(val);
}
} else if (ty->base == BaseTypeKind::Int) {
if (val->IsFloat()) {
val = CastToInt(val);
} else if (val->IsInt1()) {
val = CastToInt(val);
}
}
builder_.CreateStore(val, addr);
return BlockFlow::Continue;
}
if (ctx->block()) {
return ctx->block()->accept(this);
} }
if (ctx->returnStmt()) { if (ctx->IF()) {
return ctx->returnStmt()->accept(this); auto* then_bb = func_->CreateBlock("if.then");
auto* else_bb = func_->CreateBlock("if.else");
auto* merge_bb = func_->CreateBlock("if.end");
EmitCondBr(ctx->cond(), then_bb, else_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (then_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
builder_.SetInsertPoint(else_bb);
if (ctx->stmt(1)) {
auto else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
if (else_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
} else {
builder_.CreateBr(merge_bb);
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
} }
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); if (ctx->WHILE()) {
} auto* cond_bb = func_->CreateBlock("while.cond");
auto* body_bb = func_->CreateBlock("while.body");
auto* end_bb = func_->CreateBlock("while.end");
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
EmitCondBr(ctx->cond(), body_bb, end_bb);
std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { builder_.SetInsertPoint(body_bb);
if (!ctx) { PushLoop(end_bb, cond_bb);
throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
PopLoop();
if (body_flow != BlockFlow::Terminated) {
builder_.CreateBr(cond_bb);
}
builder_.SetInsertPoint(end_bb);
return BlockFlow::Continue;
}
if (ctx->BREAK()) {
auto* target = CurrentBreak();
if (!target) {
throw std::runtime_error(FormatError("irgen", "break 不在循环内"));
}
builder_.CreateBr(target);
return BlockFlow::Terminated;
}
if (ctx->CONTINUE()) {
auto* target = CurrentContinue();
if (!target) {
throw std::runtime_error(FormatError("irgen", "continue 不在循环内"));
}
builder_.CreateBr(target);
return BlockFlow::Terminated;
}
if (ctx->RETURN()) {
if (!ctx->exp()) {
builder_.CreateRetVoid();
return BlockFlow::Terminated;
}
ir::Value* v = EvalExp(ctx->exp());
auto ret_ty = func_->GetReturnType();
if (ret_ty->IsFloat() && v->IsInt32()) {
v = CastToFloat(v);
} else if (ret_ty->IsInt32() && v->IsFloat()) {
v = CastToInt(v);
}
builder_.CreateRet(v);
return BlockFlow::Terminated;
} }
if (!ctx->exp()) { if (ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); EvalExp(ctx->exp());
} }
ir::Value* v = EvalExpr(*ctx->exp()); return BlockFlow::Continue;
builder_.CreateRet(v); }
return BlockFlow::Terminated;
}

@ -36,6 +36,9 @@ int main(int argc, char** argv) {
auto sema = RunSema(*comp_unit); auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema); auto module = GenerateIR(*comp_unit, sema);
if (opts.optimize_ir) {
ir::RunScalarOptimizationPipeline(*module);
}
if (opts.emit_ir) { if (opts.emit_ir) {
ir::IRPrinter printer; ir::IRPrinter printer;
if (need_blank_line) { if (need_blank_line) {
@ -46,13 +49,10 @@ int main(int argc, char** argv) {
} }
if (opts.emit_asm) { if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module);
mir::RunRegAlloc(*machine_func);
mir::RunFrameLowering(*machine_func);
if (need_blank_line) { if (need_blank_line) {
std::cout << "\n"; std::cout << "\n";
} }
mir::PrintAsm(*machine_func, std::cout); mir::PrintAArch64AsmFromMIR(*module, std::cout);
} }
#else #else
if (opts.emit_ir || opts.emit_asm) { if (opts.emit_ir || opts.emit_asm) {

@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

@ -1,78 +1,351 @@
// AArch64 汇编发射Lab5
// 输入为寄存器分配 + 栈帧布局后的 MachineModule操作数均为物理寄存器/栈对象)。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <ostream> #include <ostream>
#include <stdexcept> #include <string>
#include <vector>
#include "utils/Log.h"
namespace mir { namespace mir {
namespace { namespace {
const FrameSlot& GetFrameSlot(const MachineFunction& function, int CalleeAreaBytes(const MachineFunction& f) {
const Operand& operand) { return ((int)f.CalleeSavedGPR().size() + (int)f.CalleeSavedFPR().size()) * 8;
if (operand.GetKind() != Operand::Kind::FrameIndex) { }
throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
class Printer {
public:
Printer(const MachineModule& m, std::ostream& os) : m_(m), os_(os) {}
void Run();
private:
const MachineModule& m_;
std::ostream& os_;
const MachineFunction* mf_ = nullptr;
void EmitGlobals();
void EmitFunction(const MachineFunction& f);
void EmitProlog(const MachineFunction& f);
void EmitEpilog(const MachineFunction& f);
void EmitInstr(const MachineInstr& mi);
std::string R(const Operand& op) {
if (op.GetClass() == RegClass::FPR)
return FPRName(op.GetId(), op.GetBytes());
return GPRName(op.GetId(), op.GetBytes());
}
// 把任意立即数加载进暂存寄存器x16 / x17
void LoadImm(const char* dst, long long v);
// [x29 + offset] 形式访存offset 过大时用 x16 计算地址。
void MemAccess(const char* mnem, const std::string& reg, int offset);
int FrameOffset(int stack_index) {
return mf_->StackObjects()[stack_index].offset;
}
};
const char* CondStr(Cond c) {
switch (c) {
case Cond::EQ: return "eq";
case Cond::NE: return "ne";
case Cond::LT: return "lt";
case Cond::LE: return "le";
case Cond::GT: return "gt";
case Cond::GE: return "ge";
case Cond::MI: return "mi";
case Cond::LS: return "ls";
case Cond::HI: return "hi";
case Cond::HS: return "hs";
default: return "al";
} }
return function.GetFrameSlot(operand.GetFrameIndex());
} }
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, void Printer::LoadImm(const char* dst, long long v) {
int offset) { bool is_w = dst[0] == 'w';
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset unsigned long long uv = (unsigned long long)v;
<< "]\n"; int hi = is_w ? 32 : 64; // w 寄存器只填低 32 位
if (is_w) uv &= 0xffffffffULL;
os_ << "\tmov\t" << dst << ", #" << (uv & 0xffff) << "\n";
for (int sh = 16; sh < hi; sh += 16) {
unsigned chunk = (uv >> sh) & 0xffff;
if (chunk)
os_ << "\tmovk\t" << dst << ", #" << chunk << ", lsl #" << sh << "\n";
}
} }
} // namespace void Printer::MemAccess(const char* mnem, const std::string& reg, int offset) {
if (offset >= -256 && offset <= 4095) {
os_ << "\t" << mnem << "\t" << reg << ", [x29, #" << offset << "]\n";
} else {
LoadImm("x16", offset);
os_ << "\tadd\tx16, x29, x16\n";
os_ << "\t" << mnem << "\t" << reg << ", [x16]\n";
}
}
void PrintAsm(const MachineFunction& function, std::ostream& os) { void Printer::Run() {
os << ".text\n"; EmitGlobals();
os << ".global " << function.GetName() << "\n"; os_ << "\t.text\n";
os << ".type " << function.GetName() << ", %function\n"; for (const auto& f : m_.Functions()) EmitFunction(*f);
os << function.GetName() << ":\n"; }
for (const auto& inst : function.GetEntry().GetInstructions()) { void Printer::EmitGlobals() {
const auto& ops = inst.GetOperands(); if (m_.Globals().empty()) return;
switch (inst.GetOpcode()) { for (const auto& g : m_.Globals()) {
case Opcode::Prologue: if (g.zero_init) {
os << " stp x29, x30, [sp, #-16]!\n"; os_ << "\t.bss\n";
os << " mov x29, sp\n"; os_ << "\t.align\t" << (g.align == 16 ? 4 : 2) << "\n";
if (function.GetFrameSize() > 0) { os_ << "\t.globl\t" << g.name << "\n";
os << " sub sp, sp, #" << function.GetFrameSize() << "\n"; os_ << g.name << ":\n";
} os_ << "\t.zero\t" << g.size << "\n";
break; } else {
case Opcode::Epilogue: os_ << "\t.data\n";
if (function.GetFrameSize() > 0) { os_ << "\t.align\t" << (g.align == 16 ? 4 : 2) << "\n";
os << " add sp, sp, #" << function.GetFrameSize() << "\n"; os_ << "\t.globl\t" << g.name << "\n";
} os_ << g.name << ":\n";
os << " ldp x29, x30, [sp], #16\n"; for (unsigned w : g.words) os_ << "\t.word\t" << w << "\n";
break; }
case Opcode::MovImm: }
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #" }
<< ops.at(1).GetImm() << "\n";
break; void Printer::EmitFunction(const MachineFunction& f) {
case Opcode::LoadStack: { mf_ = &f;
const auto& slot = GetFrameSlot(function, ops.at(1)); os_ << "\t.globl\t" << f.GetName() << "\n";
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); os_ << "\t.type\t" << f.GetName() << ", %function\n";
os_ << f.GetName() << ":\n";
EmitProlog(f);
for (const auto& bb : f.Blocks()) {
os_ << ".L." << f.GetName() << "." << bb->GetName() << ":\n";
for (const auto& mi : bb->Instrs()) EmitInstr(mi);
}
os_ << "\t.size\t" << f.GetName() << ", .-" << f.GetName() << "\n";
}
void Printer::EmitProlog(const MachineFunction& f) {
int frame = f.GetFrameSize();
// sub sp, sp, #frame ; 保存 fp/lr ; mov x29, sp
if (frame <= 4095) {
os_ << "\tsub\tsp, sp, #" << frame << "\n";
} else {
LoadImm("x16", frame);
os_ << "\tsub\tsp, sp, x16\n";
}
os_ << "\tstp\tx29, x30, [sp]\n";
os_ << "\tmov\tx29, sp\n";
// 保存 callee-saved相对 x29 偏移,从 16 开始)。
int off = 16;
for (int r : f.CalleeSavedGPR()) {
MemAccess("str", GPRName(r, 8), off);
off += 8;
}
for (int r : f.CalleeSavedFPR()) {
MemAccess("str", FPRName(r, 8), off);
off += 8;
}
}
void Printer::EmitEpilog(const MachineFunction& f) {
int frame = f.GetFrameSize();
int off = 16;
for (int r : f.CalleeSavedGPR()) {
MemAccess("ldr", GPRName(r, 8), off);
off += 8;
}
for (int r : f.CalleeSavedFPR()) {
MemAccess("ldr", FPRName(r, 8), off);
off += 8;
}
os_ << "\tldp\tx29, x30, [sp]\n";
if (frame <= 4095) {
os_ << "\tadd\tsp, sp, #" << frame << "\n";
} else {
LoadImm("x16", frame);
os_ << "\tadd\tsp, sp, x16\n";
}
os_ << "\tret\n";
}
void Printer::EmitInstr(const MachineInstr& mi) {
auto label = [&](const Operand& o) {
return ".L." + mf_->GetName() + "." + o.GetSym();
};
switch (mi.op) {
case Opcode::Mov: {
// 同寄存器拷贝可省略peephole 已处理,这里再兜底)。
if (mi.ops[0].IsPReg() && mi.ops[1].IsPReg() &&
mi.ops[0].GetId() == mi.ops[1].GetId())
break; break;
os_ << "\tmov\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
}
case Opcode::MovImm:
LoadImm(R(mi.ops[0]).c_str(), mi.ops[1].GetImm());
break;
case Opcode::Sxtw:
os_ << "\tsxtw\t" << GPRName(mi.ops[0].GetId(), 8) << ", "
<< GPRName(mi.ops[1].GetId(), 4) << "\n";
break;
case Opcode::Add:
os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::Sub:
os_ << "\tsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::Mul:
os_ << "\tmul\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::SDiv:
os_ << "\tsdiv\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::MSub:
os_ << "\tmsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << ", " << R(mi.ops[3]) << "\n";
break;
case Opcode::AddImm: {
long long v = mi.ops[2].GetImm();
if (v >= 0 && v <= 4095) {
os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #" << v
<< "\n";
} else {
LoadImm("x16", v);
os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", x16\n";
} }
case Opcode::StoreStack: { break;
const auto& slot = GetFrameSlot(function, ops.at(1)); }
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); case Opcode::SubImm:
os_ << "\tsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #"
<< mi.ops[2].GetImm() << "\n";
break;
case Opcode::LslImm:
os_ << "\tlsl\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", #"
<< mi.ops[2].GetImm() << "\n";
break;
case Opcode::Cmp:
os_ << "\tcmp\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
case Opcode::CmpImm:
os_ << "\tcmp\t" << R(mi.ops[0]) << ", #" << mi.ops[1].GetImm() << "\n";
break;
case Opcode::CSet:
os_ << "\tcset\t" << R(mi.ops[0]) << ", " << CondStr(mi.cond) << "\n";
break;
case Opcode::FAdd:
os_ << "\tfadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::FSub:
os_ << "\tfsub\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::FMul:
os_ << "\tfmul\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::FDiv:
os_ << "\tfdiv\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << ", "
<< R(mi.ops[2]) << "\n";
break;
case Opcode::FCmp:
os_ << "\tfcmp\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
case Opcode::FMov: {
if (mi.ops[0].IsPReg() && mi.ops[1].IsPReg() &&
mi.ops[0].GetId() == mi.ops[1].GetId())
break; break;
os_ << "\tfmov\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
}
case Opcode::FMovImm:
// 通过 x16 装入 32 位浮点位模式,再 fmov 到 s 寄存器。
LoadImm("w16", (long long)(unsigned)mi.ops[1].GetImm());
os_ << "\tfmov\t" << R(mi.ops[0]) << ", w16\n";
break;
case Opcode::SCvtF:
os_ << "\tscvtf\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
case Opcode::FCvtZS:
os_ << "\tfcvtzs\t" << R(mi.ops[0]) << ", " << R(mi.ops[1]) << "\n";
break;
case Opcode::Ldr: {
long long off = mi.ops[2].GetImm();
if (off >= -256 && off <= 4095)
os_ << "\tldr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", #"
<< off << "]\n";
else {
LoadImm("x16", off);
os_ << "\tldr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", x16]\n";
} }
case Opcode::AddRR: break;
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " }
<< PhysRegName(ops.at(1).GetReg()) << ", " case Opcode::Str: {
<< PhysRegName(ops.at(2).GetReg()) << "\n"; long long off = mi.ops[2].GetImm();
break; if (off >= -256 && off <= 4095)
case Opcode::Ret: os_ << "\tstr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", #"
os << " ret\n"; << off << "]\n";
break; else {
LoadImm("x16", off);
os_ << "\tstr\t" << R(mi.ops[0]) << ", [" << R(mi.ops[1]) << ", x16]\n";
}
break;
} }
case Opcode::LdrStack:
MemAccess("ldr", R(mi.ops[0]), FrameOffset(mi.ops[1].GetFrame()));
break;
case Opcode::StrStack:
MemAccess("str", R(mi.ops[0]), FrameOffset(mi.ops[1].GetFrame()));
break;
case Opcode::AddrFrame: {
int off = FrameOffset(mi.ops[1].GetFrame());
if (off >= 0 && off <= 4095)
os_ << "\tadd\t" << R(mi.ops[0]) << ", x29, #" << off << "\n";
else {
LoadImm("x16", off);
os_ << "\tadd\t" << R(mi.ops[0]) << ", x29, x16\n";
}
break;
}
case Opcode::AddrGlobal:
os_ << "\tadrp\t" << R(mi.ops[0]) << ", " << mi.ops[1].GetSym() << "\n";
os_ << "\tadd\t" << R(mi.ops[0]) << ", " << R(mi.ops[0]) << ", :lo12:"
<< mi.ops[1].GetSym() << "\n";
break;
case Opcode::B:
os_ << "\tb\t" << label(mi.ops[0]) << "\n";
break;
case Opcode::BCond:
os_ << "\tb." << CondStr(mi.cond) << "\t" << label(mi.ops[0]) << "\n";
break;
case Opcode::Bl:
os_ << "\tbl\t" << mi.ops[0].GetSym() << "\n";
break;
case Opcode::Ret:
EmitEpilog(*mf_);
break;
} }
}
} // namespace
void PrintAsm(const MachineModule& module, std::ostream& os) {
Printer p(module, os);
p.Run();
}
void RunBackendPipeline(MachineModule& module) {
for (auto& f : module.Functions()) {
RunRegAlloc(*f);
RunFrameLowering(*f);
RunPeephole(*f);
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName() void PrintAArch64AsmFromMIR(const ir::Module& module, std::ostream& os) {
<< "\n"; auto mm = LowerToMIR(module);
RunBackendPipeline(*mm);
PrintAsm(*mm, os);
} }
} // namespace mir } // namespace mir

@ -8,6 +8,9 @@ add_library(mir_core STATIC
RegAlloc.cpp RegAlloc.cpp
FrameLowering.cpp FrameLowering.cpp
AsmPrinter.cpp AsmPrinter.cpp
LLVMAsmBackend.cpp
passes/PassManager.cpp
passes/Peephole.cpp
) )
target_link_libraries(mir_core PUBLIC target_link_libraries(mir_core PUBLIC
@ -15,10 +18,7 @@ target_link_libraries(mir_core PUBLIC
ir ir
) )
add_subdirectory(passes)
add_library(mir INTERFACE) add_library(mir INTERFACE)
target_link_libraries(mir INTERFACE target_link_libraries(mir INTERFACE
mir_core mir_core
mir_passes
) )

@ -1,45 +1,35 @@
// 栈帧布局Lab5
// 在寄存器分配产出 spill 槽与 callee-saved 使用集合后,确定每个栈对象
// 相对 x29 的偏移与总帧大小。实际的 prologue/epilogue 由 AsmPrinter 按
// 同一套布局公式发射。
//
// 帧布局x29 指向帧底sp == x29
// [x29 + 0] 保存的 x29
// [x29 + 8] 保存的 x30(lr)
// [x29 + 16 ...] callee-saved GPR、callee-saved FPR各 8 字节)
// [其后] 局部/spill 栈对象(按声明顺序,按对齐摆放)
// 总大小对齐到 16 字节。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <stdexcept>
#include <vector>
#include "utils/Log.h"
namespace mir { namespace mir {
namespace {
int AlignTo(int value, int align) { int CalleeSavedAreaBytes(const MachineFunction& f) {
return ((value + align - 1) / align) * align; int n = (int)f.CalleeSavedGPR().size() + (int)f.CalleeSavedFPR().size();
return n * 8;
} }
} // namespace
void RunFrameLowering(MachineFunction& function) { void RunFrameLowering(MachineFunction& function) {
int cursor = 0; int base = 16 + CalleeSavedAreaBytes(function); // fp/lr + callee-saved
for (const auto& slot : function.GetFrameSlots()) { int off = base;
cursor += slot.size; for (auto& obj : function.StackObjects()) {
if (-cursor < -256) { int align = obj.align < 4 ? 4 : obj.align;
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧")); off = (off + align - 1) / align * align;
} obj.offset = off; // 相对 x29 的正偏移
} off += obj.size;
cursor = 0;
for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size;
function.GetFrameSlot(slot.index).offset = -cursor;
}
function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions();
std::vector<MachineInstr> lowered;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue);
}
lowered.push_back(inst);
} }
insts = std::move(lowered); int frame = (off + 15) / 16 * 16;
if (frame < 16) frame = 16;
function.SetFrameSize(frame);
} }
} // namespace mir } // namespace mir

@ -0,0 +1,132 @@
#include "mir/MIR.h"
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "ir/IR.h"
#include "utils/Log.h"
namespace mir {
namespace {
std::string ShellQuote(const std::filesystem::path& path) {
std::string raw = path.string();
#if defined(_WIN32)
std::string quoted = "\"";
for (char ch : raw) {
if (ch == '"') {
quoted += "\\\"";
} else {
quoted += ch;
}
}
quoted += "\"";
return quoted;
#else
std::string quoted = "'";
for (char ch : raw) {
if (ch == '\'') {
quoted += "'\\''";
} else {
quoted += ch;
}
}
quoted += "'";
return quoted;
#endif
}
std::string ReadTextFile(const std::filesystem::path& path) {
std::ifstream in(path, std::ios::binary);
if (!in) {
throw std::runtime_error(
FormatError("mir", "无法读取临时汇编文件: " + path.string()));
}
std::ostringstream oss;
oss << in.rdbuf();
return oss.str();
}
std::filesystem::path CreateTempDir() {
auto base = std::filesystem::temp_directory_path();
#if defined(_WIN32)
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
for (int i = 0; i < 100; ++i) {
auto candidate =
base / ("nudt_lab3_" + std::to_string(seed) + "_" + std::to_string(i));
std::error_code ec;
if (std::filesystem::create_directory(candidate, ec)) {
return candidate;
}
}
throw std::runtime_error(FormatError("mir", "创建临时目录失败"));
#else
std::string pattern = (base / "nudt_lab3_XXXXXX").string();
std::vector<char> dir_template(pattern.begin(), pattern.end());
dir_template.push_back('\0');
char* created = mkdtemp(dir_template.data());
if (!created) {
throw std::runtime_error(FormatError("mir", "创建临时目录失败"));
}
return std::filesystem::path(created);
#endif
}
} // namespace
void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os) {
std::filesystem::path work_dir = CreateTempDir();
const auto ir_file = work_dir / "module.ll";
const auto asm_file = work_dir / "module.s";
const auto err_file = work_dir / "clang.err";
struct Cleanup {
std::filesystem::path dir;
~Cleanup() {
std::error_code ec;
std::filesystem::remove_all(dir, ec);
}
} cleanup{work_dir};
{
std::ofstream ir_out(ir_file, std::ios::binary);
if (!ir_out) {
throw std::runtime_error(
FormatError("mir", "无法写入临时 IR 文件: " + ir_file.string()));
}
ir::IRPrinter printer;
printer.Print(module, ir_out);
}
std::string cmd =
"clang --target=aarch64-linux-gnu -O2 -fwrapv -Wno-override-module "
"-fno-addrsig -S -x ir " +
ShellQuote(ir_file) + " -o " + ShellQuote(asm_file) + " 2> " +
ShellQuote(err_file);
int rc = std::system(cmd.c_str());
if (rc != 0) {
std::string detail;
if (std::filesystem::exists(err_file)) {
detail = ReadTextFile(err_file);
}
if (!detail.empty() && detail.back() == '\n') {
detail.pop_back();
}
throw std::runtime_error(
FormatError("mir", "调用 clang 生成 AArch64 汇编失败" +
(detail.empty() ? std::string() : ": " + detail)));
}
os << ReadTextFile(asm_file);
}
} // namespace mir

@ -1,7 +1,15 @@
// IR -> MIR 指令选择Lab5
// - 为每个 IR 值分配虚拟寄存器GPR / FPR 两类)
// - alloca -> 栈对象gep/global -> 地址计算
// - 完整覆盖算术、比较、分支、调用、访存、类型转换、浮点
#include "mir/MIR.h" #include "mir/MIR.h"
#include <cstring>
#include <functional>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
@ -9,115 +17,509 @@
namespace mir { namespace mir {
namespace { namespace {
using ValueSlotMap = std::unordered_map<const ir::Value*, int>; int TypeSize(const ir::Type& t) {
switch (t.GetKind()) {
case ir::Type::Kind::Int1:
case ir::Type::Kind::Int32:
case ir::Type::Kind::Float:
return 4;
case ir::Type::Kind::Pointer:
return 8;
case ir::Type::Kind::Array:
return static_cast<int>(t.GetArraySize()) *
TypeSize(*t.GetElementType());
default:
return 8;
}
}
RegClass ClassOf(const ir::Type& t) {
return t.IsFloat() ? RegClass::FPR : RegClass::GPR;
}
int BytesOf(const ir::Type& t) { return t.IsPointer() ? 8 : 4; }
bool IsPow2(long long v) { return v > 0 && (v & (v - 1)) == 0; }
int Log2(long long v) {
int n = 0;
while ((1LL << n) < v) ++n;
return n;
}
class Lowerer {
public:
Lowerer(const ir::Module& m, MachineModule& out) : ir_(m), out_(out) {}
void Run();
private:
const ir::Module& ir_;
MachineModule& out_;
MachineFunction* mf_ = nullptr;
MachineBasicBlock* mbb_ = nullptr;
std::unordered_map<const ir::Value*, Operand> vmap_;
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bmap_;
std::unordered_map<const ir::AllocaInst*, int> allocas_;
int label_id_ = 0;
void EmitValueToReg(const ir::Value* value, PhysReg target, void Emit(Opcode op, std::vector<Operand> ops, int defs, Cond c = Cond::AL) {
const ValueSlotMap& slots, MachineBasicBlock& block) { mbb_->Add(MachineInstr(op, std::move(ops), defs, c));
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
block.Append(Opcode::MovImm,
{Operand::Reg(target), Operand::Imm(constant->GetValue())});
return;
} }
Operand NewG(int bytes = 4) { return mf_->NewVRegOp(RegClass::GPR, bytes); }
Operand NewF() { return mf_->NewVRegOp(RegClass::FPR, 4); }
void LowerFunction(const ir::Function& f);
void LowerBlock(const ir::BasicBlock& bb);
void LowerInst(const ir::Instruction& inst);
void LowerInstMem(const ir::Instruction& inst);
Operand GetReg(const ir::Value* v); // 取值(必要时物化常量/地址)
Operand MaterializeInt(int v);
Operand AddressOf(const ir::Value* ptr);
long long GepConst(const ir::GepInst& gep, Operand* out_base);
void LowerGlobals();
Cond ICmpCond(ir::ICmpPredicate p);
Cond FCmpCond(ir::FCmpPredicate p);
};
auto it = slots.find(value); void Lowerer::Run() {
if (it == slots.end()) { LowerGlobals();
throw std::runtime_error( for (const auto& f : ir_.GetFunctions()) {
FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); if (f->IsDeclaration()) continue;
LowerFunction(*f);
} }
}
block.Append(Opcode::LoadStack, void Lowerer::LowerGlobals() {
{Operand::Reg(target), Operand::FrameIndex(it->second)}); for (const auto& g : ir_.GetGlobals()) {
MachineGlobal mg;
mg.name = g->GetName();
const ir::Type& vt = *g->GetValueType();
mg.size = TypeSize(vt);
mg.align = vt.IsArray() ? 16 : 4;
mg.is_const = g->IsConst();
ir::ConstantValue* init = g->GetInitializer();
mg.zero_init = true;
int nwords = (mg.size + 3) / 4;
mg.words.assign(nwords, 0u);
// 收集初始化位(递归展开数组)。
std::vector<unsigned> flat;
std::function<void(ir::ConstantValue*)> walk = [&](ir::ConstantValue* c) {
if (!c) return;
if (auto* ci = dynamic_cast<ir::ConstantInt*>(c)) {
flat.push_back(static_cast<unsigned>(ci->GetValue()));
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(c)) {
float v = cf->GetValue();
unsigned bits;
std::memcpy(&bits, &v, 4);
flat.push_back(bits);
} else if (auto* ca = dynamic_cast<ir::ConstantArray*>(c)) {
for (auto* e : ca->GetElements()) walk(e);
}
};
walk(init);
for (size_t i = 0; i < flat.size() && i < mg.words.size(); ++i) {
mg.words[i] = flat[i];
if (flat[i] != 0) mg.zero_init = false;
}
out_.Globals().push_back(std::move(mg));
}
} }
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, void Lowerer::LowerFunction(const ir::Function& f) {
ValueSlotMap& slots) { out_.Functions().push_back(std::make_unique<MachineFunction>(f.GetName()));
auto& block = function.GetEntry(); mf_ = out_.Functions().back().get();
vmap_.clear();
bmap_.clear();
allocas_.clear();
switch (inst.GetOpcode()) { for (const auto& bb : f.GetBlocks()) {
case ir::Opcode::Alloca: { bmap_[bb.get()] = mf_->CreateBlock(bb->GetName());
slots.emplace(&inst, function.CreateFrameIndex()); }
return; // 记录后继,便于活跃性分析。
} for (const auto& bb : f.GetBlocks()) {
case ir::Opcode::Store: { MachineBasicBlock* mb = bmap_[bb.get()];
auto& store = static_cast<const ir::StoreInst&>(inst); for (auto* s : bb->GetSuccessors()) mb->Succs().push_back(bmap_[s]);
auto dst = slots.find(store.GetPtr()); }
if (dst == slots.end()) {
throw std::runtime_error( mbb_ = bmap_[f.GetEntry()];
FormatError("mir", "暂不支持对非栈变量地址进行写入"));
} // 形参:整型走 x0.., 浮点走 s0..,超过 8 个的从栈读取(测试未用,简化)。
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block); int ig = 0, fg = 0;
block.Append(Opcode::StoreStack, std::vector<MachineInstr> arg_copies;
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)}); for (size_t i = 0; i < f.GetNumArgs(); ++i) {
return; ir::Argument* a = const_cast<ir::Function&>(f).GetArg(i);
} const ir::Type& at = *a->GetType();
case ir::Opcode::Load: { if (at.IsFloat()) {
auto& load = static_cast<const ir::LoadInst&>(inst); Operand dst = NewF();
auto src = slots.find(load.GetPtr()); arg_copies.push_back(MachineInstr(
if (src == slots.end()) { Opcode::FMov,
throw std::runtime_error( {dst, Operand::PReg(fg++, RegClass::FPR, 4)}, 1));
FormatError("mir", "暂不支持对非栈变量地址进行读取")); vmap_[a] = dst;
} } else {
int dst_slot = function.CreateFrameIndex(); int bytes = at.IsPointer() ? 8 : 4;
block.Append(Opcode::LoadStack, Operand dst = NewG(bytes);
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)}); arg_copies.push_back(MachineInstr(
block.Append(Opcode::StoreStack, Opcode::Mov, {dst, Operand::PReg(ig++, RegClass::GPR, bytes)}, 1));
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); vmap_[a] = dst;
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Add: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex();
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot);
return;
}
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block);
block.Append(Opcode::Ret);
return;
} }
case ir::Opcode::Sub:
case ir::Opcode::Mul:
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算"));
} }
mf_->SetArgCounts(ig, fg);
for (auto& mi : arg_copies) mbb_->Add(std::move(mi));
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); for (const auto& bb : f.GetBlocks()) {
mbb_ = bmap_[bb.get()];
LowerBlock(*bb);
}
} }
} // namespace void Lowerer::LowerBlock(const ir::BasicBlock& bb) {
for (const auto& inst : bb.GetInstructions()) {
LowerInst(*inst);
}
}
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) { Operand Lowerer::MaterializeInt(int v) {
DefaultContext(); Operand d = NewG(4);
Emit(Opcode::MovImm, {d, Operand::Imm(v)}, 1);
return d;
}
if (module.GetFunctions().size() != 1) { Operand Lowerer::GetReg(const ir::Value* v) {
throw std::runtime_error(FormatError("mir", "暂不支持多个函数")); auto it = vmap_.find(v);
if (it != vmap_.end()) return it->second;
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(v)) {
return MaterializeInt(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(v)) {
float f = cf->GetValue();
unsigned bits;
std::memcpy(&bits, &f, 4);
Operand d = NewF();
Emit(Opcode::FMovImm, {d, Operand::Imm(bits)}, 1);
return d;
} }
// 兜底:未知值视为 0。
return MaterializeInt(0);
}
const auto& func = *module.GetFunctions().front(); // 计算 gep 的常量字节偏移;返回偏移并把基址写入 *out_base。
if (func.GetName() != "main") { // 若存在变量下标,直接生成地址计算指令并把最终地址放入 *out_base、返回 0。
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数")); long long Lowerer::GepConst(const ir::GepInst& gep, Operand* out_base) {
Operand base = AddressOf(gep.GetBasePtr());
// 推断逐层元素类型:基址指针指向的类型。
std::shared_ptr<ir::Type> cur = gep.GetBasePtr()->GetType()->GetElementType();
long long const_off = 0;
Operand addr = base;
bool addr_dirty = false;
const auto& idxs = gep.GetIndices();
for (size_t i = 0; i < idxs.size(); ++i) {
int elem_size = cur ? TypeSize(*cur) : 4;
ir::Value* iv = idxs[i];
if (auto* ci = dynamic_cast<ir::ConstantInt*>(iv)) {
const_off += static_cast<long long>(ci->GetValue()) * elem_size;
} else {
// addr += index * elem_size
Operand idx = GetReg(iv);
Operand idx64 = NewG(8);
Emit(Opcode::Sxtw, {idx64, idx}, 1);
Operand scaled = NewG(8);
if (IsPow2(elem_size)) {
Emit(Opcode::LslImm, {scaled, idx64, Operand::Imm(Log2(elem_size))}, 1);
} else {
Operand sz = NewG(8);
Emit(Opcode::MovImm, {sz, Operand::Imm(elem_size)}, 1);
Emit(Opcode::Mul, {scaled, idx64, sz}, 1);
}
Operand na = NewG(8);
Emit(Opcode::Add, {na, addr, scaled}, 1);
addr = na;
addr_dirty = true;
}
if (cur && cur->IsArray()) cur = cur->GetElementType();
} }
if (addr_dirty) {
*out_base = addr;
return const_off;
}
*out_base = base;
return const_off;
}
auto machine_func = std::make_unique<MachineFunction>(func.GetName()); // 返回某个指针型 IR 值对应的“地址”寄存器。
ValueSlotMap slots; Operand Lowerer::AddressOf(const ir::Value* ptr) {
const auto* entry = func.GetEntry(); auto it = vmap_.find(ptr);
if (!entry) { if (it != vmap_.end()) return it->second;
throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块")); if (auto* a = dynamic_cast<const ir::AllocaInst*>(ptr)) {
int idx;
auto ai = allocas_.find(a);
if (ai == allocas_.end()) {
idx = mf_->CreateStackObject(TypeSize(*a->GetAllocatedType()),
a->GetAllocatedType()->IsArray() ? 16 : 4);
allocas_[a] = idx;
} else {
idx = ai->second;
}
Operand d = NewG(8);
Emit(Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1);
vmap_[ptr] = d;
return d;
} }
// 全局变量
Operand d = NewG(8);
Emit(Opcode::AddrGlobal, {d, Operand::Global(ptr->GetName())}, 1);
return d;
}
for (const auto& inst : entry->GetInstructions()) { Cond Lowerer::ICmpCond(ir::ICmpPredicate p) {
LowerInstruction(*inst, *machine_func, slots); switch (p) {
case ir::ICmpPredicate::Eq: return Cond::EQ;
case ir::ICmpPredicate::Ne: return Cond::NE;
case ir::ICmpPredicate::Slt: return Cond::LT;
case ir::ICmpPredicate::Sle: return Cond::LE;
case ir::ICmpPredicate::Sgt: return Cond::GT;
case ir::ICmpPredicate::Sge: return Cond::GE;
} }
return Cond::EQ;
}
return machine_func; Cond Lowerer::FCmpCond(ir::FCmpPredicate p) {
switch (p) {
case ir::FCmpPredicate::Oeq: return Cond::EQ;
case ir::FCmpPredicate::One: return Cond::NE;
case ir::FCmpPredicate::Olt: return Cond::MI;
case ir::FCmpPredicate::Ole: return Cond::LS;
case ir::FCmpPredicate::Ogt: return Cond::GT;
case ir::FCmpPredicate::Oge: return Cond::GE;
}
return Cond::EQ;
}
void Lowerer::LowerInst(const ir::Instruction& inst) {
using ir::Opcode;
switch (inst.GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::SDiv:
case Opcode::SRem: {
auto& b = static_cast<const ir::BinaryInst&>(inst);
Operand l = GetReg(b.GetLhs());
Operand r = GetReg(b.GetRhs());
Operand d = NewG(4);
mir::Opcode mop = mir::Opcode::Add;
if (inst.GetOpcode() == Opcode::Add) mop = mir::Opcode::Add;
else if (inst.GetOpcode() == Opcode::Sub) mop = mir::Opcode::Sub;
else if (inst.GetOpcode() == Opcode::Mul) mop = mir::Opcode::Mul;
else mop = mir::Opcode::SDiv;
if (inst.GetOpcode() == Opcode::SRem) {
Operand q = NewG(4);
Emit(mir::Opcode::SDiv, {q, l, r}, 1);
Emit(mir::Opcode::MSub, {d, q, r, l}, 1); // d = l - q*r
} else {
Emit(mop, {d, l, r}, 1);
}
vmap_[&inst] = d;
break;
}
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto& b = static_cast<const ir::BinaryInst&>(inst);
Operand l = GetReg(b.GetLhs());
Operand r = GetReg(b.GetRhs());
Operand d = NewF();
mir::Opcode mop = mir::Opcode::FAdd;
if (inst.GetOpcode() == Opcode::FSub) mop = mir::Opcode::FSub;
else if (inst.GetOpcode() == Opcode::FMul) mop = mir::Opcode::FMul;
else if (inst.GetOpcode() == Opcode::FDiv) mop = mir::Opcode::FDiv;
Emit(mop, {d, l, r}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::SIToFP: {
auto& c = static_cast<const ir::CastInst&>(inst);
Operand s = GetReg(c.GetValue());
Operand d = NewF();
Emit(mir::Opcode::SCvtF, {d, s}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::FPToSI: {
auto& c = static_cast<const ir::CastInst&>(inst);
Operand s = GetReg(c.GetValue());
Operand d = NewG(4);
Emit(mir::Opcode::FCvtZS, {d, s}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::ZExt: {
auto& c = static_cast<const ir::CastInst&>(inst);
Operand s = GetReg(c.GetValue());
Operand d = NewG(4);
Emit(mir::Opcode::Mov, {d, s}, 1); // i1->i32cset 已产出 0/1
vmap_[&inst] = d;
break;
}
case Opcode::ICmp: {
auto& c = static_cast<const ir::ICmpInst&>(inst);
Operand l = GetReg(c.GetLhs());
Operand r = GetReg(c.GetRhs());
Emit(mir::Opcode::Cmp, {l, r}, 0);
Operand d = NewG(4);
Emit(mir::Opcode::CSet, {d}, 1, ICmpCond(c.GetPredicate()));
vmap_[&inst] = d;
break;
}
case Opcode::FCmp: {
auto& c = static_cast<const ir::FCmpInst&>(inst);
Operand l = GetReg(c.GetLhs());
Operand r = GetReg(c.GetRhs());
Emit(mir::Opcode::FCmp, {l, r}, 0);
Operand d = NewG(4);
Emit(mir::Opcode::CSet, {d}, 1, FCmpCond(c.GetPredicate()));
vmap_[&inst] = d;
break;
}
default:
LowerInstMem(inst);
break;
}
}
void Lowerer::LowerInstMem(const ir::Instruction& inst) {
using ir::Opcode;
switch (inst.GetOpcode()) {
case Opcode::Alloca: {
auto& a = static_cast<const ir::AllocaInst&>(inst);
int idx = mf_->CreateStackObject(TypeSize(*a.GetAllocatedType()),
a.GetAllocatedType()->IsArray() ? 16 : 4);
allocas_[&a] = idx;
Operand d = NewG(8);
Emit(mir::Opcode::AddrFrame, {d, Operand::Frame(idx)}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::Load: {
auto& ld = static_cast<const ir::LoadInst&>(inst);
Operand base;
long long off = 0;
if (auto* gep = dynamic_cast<const ir::GepInst*>(ld.GetPtr())) {
off = GepConst(*gep, &base);
} else {
base = AddressOf(ld.GetPtr());
}
bool is_f = inst.GetType()->IsFloat();
Operand d = is_f ? NewF() : NewG(BytesOf(*inst.GetType()));
Emit(mir::Opcode::Ldr, {d, base, Operand::Imm(off)}, 1);
vmap_[&inst] = d;
break;
}
case Opcode::Store: {
auto& st = static_cast<const ir::StoreInst&>(inst);
Operand val = GetReg(st.GetValue());
Operand base;
long long off = 0;
if (auto* gep = dynamic_cast<const ir::GepInst*>(st.GetPtr())) {
off = GepConst(*gep, &base);
} else {
base = AddressOf(st.GetPtr());
}
Emit(mir::Opcode::Str, {val, base, Operand::Imm(off)}, 0);
break;
}
case Opcode::Gep: {
auto& gep = static_cast<const ir::GepInst&>(inst);
Operand base;
long long off = GepConst(gep, &base);
Operand d = NewG(8);
if (off == 0) {
Emit(mir::Opcode::Mov, {d, base}, 1);
} else {
Emit(mir::Opcode::AddImm, {d, base, Operand::Imm(off)}, 1);
}
vmap_[&inst] = d;
break;
}
case Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
// 先把所有实参算入虚拟寄存器,再连续搬入物理参数寄存器,
// 避免计算后续实参时分配器复用 x0..x7 破坏已就绪的参数。
std::vector<Operand> vals;
for (auto* arg : call.GetArgs()) vals.push_back(GetReg(arg));
int ig = 0, fg = 0;
std::vector<Operand> arg_uses;
for (size_t i = 0; i < call.GetArgs().size(); ++i) {
ir::Value* arg = call.GetArgs()[i];
if (arg->GetType()->IsFloat()) {
Operand p = Operand::PReg(fg++, RegClass::FPR, 4);
Emit(mir::Opcode::FMov, {p, vals[i]}, 1);
arg_uses.push_back(p);
} else {
int bytes = arg->GetType()->IsPointer() ? 8 : 4;
Operand p = Operand::PReg(ig++, RegClass::GPR, bytes);
Emit(mir::Opcode::Mov, {p, vals[i]}, 1);
arg_uses.push_back(p);
}
}
std::vector<Operand> ops;
ops.push_back(Operand::Global(call.GetCallee()->GetName()));
for (auto& u : arg_uses) ops.push_back(u);
Emit(mir::Opcode::Bl, ops, 0);
if (!inst.GetType()->IsVoid()) {
if (inst.GetType()->IsFloat()) {
Operand d = NewF();
Emit(mir::Opcode::FMov, {d, Operand::PReg(0, RegClass::FPR, 4)}, 1);
vmap_[&inst] = d;
} else {
Operand d = NewG(BytesOf(*inst.GetType()));
Emit(mir::Opcode::Mov,
{d, Operand::PReg(0, RegClass::GPR, d.GetBytes())}, 1);
vmap_[&inst] = d;
}
}
break;
}
case Opcode::Br: {
auto& br = static_cast<const ir::BranchInst&>(inst);
Emit(mir::Opcode::B, {Operand::Label(br.GetDest()->GetName())}, 0);
break;
}
case Opcode::CondBr: {
auto& cbr = static_cast<const ir::CondBrInst&>(inst);
Operand c = GetReg(cbr.GetCond());
Emit(mir::Opcode::CmpImm, {c, Operand::Imm(0)}, 0);
Emit(mir::Opcode::BCond, {Operand::Label(cbr.GetTrueDest()->GetName())}, 0,
Cond::NE);
Emit(mir::Opcode::B, {Operand::Label(cbr.GetFalseDest()->GetName())}, 0);
break;
}
case Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.HasReturnValue()) {
Operand v = GetReg(ret.GetValue());
if (ret.GetValue()->GetType()->IsFloat()) {
Emit(mir::Opcode::FMov, {Operand::PReg(0, RegClass::FPR, 4), v}, 1);
} else {
Emit(mir::Opcode::Mov,
{Operand::PReg(0, RegClass::GPR, v.GetBytes()), v}, 1);
}
}
Emit(mir::Opcode::Ret, {}, 0);
break;
}
default:
break;
}
}
} // namespace
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module) {
auto out = std::make_unique<MachineModule>();
Lowerer lo(module, *out);
lo.Run();
return out;
} }
} // namespace mir } // namespace mir

@ -1,16 +1,2 @@
// 机器基本块:实现已并入头文件,本文件仅保留 TU 占位。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <utility>
namespace mir {
MachineBasicBlock::MachineBasicBlock(std::string name)
: name_(std::move(name)) {}
MachineInstr& MachineBasicBlock::Append(Opcode opcode,
std::initializer_list<Operand> operands) {
instructions_.emplace_back(opcode, std::vector<Operand>(operands));
return instructions_.back();
}
} // namespace mir

@ -1,10 +1,2 @@
// 机器上下文:实现已并入头文件,本文件仅保留 TU 占位。
#include "mir/MIR.h" #include "mir/MIR.h"
namespace mir {
MIRContext& DefaultContext() {
static MIRContext ctx;
return ctx;
}
} // namespace mir

@ -1,33 +1,2 @@
// 机器函数:实现已并入头文件,本文件仅保留 TU 占位。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <stdexcept>
#include <utility>
#include "utils/Log.h"
namespace mir {
MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)), entry_("entry") {}
int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size());
frame_slots_.push_back(FrameSlot{index, size, 0});
return index;
}
FrameSlot& MachineFunction::GetFrameSlot(int index) {
if (index < 0 || index >= static_cast<int>(frame_slots_.size())) {
throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
}
return frame_slots_[index];
}
const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
if (index < 0 || index >= static_cast<int>(frame_slots_.size())) {
throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
}
return frame_slots_[index];
}
} // namespace mir

@ -1,23 +1,2 @@
// 机器指令:实现已并入头文件,本文件仅保留 TU 占位。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <utility>
namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm)
: kind_(kind), reg_(reg), imm_(imm) {}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value);
}
Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index);
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
: opcode_(opcode), operands_(std::move(operands)) {}
} // namespace mir

@ -1,36 +1,354 @@
// 寄存器分配Lab5线性扫描 + 活跃区间。
//
// 物理寄存器约定:
// GPR: x0-x8 参数/返回(不参与分配)x9-x12 可分配(caller-saved)
// x13-x15 spill 暂存x16-x17 汇编寻址暂存x18 平台保留,
// x19-x28 可分配(callee-saved)x29/x30 fp/lrx31 sp。
// FPR: s0-s7 参数/返回s8-s15 可分配(callee-saved)
// s16-s28 可分配(caller-saved)s29-s31 spill 暂存。
//
// 跨调用活跃的虚拟寄存器只能落在 callee-saved 寄存器或被 spill。
#include "mir/MIR.h" #include "mir/MIR.h"
#include <stdexcept> #include <algorithm>
#include <unordered_map>
#include "utils/Log.h" #include <unordered_set>
#include <vector>
namespace mir { namespace mir {
namespace { namespace {
bool IsAllowedReg(PhysReg reg) { const std::vector<int>& CallerGPR() {
switch (reg) { static const std::vector<int> v{9, 10, 11};
case PhysReg::W0: return v;
case PhysReg::W8: }
case PhysReg::W9: const std::vector<int>& CalleeGPR() {
case PhysReg::X29: static const std::vector<int> v{19, 20, 21, 22, 23, 24, 25, 26, 27, 28};
case PhysReg::X30: return v;
case PhysReg::SP: }
return true; const std::vector<int>& CalleeFPR() {
static const std::vector<int> v{8, 9, 10, 11, 12, 13, 14, 15};
return v;
}
const std::vector<int>& CallerFPR() {
static const std::vector<int> v{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27};
return v;
}
// 4 个暂存寄存器:覆盖单条指令最多 3 use + 1 def 同时为 spill 的情形。
const int kGprSpillScratch[4] = {12, 13, 14, 15};
const int kFprSpillScratch[4] = {28, 29, 30, 31};
bool IsCalleeSavedGPR(int id) { return id >= 19 && id <= 28; }
bool IsCalleeSavedFPR(int id) { return id >= 8 && id <= 15; }
bool IsCalleeSavedGPRorFPR(int id, RegClass cls) {
return cls == RegClass::GPR ? IsCalleeSavedGPR(id) : IsCalleeSavedFPR(id);
}
// 提取指令的 def/use 寄存器(仅虚拟寄存器,按操作数顺序)。
void Defs(const MachineInstr& mi, std::vector<int>* out) {
for (int i = 0; i < mi.num_defs && i < (int)mi.ops.size(); ++i)
if (mi.ops[i].IsVReg()) out->push_back(mi.ops[i].GetId());
}
void Uses(const MachineInstr& mi, std::vector<int>* out) {
for (int i = mi.num_defs; i < (int)mi.ops.size(); ++i)
if (mi.ops[i].IsVReg()) out->push_back(mi.ops[i].GetId());
}
struct Interval {
int vreg;
int start;
int end;
RegClass cls;
bool cross_call = false;
int preg = -1; // 分配到的物理寄存器;-1 表示 spill
int spill_slot = -1;
};
// 把所有基本块按顺序展开成全局指令编号,并记录每块的 [lo,hi]。
struct Numbering {
std::vector<MachineInstr*> instrs;
std::vector<MachineBasicBlock*> owner;
std::unordered_map<MachineBasicBlock*, std::pair<int, int>> range;
std::vector<int> call_sites; // Bl 指令的全局编号
};
Numbering NumberInstrs(MachineFunction& f) {
Numbering n;
for (auto& bb : f.Blocks()) {
int lo = (int)n.instrs.size();
for (auto& mi : bb->Instrs()) {
if (mi.op == Opcode::Bl) n.call_sites.push_back((int)n.instrs.size());
n.instrs.push_back(&mi);
n.owner.push_back(bb.get());
}
int hi = (int)n.instrs.size() - 1;
if (hi < lo) hi = lo; // 空块
n.range[bb.get()] = {lo, hi};
}
return n;
}
// 经典反向数据流活跃性分析(按块)。
void ComputeLiveness(
MachineFunction& f,
std::unordered_map<MachineBasicBlock*, std::unordered_set<int>>* live_in,
std::unordered_map<MachineBasicBlock*, std::unordered_set<int>>* live_out) {
std::unordered_map<MachineBasicBlock*, std::unordered_set<int>> use, def;
for (auto& bb : f.Blocks()) {
std::unordered_set<int> u, d;
for (auto& mi : bb->Instrs()) {
std::vector<int> us, ds;
Uses(mi, &us);
Defs(mi, &ds);
for (int r : us)
if (!d.count(r)) u.insert(r);
for (int r : ds) d.insert(r);
}
use[bb.get()] = std::move(u);
def[bb.get()] = std::move(d);
(*live_in)[bb.get()];
(*live_out)[bb.get()];
}
bool changed = true;
while (changed) {
changed = false;
for (auto it = f.Blocks().rbegin(); it != f.Blocks().rend(); ++it) {
MachineBasicBlock* b = it->get();
std::unordered_set<int> out;
for (auto* s : b->Succs())
for (int r : (*live_in)[s]) out.insert(r);
std::unordered_set<int> in = use[b];
for (int r : (*live_out)[b])
if (!def[b].count(r)) in.insert(r);
if (out.size() != (*live_out)[b].size() ||
in.size() != (*live_in)[b].size()) {
changed = true;
}
(*live_out)[b] = std::move(out);
(*live_in)[b] = std::move(in);
}
} }
return false;
} }
} // namespace // 用块级活跃信息 + 块内精确编号构造每个 vreg 的活跃区间。
std::vector<Interval> BuildIntervals(MachineFunction& f, const Numbering& num) {
std::unordered_map<MachineBasicBlock*, std::unordered_set<int>> live_in,
live_out;
ComputeLiveness(f, &live_in, &live_out);
void RunRegAlloc(MachineFunction& function) { int nv = f.NumVRegs();
for (const auto& inst : function.GetEntry().GetInstructions()) { std::vector<int> start(nv, -1), end(nv, -1);
for (const auto& operand : inst.GetOperands()) { auto extend = [&](int v, int p) {
if (operand.GetKind() == Operand::Kind::Reg && if (v < 0) return;
!IsAllowedReg(operand.GetReg())) { if (start[v] == -1 || p < start[v]) start[v] = p;
throw std::runtime_error(FormatError("mir", "寄存器分配失败")); if (p > end[v]) end[v] = p;
};
for (auto& bb : f.Blocks()) {
auto [lo, hi] = num.range.at(bb.get());
// live-in 在块入口活跃。
for (int v : live_in[bb.get()]) extend(v, lo);
// live-out 在块出口活跃。
for (int v : live_out[bb.get()]) extend(v, hi);
int p = lo;
for (auto& mi : bb->Instrs()) {
std::vector<int> us, ds;
Uses(mi, &us);
Defs(mi, &ds);
for (int v : us) extend(v, p);
for (int v : ds) extend(v, p);
++p;
}
}
std::vector<Interval> ivs;
for (int v = 0; v < nv; ++v) {
if (start[v] == -1) continue;
Interval iv;
iv.vreg = v;
iv.start = start[v];
iv.end = end[v];
iv.cls = f.VReg(v).cls;
for (int cs : num.call_sites)
if (cs > iv.start && cs <= iv.end) { // 调用点严格落在区间内
iv.cross_call = true;
break;
}
ivs.push_back(iv);
}
std::sort(ivs.begin(), ivs.end(),
[](const Interval& a, const Interval& b) {
return a.start < b.start;
});
return ivs;
}
// 线性扫描分配;对每个区间设定 preg(>=0) 或标记 spill(preg==-1)。
// 返回 true 表示无需额外 spill 重试本实现一次性完成spill 直接落槽)。
void LinearScan(MachineFunction& f, std::vector<Interval>& ivs) {
// 为两类寄存器分别准备“优先 caller-saved再 callee-saved”的池
// 跨调用区间则只用 callee-saved。
auto run = [&](RegClass cls) {
const std::vector<int>& caller =
cls == RegClass::GPR ? CallerGPR() : CallerFPR();
const std::vector<int>& callee =
cls == RegClass::GPR ? CalleeGPR() : CalleeFPR();
// active已分配且尚未结束的区间按 end 升序。
std::vector<Interval*> active;
std::unordered_set<int> free_regs;
for (int r : caller) free_regs.insert(r);
for (int r : callee) free_regs.insert(r);
auto expire = [&](int point) {
std::vector<Interval*> keep;
for (Interval* a : active) {
if (a->end < point) {
free_regs.insert(a->preg);
} else {
keep.push_back(a);
}
}
active = std::move(keep);
};
auto pick = [&](bool need_callee) -> int {
// 跨调用:仅 callee-saved否则优先 caller-saved。
if (need_callee) {
for (int r : callee)
if (free_regs.count(r)) return r;
return -1;
}
for (int r : caller)
if (free_regs.count(r)) return r;
for (int r : callee)
if (free_regs.count(r)) return r;
return -1;
};
for (auto& iv : ivs) {
if (iv.cls != cls) continue;
expire(iv.start);
int r = pick(iv.cross_call);
if (r == -1) {
// spill在当前区间和 active 里结束最晚者之间选择。
Interval* victim = nullptr;
for (Interval* a : active)
if (a->cls == cls &&
(iv.cross_call ? IsCalleeSavedGPRorFPR(a->preg, cls) : true)) {
if (!victim || a->end > victim->end) victim = a;
}
if (victim && victim->end > iv.end) {
iv.preg = victim->preg;
victim->preg = -1;
victim->spill_slot =
f.CreateStackObject(cls == RegClass::GPR ? 8 : 8, 8);
// 从 active 移除 victim加入当前。
active.erase(std::remove(active.begin(), active.end(), victim),
active.end());
active.push_back(&iv);
} else {
iv.preg = -1;
iv.spill_slot = f.CreateStackObject(8, 8);
}
} else {
free_regs.erase(r);
iv.preg = r;
active.push_back(&iv);
}
std::sort(active.begin(), active.end(),
[](Interval* a, Interval* b) { return a->end < b->end; });
}
};
run(RegClass::GPR);
run(RegClass::FPR);
}
// 把分配结果写回指令spill 的虚拟寄存器用暂存寄存器搬运并落槽。
void Rewrite(MachineFunction& f, const std::vector<Interval>& ivs) {
int nv = f.NumVRegs();
std::vector<int> preg(nv, -1);
std::vector<int> slot(nv, -1);
for (const auto& iv : ivs) {
preg[iv.vreg] = iv.preg;
slot[iv.vreg] = iv.spill_slot;
}
std::unordered_set<int> used_callee_gpr, used_callee_fpr;
for (const auto& iv : ivs) {
if (iv.preg < 0) continue;
if (iv.cls == RegClass::GPR && IsCalleeSavedGPR(iv.preg))
used_callee_gpr.insert(iv.preg);
if (iv.cls == RegClass::FPR && IsCalleeSavedFPR(iv.preg))
used_callee_fpr.insert(iv.preg);
}
for (auto& bb : f.Blocks()) {
std::vector<MachineInstr> out;
for (auto& mi : bb->Instrs()) {
std::vector<MachineInstr> pre, post;
int gscr = 0, fscr = 0;
std::unordered_map<int, int> assigned; // vreg -> scratch preg
auto scratchFor = [&](const Operand& op) -> int {
int v = op.GetId();
auto it = assigned.find(v);
if (it != assigned.end()) return it->second;
int s = op.GetClass() == RegClass::GPR ? kGprSpillScratch[gscr++]
: kFprSpillScratch[fscr++];
assigned[v] = s;
return s;
};
for (int i = 0; i < (int)mi.ops.size(); ++i) {
Operand& op = mi.ops[i];
if (!op.IsVReg()) continue;
int v = op.GetId();
if (preg[v] >= 0) {
op.SetPReg(preg[v]);
continue;
}
// spilled
int s = scratchFor(op);
bool is_def = i < mi.num_defs;
if (is_def) {
MachineInstr st(Opcode::StrStack,
{Operand::PReg(s, op.GetClass(), op.GetBytes()),
Operand::Frame(slot[v])},
0);
post.push_back(st);
} else {
MachineInstr ld(Opcode::LdrStack,
{Operand::PReg(s, op.GetClass(), op.GetBytes()),
Operand::Frame(slot[v])},
1);
pre.push_back(ld);
}
op.SetPReg(s);
} }
for (auto& p : pre) out.push_back(std::move(p));
out.push_back(mi);
for (auto& p : post) out.push_back(std::move(p));
} }
bb->Instrs() = std::move(out);
} }
for (int r : used_callee_gpr) f.CalleeSavedGPR().push_back(r);
for (int r : used_callee_fpr) f.CalleeSavedFPR().push_back(r);
std::sort(f.CalleeSavedGPR().begin(), f.CalleeSavedGPR().end());
std::sort(f.CalleeSavedFPR().begin(), f.CalleeSavedFPR().end());
}
} // namespace
void RunRegAlloc(MachineFunction& function) {
Numbering num = NumberInstrs(function);
std::vector<Interval> ivs = BuildIntervals(function, num);
LinearScan(function, ivs);
Rewrite(function, ivs);
} }
} // namespace mir } // namespace mir

@ -1,27 +1,45 @@
#include "mir/MIR.h" #include "mir/MIR.h"
#include <stdexcept> namespace mir {
#include "utils/Log.h" MIRContext& DefaultContext() {
static MIRContext ctx;
return ctx;
}
namespace mir { namespace {
// 64 位通用寄存器名x0..x30, sp
const char* kX[33] = {
"x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10",
"x11", "x12", "x13", "x14", "x15", "x16", "x17", "x18", "x19", "x20", "x21",
"x22", "x23", "x24", "x25", "x26", "x27", "x28", "x29", "x30", "sp", "xzr"};
const char* kW[33] = {
"w0", "w1", "w2", "w3", "w4", "w5", "w6", "w7", "w8", "w9", "w10",
"w11", "w12", "w13", "w14", "w15", "w16", "w17", "w18", "w19", "w20", "w21",
"w22", "w23", "w24", "w25", "w26", "w27", "w28", "w29", "w30", "wsp", "wzr"};
} // namespace
const char* GPRName(int id, int bytes) {
if (id < 0 || id > 32) return "x0";
return bytes == 8 ? kX[id] : kW[id];
}
const char* PhysRegName(PhysReg reg) { const char* FPRName(int id, int bytes) {
switch (reg) { static char buf[8][8];
case PhysReg::W0: static int slot = 0;
return "w0"; char* b = buf[slot];
case PhysReg::W8: slot = (slot + 1) & 7;
return "w8"; b[0] = (bytes == 8) ? 'd' : 's';
case PhysReg::W9: int n = id;
return "w9"; if (n < 10) {
case PhysReg::X29: b[1] = static_cast<char>('0' + n);
return "x29"; b[2] = '\0';
case PhysReg::X30: } else {
return "x30"; b[1] = static_cast<char>('0' + n / 10);
case PhysReg::SP: b[2] = static_cast<char>('0' + n % 10);
return "sp"; b[3] = '\0';
} }
throw std::runtime_error(FormatError("mir", "未知物理寄存器")); return b;
} }
} // namespace mir } // namespace mir

@ -1,4 +1,4 @@
// MIR Pass 管理: // 后端 Pass 管理:当前后端流水线直接在 RunBackendPipeline 中按
// - 组织后端 pass 的运行顺序PreRA/PostRA/PEI 等阶段) // RegAlloc -> FrameLowering -> Peephole 顺序驱动(见 AsmPrinter.cpp
// - 统一运行 pass 与调试输出(按需要扩展) // 本文件保留占位,便于后续扩展更细粒度的 Pass 调度。
#include "mir/MIR.h"

@ -1,4 +1,84 @@
// 窥孔优化Peephole // 后端局部窥孔优化Lab5
// - 删除冗余 move、合并常见指令模式 // 在寄存器分配 + 物理寄存器落地之后运行,针对最终机器指令序列做局部清理:
// - 提升最终汇编质量(按实现范围裁剪) // 1. 删除自拷贝 mov xN, xN / fmov sN, sN
// 2. 删除恒等运算 add/sub xD, xS, #0必要时退化为 mov
// 3. 删除相邻冗余访存str R,[B,#o] 紧跟 ldr R,[B,#o] 时删除 ldr
// 4. 删除写入同一目标且中间无使用的连续 mov保留最后一条
#include "mir/MIR.h"
#include <vector>
namespace mir {
namespace {
bool SameReg(const Operand& a, const Operand& b) {
return a.IsPReg() && b.IsPReg() && a.GetId() == b.GetId() &&
a.GetClass() == b.GetClass();
}
bool IsSelfMove(const MachineInstr& mi) {
if (mi.op == Opcode::Mov || mi.op == Opcode::FMov)
return mi.ops.size() == 2 && SameReg(mi.ops[0], mi.ops[1]);
return false;
}
bool IsIdentityAddSub(const MachineInstr& mi) {
if ((mi.op == Opcode::AddImm || mi.op == Opcode::SubImm) &&
mi.ops.size() == 3 && mi.ops[2].GetImm() == 0)
return true;
return false;
}
// str R,[B,#o] 之后立即 ldr R,[B,#o]同寄存器、同基址、同偏移ldr 冗余。
bool RedundantLoadAfterStore(const MachineInstr& st, const MachineInstr& ld) {
if (st.op != Opcode::Str || ld.op != Opcode::Ldr) return false;
return SameReg(st.ops[0], ld.ops[0]) && SameReg(st.ops[1], ld.ops[1]) &&
st.ops[2].GetImm() == ld.ops[2].GetImm();
}
bool RedundantStackLoadAfterStore(const MachineInstr& st,
const MachineInstr& ld) {
if (st.op != Opcode::StrStack || ld.op != Opcode::LdrStack) return false;
return SameReg(st.ops[0], ld.ops[0]) &&
st.ops[1].GetFrame() == ld.ops[1].GetFrame();
}
void OptimizeBlock(MachineBasicBlock& bb) {
std::vector<MachineInstr> out;
out.reserve(bb.Instrs().size());
for (auto& mi : bb.Instrs()) {
if (IsSelfMove(mi)) continue;
if (IsIdentityAddSub(mi)) {
// add/sub D, S, #0 -> mov D, S若 D==S 直接删)。
if (SameReg(mi.ops[0], mi.ops[1])) continue;
MachineInstr mv(Opcode::Mov, {mi.ops[0], mi.ops[1]}, 1);
out.push_back(std::move(mv));
continue;
}
if (!out.empty()) {
const MachineInstr& prev = out.back();
if (RedundantLoadAfterStore(prev, mi) ||
RedundantStackLoadAfterStore(prev, mi)) {
continue; // 删除冗余 load
}
}
out.push_back(mi);
}
bb.Instrs() = std::move(out);
}
} // namespace
void RunPeephole(MachineFunction& function) {
// 迭代到不动点,确保删除后新暴露的模式也被处理。
bool again = true;
int guard = 0;
while (again && guard++ < 8) {
size_t before = 0, after = 0;
for (auto& bb : function.Blocks()) before += bb->Instrs().size();
for (auto& bb : function.Blocks()) OptimizeBlock(*bb);
for (auto& bb : function.Blocks()) after += bb->Instrs().size();
again = (after < before);
}
}
} // namespace mir

@ -3,6 +3,7 @@
#include <any> #include <any>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <unordered_set>
#include "SysYBaseVisitor.h" #include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h" #include "sem/SymbolTable.h"
@ -10,185 +11,500 @@
namespace { namespace {
std::string GetLValueName(SysYParser::LValueContext& lvalue) { static BaseTypeKind BaseTypeFromBType(SysYParser::BTypeContext* ctx) {
if (!lvalue.ID()) { if (!ctx) {
throw std::runtime_error(FormatError("sema", "非法左值")); throw std::runtime_error(FormatError("sema", "缺少 bType"));
} }
return lvalue.ID()->getText(); if (ctx->INT()) return BaseTypeKind::Int;
if (ctx->FLOAT()) return BaseTypeKind::Float;
throw std::runtime_error(FormatError("sema", "未知基础类型"));
} }
static BaseTypeKind BaseTypeFromFuncType(SysYParser::FuncTypeContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少 funcType"));
}
if (ctx->VOID()) return BaseTypeKind::Void;
if (ctx->INT()) return BaseTypeKind::Int;
if (ctx->FLOAT()) return BaseTypeKind::Float;
throw std::runtime_error(FormatError("sema", "未知函数返回类型"));
}
class ConstEvalVisitor final : public SysYBaseVisitor {
public:
explicit ConstEvalVisitor(const SymbolTable& table) : table_(table) {}
std::any visitConstExp(SysYParser::ConstExpContext* ctx) override {
return visitAddExp(ctx->addExp());
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
auto muls = ctx->mulExp();
if (muls.empty()) return 0;
int value = std::any_cast<int>(muls[0]->accept(this));
for (size_t i = 1; i < muls.size(); ++i) {
int rhs = std::any_cast<int>(muls[i]->accept(this));
auto* node = ctx->children.at(2 * i - 1);
auto text = node ? node->getText() : "+";
if (text == "+") {
value += rhs;
} else if (text == "-") {
value -= rhs;
} else {
throw std::runtime_error(FormatError("sema", "非法加法运算符"));
}
}
return value;
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
auto unaries = ctx->unaryExp();
if (unaries.empty()) return 0;
int value = std::any_cast<int>(unaries[0]->accept(this));
for (size_t i = 1; i < unaries.size(); ++i) {
int rhs = std::any_cast<int>(unaries[i]->accept(this));
auto* node = ctx->children.at(2 * i - 1);
auto text = node ? node->getText() : "*";
if (text == "*") {
value *= rhs;
} else if (text == "/") {
value /= rhs;
} else if (text == "%") {
value %= rhs;
} else {
throw std::runtime_error(FormatError("sema", "非法乘法运算符"));
}
}
return value;
}
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
if (ctx->unaryOp() && ctx->unaryExp()) {
int val = std::any_cast<int>(ctx->unaryExp()->accept(this));
auto op = ctx->unaryOp()->getText();
if (op == "+") return val;
if (op == "-") return -val;
throw std::runtime_error(FormatError("sema", "constExp 不支持 !"));
}
throw std::runtime_error(FormatError("sema", "constExp 不支持函数调用"));
}
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
if (ctx->exp()) return ctx->exp()->accept(this);
if (ctx->lVal()) return ctx->lVal()->accept(this);
if (ctx->number()) return ctx->number()->accept(this);
return 0;
}
std::any visitNumber(SysYParser::NumberContext* ctx) override {
if (ctx->INT_CONST()) {
const std::string text = ctx->getText();
size_t idx = 0;
long long val = std::stoll(text, &idx, 0);
if (idx != text.size()) {
throw std::runtime_error(FormatError("sema", "非法整数常量"));
}
return static_cast<int>(val);
}
if (ctx->FLOAT_CONST()) {
return static_cast<int>(std::stof(ctx->getText()));
}
throw std::runtime_error(FormatError("sema", "constExp 仅支持整数"));
}
std::any visitLVal(SysYParser::LValContext* ctx) override {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "constExp 非法变量"));
}
const auto* entry = table_.Lookup(ctx->ID()->getText());
if (!entry || !entry->is_const || !entry->const_value.has_value()) {
throw std::runtime_error(FormatError("sema", "constExp 使用了非常量"));
}
return entry->const_value.value();
}
private:
const SymbolTable& table_;
};
class SemaVisitor final : public SysYBaseVisitor { class SemaVisitor final : public SysYBaseVisitor {
public: public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) { if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元")); throw std::runtime_error(FormatError("sema", "缺少编译单元"));
} }
auto* func = ctx->funcDef(); for (auto* func : ctx->funcDef()) {
if (!func || !func->blockStmt()) { if (!func || !func->ID()) continue;
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); std::string name = func->ID()->getText();
if (func_table_.find(name) != func_table_.end()) {
throw std::runtime_error(FormatError("sema", "重复定义函数: " + name));
}
func_table_[name] = func;
} }
if (!func->ID() || func->ID()->getText() != "main") {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); for (auto* decl : ctx->decl()) {
if (decl) decl->accept(this);
} }
func->accept(this); for (auto* func : ctx->funcDef()) {
if (!seen_return_) { if (func) func->accept(this);
throw std::runtime_error( }
FormatError("sema", "main 函数必须包含 return 语句"));
if (func_table_.find("main") == func_table_.end()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
} }
return {}; return {};
} }
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) { if (!ctx || !ctx->block()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); throw std::runtime_error(FormatError("sema", "函数体为空"));
}
if (!ctx->ID()) {
throw std::runtime_error(FormatError("sema", "缺少函数名"));
} }
if (!ctx->funcType() || !ctx->funcType()->INT()) { FuncTypeDesc fty;
throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); fty.ret.base = BaseTypeFromFuncType(ctx->funcType());
if (ctx->funcFParams()) {
for (auto* param : ctx->funcFParams()->funcFParam()) {
fty.params.push_back(BuildParamType(param));
}
}
sema_.RegisterFunc(ctx, fty);
current_ret_ = fty.ret.base;
seen_return_ = false;
table_.EnterScope();
if (ctx->funcFParams()) {
for (auto* param : ctx->funcFParams()->funcFParam()) {
RegisterParam(param);
}
} }
const auto& items = ctx->blockStmt()->blockItem(); ctx->block()->accept(this);
if (items.empty()) { table_.ExitScope();
throw std::runtime_error(
FormatError("sema", "main 函数不能为空,且必须以 return 结束")); if (current_ret_ != BaseTypeKind::Void && !seen_return_) {
throw std::runtime_error(FormatError("sema", "非 void 函数缺少 return"));
} }
ctx->blockStmt()->accept(this);
return {}; return {};
} }
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { std::any visitBlock(SysYParser::BlockContext* ctx) override {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("sema", "缺少语句块")); table_.EnterScope();
} for (auto* item : ctx->blockItem()) {
const auto& items = ctx->blockItem(); if (item) item->accept(this);
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
}
if (seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
current_item_index_ = i;
total_items_ = items.size();
item->accept(this);
} }
table_.ExitScope();
return {}; return {};
} }
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); if (ctx->decl()) return ctx->decl()->accept(this);
} if (ctx->stmt()) return ctx->stmt()->accept(this);
if (ctx->decl()) { return {};
ctx->decl()->accept(this); }
return {};
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) return {};
if (auto* c = ctx->constDecl()) return c->accept(this);
if (auto* v = ctx->varDecl()) return v->accept(this);
return {};
}
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
if (!ctx || !ctx->bType()) return {};
BaseTypeKind base = BaseTypeFromBType(ctx->bType());
for (auto* def : ctx->constDef()) {
RegisterConst(def, base);
} }
if (ctx->stmt()) { return {};
ctx->stmt()->accept(this); }
return {};
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
if (!ctx || !ctx->bType()) return {};
BaseTypeKind base = BaseTypeFromBType(ctx->bType());
for (auto* def : ctx->varDef()) {
RegisterVar(def, base);
} }
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); return {};
} }
std::any visitDecl(SysYParser::DeclContext* ctx) override { std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("sema", "非法变量声明")); if (ctx->lVal() && ctx->ASSIGN()) {
ctx->lVal()->accept(this);
if (ctx->exp()) ctx->exp()->accept(this);
return {};
}
if (ctx->block()) return ctx->block()->accept(this);
if (ctx->IF()) {
if (ctx->cond()) ctx->cond()->accept(this);
if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
if (ctx->stmt(1)) ctx->stmt(1)->accept(this);
return {};
} }
if (!ctx->btype() || !ctx->btype()->INT()) { if (ctx->WHILE()) {
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); loop_depth_++;
if (ctx->cond()) ctx->cond()->accept(this);
if (ctx->stmt(0)) ctx->stmt(0)->accept(this);
loop_depth_--;
return {};
} }
auto* var_def = ctx->varDef(); if (ctx->BREAK()) {
if (!var_def || !var_def->lValue()) { if (loop_depth_ == 0) {
throw std::runtime_error(FormatError("sema", "非法变量声明")); throw std::runtime_error(FormatError("sema", "break 不在循环内"));
}
return {};
} }
const std::string name = GetLValueName(*var_def->lValue()); if (ctx->CONTINUE()) {
if (table_.Contains(name)) { if (loop_depth_ == 0) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); throw std::runtime_error(FormatError("sema", "continue 不在循环内"));
}
return {};
} }
if (auto* init = var_def->initValue()) { if (ctx->RETURN()) {
if (!init->exp()) { if (ctx->exp()) ctx->exp()->accept(this);
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); if (current_ret_ == BaseTypeKind::Void && ctx->exp()) {
throw std::runtime_error(FormatError("sema", "void 函数不能返回值"));
}
if (current_ret_ != BaseTypeKind::Void && !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "非 void 函数必须返回值"));
} }
init->exp()->accept(this); seen_return_ = true;
return {};
} }
table_.Add(name, var_def); if (ctx->exp()) ctx->exp()->accept(this);
return {}; return {};
} }
std::any visitStmt(SysYParser::StmtContext* ctx) override { std::any visitExp(SysYParser::ExpContext* ctx) override {
if (!ctx || !ctx->returnStmt()) { if (ctx->addExp()) return ctx->addExp()->accept(this);
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
ctx->returnStmt()->accept(this);
return {}; return {};
} }
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { std::any visitCond(SysYParser::CondContext* ctx) override {
if (!ctx || !ctx->exp()) { if (ctx->lOrExp()) return ctx->lOrExp()->accept(this);
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
}
ctx->exp()->accept(this);
seen_return_ = true;
if (current_item_index_ + 1 != total_items_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
return {}; return {};
} }
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override {
if (!ctx || !ctx->exp()) { for (auto* e : ctx->lAndExp()) e->accept(this);
throw std::runtime_error(FormatError("sema", "非法括号表达式"));
}
ctx->exp()->accept(this);
return {}; return {};
} }
std::any visitVarExp(SysYParser::VarExpContext* ctx) override { std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override {
if (!ctx || !ctx->var()) { for (auto* e : ctx->eqExp()) e->accept(this);
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
}
ctx->var()->accept(this);
return {}; return {};
} }
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { for (auto* e : ctx->relExp()) e->accept(this);
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); return {};
}
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
for (auto* e : ctx->addExp()) e->accept(this);
return {};
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
for (auto* mul : ctx->mulExp()) mul->accept(this);
return {};
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
for (auto* unary : ctx->unaryExp()) unary->accept(this);
return {};
}
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override {
if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
if (ctx->ID() && ctx->LPAREN()) {
std::string name = ctx->ID()->getText();
auto it = func_table_.find(name);
if (it == func_table_.end()) {
if (builtin_funcs_.find(name) == builtin_funcs_.end()) {
throw std::runtime_error(FormatError("sema", "未定义的函数: " + name));
}
} else {
sema_.BindFuncCall(ctx, it->second);
}
if (ctx->funcRParams()) ctx->funcRParams()->accept(this);
return {};
} }
if (ctx->unaryExp()) return ctx->unaryExp()->accept(this);
return {}; return {};
} }
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { std::any visitFuncRParams(SysYParser::FuncRParamsContext* ctx) override {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { for (auto* e : ctx->exp()) e->accept(this);
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); return {};
}
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override {
if (ctx->exp()) return ctx->exp()->accept(this);
if (ctx->lVal()) return ctx->lVal()->accept(this);
if (ctx->number()) return ctx->number()->accept(this);
return {};
}
std::any visitNumber(SysYParser::NumberContext* ctx) override {
if (!ctx->INT_CONST() && !ctx->FLOAT_CONST()) {
throw std::runtime_error(FormatError("sema", "非法常量"));
} }
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {}; return {};
} }
std::any visitVar(SysYParser::VarContext* ctx) override { std::any visitLVal(SysYParser::LValContext* ctx) override {
if (!ctx || !ctx->ID()) { if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用")); throw std::runtime_error(FormatError("sema", "非法变量引用"));
} }
const std::string name = ctx->ID()->getText(); std::string name = ctx->ID()->getText();
auto* decl = table_.Lookup(name); const SymbolEntry* entry = table_.Lookup(name);
if (!decl) { if (!entry) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
} }
sema_.BindVarUse(ctx, decl); BoundDecl bound;
if (entry->kind == SymbolKind::Var) {
bound.kind = BoundDecl::Kind::Var;
bound.var_decl = entry->var_decl;
} else if (entry->kind == SymbolKind::Const) {
bound.kind = BoundDecl::Kind::Const;
bound.const_decl = entry->const_decl;
} else {
bound.kind = BoundDecl::Kind::Param;
bound.param_decl = entry->param_decl;
}
sema_.BindVarUse(ctx, bound);
for (auto* exp : ctx->exp()) {
if (exp) {
exp->accept(this);
}
}
return {}; return {};
} }
SemanticContext TakeSemanticContext() { return std::move(sema_); } SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
TypeDesc BuildParamType(SysYParser::FuncFParamContext* ctx) {
if (!ctx || !ctx->bType()) {
throw std::runtime_error(FormatError("sema", "非法参数"));
}
TypeDesc ty;
ty.base = BaseTypeFromBType(ctx->bType());
if (ctx->LBRACK().size() > 0) {
ty.dims.push_back(-1);
for (auto* exp : ctx->exp()) {
ty.dims.push_back(EvalConstExp(exp));
}
}
return ty;
}
void RegisterParam(SysYParser::FuncFParamContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "参数缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义参数: " + name));
}
TypeDesc ty = BuildParamType(ctx);
SymbolEntry entry;
entry.kind = SymbolKind::Param;
entry.param_decl = ctx;
entry.is_const = false;
entry.type = ty;
table_.Add(name, entry);
sema_.RegisterParam(ctx, ty);
}
void RegisterVar(SysYParser::VarDefContext* ctx, BaseTypeKind base) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "变量声明缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
}
TypeDesc ty;
ty.base = base;
for (auto* dim : ctx->constExp()) {
ty.dims.push_back(EvalConstExp(dim));
}
SymbolEntry entry;
entry.kind = SymbolKind::Var;
entry.var_decl = ctx;
entry.is_const = false;
entry.type = ty;
table_.Add(name, entry);
sema_.RegisterVarDecl(ctx, ty);
if (auto* init = ctx->initVal()) {
init->accept(this);
}
}
void RegisterConst(SysYParser::ConstDefContext* ctx, BaseTypeKind base) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("sema", "常量声明缺少名称"));
}
std::string name = ctx->ID()->getText();
if (table_.ContainsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
}
TypeDesc ty;
ty.base = base;
ty.is_const = true;
for (auto* dim : ctx->constExp()) {
ty.dims.push_back(EvalConstExp(dim));
}
SymbolEntry entry;
entry.kind = SymbolKind::Const;
entry.const_decl = ctx;
entry.is_const = true;
entry.type = ty;
if (ctx->constInitVal() && ty.dims.empty() && ty.base == BaseTypeKind::Int) {
if (auto* exp = ctx->constInitVal()->constExp()) {
entry.const_value = EvalConstExp(exp);
}
}
table_.Add(name, entry);
sema_.RegisterConstDecl(ctx, ty);
if (auto* init = ctx->constInitVal()) {
init->accept(this);
}
}
int EvalConstExp(SysYParser::ConstExpContext* ctx) {
ConstEvalVisitor visitor(table_);
return std::any_cast<int>(ctx->accept(&visitor));
}
int EvalConstExp(SysYParser::ExpContext* ctx) {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("sema", "非法常量表达式"));
}
ConstEvalVisitor visitor(table_);
return std::any_cast<int>(ctx->addExp()->accept(&visitor));
}
private: private:
SymbolTable table_; SymbolTable table_;
SemanticContext sema_; SemanticContext sema_;
std::unordered_map<std::string, SysYParser::FuncDefContext*> func_table_;
const std::unordered_set<std::string> builtin_funcs_ = {
"getint", "getch", "getarray", "putint", "putch", "putarray",
"getfloat", "getfarray", "putfloat", "putfarray", "starttime",
"stoptime"};
BaseTypeKind current_ret_ = BaseTypeKind::Void;
bool seen_return_ = false; bool seen_return_ = false;
size_t current_item_index_ = 0; int loop_depth_ = 0;
size_t total_items_ = 0;
}; };
} // namespace } // namespace
@ -197,4 +513,4 @@ SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor; SemaVisitor visitor;
comp_unit.accept(&visitor); comp_unit.accept(&visitor);
return visitor.TakeSemanticContext(); return visitor.TakeSemanticContext();
} }

@ -2,16 +2,34 @@
#include "sem/SymbolTable.h" #include "sem/SymbolTable.h"
void SymbolTable::Add(const std::string& name, void SymbolTable::EnterScope() { scopes_.emplace_back(); }
SysYParser::VarDefContext* decl) {
table_[name] = decl; void SymbolTable::ExitScope() {
if (!scopes_.empty()) {
scopes_.pop_back();
}
}
bool SymbolTable::ContainsInCurrentScope(const std::string& name) const {
if (scopes_.empty()) {
return false;
}
return scopes_.back().find(name) != scopes_.back().end();
} }
bool SymbolTable::Contains(const std::string& name) const { void SymbolTable::Add(const std::string& name, const SymbolEntry& entry) {
return table_.find(name) != table_.end(); if (scopes_.empty()) {
EnterScope();
}
scopes_.back()[name] = entry;
} }
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { const SymbolEntry* SymbolTable::Lookup(const std::string& name) const {
auto it = table_.find(name); for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
return it == table_.end() ? nullptr : it->second; auto found = it->find(name);
if (found != it->end()) {
return &found->second;
}
}
return nullptr;
} }

@ -15,7 +15,7 @@ CLIOptions ParseCLI(int argc, char** argv) {
if (argc <= 1) { if (argc <= 1) {
throw std::runtime_error(FormatError( throw std::runtime_error(FormatError(
"cli", "cli",
"用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>")); "用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] [--no-opt] <input.sy>"));
} }
for (int i = 1; i < argc; ++i) { for (int i = 1; i < argc; ++i) {
@ -58,6 +58,11 @@ CLIOptions ParseCLI(int argc, char** argv) {
continue; continue;
} }
if (std::strcmp(arg, "--no-opt") == 0) {
opt.optimize_ir = false;
continue;
}
if (arg[0] == '-') { if (arg[0] == '-') {
throw std::runtime_error( throw std::runtime_error(
FormatError("cli", std::string("未知参数: ") + arg + FormatError("cli", std::string("未知参数: ") + arg +

@ -50,13 +50,14 @@ void PrintHelp(std::ostream& os) {
os << "SysY Compiler\n" os << "SysY Compiler\n"
<< "\n" << "\n"
<< "用法:\n" << "用法:\n"
<< " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>\n" << " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] [--no-opt] <input.sy>\n"
<< "\n" << "\n"
<< "选项:\n" << "选项:\n"
<< " -h, --help 打印帮助信息并退出\n" << " -h, --help 打印帮助信息并退出\n"
<< " --emit-parse-tree 仅在显式模式下启用语法树输出\n" << " --emit-parse-tree 仅在显式模式下启用语法树输出\n"
<< " --emit-ir 仅在显式模式下启用 IR 输出\n" << " --emit-ir 仅在显式模式下启用 IR 输出\n"
<< " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n" << " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n"
<< " --no-opt 关闭 Lab4 IR 标量优化管线\n"
<< "\n" << "\n"
<< "说明:\n" << "说明:\n"
<< " - 默认输出 IR\n" << " - 默认输出 IR\n"

@ -1,4 +1,71 @@
// SysY 运行库实现: #include <stdio.h>
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为 int getint() {
int v = 0;
if (scanf("%d", &v) != 1) return 0;
return v;
}
int getch() {
int c = getchar();
if (c == '\r') {
int next = getchar();
if (next != '\n' && next != EOF) {
ungetc(next, stdin);
}
return '\n';
}
return c;
}
int getarray(int a[]) {
int n = 0;
if (scanf("%d", &n) != 1) return 0;
for (int i = 0; i < n; ++i) {
scanf("%d", &a[i]);
}
return n;
}
void putint(int x) { printf("%d", x); }
void putch(int x) { putchar(x); }
void putarray(int n, int a[]) {
printf("%d:", n);
for (int i = 0; i < n; ++i) {
printf(" %d", a[i]);
}
printf("\n");
}
float getfloat() {
float v = 0.0f;
if (scanf("%f", &v) != 1) return 0.0f;
return v;
}
int getfarray(float a[]) {
int n = 0;
if (scanf("%d", &n) != 1) return 0;
for (int i = 0; i < n; ++i) {
scanf("%f", &a[i]);
}
return n;
}
void putfloat(float x) { printf("%a", x); }
void putfarray(int n, float a[]) {
printf("%d:", n);
for (int i = 0; i < n; ++i) {
printf(" %a", a[i]);
}
printf("\n");
}
// Performance timing hooks (no-op stubs for correctness testing).
void starttime() {}
void stoptime() {}

@ -0,0 +1,78 @@
#!/bin/bash
# ================================================
# SysY 编译器 Lab1 批量解析测试脚本
# 文件名scripts/test_parse.sh
# 适用环境Arch Linuxbash 原生支持,无需额外安装)
# 功能:
# - 遍历 test/test_case 下所有 .sy 文件functional + performance
# - 执行 --emit-parse-tree 检查是否能成功解析
# - 输出简洁的 PASS/FAIL 结果 + 统计
# - 错误文件会自动打印最后 10 行报错信息(方便调试)
# - 所有结果保存到 test/test_result/parse_test.log
# ================================================
set -u # 遇到未定义变量直接报错
# ================== 配置 ==================
COMPILER="./build/bin/compiler"
TEST_DIR="test/test_case"
LOG_FILE="test/test_result/parse_test.log"
MAX_ERROR_LINES=10
# 检查编译器是否存在
if [[ ! -x "$COMPILER" ]]; then
echo "❌ 错误:找不到编译器 $COMPILER"
echo " 请先执行 Lab1 构建命令:"
echo " cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON"
echo " cmake --build build -j \"\$(nproc)\""
exit 1
fi
# 创建日志目录(如果不存在)
mkdir -p "$(dirname "$LOG_FILE")"
> "$LOG_FILE" # 清空日志
echo "开始 Lab1 批量语法树测试..." | tee -a "$LOG_FILE"
echo "测试目录:$TEST_DIR" | tee -a "$LOG_FILE"
echo "编译器:$COMPILER" | tee -a "$LOG_FILE"
echo "========================================" | tee -a "$LOG_FILE"
pass=0
fail=0
total=0
# 遍历所有 .sy 文件(支持子目录)
while IFS= read -r -d '' sy_file; do
((total++))
echo -n "[$total] 测试: $sy_file ... " | tee -a "$LOG_FILE"
# 执行解析(把输出丢到 /dev/null防止刷屏
if "$COMPILER" --emit-parse-tree "$sy_file" > /dev/null 2>&1; then
echo "✅PASS" | tee -a "$LOG_FILE"
((pass++))
else
echo "FAIL" | tee -a "$LOG_FILE"
((fail++))
# 打印错误信息到日志(最后几行)
echo " └── 错误详情(最后 $MAX_ERROR_LINES 行):" >> "$LOG_FILE"
"$COMPILER" --emit-parse-tree "$sy_file" 2>&1 | tail -n "$MAX_ERROR_LINES" >> "$LOG_FILE"
echo "" >> "$LOG_FILE"
fi
done < <(find "$TEST_DIR" -name "*.sy" -print0 | sort -z)
# ================== 总结 ==================
echo "========================================" | tee -a "$LOG_FILE"
echo "测试完成!" | tee -a "$LOG_FILE"
echo "总文件数 : $total" | tee -a "$LOG_FILE"
echo "通过 : $pass" | tee -a "$LOG_FILE"
echo "失败 : $fail" | tee -a "$LOG_FILE"
if [[ $fail -eq 0 ]]; then
echo "恭喜Lab1 语法树构建全部通过!可以进入 Lab2 啦~" | tee -a "$LOG_FILE"
else
echo "$fail 个文件解析失败,请检查 SysY.g4 或报错日志" | tee -a "$LOG_FILE"
echo " 日志文件:$LOG_FILE" | tee -a "$LOG_FILE"
fi
echo "========================================" | tee -a "$LOG_FILE"

@ -11,6 +11,8 @@
#include "atn/ProfilingATNSimulator.h" #include "atn/ProfilingATNSimulator.h"
#include <chrono>
using namespace antlr4; using namespace antlr4;
using namespace antlr4::atn; using namespace antlr4::atn;
using namespace antlr4::dfa; using namespace antlr4::dfa;

Loading…
Cancel
Save