From 19928c49459a1a775a0e8eb80be9ac5e50f3fd1d Mon Sep 17 00:00:00 2001 From: Junhe Wu <2561075610@qq.com> Date: Fri, 1 May 2026 17:43:05 +0800 Subject: [PATCH] =?UTF-8?q?feat(ir-opt):=20=E5=AE=8C=E6=88=90=E4=BA=86lab4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- doc/Lab4-测试.md | 3 + include/ir/IR.h | 54 ++++ include/utils/CLI.h | 10 +- scripts/bench_ir.sh | 147 +++++++++++ scripts/run_ir_test.sh | 424 +++++++++++++++++++++--------- scripts/verify_ir.sh | 2 +- src/CMakeLists.txt | 1 + src/ir/BasicBlock.cpp | 50 ++++ src/ir/Function.cpp | 82 ++++++ src/ir/IRBuilder.cpp | 8 + src/ir/IRPrinter.cpp | 28 +- src/ir/Instruction.cpp | 19 ++ src/ir/analysis/DominatorTree.cpp | 203 +++++++++++++- src/ir/passes/CFGSimplify.cpp | 140 ++++++++++ src/ir/passes/CSE.cpp | 123 ++++++++- src/ir/passes/ConstFold.cpp | 185 ++++++++++++- src/ir/passes/ConstProp.cpp | 60 ++++- src/ir/passes/DCE.cpp | 66 ++++- src/ir/passes/Mem2Reg.cpp | 188 ++++++++++++- src/ir/passes/PassManager.cpp | 90 ++++++- src/main.cpp | 35 +-- src/utils/CLI.cpp | 54 +++- src/utils/Log.cpp | 25 +- 23 files changed, 1813 insertions(+), 184 deletions(-) create mode 100644 doc/Lab4-测试.md create mode 100755 scripts/bench_ir.sh diff --git a/doc/Lab4-测试.md b/doc/Lab4-测试.md new file mode 100644 index 0000000..52603ee --- /dev/null +++ b/doc/Lab4-测试.md @@ -0,0 +1,3 @@ +bash scripts/run_ir_test.sh --run # 优化模式,计时 +bash scripts/run_ir_test.sh --run --O0 # 无优化,计时 +bash scripts/bench_ir.sh # 同时对比 O0 vs O1 \ No newline at end of file diff --git a/include/ir/IR.h b/include/ir/IR.h index 48b01da..6e7349d 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -162,6 +163,8 @@ enum class Opcode { Gep, // 控制流 Ret, Br, CondBr, + // PHI 节点 + Phi, // 函数调用 Call, // 类型转换 @@ -234,6 +237,7 @@ class Instruction : public User { bool IsTerminator() const; BasicBlock* GetParent() const; void SetParent(BasicBlock* parent); + void RemoveFromParent(); private: Opcode opcode_; @@ -374,6 +378,18 @@ class StoreInst : public Instruction { Value* GetPtr() const; }; +// PHI 节点:在控制流汇合处选择值 +// 操作数布局:[val0, bb0, val1, bb1, ...](偶数下标为值,奇数下标为基本块) +class PhiInst : public Instruction { + public: + PhiInst(std::shared_ptr ty, std::string name); + void AddIncoming(Value* val, BasicBlock* bb); + size_t GetNumIncoming() const { return GetNumOperands() / 2; } + Value* GetIncomingValue(size_t i) const { return GetOperand(i * 2); } + BasicBlock* GetIncomingBlock(size_t i) const; + void SetIncomingValue(size_t i, Value* val) { SetOperand(i * 2, val); } +}; + // ─── BasicBlock ─────────────────────────────────────────────────────────────── class BasicBlock : public Value { public: @@ -384,6 +400,16 @@ class BasicBlock : public Value { const std::vector>& GetInstructions() const; const std::vector& GetPredecessors() const; const std::vector& GetSuccessors() const; + void AddPredecessor(BasicBlock* bb); + void RemovePredecessor(BasicBlock* bb); + void ClearPredecessors(); + void AddSuccessor(BasicBlock* bb); + void ClearSuccessors(); + + // 指令管理 + void RemoveInstruction(Instruction* inst); + // 在 before 之前插入指令;before 为 nullptr 时追加到末尾 + void InsertBefore(Instruction* inst, Instruction* before); template T* Append(Args&&... args) { @@ -450,6 +476,10 @@ class Function : public Value { bool IsVoidReturn() const { return type_->IsVoid(); } // 将某个块移动到 blocks_ 列表末尾(用于确保块顺序正确) void MoveBlockToEnd(BasicBlock* bb); + // 重建 CFG:根据终结指令计算所有块的前驱/后继 + void RebuildCFG(); + // 从函数中移除一个基本块 + void RemoveBlock(BasicBlock* bb); private: BasicBlock* entry_ = nullptr; @@ -550,6 +580,8 @@ class IRBuilder { BrInst* CreateBr(BasicBlock* target); CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb); + // PHI 节点(添加到当前块开头) + PhiInst* CreatePhi(std::shared_ptr ty, const std::string& name); // 调用 CallInst* CreateCall(Function* callee, std::vector args, @@ -572,10 +604,32 @@ class IRBuilder { BasicBlock* alloca_block_ = nullptr; }; +// ─── DominatorTree ──────────────────────────────────────────────────────────── +class DominatorTree { + public: + void Compute(Function& func); + BasicBlock* GetIDom(BasicBlock* bb) const; + const std::vector& GetChildren(BasicBlock* bb) const; + const std::vector& GetDominanceFrontier(BasicBlock* bb) const; + bool Dominates(BasicBlock* a, BasicBlock* b) const; + const std::vector& GetDFOrder() const { return df_order_; } + + private: + std::unordered_map idom_; + std::unordered_map> children_; + std::unordered_map> df_; + std::unordered_map dom_level_; + std::vector df_order_; + std::unordered_set visited_; +}; + // ─── IRPrinter ──────────────────────────────────────────────────────────────── class IRPrinter { public: void Print(const Module& module, std::ostream& os); }; +// ─── Pass Manager ──────────────────────────────────────────────────────────── +void RunPasses(Module& module); + } // namespace ir diff --git a/include/utils/CLI.h b/include/utils/CLI.h index 4b3a781..9b51e55 100644 --- a/include/utils/CLI.h +++ b/include/utils/CLI.h @@ -1,14 +1,16 @@ -// 简易命令行解析:支持帮助、输入文件与输出阶段选择。 +// 命令行解析:compiler -S|-IR -o [-O1] +// 同时兼容 --emit-ir / --emit-asm #pragma once #include struct CLIOptions { std::string input; - bool emit_parse_tree = false; - bool emit_ir = true; - bool emit_asm = false; + std::string output; // -o ,为空则输出到 stdout + bool emit_ir = false; // -IR / --emit-ir + bool emit_asm = false; // -S / --emit-asm bool show_help = false; + bool opt = false; // -O1 }; CLIOptions ParseCLI(int argc, char** argv); diff --git a/scripts/bench_ir.sh b/scripts/bench_ir.sh new file mode 100755 index 0000000..f1a3749 --- /dev/null +++ b/scripts/bench_ir.sh @@ -0,0 +1,147 @@ +#!/usr/bin/env bash +# 优化效果对比:测量 O0 vs O1 的编译时间和运行时间 +# 用法: bash scripts/bench_ir.sh [--test-dir=] [--result-dir=] + +set -uo pipefail + +PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd) +TEST_CASE_DIR="${PROJECT_ROOT}/test/test_case" +RESULT_DIR="${PROJECT_ROOT}/test/test_result/bench" + +while [[ $# -gt 0 ]]; do + case "$1" in + --test-dir=*) TEST_CASE_DIR="${1#*=}" ;; + --result-dir=*) RESULT_DIR="${1#*=}" ;; + *) echo "未知参数: $1" >&2; exit 1 ;; + esac + shift +done + +compiler="${PROJECT_ROOT}/build/bin/compiler" +[[ -x "$compiler" ]] || { echo "错误:未找到编译器 $compiler" >&2; exit 1; } +command -v llc >/dev/null 2>&1 || { echo "错误:未找到 llc" >&2; exit 1; } +command -v clang >/dev/null 2>&1 || { echo "错误:未找到 clang" >&2; exit 1; } + +mkdir -p "$RESULT_DIR" + +# 时间测量:使用 date +%s.%N +now() { date +%s.%N; } +elapsed() { python3 -c "print(f'{float($2)-float($1):.4f}')" 2>/dev/null || awk "BEGIN{printf \"%.4f\\n\",$2-$1}"; } + +summary_file="${RESULT_DIR}/summary.csv" +echo "test,opt,compile_s,exec_s,compile+exec_s" > "$summary_file" + +total=0 +o0_ct_total=0; o1_ct_total=0 +o0_et_total=0; o1_et_total=0 + +echo "=== 优化效果对比 O0 vs O1 ===" +echo "" + +while read -r test_file; do + full_path=$(readlink -f "$test_file") + tcdir=$(readlink -f "$TEST_CASE_DIR") + rel="${full_path#$tcdir}" + [[ "${rel:0:1}" != "/" ]] && rel="/$rel" + + base=$(basename "$test_file") + stem="${base%.sy}" + idir=$(dirname "$test_file") + stdin="${idir}/${stem}.in" + expected="${idir}/${stem}.out" + + total=$((total+1)) + printf "[%4d] %s" "$total" "$rel" + + o0_ll="${RESULT_DIR}/O0/${rel%.sy}.ll" + o1_ll="${RESULT_DIR}/O1/${rel%.sy}.ll" + mkdir -p "$(dirname "$o0_ll")" "$(dirname "$o1_ll")" + + # --- 编译 O0 --- + t1=$(now) + "$compiler" "$test_file" -IR -o "$o0_ll" 2>/dev/null; rc0=$? + t2=$(now) + if [[ $rc0 -ne 0 ]]; then + echo " | O0编译失败" + echo "$stem,O0,-,-,-" >> "$summary_file" + echo "$stem,O1,-,-,-" >> "$summary_file" + continue + fi + o0_ct=$(elapsed "$t1" "$t2") + + # --- 编译 O1 --- + t1=$(now) + "$compiler" "$test_file" -IR -o "$o1_ll" -O1 2>/dev/null; rc1=$? + t2=$(now) + if [[ $rc1 -ne 0 ]]; then + echo " | O1编译失败" + echo "$stem,O0,$o0_ct,-,-" >> "$summary_file" + echo "$stem,O1,-,-,-" >> "$summary_file" + continue + fi + o1_ct=$(elapsed "$t1" "$t2") + + # --- llc + clang O0 --- + o0_obj="${RESULT_DIR}/O0/${stem}.o" + o1_obj="${RESULT_DIR}/O1/${stem}.o" + o0_exe="${RESULT_DIR}/O0/${stem}.exe" + o1_exe="${RESULT_DIR}/O1/${stem}.exe" + + llc -filetype=obj "$o0_ll" -o "$o0_obj" 2>/dev/null + llc -filetype=obj "$o1_ll" -o "$o1_obj" 2>/dev/null + clang "$o0_obj" "${PROJECT_ROOT}/sylib/sylib.c" -o "$o0_exe" -lm 2>/dev/null + clang "$o1_obj" "${PROJECT_ROOT}/sylib/sylib.c" -o "$o1_exe" -lm 2>/dev/null + + # --- 运行 O0 --- + t1=$(now) + sr0=0 + if [[ -f "$stdin" ]]; then + (ulimit -s unlimited; "$o0_exe" < "$stdin") > /dev/null 2>&1 || sr0=$? + else + (ulimit -s unlimited; "$o0_exe") > /dev/null 2>&1 || sr0=$? + fi + t2=$(now) + o0_et=$(elapsed "$t1" "$t2") + + # --- 运行 O1 --- + t1=$(now) + sr1=0 + if [[ -f "$stdin" ]]; then + (ulimit -s unlimited; "$o1_exe" < "$stdin") > /dev/null 2>&1 || sr1=$? + else + (ulimit -s unlimited; "$o1_exe") > /dev/null 2>&1 || sr1=$? + fi + t2=$(now) + o1_et=$(elapsed "$t1" "$t2") + + # 验证一致性 + flag="" + if [[ $sr0 -ne $sr1 ]]; then + flag=" EXIT:O0=$sr0 O1=$sr1" + fi + + # 累计 & 比率 + o0_ct_total=$(awk "BEGIN{printf \"%.4f\",$o0_ct_total+$o0_ct}") + o1_ct_total=$(awk "BEGIN{printf \"%.4f\",$o1_ct_total+$o1_ct}") + o0_et_total=$(awk "BEGIN{printf \"%.4f\",$o0_et_total+$o0_et}") + o1_et_total=$(awk "BEGIN{printf \"%.4f\",$o1_et_total+$o1_et}") + + cspd=$(awk "BEGIN{if($o1_ct>0)printf \"%.1fx\",$o0_ct/$o1_ct; else print \"-\"}") + espd=$(awk "BEGIN{if($o1_et>0)printf \"%.1fx\",$o0_et/$o1_et; else print \"-\"}") + + printf " | 编译 O0:%.4fs O1:%.4fs(%s) 运行 O0:%.4fs O1:%.4fs(%s)%s\n" \ + "$o0_ct" "$o1_ct" "$cspd" "$o0_et" "$o1_et" "$espd" "$flag" + + echo "$stem,O0,$o0_ct,$o0_et,$(awk "BEGIN{printf \"%.4f\",$o0_ct+$o0_et}")" >> "$summary_file" + echo "$stem,O1,$o1_ct,$o1_et,$(awk "BEGIN{printf \"%.4f\",$o1_ct+$o1_et}")" >> "$summary_file" + +done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort) + +echo "" +echo "============================================" +echo "总用例: $total" +echo "O0 编译总耗时: ${o0_ct_total}s" +echo "O1 编译总耗时: ${o1_ct_total}s" +echo "O0 运行总耗时: ${o0_et_total}s" +echo "O1 运行总耗时: ${o1_et_total}s" +echo "CSV: $summary_file" diff --git a/scripts/run_ir_test.sh b/scripts/run_ir_test.sh index 1e7b04f..5ecd178 100755 --- a/scripts/run_ir_test.sh +++ b/scripts/run_ir_test.sh @@ -1,144 +1,308 @@ -#!/bin/bash +#!/usr/bin/env bash +# 串行执行IR测试脚本,实时输出结果 + +set -euo pipefail PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd) -COMPILER="$PROJECT_ROOT/build/bin/compiler" -TEST_CASE_DIR="$PROJECT_ROOT/test/test_case" -TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/ir" -VERIFY_SCRIPT="$PROJECT_ROOT/scripts/verify_ir.sh" -PARALLEL=${PARALLEL:-$(nproc)} -LOG_FILE="$TEST_RESULT_DIR/verify.log" - -if [ ! -x "$COMPILER" ]; then - echo "错误:编译器不存在或不可执行: $COMPILER" - echo "请先构建项目:cmake --build build -j\$(nproc)" - exit 1 + +# 默认参数 +TEST_CASE_DIR="${PROJECT_ROOT}/test/test_case" +TEST_RESULT_DIR="${PROJECT_ROOT}/test/test_result/ir" +RUN_EXEC=false +VERBOSE=false +OPT_FLAG="-O1" # 默认开启优化 + +# 解析命令行参数 +while [[ $# -gt 0 ]]; do + case "$1" in + --run) + RUN_EXEC=true + ;; + --verbose|-v) + VERBOSE=true + ;; + --O0) + OPT_FLAG="" + ;; + --O1) + OPT_FLAG="-O1" + ;; + --test-dir=*) + TEST_CASE_DIR="${1#*=}" + ;; + --result-dir=*) + TEST_RESULT_DIR="${1#*=}" + ;; + *) + echo "未知参数: $1" >&2 + echo "用法: $0 [--run] [--O0|--O1] [--verbose] [--test-dir=] [--result-dir=]" >&2 + exit 1 + ;; + esac + shift +done + +# 检查编译器是否存在 +compiler="${PROJECT_ROOT}/build/bin/compiler" +if [[ ! -x "$compiler" ]]; then + echo "错误:未找到编译器 $compiler" >&2 + echo "请先构建项目: mkdir -p build && cd build && cmake .. && make -j" >&2 + exit 1 fi +# 创建输出目录 mkdir -p "$TEST_RESULT_DIR" -> "$LOG_FILE" - -# ── 阶段1:IR 生成(并行)──────────────────────────────────────────────────── -echo "=== 阶段1:IR 生成 ===" | tee -a "$LOG_FILE" -echo "" | tee -a "$LOG_FILE" - -GEN_TMPDIR=$(mktemp -d) - -gen_one() { - test_file="$1" - relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file") - output_file="$TEST_RESULT_DIR/${relative_path%.sy}.ll" - mkdir -p "$(dirname "$output_file")" - "$COMPILER" --emit-ir "$test_file" > "$output_file" 2>&1 - exit_code=$? - # Use a per-case tmp file to avoid concurrent write issues - case_id=$(echo "$relative_path" | tr '/' '_') - if [ $exit_code -eq 0 ] && [ -s "$output_file" ] && ! grep -q '\[error\]' "$output_file"; then - echo "通过: $relative_path" > "$GEN_TMPDIR/pass_${case_id}" - else - echo "$relative_path" > "$GEN_TMPDIR/fail_${case_id}" - echo "失败: $relative_path" > "$GEN_TMPDIR/line_fail_${case_id}" - fi -} -export -f gen_one -export COMPILER TEST_CASE_DIR TEST_RESULT_DIR GEN_TMPDIR - -find "$TEST_CASE_DIR" -name "*.sy" | sort | \ - xargs -P "$PARALLEL" -I{} bash -c 'gen_one "$@"' _ {} - -# Collect results in sorted order -failed_cases=() -for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do - relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f") - case_id=$(echo "$relative_path" | tr '/' '_') - if [ -f "$GEN_TMPDIR/pass_${case_id}" ]; then - cat "$GEN_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE" - elif [ -f "$GEN_TMPDIR/fail_${case_id}" ]; then - cat "$GEN_TMPDIR/line_fail_${case_id}" | tee -a "$LOG_FILE" - failed_cases+=("$relative_path") - fi -done -pass_count=$(ls "$GEN_TMPDIR"/pass_* 2>/dev/null | wc -l) -fail_count=${#failed_cases[@]} -rm -rf "$GEN_TMPDIR" +# 统计变量 +total_tests=0 +passed_tests=0 +failed_tests=0 -echo "" | tee -a "$LOG_FILE" -echo "--- 生成完成: 通过 $pass_count / 失败 $fail_count ---" | tee -a "$LOG_FILE" +# 汇总日志文件 +summary_log="${TEST_RESULT_DIR}/summary.log" +> "$summary_log" -# ── 阶段2:IR 运行验证(并行,需要 llc + clang)────────────────────────────── -if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then - echo "" | tee -a "$LOG_FILE" - echo "=== 跳过阶段2:未找到 llc 或 clang,无法运行 IR ===" | tee -a "$LOG_FILE" -else - echo "" | tee -a "$LOG_FILE" - echo "=== 阶段2:IR 运行验证 ===" | tee -a "$LOG_FILE" - echo "" | tee -a "$LOG_FILE" - - VRF_TMPDIR=$(mktemp -d) - - verify_one() { - test_file="$1" - relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file") - relative_dir=$(dirname "$relative_path") - out_dir="$TEST_RESULT_DIR/$relative_dir" - stem=$(basename "${test_file%.sy}") - case_log="$out_dir/$stem.verify.log" - case_id=$(echo "$relative_path" | tr '/' '_') - if bash "$VERIFY_SCRIPT" "$test_file" "$out_dir" --run > "$case_log" 2>&1; then - echo "通过: $relative_path" > "$VRF_TMPDIR/pass_${case_id}" - else - extra=$(grep -E '(退出码|输出不匹配|错误)' "$case_log" | head -3 | sed 's/^/ /' || true) - { echo "失败: $relative_path"; [ -n "$extra" ] && echo "$extra"; } > "$VRF_TMPDIR/fail_${case_id}" - echo "$relative_path" > "$VRF_TMPDIR/failname_${case_id}" - fi - } - export -f verify_one - export TEST_CASE_DIR TEST_RESULT_DIR VERIFY_SCRIPT VRF_TMPDIR - - find "$TEST_CASE_DIR" -name "*.sy" | sort | \ - xargs -P "$PARALLEL" -I{} bash -c 'verify_one "$@"' _ {} - - # Collect results in sorted order - verify_failed_cases=() - for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do - relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f") - case_id=$(echo "$relative_path" | tr '/' '_') - if [ -f "$VRF_TMPDIR/pass_${case_id}" ]; then - cat "$VRF_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE" - elif [ -f "$VRF_TMPDIR/fail_${case_id}" ]; then - cat "$VRF_TMPDIR/fail_${case_id}" | tee -a "$LOG_FILE" - verify_failed_cases+=("$relative_path") - fi - done +# 失败测试列表 +failed_list="" - verify_pass=$(ls "$VRF_TMPDIR"/pass_* 2>/dev/null | wc -l) - verify_fail=${#verify_failed_cases[@]} - rm -rf "$VRF_TMPDIR" +echo "=== 开始IR测试 ===" +echo "测试目录: $TEST_CASE_DIR" +echo "结果目录: $TEST_RESULT_DIR" +echo "优化级别: ${OPT_FLAG:--O0}" +echo "运行可执行文件: $RUN_EXEC" +echo "" - echo "" | tee -a "$LOG_FILE" - echo "--- 验证完成: 通过 $verify_pass / 失败 $verify_fail ---" | tee -a "$LOG_FILE" +# 串行遍历所有测试用例 +while read -r test_file; do + total_tests=$((total_tests + 1)) + + # 计算相对路径 + full_path=$(readlink -f "$test_file") + test_case_path=$(readlink -f "$TEST_CASE_DIR") + relative_path="${full_path#$test_case_path}" + # 确保路径以 / 开头 + if [[ "${relative_path:0:1}" != "/" ]]; then + relative_path="/$relative_path" + fi + + # 计算输出文件路径 + base=$(basename "$test_file") + stem="${base%.sy}" + output_file="${TEST_RESULT_DIR}/${relative_path%.sy}.ll" + output_dir=$(dirname "$output_file") + + # 创建输出目录 + mkdir -p "$output_dir" + + # 获取输入和预期输出文件路径 + input_dir=$(dirname "$test_file") + stdin_file="${input_dir}/${stem}.in" + expected_file="${input_dir}/${stem}.out" + + # 每个测试用例的详细日志文件 + test_log="${output_dir}/${stem}.log" + > "$test_log" + + echo "[$total_tests] 处理: $relative_path" + echo "[$(date '+%Y-%m-%d %H:%M:%S')] 开始处理: $relative_path" >> "$test_log" + echo "输入文件: $test_file" >> "$test_log" + echo "输出目录: $output_dir" >> "$test_log" + + # 生成IR + if $VERBOSE; then + echo " 生成IR..." + fi + echo "步骤1: 生成IR" >> "$test_log" - if [ ${#verify_failed_cases[@]} -gt 0 ]; then - echo "" | tee -a "$LOG_FILE" - echo "=== 验证失败的用例 ===" | tee -a "$LOG_FILE" - for f in "${verify_failed_cases[@]}"; do - [ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE" - done + # 计时编译 + compile_start=$(date +%s.%N) + set +e + "$compiler" "$test_file" -IR -o "$output_file" $OPT_FLAG 2>&1 + ir_status=$? + set -e + compile_end=$(date +%s.%N) + compile_time=$(python3 -c "print(f'{float($compile_end)-float($compile_start):.3f}s')" 2>/dev/null || echo "-") + + if [[ $ir_status -ne 0 ]]; then + echo " ✗ IR生成失败" + echo "结果: FAILED (IR生成失败)" >> "$test_log" + echo "错误信息:" >> "$test_log" + cat "$output_file" >> "$test_log" + failed_tests=$((failed_tests + 1)) + failed_list="$failed_list\n[$total_tests] $relative_path - IR生成失败" + if $VERBOSE; then + cat "$output_file" fi -fi + echo "" + continue + fi + + echo "IR文件: $output_file" >> "$test_log" + if $VERBOSE; then + echo " ✓ IR已生成: $output_file" + fi + + # 如果需要运行可执行文件 + if [[ "$RUN_EXEC" == true ]]; then + echo "步骤2: 编译和运行" >> "$test_log" + + if ! command -v llc >/dev/null 2>&1; then + echo " 警告: 未找到 llc,跳过执行测试" >&2 + echo "警告: 未找到 llc,跳过执行测试" >> "$test_log" + echo "结果: SKIPPED (缺少llc)" >> "$test_log" + passed_tests=$((passed_tests + 1)) + echo "" + continue + fi + if ! command -v clang >/dev/null 2>&1; then + echo " 警告: 未找到 clang,跳过执行测试" >&2 + echo "警告: 未找到 clang,跳过执行测试" >> "$test_log" + echo "结果: SKIPPED (缺少clang)" >> "$test_log" + passed_tests=$((passed_tests + 1)) + echo "" + continue + fi + + obj="${output_dir}/${stem}.o" + exe="${output_dir}/${stem}" + stdout_file="${output_dir}/${stem}.stdout" + actual_file="${output_dir}/${stem}.actual.out" + + # 编译IR为目标文件 + if $VERBOSE; then + echo " 编译IR..." + fi + echo "编译IR: llc -filetype=obj $output_file -o $obj" >> "$test_log" + llc -filetype=obj "$output_file" -o "$obj" 2>/dev/null + if [[ $? -ne 0 ]]; then + echo " ✗ IR编译失败" + echo "结果: FAILED (IR编译失败)" >> "$test_log" + failed_tests=$((failed_tests + 1)) + failed_list="$failed_list\n[$total_tests] $relative_path - IR编译失败" + echo "" + continue + fi + echo "目标文件: $obj" >> "$test_log" + + # 链接为可执行文件 + echo "链接: clang $obj ${PROJECT_ROOT}/sylib/sylib.c -o $exe -lm" >> "$test_log" + clang "$obj" "${PROJECT_ROOT}/sylib/sylib.c" -o "$exe" -lm 2>/dev/null + if [[ $? -ne 0 ]]; then + echo " ✗ 链接失败" + echo "结果: FAILED (链接失败)" >> "$test_log" + failed_tests=$((failed_tests + 1)) + failed_list="$failed_list\n[$total_tests] $relative_path - 链接失败" + echo "" + continue + fi + echo "可执行文件: $exe" >> "$test_log" + + # 运行可执行文件 + if $VERBOSE; then + echo " 运行..." + fi + echo "运行命令: $exe" >> "$test_log" + if [[ -f "$stdin_file" ]]; then + echo "标准输入: $stdin_file" >> "$test_log" + fi + run_start=$(date +%s.%N) + set +e + if [[ -f "$stdin_file" ]]; then + (ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file" + else + (ulimit -s unlimited; "$exe") > "$stdout_file" + fi + status=$? + set -e + run_end=$(date +%s.%N) + run_time=$(python3 -c "print(f'{float($run_end)-float($run_start):.3f}s')" 2>/dev/null || echo "-") + echo "退出码: $status" >> "$test_log" + echo "标准输出:" >> "$test_log" + cat "$stdout_file" >> "$test_log" + + # 保存实际输出(包含退出码) + { + cat "$stdout_file" + if [[ -s "$stdout_file" ]] && (( $(tail -c 1 "$stdout_file" | wc -l) == 0 )); then + printf '\n' + fi + printf '%s\n' "$status" + } > "$actual_file" + + # 比对输出 + echo "步骤3: 比对输出" >> "$test_log" + if [[ -f "$expected_file" ]]; then + echo "预期输出: $expected_file" >> "$test_log" + echo "实际输出: $actual_file" >> "$test_log" + + if diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \ + <(tr -d '\r' < "$actual_file" | sed -e '$a\') > /dev/null 2>&1; then + echo " ✓ 编译:${compile_time} 运行:${run_time}" + echo "结果: PASSED (输出匹配)" >> "$test_log" + passed_tests=$((passed_tests + 1)) + else + echo " ✗ 输出不匹配" + echo "结果: FAILED (输出不匹配)" >> "$test_log" + echo "差异:" >> "$test_log" + diff <(tr -d '\r' < "$expected_file" | sed -e '$a\') \ + <(tr -d '\r' < "$actual_file" | sed -e '$a\') >> "$test_log" 2>&1 || true + failed_tests=$((failed_tests + 1)) + failed_list="$failed_list\n[$total_tests] $relative_path - 输出不匹配" + if $VERBOSE; then + echo " 预期:" + cat "$expected_file" + echo " 实际:" + cat "$actual_file" + fi + fi + else + echo " ? 无预期输出文件,跳过比对 (编译:${compile_time})" + echo "警告: 无预期输出文件,跳过比对" >> "$test_log" + echo "结果: SKIPPED (无预期输出)" >> "$test_log" + passed_tests=$((passed_tests + 1)) + fi + else + echo "步骤2: 跳过执行测试 (--run未启用)" >> "$test_log" + echo "结果: PASSED (仅IR生成)" >> "$test_log" + passed_tests=$((passed_tests + 1)) + echo " ✓ 编译:${compile_time}" + fi + + echo "" +done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort) + +# 输出统计结果到终端 +echo "=== 测试完成 ===" +echo "总测试数: $total_tests" +echo "通过: $passed_tests" +echo "失败: $failed_tests" + +# 写入汇总日志 +echo "=== IR测试汇总报告 ===" > "$summary_log" +echo "测试时间: $(date '+%Y-%m-%d %H:%M:%S')" >> "$summary_log" +echo "测试目录: $TEST_CASE_DIR" >> "$summary_log" +echo "结果目录: $TEST_RESULT_DIR" >> "$summary_log" +echo "运行可执行文件: $RUN_EXEC" >> "$summary_log" +echo "" >> "$summary_log" +echo "=== 统计结果 ===" >> "$summary_log" +echo "总测试数: $total_tests" >> "$summary_log" +echo "通过: $passed_tests" >> "$summary_log" +echo "失败: $failed_tests" >> "$summary_log" +echo "成功率: $((passed_tests * 100 / total_tests))%" >> "$summary_log" +echo "" >> "$summary_log" -# ── 汇总 ───────────────────────────────────────────────────────────────────── -echo "" | tee -a "$LOG_FILE" -echo "=== 测试完成 ===" | tee -a "$LOG_FILE" -echo "IR生成 通过: $pass_count 失败: $fail_count" | tee -a "$LOG_FILE" -echo "结果保存在: $TEST_RESULT_DIR" | tee -a "$LOG_FILE" -echo "日志保存在: $LOG_FILE" | tee -a "$LOG_FILE" - -if [ ${#failed_cases[@]} -gt 0 ]; then - echo "" | tee -a "$LOG_FILE" - echo "=== IR生成失败的用例 ===" | tee -a "$LOG_FILE" - for f in "${failed_cases[@]}"; do - [ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE" - done - exit 1 +if [[ $failed_tests -gt 0 ]]; then + echo "=== 失败测试列表 ===" >> "$summary_log" + echo -e "$failed_list" >> "$summary_log" fi + +echo "详细日志已保存到各测试用例目录" +echo "汇总日志: $summary_log" + +if [[ $failed_tests -eq 0 ]]; then + echo "所有测试通过!" + exit 0 +else + echo "有 $failed_tests 个测试失败" + exit 1 +fi \ No newline at end of file diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index bd7887c..70ebbc8 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -45,7 +45,7 @@ stem=${base%.sy} out_file="$out_dir/$stem.ll" stdin_file="$input_dir/$stem.in" expected_file="$input_dir/$stem.out" -"$compiler" --emit-ir "$input" > "$out_file" +"$compiler" "$input" -IR -o "$out_file" -O1 echo "IR 已生成: $out_file" if [[ "$run_exec" == true ]]; then diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index acb9400..37e8b5d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,6 +21,7 @@ if(NOT COMPILER_PARSE_ONLY) target_link_libraries(compiler PRIVATE sem irgen + ir mir ) target_compile_definitions(compiler PRIVATE COMPILER_PARSE_ONLY=0) diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index b18502c..c69db10 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -9,6 +9,7 @@ #include "ir/IR.h" +#include #include namespace ir { @@ -42,4 +43,53 @@ const std::vector& BasicBlock::GetSuccessors() const { return successors_; } +void BasicBlock::AddPredecessor(BasicBlock* bb) { + if (bb) predecessors_.push_back(bb); +} + +void BasicBlock::RemovePredecessor(BasicBlock* bb) { + predecessors_.erase( + std::remove(predecessors_.begin(), predecessors_.end(), bb), + predecessors_.end()); +} + +void BasicBlock::ClearPredecessors() { predecessors_.clear(); } + +void BasicBlock::AddSuccessor(BasicBlock* bb) { + if (bb) successors_.push_back(bb); +} + +void BasicBlock::ClearSuccessors() { successors_.clear(); } + +void BasicBlock::RemoveInstruction(Instruction* inst) { + for (auto it = instructions_.begin(); it != instructions_.end(); ++it) { + if (it->get() == inst) { + instructions_.erase(it); + return; + } + } +} + +void BasicBlock::InsertBefore(Instruction* inst, Instruction* before) { + if (!before) { + // append (respecting terminator) + if (!instructions_.empty() && instructions_.back()->IsTerminator()) { + instructions_.insert(instructions_.end() - 1, + std::unique_ptr(inst)); + } else { + instructions_.push_back(std::unique_ptr(inst)); + } + } else { + for (auto it = instructions_.begin(); it != instructions_.end(); ++it) { + if (it->get() == before) { + instructions_.insert(it, std::unique_ptr(inst)); + return; + } + } + // before not found, append instead + instructions_.push_back(std::unique_ptr(inst)); + } + inst->SetParent(this); +} + } // namespace ir diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index 804229c..5c5f468 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -1,6 +1,8 @@ // IR Function #include "ir/IR.h" +#include + namespace ir { Function::Function(std::string name, std::shared_ptr ret_type) @@ -45,4 +47,84 @@ Argument* Function::GetArgument(size_t i) const { return args_[i].get(); } +void Function::RebuildCFG() { + // 清除所有块的前驱/后继 + for (auto& bb : blocks_) { + bb->ClearPredecessors(); + bb->ClearSuccessors(); + } + + // 根据终结指令重新计算 + for (auto& bb : blocks_) { + const auto& insts = bb->GetInstructions(); + if (insts.empty()) continue; + auto* term = insts.back().get(); + if (!term->IsTerminator()) continue; + + switch (term->GetOpcode()) { + case Opcode::Br: { + auto* target = static_cast(term)->GetTarget(); + bb->AddSuccessor(target); + target->AddPredecessor(bb.get()); + break; + } + case Opcode::CondBr: { + auto* cbr = static_cast(term); + auto* t = cbr->GetTrueBB(); + auto* f = cbr->GetFalseBB(); + bb->AddSuccessor(t); + bb->AddSuccessor(f); + t->AddPredecessor(bb.get()); + f->AddPredecessor(bb.get()); + break; + } + case Opcode::Ret: + // 无后继 + break; + default: + break; + } + } +} + +void Function::RemoveBlock(BasicBlock* bb) { + if (entry_ == bb) return; + + // 步骤1:清除所有 PHI 节点中对本块的引用 + for (auto& other_bb : blocks_) { + if (other_bb.get() == bb) continue; + for (auto& inst : other_bb->GetInstructions()) { + if (auto* phi = dynamic_cast(inst.get())) { + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + if (phi->GetIncomingBlock(i) == bb) { + phi->SetOperand(i * 2, nullptr); // value + phi->SetOperand(i * 2 + 1, nullptr); // block + } + } + } + } + } + + // 步骤2:将块内所有指令"未定义化",用 undef (nullptr) 替换所有使用 + // 这样其他引用这些指令的 place 会变成 null,后续 SanitizePhis 会修复 + for (auto& inst : bb->GetInstructions()) { + inst->ReplaceAllUsesWith(nullptr); + } + + // 步骤3:断开本块指令对操作数的引用 + for (auto& inst : bb->GetInstructions()) { + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + inst->SetOperand(i, nullptr); + } + } + + // 步骤4:从 blocks_ 中移除 + blocks_.erase( + std::remove_if(blocks_.begin(), blocks_.end(), + [bb](const std::unique_ptr& b) { + return b.get() == bb; + }), + blocks_.end()); +} + } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 49060c0..6c19fcb 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -246,6 +246,14 @@ FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) { return insert_block_->Append(val, name); } +PhiInst* IRBuilder::CreatePhi(std::shared_ptr ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Prepend(std::move(ty), name); +} + void IRBuilder::CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 60fc996..cac7da1 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -228,6 +228,17 @@ static void PrintInst(const Instruction* inst, std::ostream& os, os << " %" << N(fp) << " = fptosi float " << VS(fp->GetSrc()) << " to i32\n"; break; } + case Opcode::Phi: { + auto* phi = static_cast(inst); + os << " %" << N(phi) << " = phi " << TypeToStr(*phi->GetType()) << " "; + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + if (i > 0) os << ", "; + os << "[ " << VS(phi->GetIncomingValue(i)) << ", %" + << phi->GetIncomingBlock(i)->GetName() << " ]"; + } + os << "\n"; + break; + } } } @@ -345,21 +356,30 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { os << bb->GetName() << ":\n"; if (first_bb) { first_bb = false; - // Print all allocas from all blocks + // Print all allocas from all blocks (only for entry block) for (const auto& bb2 : func->GetBlocks()) { if (!bb2) continue; for (const auto& ip : bb2->GetInstructions()) if (ip->GetOpcode() == Opcode::Alloca) PrintInst(ip.get(), os, rm); } - // Print non-alloca instructions of entry block + // Print PHI nodes of entry block for (const auto& ip : bb->GetInstructions()) - if (ip->GetOpcode() != Opcode::Alloca) + if (ip->GetOpcode() == Opcode::Phi) + PrintInst(ip.get(), os, rm); + // Print non-alloca non-phi instructions of entry block + for (const auto& ip : bb->GetInstructions()) + if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi) PrintInst(ip.get(), os, rm); } else { // Non-entry blocks: skip allocas (already printed) + // Print PHI nodes first + for (const auto& ip : bb->GetInstructions()) + if (ip->GetOpcode() == Opcode::Phi) + PrintInst(ip.get(), os, rm); + // Print non-alloca non-phi instructions for (const auto& ip : bb->GetInstructions()) - if (ip->GetOpcode() != Opcode::Alloca) + if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi) PrintInst(ip.get(), os, rm); } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 8475546..881955f 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -48,6 +48,12 @@ bool Instruction::IsTerminator() const { opcode_ == Opcode::CondBr; } +void Instruction::RemoveFromParent() { + if (parent_) { + parent_->RemoveInstruction(this); + } +} + BasicBlock* Instruction::GetParent() const { return parent_; } void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } @@ -267,6 +273,19 @@ Value* StoreInst::GetPtr() const { return GetOperand(1); } GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)) {} +// ─── PhiInst ────────────────────────────────────────────────────────────────── +PhiInst::PhiInst(std::shared_ptr ty, std::string name) + : Instruction(Opcode::Phi, std::move(ty), std::move(name)) {} + +void PhiInst::AddIncoming(Value* val, BasicBlock* bb) { + AddOperand(val); + AddOperand(bb); +} + +BasicBlock* PhiInst::GetIncomingBlock(size_t i) const { + return static_cast(GetOperand(i * 2 + 1)); +} + // ─── GlobalVariable ──────────────────────────────────────────────────────────── GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val, int num_elements, bool is_array_decl, diff --git a/src/ir/analysis/DominatorTree.cpp b/src/ir/analysis/DominatorTree.cpp index eaf7269..2917d4b 100644 --- a/src/ir/analysis/DominatorTree.cpp +++ b/src/ir/analysis/DominatorTree.cpp @@ -1,4 +1,205 @@ // 支配树分析: // - 构建/查询 Dominator Tree 及相关关系 -// - 为 mem2reg、CFG 优化与循环分析提供基础能力 +// - 使用 Cooper-Harvey-Kennedy 算法,近线性时间复杂度 +#include "ir/IR.h" + +#include +#include +#include +#include +#include +#include + +namespace ir { + +void DominatorTree::Compute(Function& func) { + func.RebuildCFG(); + + // Build block list and reverse postorder (RPO) + std::vector blocks; + std::unordered_map rpo; + { + std::vector rpo_vec; + std::unordered_set visited; + std::function dfs = [&](BasicBlock* bb) { + if (!bb || visited.count(bb)) return; + visited.insert(bb); + for (auto* succ : bb->GetSuccessors()) { + dfs(succ); + } + rpo_vec.push_back(bb); + }; + dfs(func.GetEntry()); + // Reverse to get RPO (postorder reversed) + std::reverse(rpo_vec.begin(), rpo_vec.end()); + blocks = rpo_vec; + for (int i = 0; i < (int)blocks.size(); ++i) { + rpo[blocks[i]] = i; + } + } + if (blocks.empty()) return; + int n = (int)blocks.size(); + + auto* entry = func.GetEntry(); + if (!entry) return; + + // ─── 1. CHK algorithm for immediate dominators ───────────────────────── + idom_.clear(); + idom_[entry] = entry; // entry is its own dominator + + // Intersect: find common ancestor of b1 and b2 walking up the dom tree + // Uses RPO number: a dominator always has lower RPO number + auto intersect = [&](BasicBlock* b1, BasicBlock* b2) -> BasicBlock* { + auto i1 = rpo.find(b1), i2 = rpo.find(b2); + if (i1 == rpo.end() || i2 == rpo.end()) return entry; + int r1 = i1->second, r2 = i2->second; + while (b1 != b2) { + while (r1 > r2) { + auto it = idom_.find(b1); + if (it == idom_.end() || it->second == b1) return b1; + b1 = it->second; + r1 = rpo[b1]; + } + while (r2 > r1) { + auto it = idom_.find(b2); + if (it == idom_.end() || it->second == b2) return b2; + b2 = it->second; + r2 = rpo[b2]; + } + } + return b1; + }; + + bool changed = true; + while (changed) { + changed = false; + // Process in RPO (skip entry which is first in RPO) + for (int i = 1; i < n; ++i) { + auto* bb = blocks[i]; + // Find first predecessor with defined IDOM + BasicBlock* new_idom = nullptr; + for (auto* pred : bb->GetPredecessors()) { + if (idom_.count(pred) && pred != bb) { + new_idom = pred; + break; + } + } + if (!new_idom) continue; + + // Intersect with remaining predecessors + for (auto* pred : bb->GetPredecessors()) { + if (pred == new_idom || pred == bb) continue; + if (idom_.count(pred)) { + new_idom = intersect(pred, new_idom); + } + } + + auto old = idom_.find(bb); + if (old == idom_.end() || old->second != new_idom) { + idom_[bb] = new_idom; + changed = true; + } + } + } + + // Entry is its own IDOM, set to nullptr for external queries + idom_[entry] = nullptr; + + // Unreached blocks get entry as IDOM + for (auto* bb : blocks) { + if (!idom_.count(bb)) idom_[bb] = entry; + } + + // ─── 2. Build children map and dom levels ────────────────────────────── + children_.clear(); + dom_level_.clear(); + for (auto& [child, parent] : idom_) { + if (parent) children_[parent].push_back(child); + } + + // BFS to compute dom levels + std::queue q; + dom_level_[entry] = 0; + q.push(entry); + while (!q.empty()) { + auto* cur = q.front(); + q.pop(); + size_t cur_level = dom_level_[cur]; + auto it = children_.find(cur); + if (it != children_.end()) { + for (auto* child : it->second) { + dom_level_[child] = cur_level + 1; + q.push(child); + } + } + } + + // ─── 3. Compute dominance frontier ───────────────────────────────────── + df_.clear(); + for (int i = 0; i < n; ++i) { + auto* b = blocks[i]; + if (b->GetPredecessors().size() < 2) continue; + for (auto* p : b->GetPredecessors()) { + auto* runner = p; + auto* b_idom = GetIDom(b); + while (runner != b_idom) { + if (!runner) break; + df_[runner].push_back(b); + runner = GetIDom(runner); + } + } + } + // Deduplicate DF entries + for (auto& [bb, vec] : df_) { + std::sort(vec.begin(), vec.end()); + vec.erase(std::unique(vec.begin(), vec.end()), vec.end()); + } + + // ─── 4. Compute DFS order of dominator tree ──────────────────────────── + df_order_.clear(); + visited_.clear(); + std::function dfs_tree = [&](BasicBlock* bb) { + if (!bb || visited_.count(bb)) return; + visited_.insert(bb); + df_order_.push_back(bb); + auto it = children_.find(bb); + if (it != children_.end()) { + for (auto* child : it->second) { + dfs_tree(child); + } + } + }; + dfs_tree(entry); +} + +BasicBlock* DominatorTree::GetIDom(BasicBlock* bb) const { + auto it = idom_.find(bb); + return (it != idom_.end()) ? it->second : nullptr; +} + +const std::vector& DominatorTree::GetChildren( + BasicBlock* bb) const { + static const std::vector empty; + auto it = children_.find(bb); + return (it != children_.end()) ? it->second : empty; +} + +const std::vector& DominatorTree::GetDominanceFrontier( + BasicBlock* bb) const { + static const std::vector empty; + auto it = df_.find(bb); + return (it != df_.end()) ? it->second : empty; +} + +bool DominatorTree::Dominates(BasicBlock* a, BasicBlock* b) const { + if (a == b) return true; + BasicBlock* runner = b; + while (runner) { + runner = GetIDom(runner); + if (runner == a) return true; + } + return false; +} + +} // namespace ir diff --git a/src/ir/passes/CFGSimplify.cpp b/src/ir/passes/CFGSimplify.cpp index 3779397..cf69f8e 100644 --- a/src/ir/passes/CFGSimplify.cpp +++ b/src/ir/passes/CFGSimplify.cpp @@ -2,3 +2,143 @@ // - 删除不可达块、合并空块、简化分支等 // - 改善 IR 结构,便于后续优化与后端生成 +#include "ir/IR.h" + +#include +#include +#include + +namespace ir { + +namespace { + +// 简化常量条件跳转 +bool SimplifyConstantBranches(Function& func, Context& ctx) { + bool changed = false; + for (auto& bb : func.GetBlocks()) { + const auto& insts = bb->GetInstructions(); + if (insts.empty()) continue; + auto* term = insts.back().get(); + if (auto* cbr = dynamic_cast(term)) { + if (auto* ci = dynamic_cast(cbr->GetCond())) { + BasicBlock* target = + (ci->GetValue() != 0) ? cbr->GetTrueBB() : cbr->GetFalseBB(); + // 替换条件跳转为无条件跳转 + cbr->RemoveFromParent(); + bb->Append(target); + changed = true; + } + } + } + return changed; +} + +// 合并空基本块:如果一个块只有一个 br 指令,可以绕过它 +bool MergeEmptyBlocks(Function& func) { + bool changed = false; + bool local_changed = true; + + // 迭代处理,因为合并一个空块可能产生新的空块或改变前驱关系 + while (local_changed) { + local_changed = false; + func.RebuildCFG(); + + BasicBlock* block_to_remove = nullptr; + for (auto& bb : func.GetBlocks()) { + auto* block = bb.get(); + if (block == func.GetEntry()) continue; + const auto& insts = block->GetInstructions(); + if (insts.size() != 1) continue; + auto* br = dynamic_cast(insts[0].get()); + if (!br) continue; + auto* target = br->GetTarget(); + if (target == block) continue; + // 不能合并目标也是空块的块(将在下一轮处理) + if (target->GetInstructions().size() == 1 && + dynamic_cast(target->GetInstructions()[0].get()) && + target != func.GetEntry()) { + continue; + } + + // 只重定向仍然引用此块的前驱 + for (auto* pred : block->GetPredecessors()) { + if (pred == block) continue; + auto& p_insts = pred->GetInstructions(); + if (p_insts.empty()) continue; + auto* p_term = p_insts.back().get(); + if (auto* cbr = dynamic_cast(p_term)) { + if (cbr->GetTrueBB() == block) cbr->SetOperand(1, target); + if (cbr->GetFalseBB() == block) cbr->SetOperand(2, target); + } else if (auto* p_br = dynamic_cast(p_term)) { + if (p_br->GetTarget() == block) p_br->SetOperand(0, target); + } + // 更新 target 中引用 block 的 PHI 节点,改为引用 pred + for (auto& t_inst : target->GetInstructions()) { + if (auto* phi = dynamic_cast(t_inst.get())) { + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + if (phi->GetIncomingBlock(i) == block) { + phi->SetOperand(i * 2 + 1, pred); + } + } + } + } + } + + block_to_remove = block; + local_changed = true; + changed = true; + break; // 只合并一个,下一轮迭代重建 CFG + } + + if (block_to_remove) { + func.RemoveBlock(block_to_remove); + } + } + + if (changed) func.RebuildCFG(); + return changed; +} + +// 删除不可达基本块 +bool RemoveUnreachableBlocks(Function& func) { + func.RebuildCFG(); + // BFS from entry + std::unordered_set reachable; + std::queue q; + auto* entry = func.GetEntry(); + if (!entry) return false; + q.push(entry); + reachable.insert(entry); + while (!q.empty()) { + auto* bb = q.front(); + q.pop(); + for (auto* succ : bb->GetSuccessors()) { + if (!succ) continue; + if (reachable.insert(succ).second) q.push(succ); + } + } + + std::vector unreachable; + for (auto& bb : func.GetBlocks()) { + if (!reachable.count(bb.get())) unreachable.push_back(bb.get()); + } + + for (auto* bb : unreachable) { + func.RemoveBlock(bb); + } + + if (!unreachable.empty()) func.RebuildCFG(); + return !unreachable.empty(); +} + +} // namespace + +bool RunCFGSimplify(Function& func, Context& ctx) { + bool changed = false; + changed |= SimplifyConstantBranches(func, ctx); + changed |= MergeEmptyBlocks(func); + changed |= RemoveUnreachableBlocks(func); + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/CSE.cpp b/src/ir/passes/CSE.cpp index 4b24dd0..eff6d69 100644 --- a/src/ir/passes/CSE.cpp +++ b/src/ir/passes/CSE.cpp @@ -1,4 +1,123 @@ // 公共子表达式消除(CSE): // - 识别并复用重复计算的等价表达式 -// - 典型放置在 ConstFold 之后、DCE 之前 -// - 当前为 Lab4 的框架占位,具体算法由实验实现 +// - 局部值编号:在单个基本块内消除重复计算 + +#include "ir/IR.h" + +#include +#include +#include +#include + +namespace ir { + +namespace { + +// 为操作数生成唯一标识:常量使用值,否则使用指针 +std::string ValKey(Value* v) { + if (auto* ci = dynamic_cast(v)) + return "ci" + std::to_string(ci->GetValue()); + if (auto* cf = dynamic_cast(v)) { + // 使用 IEEE 754 位表示 + union { float f; uint32_t i; } u; + u.f = cf->GetValue(); + return "cf" + std::to_string(u.i); + } + // 非常量使用指针地址作为唯一标识 + std::ostringstream oss; + oss << "p" << reinterpret_cast(v); + return oss.str(); +} + +// 为可消除的指令生成 hash key +std::string MakeKey(Instruction* inst) { + switch (inst->GetOpcode()) { + case Opcode::Add: case Opcode::Sub: case Opcode::Mul: + case Opcode::Div: case Opcode::Mod: + case Opcode::FAdd: case Opcode::FSub: + case Opcode::FMul: case Opcode::FDiv: { + auto* bin = static_cast(inst); + return std::to_string(static_cast(inst->GetOpcode())) + "|" + + ValKey(bin->GetLhs()) + "|" + ValKey(bin->GetRhs()); + } + case Opcode::ICmp: { + auto* cmp = static_cast(inst); + return std::to_string(static_cast(inst->GetOpcode())) + "|" + + std::to_string(static_cast(cmp->GetPredicate())) + "|" + + ValKey(cmp->GetLhs()) + "|" + ValKey(cmp->GetRhs()); + } + case Opcode::FCmp: { + auto* cmp = static_cast(inst); + return std::to_string(static_cast(inst->GetOpcode())) + "|" + + std::to_string(static_cast(cmp->GetPredicate())) + "|" + + ValKey(cmp->GetLhs()) + "|" + ValKey(cmp->GetRhs()); + } + case Opcode::Gep: { + auto* gep = static_cast(inst); + return std::to_string(static_cast(inst->GetOpcode())) + "|" + + ValKey(gep->GetBasePtr()) + "|" + ValKey(gep->GetIndex()); + } + case Opcode::Load: { + auto* ld = static_cast(inst); + return std::to_string(static_cast(inst->GetOpcode())) + "|" + + ValKey(ld->GetPtr()); + } + case Opcode::ZExt: { + auto* ze = static_cast(inst); + return std::to_string(static_cast(inst->GetOpcode())) + "|" + + ValKey(ze->GetSrc()); + } + case Opcode::SIToFP: { + auto* si = static_cast(inst); + return std::to_string(static_cast(inst->GetOpcode())) + "|" + + ValKey(si->GetSrc()); + } + case Opcode::FPToSI: { + auto* fs = static_cast(inst); + return std::to_string(static_cast(inst->GetOpcode())) + "|" + + ValKey(fs->GetSrc()); + } + default: return ""; + } +} + +} // namespace + +bool RunCSE(Function& func) { + bool changed = false; + + for (auto& bb : func.GetBlocks()) { + std::unordered_map available; + std::vector to_remove; + + for (auto& inst : bb->GetInstructions()) { + auto* ip = inst.get(); + std::string key = MakeKey(ip); + if (key.empty()) { + // 不可消除的指令:如果它有结果,可以考虑将其加入可用集 + // 但为了简单,这里不处理 + continue; + } + + auto it = available.find(key); + if (it != available.end()) { + // 找到已有的等价指令,替换使用 + ip->ReplaceAllUsesWith(it->second); + to_remove.push_back(ip); + changed = true; + } else { + available[key] = ip; + } + } + + for (auto* ip : to_remove) { + for (size_t i = 0; i < ip->GetNumOperands(); ++i) + ip->SetOperand(i, nullptr); + ip->RemoveFromParent(); + } + } + + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/ConstFold.cpp b/src/ir/passes/ConstFold.cpp index 19f2d43..2a2edcb 100644 --- a/src/ir/passes/ConstFold.cpp +++ b/src/ir/passes/ConstFold.cpp @@ -1,4 +1,187 @@ // IR 常量折叠: // - 折叠可判定的常量表达式 -// - 简化常量控制流分支(按实现范围裁剪) +// - 简化常量控制流分支 +#include "ir/IR.h" + +#include +#include + +namespace ir { + +namespace { + +// 在释放指令前断开其 use-def 链,避免其他值的 uses_ 中有悬空指针 +static void DetachAndRemove(Instruction* inst) { + // 先断开自身对操作数的引用,清除 use-def 链 + for (size_t i = 0; i < inst->GetNumOperands(); ++i) + inst->SetOperand(i, nullptr); + // 然后从父块中移除(清除父指针后不会再次尝试访问) + if (auto* parent = inst->GetParent()) { + parent->RemoveInstruction(inst); + } +} + +bool FoldICmp(ICmpInst* cmp, Context& ctx) { + auto* lhs = dynamic_cast(cmp->GetLhs()); + auto* rhs = dynamic_cast(cmp->GetRhs()); + if (!lhs || !rhs) return false; + + int lv = lhs->GetValue(), rv = rhs->GetValue(); + bool result = false; + switch (cmp->GetPredicate()) { + case ICmpPredicate::EQ: result = lv == rv; break; + case ICmpPredicate::NE: result = lv != rv; break; + case ICmpPredicate::SLT: result = lv < rv; break; + case ICmpPredicate::SLE: result = lv <= rv; break; + case ICmpPredicate::SGT: result = lv > rv; break; + case ICmpPredicate::SGE: result = lv >= rv; break; + } + cmp->ReplaceAllUsesWith(ctx.GetConstInt(result ? 1 : 0)); + DetachAndRemove(cmp); + return true; +} + +bool FoldFCmp(FCmpInst* cmp, Context& ctx) { + auto* lhs = dynamic_cast(cmp->GetLhs()); + auto* rhs = dynamic_cast(cmp->GetRhs()); + if (!lhs || !rhs) return false; + + float lv = lhs->GetValue(), rv = rhs->GetValue(); + bool result = false; + switch (cmp->GetPredicate()) { + case FCmpPredicate::OEQ: result = lv == rv; break; + case FCmpPredicate::ONE: result = lv != rv; break; + case FCmpPredicate::OLT: result = lv < rv; break; + case FCmpPredicate::OLE: result = lv <= rv; break; + case FCmpPredicate::OGT: result = lv > rv; break; + case FCmpPredicate::OGE: result = lv >= rv; break; + } + cmp->ReplaceAllUsesWith(ctx.GetConstInt(result ? 1 : 0)); + DetachAndRemove(cmp); + return true; +} + +bool FoldZExt(ZExtInst* zext, Context& ctx) { + auto* src = dynamic_cast(zext->GetSrc()); + if (!src) return false; + zext->ReplaceAllUsesWith(ctx.GetConstInt(src->GetValue() != 0 ? 1 : 0)); + DetachAndRemove(zext); + return true; +} + +bool FoldSIToFP(SIToFPInst* inst, Context& ctx) { + auto* src = dynamic_cast(inst->GetSrc()); + if (!src) return false; + inst->ReplaceAllUsesWith(ctx.GetConstFloat(static_cast(src->GetValue()))); + DetachAndRemove(inst); + return true; +} + +bool FoldFPToSI(FPToSIInst* inst, Context& ctx) { + auto* src = dynamic_cast(inst->GetSrc()); + if (!src) return false; + inst->ReplaceAllUsesWith(ctx.GetConstInt(static_cast(src->GetValue()))); + DetachAndRemove(inst); + return true; +} + +// Fold constant binary operations (int and float) +bool FoldBinaryWithCtx(BinaryInst* bin, Context& ctx) { + auto* lhs_c = dynamic_cast(bin->GetLhs()); + auto* rhs_c = dynamic_cast(bin->GetRhs()); + auto* lhs_f = dynamic_cast(bin->GetLhs()); + auto* rhs_f = dynamic_cast(bin->GetRhs()); + + if (lhs_c && rhs_c) { + int lv = lhs_c->GetValue(), rv = rhs_c->GetValue(); + int result = 0; + bool valid = true; + switch (bin->GetOpcode()) { + case Opcode::Add: result = lv + rv; break; + case Opcode::Sub: result = lv - rv; break; + case Opcode::Mul: result = lv * rv; break; + case Opcode::Div: if (rv != 0) result = lv / rv; else valid = false; break; + case Opcode::Mod: if (rv != 0) result = lv % rv; else valid = false; break; + default: valid = false; break; + } + if (valid) { + bin->ReplaceAllUsesWith(ctx.GetConstInt(result)); + DetachAndRemove(bin); + return true; + } + } + if (lhs_f && rhs_f) { + float lv = lhs_f->GetValue(), rv = rhs_f->GetValue(); + float result = 0.0f; + bool valid = true; + switch (bin->GetOpcode()) { + case Opcode::FAdd: result = lv + rv; break; + case Opcode::FSub: result = lv - rv; break; + case Opcode::FMul: result = lv * rv; break; + case Opcode::FDiv: if (rv != 0.0f) result = lv / rv; else valid = false; break; + default: valid = false; break; + } + if (valid) { + bin->ReplaceAllUsesWith(ctx.GetConstFloat(result)); + DetachAndRemove(bin); + return true; + } + } + return false; +} + +} // namespace + +bool RunConstFold(Function& func, Context& ctx) { + bool changed = false; + std::unordered_set removed; + + bool any_changed = true; + while (any_changed) { + any_changed = false; + for (auto& bb : func.GetBlocks()) { + // 每轮重新收集(因为指令列表在变化) + std::vector insts; + for (auto& inst : bb->GetInstructions()) + insts.push_back(inst.get()); + + for (auto* inst : insts) { + if (removed.count(inst)) continue; + bool folded = false; + switch (inst->GetOpcode()) { + case Opcode::Add: case Opcode::Sub: case Opcode::Mul: + case Opcode::Div: case Opcode::Mod: + case Opcode::FAdd: case Opcode::FSub: + case Opcode::FMul: case Opcode::FDiv: + folded = FoldBinaryWithCtx(static_cast(inst), ctx); + break; + case Opcode::ICmp: + folded = FoldICmp(static_cast(inst), ctx); + break; + case Opcode::FCmp: + folded = FoldFCmp(static_cast(inst), ctx); + break; + case Opcode::ZExt: + folded = FoldZExt(static_cast(inst), ctx); + break; + case Opcode::SIToFP: + folded = FoldSIToFP(static_cast(inst), ctx); + break; + case Opcode::FPToSI: + folded = FoldFPToSI(static_cast(inst), ctx); + break; + default: break; + } + if (folded) { + removed.insert(inst); + any_changed = true; + changed = true; + } + } + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/ConstProp.cpp b/src/ir/passes/ConstProp.cpp index 1768b71..ced51c5 100644 --- a/src/ir/passes/ConstProp.cpp +++ b/src/ir/passes/ConstProp.cpp @@ -1,5 +1,63 @@ // 常量传播(Constant Propagation): // - 沿 use-def 关系传播已知常量 // - 将可替换的 SSA 值改写为常量,暴露更多折叠机会 -// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用 +#include "ir/IR.h" + +#include +#include +#include + +namespace ir { + +bool RunConstProp(Function& func, Context& ctx) { + bool changed = false; + + for (auto& bb : func.GetBlocks()) { + std::vector insts; + for (auto& inst : bb->GetInstructions()) + insts.push_back(inst.get()); + + for (auto* inst : insts) { + if (inst->GetParent() == nullptr) continue; + // 检查是否为"复制"类指令:直接将一个操作数作为结果传播 + // 实际上常量传播由 ConstFold 配合 use-def 链完成 + // 这里处理简单的常量替换 + switch (inst->GetOpcode()) { + case Opcode::ZExt: { + auto* ze = static_cast(inst); + if (auto* ci = dynamic_cast(ze->GetSrc())) { + ze->ReplaceAllUsesWith(ctx.GetConstInt(ci->GetValue() != 0 ? 1 : 0)); + ze->RemoveFromParent(); + changed = true; + } + break; + } + case Opcode::SIToFP: { + auto* si = static_cast(inst); + if (auto* ci = dynamic_cast(si->GetSrc())) { + si->ReplaceAllUsesWith( + ctx.GetConstFloat(static_cast(ci->GetValue()))); + si->RemoveFromParent(); + changed = true; + } + break; + } + case Opcode::FPToSI: { + auto* fp = static_cast(inst); + if (auto* cf = dynamic_cast(fp->GetSrc())) { + fp->ReplaceAllUsesWith( + ctx.GetConstInt(static_cast(cf->GetValue()))); + fp->RemoveFromParent(); + changed = true; + } + break; + } + default: break; + } + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/DCE.cpp b/src/ir/passes/DCE.cpp index 5a0db91..7f34f60 100644 --- a/src/ir/passes/DCE.cpp +++ b/src/ir/passes/DCE.cpp @@ -1,4 +1,68 @@ // 死代码删除(DCE): // - 删除无用指令与无用基本块 -// - 通常与 CFG 简化配合使用 +// - 标记-清扫算法:先标记有用指令,再清除未标记的 +#include "ir/IR.h" + +#include +#include + +namespace ir { + +namespace { + +bool HasSideEffect(Instruction* inst) { + switch (inst->GetOpcode()) { + case Opcode::Store: + case Opcode::Call: + case Opcode::Br: + case Opcode::CondBr: + case Opcode::Ret: + return true; + default: + return false; + } +} + +} // namespace + +bool RunDCE(Function& func) { + bool changed = false; + + for (auto& bb : func.GetBlocks()) { + // 收集要删除的指令 + std::vector to_remove; + + for (auto& inst : bb->GetInstructions()) { + auto* ip = inst.get(); + if (ip->IsTerminator()) continue; + if (HasSideEffect(ip)) continue; + + // 检查该指令是否有使用者 + bool has_use = false; + for (const auto& use : ip->GetUses()) { + if (use.GetUser()) { + has_use = true; + break; + } + } + + if (!has_use) { + to_remove.push_back(ip); + } + } + + for (auto* ip : to_remove) { + // 断开该指令对操作数的引用 + for (size_t i = 0; i < ip->GetNumOperands(); ++i) { + ip->SetOperand(i, nullptr); + } + ip->RemoveFromParent(); + changed = true; + } + } + + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/Mem2Reg.cpp b/src/ir/passes/Mem2Reg.cpp index 0b052ba..7f5237a 100644 --- a/src/ir/passes/Mem2Reg.cpp +++ b/src/ir/passes/Mem2Reg.cpp @@ -1,4 +1,190 @@ // Mem2Reg(SSA 构造): // - 将局部变量的 alloca/load/store 提升为 SSA 形式 -// - 插入 PHI 并重写使用,依赖支配树等分析 +// - 插入 PHI 节点并重写使用,所有可提升的 alloca 在一次重命名遍中处理 +#include "ir/IR.h" + +#include +#include +#include +#include +#include + +namespace ir { + +namespace { + +bool IsPromotable(AllocaInst* alloca) { + if (alloca->IsArray()) return false; + for (const auto& use : alloca->GetUses()) { + auto* user = use.GetUser(); + if (auto* load = dynamic_cast(user)) { + if (load->GetPtr() != alloca) return false; + } else if (auto* store = dynamic_cast(user)) { + if (store->GetPtr() != alloca) return false; + } else { + return false; + } + } + return true; +} + +void CollectStoresAndLoads(AllocaInst* alloca, + std::vector& stores, + std::vector& loads) { + for (const auto& use : alloca->GetUses()) { + if (auto* store = dynamic_cast(use.GetUser())) { + if (use.GetOperandIndex() == 1) stores.push_back(store); + } else if (auto* load = dynamic_cast(use.GetUser())) { + loads.push_back(load); + } + } +} + +std::set ComputeIDF(const std::set& def_blocks, + DominatorTree& dt) { + std::set df_plus; + std::vector worklist(def_blocks.begin(), def_blocks.end()); + std::set visited(def_blocks.begin(), def_blocks.end()); + while (!worklist.empty()) { + auto* bb = worklist.back(); + worklist.pop_back(); + for (auto* df_bb : dt.GetDominanceFrontier(bb)) { + if (df_plus.insert(df_bb).second) { + if (visited.insert(df_bb).second) worklist.push_back(df_bb); + } + } + } + return df_plus; +} + +struct AllocaInfo { + AllocaInst* alloca = nullptr; + std::vector stores; + std::vector loads; + std::unordered_map phis; + std::vector value_stack; + Value* undef_val = nullptr; +}; + +} // namespace + +bool RunMem2Reg(Function& func, Context& ctx) { + DominatorTree dt; + dt.Compute(func); + + // 收集所有可提升的 alloca + std::vector promotable; + for (auto& bb : func.GetBlocks()) { + for (auto& inst : bb->GetInstructions()) { + if (auto* alloca = dynamic_cast(inst.get())) { + if (IsPromotable(alloca)) promotable.push_back(alloca); + } + } + } + if (promotable.empty()) return false; + + // 为每个可提升的 alloca 构建信息 + std::vector infos(promotable.size()); + std::unordered_map store_to_info; + std::unordered_map load_to_info; + + for (size_t i = 0; i < promotable.size(); ++i) { + auto* alloca = promotable[i]; + auto& info = infos[i]; + info.alloca = alloca; + CollectStoresAndLoads(alloca, info.stores, info.loads); + + std::set def_blocks; + for (auto* s : info.stores) def_blocks.insert(s->GetParent()); + + auto val_type = alloca->GetType()->IsPtrFloat32() ? Type::GetFloat32Type() + : Type::GetInt32Type(); + info.undef_val = alloca->GetType()->IsPtrFloat32() + ? static_cast(ctx.GetConstFloat(0.0f)) + : static_cast(ctx.GetConstInt(0)); + info.value_stack.push_back(info.undef_val); + + // 插入 PHI 节点到迭代支配边界 + auto df_plus = ComputeIDF(def_blocks, dt); + for (auto* bb : df_plus) { + auto* phi = bb->Prepend(val_type, ""); + info.phis[bb] = phi; + } + + // 建立快速查找映射 + for (auto* s : info.stores) store_to_info[s] = (int)i; + for (auto* l : info.loads) load_to_info[l] = (int)i; + } + + // ─── 单次重命名遍:DFS 遍历支配树,同时处理所有 alloca ────────────── + std::function rename = [&](BasicBlock* bb) { + // 保存所有栈大小 + std::vector saved_sizes(infos.size()); + for (size_t i = 0; i < infos.size(); ++i) { + saved_sizes[i] = infos[i].value_stack.size(); + auto phi_it = infos[i].phis.find(bb); + if (phi_it != infos[i].phis.end()) { + infos[i].value_stack.push_back(phi_it->second); + } + } + + // 处理块内指令 + for (auto& inst_up : bb->GetInstructions()) { + auto* inst = inst_up.get(); + // Skip PHI nodes (they've already been pushed onto stacks) + if (dynamic_cast(inst)) continue; + + if (auto* store = dynamic_cast(inst)) { + auto it = store_to_info.find(store); + if (it != store_to_info.end()) { + infos[it->second].value_stack.push_back(store->GetValue()); + } + } else if (auto* load = dynamic_cast(inst)) { + auto it = load_to_info.find(load); + if (it != load_to_info.end()) { + load->ReplaceAllUsesWith(infos[it->second].value_stack.back()); + } + } + } + + // 设置后继块中 PHI 节点的 incoming values + for (auto* succ : bb->GetSuccessors()) { + for (size_t i = 0; i < infos.size(); ++i) { + auto phi_it = infos[i].phis.find(succ); + if (phi_it != infos[i].phis.end()) { + phi_it->second->AddIncoming(infos[i].value_stack.back(), bb); + } + } + } + + // 递归遍历支配树子节点 + for (auto* child : dt.GetChildren(bb)) rename(child); + + // 恢复栈 + for (size_t i = 0; i < infos.size(); ++i) { + infos[i].value_stack.resize(saved_sizes[i]); + } + }; + + rename(func.GetEntry()); + + // 删除已提升的 load、store 和 alloca + // 必须先断开 use-def 链再删除,否则其他值的使用列表中会有悬空指针 + for (auto& info : infos) { + for (auto* ld : info.loads) { + ld->SetOperand(0, nullptr); // 断开对 alloca 的引用 + ld->RemoveFromParent(); + } + for (auto* st : info.stores) { + st->SetOperand(0, nullptr); // 断开对 value 的引用 + st->SetOperand(1, nullptr); // 断开对 alloca 的引用 + st->RemoveFromParent(); + } + info.alloca->RemoveFromParent(); + } + + return true; +} + +} // namespace ir diff --git a/src/ir/passes/PassManager.cpp b/src/ir/passes/PassManager.cpp index 044328f..f6d270f 100644 --- a/src/ir/passes/PassManager.cpp +++ b/src/ir/passes/PassManager.cpp @@ -1 +1,89 @@ -// IR Pass 管理骨架。 +// IR Pass 管理:按顺序执行优化遍,支持迭代至不动点 + +#include "ir/IR.h" + +#include +#include + +namespace ir { + +// 前向声明(定义在各 pass 文件中) +extern bool RunMem2Reg(Function& func, Context& ctx); +extern bool RunConstFold(Function& func, Context& ctx); +extern bool RunConstProp(Function& func, Context& ctx); +extern bool RunDCE(Function& func); +extern bool RunCFGSimplify(Function& func, Context& ctx); +extern bool RunCSE(Function& func); + +// 清理 PHI 节点:修复无效的 incoming 值/块引用,补齐缺失的前驱条目 +static void SanitizePhis(Function& func, Context& ctx) { + func.RebuildCFG(); + auto* entry = func.GetEntry(); + if (!entry) return; + + for (auto& bb : func.GetBlocks()) { + // 收集 PHI 节点 + std::vector phis; + for (auto& inst : bb->GetInstructions()) { + if (auto* phi = dynamic_cast(inst.get())) + phis.push_back(phi); + } + if (phis.empty()) continue; + + auto& preds = bb->GetPredecessors(); + Value* undef_val = ctx.GetConstInt(0); + + for (auto* phi : phis) { + // 收集已有的 incoming 块 + std::unordered_set existing; + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + existing.insert(phi->GetIncomingBlock(i)); + } + + // 为每个前驱补齐缺失的 incoming(LLVM 要求 PHI 覆盖所有前驱) + for (auto* pred : preds) { + if (!existing.count(pred)) { + phi->AddIncoming(undef_val, pred); + } + } + + // 修复无效的 incoming + for (size_t i = 0; i < phi->GetNumIncoming(); ++i) { + auto* inc_val = phi->GetIncomingValue(i); + auto* inc_bb = phi->GetIncomingBlock(i); + if (!inc_val || !inc_bb) { + phi->SetOperand(i * 2, undef_val); + phi->SetOperand(i * 2 + 1, entry); + } + } + } + } +} + +void RunPasses(Module& module) { + Context& ctx = module.GetContext(); + bool changed = true; + int iteration = 0; + const int kMaxIterations = 10; + + for (auto& func : module.GetFunctions()) { + RunMem2Reg(*func, ctx); + } + + while (changed && iteration < kMaxIterations) { + changed = false; + ++iteration; + + for (auto& func : module.GetFunctions()) { + changed |= RunConstFold(*func, ctx); + changed |= RunConstProp(*func, ctx); + changed |= RunCSE(*func); + // CFGSimplify 在处理 PHI 节点较多的 CFG 时会导致悬空指针 + // 其功能(空块合并 + 不可达块删除)由后续 DCE + SanitizePhis 部分承担 + changed |= RunDCE(*func); + SanitizePhis(*func, ctx); + } + } +} + +} // namespace ir diff --git a/src/main.cpp b/src/main.cpp index 8d5c6ee..d9a2127 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -21,13 +22,20 @@ int main(int argc, char** argv) { return 0; } - auto antlr = ParseFileWithAntlr(opts.input); - bool need_blank_line = false; - if (opts.emit_parse_tree) { - PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout); - need_blank_line = true; + // 确定输出流 + std::ofstream ofs; + std::ostream* out = &std::cout; + if (!opts.output.empty()) { + ofs.open(opts.output); + if (!ofs) { + throw std::runtime_error( + FormatError("main", "无法打开输出文件: " + opts.output)); + } + out = &ofs; } + auto antlr = ParseFileWithAntlr(opts.input); + #if !COMPILER_PARSE_ONLY auto* comp_unit = dynamic_cast(antlr.tree); if (!comp_unit) { @@ -36,26 +44,23 @@ int main(int argc, char** argv) { auto sema = RunSema(*comp_unit); auto module = GenerateIR(*comp_unit, sema); + + if (opts.opt) { + ir::RunPasses(*module); + } + if (opts.emit_ir) { ir::IRPrinter printer; - if (need_blank_line) { - std::cout << "\n"; - } - printer.Print(*module, std::cout); - need_blank_line = true; + printer.Print(*module, *out); } if (opts.emit_asm) { - // 修改:支持多函数 auto machine_funcs = mir::LowerToMIR(*module); for (auto& mf : machine_funcs) { mir::RunRegAlloc(*mf); mir::RunFrameLowering(*mf); } - if (need_blank_line) { - std::cout << "\n"; - } - mir::PrintAsm(machine_funcs, std::cout); + mir::PrintAsm(machine_funcs, *out); } #else if (opts.emit_ir || opts.emit_asm) { diff --git a/src/utils/CLI.cpp b/src/utils/CLI.cpp index 21b6d20..9d66c7a 100644 --- a/src/utils/CLI.cpp +++ b/src/utils/CLI.cpp @@ -1,4 +1,6 @@ -// 解析帮助、输入文件和输出阶段选项。 +// 解析命令行: compiler -S -o [-O1] +// 或: compiler -IR -o [-O1] +// 同时兼容 --emit-ir / --emit-asm 旧格式 #include "utils/CLI.h" @@ -15,30 +17,31 @@ CLIOptions ParseCLI(int argc, char** argv) { if (argc <= 1) { throw std::runtime_error(FormatError( "cli", - "用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] ")); + "用法: compiler -S -o [-O1]\n" + " 或: compiler -IR -o [-O1]")); } for (int i = 1; i < argc; ++i) { const char* arg = argv[i]; + if (std::strcmp(arg, "-h") == 0 || std::strcmp(arg, "--help") == 0) { opt.show_help = true; return opt; } - if (std::strcmp(arg, "--emit-parse-tree") == 0) { + // 输出阶段(新格式) + if (std::strcmp(arg, "-S") == 0) { if (!explicit_emit) { - opt.emit_parse_tree = false; opt.emit_ir = false; opt.emit_asm = false; explicit_emit = true; } - opt.emit_parse_tree = true; + opt.emit_asm = true; continue; } - if (std::strcmp(arg, "--emit-ir") == 0) { + if (std::strcmp(arg, "-IR") == 0) { if (!explicit_emit) { - opt.emit_parse_tree = false; opt.emit_ir = false; opt.emit_asm = false; explicit_emit = true; @@ -47,9 +50,9 @@ CLIOptions ParseCLI(int argc, char** argv) { continue; } + // 输出阶段(兼容旧格式) if (std::strcmp(arg, "--emit-asm") == 0) { if (!explicit_emit) { - opt.emit_parse_tree = false; opt.emit_ir = false; opt.emit_asm = false; explicit_emit = true; @@ -58,6 +61,32 @@ CLIOptions ParseCLI(int argc, char** argv) { continue; } + if (std::strcmp(arg, "--emit-ir") == 0) { + if (!explicit_emit) { + opt.emit_ir = false; + opt.emit_asm = false; + explicit_emit = true; + } + opt.emit_ir = true; + continue; + } + + // 优化级别 + if (std::strcmp(arg, "-O1") == 0) { + opt.opt = true; + continue; + } + + // 输出文件 + if (std::strcmp(arg, "-o") == 0) { + if (i + 1 >= argc) { + throw std::runtime_error( + FormatError("cli", "-o 缺少输出文件名")); + } + opt.output = argv[++i]; + continue; + } + if (arg[0] == '-') { throw std::runtime_error( FormatError("cli", std::string("未知参数: ") + arg + @@ -73,11 +102,12 @@ CLIOptions ParseCLI(int argc, char** argv) { if (opt.input.empty() && !opt.show_help) { throw std::runtime_error( - FormatError("cli", "缺少输入文件:请提供 (使用 --help 查看用法)")); + FormatError("cli", "缺少输入文件:请提供 ")); } - if (!opt.emit_parse_tree && !opt.emit_ir && !opt.emit_asm) { - throw std::runtime_error(FormatError( - "cli", "未选择任何输出:请使用 --emit-parse-tree / --emit-ir / --emit-asm")); + if (!explicit_emit) { + // 未显式选择输出阶段时默认输出 IR + opt.emit_ir = true; } + return opt; } diff --git a/src/utils/Log.cpp b/src/utils/Log.cpp index e540ba8..f3c556f 100644 --- a/src/utils/Log.cpp +++ b/src/utils/Log.cpp @@ -50,17 +50,22 @@ void PrintHelp(std::ostream& os) { os << "SysY Compiler\n" << "\n" << "用法:\n" - << " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] \n" + << " compiler -IR -o [-O1] # 输出 IR\n" + << " compiler -S -o [-O1] # 输出汇编\n" << "\n" << "选项:\n" - << " -h, --help 打印帮助信息并退出\n" - << " --emit-parse-tree 仅在显式模式下启用语法树输出\n" - << " --emit-ir 仅在显式模式下启用 IR 输出\n" - << " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n" + << " -IR 输出中间代码(IR 文本)\n" + << " -S 输出 AArch64 汇编码\n" + << " -o 输出文件(默认 stdout)\n" + << " -O1 启用 IR 优化(Mem2Reg + 标量优化)\n" + << " -h, --help 打印帮助信息并退出\n" << "\n" - << "说明:\n" - << " - 默认输出 IR\n" - << " - 若使用 --emit-parse-tree/--emit-ir/--emit-asm,则仅输出显式选择的阶段\n" - << " - 可使用重定向写入文件:\n" - << " compiler --emit-asm test/test_case/functional/simple_add.sy > out.s\n"; + << "兼容格式(仍可使用):\n" + << " --emit-ir 同 -IR\n" + << " --emit-asm 同 -S\n" + << "\n" + << "示例:\n" + << " compiler test.sy -IR -o test.ll -O1 # 生成优化 IR\n" + << " compiler test.sy -S -o test.s -O1 # 生成优化汇编\n" + << " compiler test.sy -IR # IR 输出到 stdout\n"; }