diff --git a/doc/Lab6-并行与循环优化.md b/doc/Lab6-并行与循环优化.md index 021fd48..d3e6e6d 100644 --- a/doc/Lab6-并行与循环优化.md +++ b/doc/Lab6-并行与循环优化.md @@ -59,8 +59,8 @@ cmake --build build -j "$(nproc)" ### 7.1 功能回归 ```bash -./scripts/verify_ir.sh test/test_case/functional/simple_add.sy test/test_result/function/ir --run -./scripts/verify_asm.sh test/test_case/functional/simple_add.sy test/test_result/function/asm --run +./scripts/verify_ir.sh test/test_case/performance/2025-MYO-20.sy test/test_result/performance/ir --run +./scripts/verify_asm.sh test/test_case/performance/2025-MYO-20.sy test/test_result/performance/asm --run ``` `--run` 模式下脚本会自动读取同名 `.in`,并将程序输出与退出码和同名 `.out` 比对。 diff --git a/include/mir/MIR.h b/include/mir/MIR.h index bdd211d..72c6a08 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -44,6 +44,8 @@ enum class Opcode { LoadStackAddr, // 加载栈地址:add x9, x29, #offset(用于数组基址) LoadIndirect, // 间接加载:ldr w8, [x9] StoreIndirect, // 间接存储:str w8, [x9] + LoadIndirectScaled, // 间接加载:ldr w8, [x9, w10, uxtw #2] + StoreIndirectScaled, // 间接存储:str w8, [x9, w10, uxtw #2] LoadGlobal, StoreGlobal, LoadGlobalAddr, // 加载全局变量地址(用于数组) diff --git a/scripts/run_all_tests.sh b/scripts/run_all_tests.sh index 53771a7..d65cc9a 100755 --- a/scripts/run_all_tests.sh +++ b/scripts/run_all_tests.sh @@ -38,6 +38,20 @@ if [[ ! -x "$compiler" ]]; then exit 1 fi +find_tool() { + local name + for name in "$@"; do + if command -v "$name" >/dev/null 2>&1; then + command -v "$name" + return 0 + fi + done + return 1 +} + +LLC_CMD=$(find_tool llc llc-20 || true) +CLANG_CMD=$(find_tool clang clang-20 || true) + total=0 passed=0 failed=0 @@ -76,10 +90,10 @@ run_ir_test() { local emit_ns=$(( $(now_ns) - emit_start_ns )) # 需要 llc + clang - if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then + if [[ -z "$LLC_CMD" || -z "$CLANG_CMD" ]]; then RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns")" - echo " [SKIP-IR] $sy (缺少 llc/clang)" + echo " [SKIP-IR] $sy (缺少 llc/llc-20 或 clang/clang-20)" return 2 fi @@ -88,16 +102,16 @@ run_ir_test() { local lower_link_start_ns lower_link_start_ns=$(now_ns) - if ! llc -filetype=obj "$out_file" -o "$obj" 2>/dev/null; then + if ! "$LLC_CMD" -filetype=obj "$out_file" -o "$obj" 2>/dev/null; then RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns")" - echo " [SKIP-IR] $sy (llc 编译失败)" + echo " [SKIP-IR] $sy ($LLC_CMD 编译失败)" return 2 fi - if ! clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm -pthread 2>/dev/null; then + if ! "$CLANG_CMD" -no-pie "$obj" sylib/sylib.c -o "$exe" -lm -pthread 2>/dev/null; then RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") lower+link=$(format_ns $(( $(now_ns) - lower_link_start_ns )))" - echo " [SKIP-IR] $sy (clang 链接失败)" + echo " [SKIP-IR] $sy ($CLANG_CMD 链接失败)" return 2 fi local lower_link_ns=$(( $(now_ns) - lower_link_start_ns )) @@ -116,7 +130,7 @@ run_ir_test() { timeout $run_timeout "$exe" > "$stdout_file" 2>/dev/null fi local status=$? - set -e + set +e local run_ns=$(( $(now_ns) - run_start_ns )) RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") lower+link=$(format_ns "$lower_link_ns") run=$(format_ns "$run_ns")" @@ -216,7 +230,7 @@ run_asm_test() { timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe" > "$stdout_file" 2>/dev/null fi local status=$? - set -e + set +e local run_ns=$(( $(now_ns) - run_start_ns )) RUN_LAST_TOTAL_NS=$(( $(now_ns) - start_ns )) RUN_LAST_BREAKDOWN="emit=$(format_ns "$emit_ns") asm+link=$(format_ns "$link_ns") run=$(format_ns "$run_ns")" @@ -251,6 +265,7 @@ run_asm_test() { echo "========================================" echo " Lab4 批量回归测试 (mode: $mode)" echo "========================================" +echo " LLVM tools: $LLC_CMD / $CLANG_CMD" echo "" # 收集所有测试文件 diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index 2d00ff5..3cb42c9 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -49,6 +49,20 @@ if [[ ! -x "$compiler" ]]; then exit 1 fi +find_tool() { + local name + for name in "$@"; do + if command -v "$name" >/dev/null 2>&1; then + command -v "$name" + return 0 + fi + done + return 1 +} + +llc_cmd=$(find_tool llc llc-20 || true) +clang_cmd=$(find_tool clang clang-20 || true) + mkdir -p "$out_dir" base=$(basename "$input") stem=${base%.sy} @@ -63,12 +77,12 @@ echo "IR 已生成: $out_file" echo "IR 生成耗时: $(format_ns "$emit_elapsed_ns")" if [[ "$run_exec" == true ]]; then - if ! command -v llc >/dev/null 2>&1; then - echo "未找到 llc,无法运行 IR。请安装 LLVM。" >&2 + if [[ -z "$llc_cmd" ]]; then + echo "未找到 llc 或 llc-20,无法运行 IR。请安装 LLVM。" >&2 exit 1 fi - if ! command -v clang >/dev/null 2>&1; then - echo "未找到 clang,无法链接可执行文件。请安装 LLVM/Clang。" >&2 + if [[ -z "$clang_cmd" ]]; then + echo "未找到 clang 或 clang-20,无法链接可执行文件。请安装 Clang。" >&2 exit 1 fi obj="$out_dir/$stem.o" @@ -76,8 +90,8 @@ if [[ "$run_exec" == true ]]; then stdout_file="$out_dir/$stem.stdout" actual_file="$out_dir/$stem.actual.out" lower_link_start_ns=$(now_ns) - llc -filetype=obj "$out_file" -o "$obj" - clang -no-pie "$obj" sylib/sylib.c -o "$exe" -lm -pthread + "$llc_cmd" -filetype=obj "$out_file" -o "$obj" + "$clang_cmd" -no-pie "$obj" sylib/sylib.c -o "$exe" -lm -pthread lower_link_elapsed_ns=$(( $(now_ns) - lower_link_start_ns )) echo "IR 落地/链接耗时: $(format_ns "$lower_link_elapsed_ns")" echo "运行 $exe ..." diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index 2719577..1161b7d 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -100,6 +100,20 @@ PhiInst* BasicBlock::PrependPhi(std::shared_ptr ty, void BasicBlock::RemoveInstruction(Instruction* inst) { if (!inst) return; + if (auto* br = dynamic_cast(inst)) { + auto* target = br->GetTarget(); + RemoveSuccessor(target); + target->RemovePredecessor(this); + } else if (auto* cbr = dynamic_cast(inst)) { + auto* true_bb = cbr->GetTrueBlock(); + auto* false_bb = cbr->GetFalseBlock(); + RemoveSuccessor(true_bb); + true_bb->RemovePredecessor(this); + if (false_bb != true_bb) { + RemoveSuccessor(false_bb); + false_bb->RemovePredecessor(this); + } + } // 清除该指令所有操作数的 use 关系 for (size_t i = 0; i < inst->GetNumOperands(); ++i) { auto* operand = inst->GetOperand(i); diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index df62cfa..ac0cc29 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -68,6 +69,24 @@ static const char* OpcodeToString(Opcode op) { return "?"; } +static const char* BinaryOpcodeToString(Opcode op, const Type& ty) { + const bool is_float = ty.IsFloat32(); + switch (op) { + case Opcode::Add: + return is_float ? "fadd" : "add"; + case Opcode::Sub: + return is_float ? "fsub" : "sub"; + case Opcode::Mul: + return is_float ? "fmul" : "mul"; + case Opcode::Div: + return is_float ? "fdiv" : "sdiv"; + case Opcode::Mod: + return "srem"; + default: + return OpcodeToString(op); + } +} + static const char* CmpOpToString(CmpOp op) { switch (op) { case CmpOp::Eq: @@ -86,10 +105,44 @@ static const char* CmpOpToString(CmpOp op) { return "?"; } +static const char* CmpOpcodeToString(const Type& ty) { + return ty.IsFloat32() ? "fcmp" : "icmp"; +} + +static const char* FloatCmpOpToString(CmpOp op) { + switch (op) { + case CmpOp::Eq: + return "oeq"; + case CmpOp::Ne: + return "one"; + case CmpOp::Lt: + return "olt"; + case CmpOp::Le: + return "ole"; + case CmpOp::Gt: + return "ogt"; + case CmpOp::Ge: + return "oge"; + } + return "?"; +} + +static std::string FloatToString(float value) { + double widened = static_cast(value); + std::uint64_t bits = 0; + std::memcpy(&bits, &widened, sizeof(bits)); + std::ostringstream oss; + oss << "0x" << std::uppercase << std::hex << bits; + return oss.str(); +} + static std::string ValueToString(const Value* v) { if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } + if (auto* cf = dynamic_cast(v)) { + return FloatToString(cf->GetValue()); + } if (auto* gv = dynamic_cast(v)) { return "@" + gv->GetName(); } @@ -102,6 +155,20 @@ static std::string ValueToString(const Value* v) { return v ? v->GetName() : ""; } +static std::string CmpBoolName(const CmpInst* cmp) { + return cmp->GetName() + ".cmp"; +} + +static std::string BranchCondToString(const Value* v) { + if (auto* cmp = dynamic_cast(v)) { + return CmpBoolName(cmp); + } + if (auto* ci = dynamic_cast(v)) { + return ci->GetValue() == 0 ? "false" : "true"; + } + return ValueToString(v); +} + void IRPrinter::Print(const Module& module, std::ostream& os) { // 先打印全局变量 for (const auto& gv : module.GetGlobalVars()) { @@ -115,7 +182,8 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { std::int32_t bits = static_cast(gv->GetInitValue()); float fval = 0.0f; std::memcpy(&fval, &bits, sizeof(fval)); - os << "@" << gv->GetName() << " = global float " << fval << "\n"; + os << "@" << gv->GetName() << " = global float " + << FloatToString(fval) << "\n"; } else { os << "@" << gv->GetName() << " = global i32 " << gv->GetInitValue() << "\n"; @@ -163,7 +231,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { case Opcode::Mod: { auto* bin = static_cast(inst); os << " " << bin->GetName() << " = " - << OpcodeToString(bin->GetOpcode()) << " " + << BinaryOpcodeToString(bin->GetOpcode(), *bin->GetLhs()->GetType()) << " " << TypeToString(*bin->GetLhs()->GetType()) << " " << ValueToString(bin->GetLhs()) << ", " << ValueToString(bin->GetRhs()) << "\n"; @@ -171,11 +239,16 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::Cmp: { auto* cmp = static_cast(inst); - os << " " << cmp->GetName() << " = " << OpcodeToString(cmp->GetOpcode()) - << " " << CmpOpToString(cmp->GetCmpOp()) << " " + const bool is_float_cmp = cmp->GetLhs()->GetType()->IsFloat32(); + os << " " << CmpBoolName(cmp) << " = " + << CmpOpcodeToString(*cmp->GetLhs()->GetType()) + << " " << (is_float_cmp ? FloatCmpOpToString(cmp->GetCmpOp()) + : CmpOpToString(cmp->GetCmpOp())) << " " << TypeToString(*cmp->GetLhs()->GetType()) << " " << ValueToString(cmp->GetLhs()) << ", " << ValueToString(cmp->GetRhs()) << "\n"; + os << " " << cmp->GetName() << " = zext i1 " + << CmpBoolName(cmp) << " to i32\n"; break; } case Opcode::Cast: { @@ -222,7 +295,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::CondBr: { auto* cbr = static_cast(inst); - os << " br i1 " << ValueToString(cbr->GetCond()) << ", label %" + os << " br i1 " << BranchCondToString(cbr->GetCond()) << ", label %" << cbr->GetTrueBlock()->GetName() << ", label %" << cbr->GetFalseBlock()->GetName() << "\n"; break; diff --git a/src/ir/passes/CFGSimplify.cpp b/src/ir/passes/CFGSimplify.cpp index 10e2b79..1eaa719 100644 --- a/src/ir/passes/CFGSimplify.cpp +++ b/src/ir/passes/CFGSimplify.cpp @@ -99,6 +99,11 @@ bool RunCFGSimplify(Function& func, Context& ctx) { // 从后继的前驱列表中移除 for (auto* succ : bb->GetSuccessors()) { succ->RemovePredecessor(bb); + for (auto& inst_ptr : succ->MutableInstructions()) { + auto* phi = dynamic_cast(inst_ptr.get()); + if (!phi) break; + phi->RemoveIncomingBlock(bb); + } } // 清除块中所有指令的 use 关系 std::vector all_insts; diff --git a/src/ir/passes/LICM.cpp b/src/ir/passes/LICM.cpp index 4561e33..808294f 100644 --- a/src/ir/passes/LICM.cpp +++ b/src/ir/passes/LICM.cpp @@ -57,6 +57,15 @@ bool IsSupportedInvariantOpcode(Opcode op) { } } +bool IsCommutativeExpr(Instruction* inst) { + if (!inst || inst->GetNumOperands() != 2) return false; + if (inst->GetOpcode() == Opcode::Add || inst->GetOpcode() == Opcode::Mul) { + return true; + } + auto* cmp = dynamic_cast(inst); + return cmp && (cmp->GetCmpOp() == CmpOp::Eq || cmp->GetCmpOp() == CmpOp::Ne); +} + bool IsLoopInvariantValue(Value* value, analysis::Loop* loop, const std::unordered_set& invariant) { if (!value) return false; @@ -126,6 +135,10 @@ ExprKey MakeExprKey(Instruction* inst) { for (size_t i = 0; i < inst->GetNumOperands(); ++i) { key.operands.push_back(inst->GetOperand(i)); } + if (IsCommutativeExpr(inst) && + std::less()(key.operands[1], key.operands[0])) { + std::swap(key.operands[0], key.operands[1]); + } return key; } diff --git a/src/ir/passes/LoopFission.cpp b/src/ir/passes/LoopFission.cpp index 8833999..a42e0ea 100644 --- a/src/ir/passes/LoopFission.cpp +++ b/src/ir/passes/LoopFission.cpp @@ -43,6 +43,36 @@ bool DependsOnAny(Instruction* inst, const std::unordered_set& def return false; } +void CollectMemoryBases(const std::vector& group, + std::unordered_set* loads, + std::unordered_set* stores) { + for (auto* inst : group) { + if (auto* load = dynamic_cast(inst)) { + loads->insert(StripPointerBase(load->GetPtr())); + } else if (auto* store = dynamic_cast(inst)) { + stores->insert(StripPointerBase(store->GetPtr())); + } + } +} + +bool HasCrossMemoryDependence(const std::vector& group1, + const std::vector& group2) { + std::unordered_set loads1; + std::unordered_set stores1; + std::unordered_set loads2; + std::unordered_set stores2; + CollectMemoryBases(group1, &loads1, &stores1); + CollectMemoryBases(group2, &loads2, &stores2); + + for (auto* base : stores1) { + if (stores2.count(base) != 0 || loads2.count(base) != 0) return true; + } + for (auto* base : stores2) { + if (loads1.count(base) != 0) return true; + } + return false; +} + bool RunFissionOnLoop(Function& func, const CanonicalLoopMatch& match, Context& ctx) { if (!IsFissionCandidate(match)) return false; @@ -93,6 +123,7 @@ bool RunFissionOnLoop(Function& func, const CanonicalLoopMatch& match, for (auto* inst : group1) { if (DependsOnAny(inst, group2_defs)) return false; } + if (HasCrossMemoryDependence(group1, group2)) return false; auto* original_exit = match.exit; std::string block_suffix = ctx.NextTemp(); diff --git a/src/ir/passes/LoopParallelize.cpp b/src/ir/passes/LoopParallelize.cpp index ab45a97..c51c750 100644 --- a/src/ir/passes/LoopParallelize.cpp +++ b/src/ir/passes/LoopParallelize.cpp @@ -349,6 +349,7 @@ bool BuildGuardedFillCandidate(Function& func, analysis::Loop* loop, bool BuildParallelCandidate(Function& func, analysis::Loop* loop, ParallelLoopCandidate* out) { if (BuildGuardedFillCandidate(func, loop, out)) return true; + if (!loop || !loop->IsParallelCandidate()) return false; auto match = MatchCanonicalLoop(loop); if (!match.has_value()) return false; diff --git a/src/ir/passes/StrengthReduction.cpp b/src/ir/passes/StrengthReduction.cpp index ba211f2..fa2edac 100644 --- a/src/ir/passes/StrengthReduction.cpp +++ b/src/ir/passes/StrengthReduction.cpp @@ -31,15 +31,21 @@ int ExtractScale(BinaryInst* mul, PhiInst* iv) { return 0; } -bool ReplaceMulWithRecurrence(Function& func, const CanonicalLoopMatch& match, - BinaryInst* mul, Context& ctx) { - (void)func; - const int scale = ExtractScale(mul, match.induction.phi); - if (scale == 0) return false; - if (match.latch != mul->GetParent() && - !match.loop->Contains(mul->GetParent())) { - return false; - } +bool HasConstantScale(BinaryInst* mul, PhiInst* iv) { + if (!mul || !iv) return false; + return (mul->GetLhs() == iv && + dynamic_cast(mul->GetRhs()) != nullptr) || + (mul->GetRhs() == iv && + dynamic_cast(mul->GetLhs()) != nullptr); +} + +Value* GetOrCreateScaledRecurrence( + const CanonicalLoopMatch& match, int scale, Context& ctx, + std::unordered_map& recurrence_by_scale) { + if (scale == 0) return ctx.GetConstInt(0); + if (scale == 1) return match.induction.phi; + auto it = recurrence_by_scale.find(scale); + if (it != recurrence_by_scale.end()) return it->second; auto* init_scale = InsertBinaryBeforeTerminator(match.preheader, Opcode::Mul, @@ -49,18 +55,26 @@ bool ReplaceMulWithRecurrence(Function& func, const CanonicalLoopMatch& match, auto* sr_phi = match.header->PrependPhi(Type::GetInt32Type(), ctx.NextTemp()); - auto* step_scale = - InsertBinaryBeforeTerminator(match.latch, Opcode::Mul, - ctx.GetConstInt(match.induction.step), - ctx.GetConstInt(scale), ctx.NextTemp()); auto* sr_next = - InsertBinaryBeforeTerminator(match.latch, Opcode::Add, sr_phi, step_scale, + InsertBinaryBeforeTerminator(match.latch, Opcode::Add, sr_phi, + ctx.GetConstInt(match.induction.step * scale), ctx.NextTemp()); sr_phi->AddIncoming(init_scale, match.preheader); sr_phi->AddIncoming(sr_next, match.latch); + recurrence_by_scale.emplace(scale, sr_phi); + return sr_phi; +} + +bool ReplaceMulWithRecurrence( + const CanonicalLoopMatch& match, BinaryInst* mul, Context& ctx, + std::unordered_map& recurrence_by_scale) { + const int scale = ExtractScale(mul, match.induction.phi); + auto* replacement = + GetOrCreateScaledRecurrence(match, scale, ctx, recurrence_by_scale); + if (!replacement) return false; - mul->ReplaceAllUsesWith(sr_phi); + mul->ReplaceAllUsesWith(replacement); mul->GetParent()->RemoveInstruction(mul); return true; } @@ -96,15 +110,16 @@ bool RunStrengthReduction(Function& func, Context& ctx) { auto* inst = inst_ptr.get(); if (IsStrengthReductionCandidate(inst, match->induction.phi)) { auto* mul = static_cast(inst); - if (ExtractScale(mul, match->induction.phi) != 0) { + if (HasConstantScale(mul, match->induction.phi)) { candidates.push_back(mul); } } } } + std::unordered_map recurrence_by_scale; for (auto* mul : candidates) { - changed |= ReplaceMulWithRecurrence(func, *match, mul, ctx); + changed |= ReplaceMulWithRecurrence(*match, mul, ctx, recurrence_by_scale); } } diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 1d84451..634630b 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -242,6 +242,20 @@ void PrintAsm(const MachineModule& module, std::ostream& os) { << PhysRegName(ops.at(1).GetReg()) << "]\n"; break; } + case Opcode::LoadIndirectScaled: { + // ops: wN, xM, wK + os << " ldr " << PhysRegName(ops.at(0).GetReg()) << ", [" + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << ", uxtw #2]\n"; + break; + } + case Opcode::StoreIndirectScaled: { + // ops: wN, xM, wK + os << " str " << PhysRegName(ops.at(0).GetReg()) << ", [" + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << ", uxtw #2]\n"; + break; + } case Opcode::LoadGlobal: { // adrp x9, global_var // add x9, x9, :lo12:global_var diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index c5c395a..c79d7f8 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -15,15 +15,16 @@ namespace { using ValueSlotMap = std::unordered_map; -// GEP 结果:(base_slot_index, byte_offset, global_symbol) +// GEP 结果:(base_slot_index, byte_offset, global_symbol, index_value) // - base_slot >= 0: 本地数组,base_slot 是栈槽索引 // - base_slot = -1: 全局数组,global_symbol 是全局变量名 // - byte_offset >= 0: 常量索引 -// - byte_offset < 0: 变量索引,编码为 -1 - index_slot +// - byte_offset < 0: 变量索引,index_value 是原始 IR 下标值 struct GepInfo { int base_slot; int byte_offset; std::string global_symbol; + const ir::Value* index_value = nullptr; }; using GepMap = std::unordered_map; @@ -176,7 +177,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, if (auto* const_index = dynamic_cast(index)) { // 常量索引:计算地址并存储 int byte_offset = const_index->GetValue() * 4; - geps.emplace(&inst, GepInfo{-1, byte_offset, gv->GetName()}); + geps.emplace(&inst, GepInfo{-1, byte_offset, gv->GetName(), nullptr}); if (ptr_slot >= 0) { // 计算地址:x9 = &global_array + offset @@ -188,16 +189,11 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } } else { // 变量索引 - int index_slot = function.CreateFrameIndex(); - EmitValueToReg(index, PhysReg::W8, slots, block); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)}); - geps.emplace(&inst, GepInfo{-1, -1 - index_slot, gv->GetName()}); + geps.emplace(&inst, GepInfo{-1, -1, gv->GetName(), index}); if (ptr_slot >= 0) { // 计算地址:x9 = &global_array + (w10 * 4) - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitValueToReg(index, PhysReg::W10, slots, block); EmitLslBy2(PhysReg::W10, block); block.Append(Opcode::LoadGlobalAddr, {Operand::Reg(PhysReg::X9), Operand::Symbol(gv->GetName())}); @@ -238,17 +234,11 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(ptr_slot)}); } else { // 变量索引 - int index_slot = function.CreateFrameIndex(); - EmitValueToReg(index, PhysReg::W8, slots, block); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)}); - // x9 = 从栈加载指针 block.Append(Opcode::LoadStack, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)}); // w10 = index * 4 - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitValueToReg(index, PhysReg::W10, slots, block); EmitLslBy2(PhysReg::W10, block); // x9 = x9 + uxtw(w10) block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), @@ -268,7 +258,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, // 检查是否是常量索引 if (auto* const_index = dynamic_cast(index)) { int byte_offset = const_index->GetValue() * 4; - geps.emplace(&inst, GepInfo{base_it->second, byte_offset, ""}); + geps.emplace(&inst, GepInfo{base_it->second, byte_offset, "", nullptr}); if (ptr_slot >= 0) { // 计算地址:x9 = &array_base + byte_offset @@ -280,16 +270,11 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } } else { // 变量索引 - int index_slot = function.CreateFrameIndex(); - EmitValueToReg(index, PhysReg::W8, slots, block); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(index_slot)}); - geps.emplace(&inst, GepInfo{base_it->second, -1 - index_slot, ""}); + geps.emplace(&inst, GepInfo{base_it->second, -1, "", index}); if (ptr_slot >= 0) { // 计算地址:x9 = x29 + base_offset + (w10 * 4) - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); + EmitValueToReg(index, PhysReg::W10, slots, block); EmitLslBy2(PhysReg::W10, block); block.Append(Opcode::LoadStackAddr, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(base_it->second)}); @@ -330,22 +315,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); } else { // 变量索引:global_array[var_idx] - int index_slot = -1 - gep_info.byte_offset; - // 1. 加载 index(4字节 W 寄存器) - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); - // 2. index * 4 - EmitLslBy2(PhysReg::W10, block); - // 3. 获取全局数组基址 + EmitValueToReg(gep_info.index_value, PhysReg::W10, slots, block); block.Append(Opcode::LoadGlobalAddr, {Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)}); - // 4. x9 + w10, uxtw(零扩展 W 寄存器后加到 X 寄存器) - block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::W10)}); - // 5. 存储 - block.Append(Opcode::StoreIndirect, - {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::StoreIndirectScaled, + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::W10)}); } } else if (gep_info.byte_offset >= 0) { // 本地数组,常量索引 @@ -355,18 +330,13 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, Operand::Imm(gep_info.byte_offset)}); } else { // 本地数组,变量索引 - int index_slot = -1 - gep_info.byte_offset; - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); - EmitLslBy2(PhysReg::W10, block); + EmitValueToReg(gep_info.index_value, PhysReg::W10, slots, block); block.Append(Opcode::LoadStackAddr, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(gep_info.base_slot)}); - block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::W10)}); - block.Append(Opcode::StoreIndirect, - {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::StoreIndirectScaled, + {Operand::Reg(src_reg), Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::W10)}); } return; } @@ -426,17 +396,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); } else { // 变量索引 - int index_slot = -1 - gep_info.byte_offset; - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); - EmitLslBy2(PhysReg::W10, block); + EmitValueToReg(gep_info.index_value, PhysReg::W10, slots, block); block.Append(Opcode::LoadGlobalAddr, {Operand::Reg(PhysReg::X9), Operand::Symbol(gep_info.global_symbol)}); - block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::W10)}); - block.Append(Opcode::LoadIndirect, - {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::LoadIndirectScaled, + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::W10)}); } } else if (gep_info.byte_offset >= 0) { // 本地数组,常量索引 @@ -446,18 +411,13 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, Operand::Imm(gep_info.byte_offset)}); } else { // 本地数组,变量索引 - int index_slot = -1 - gep_info.byte_offset; - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W10), Operand::FrameIndex(index_slot)}); - EmitLslBy2(PhysReg::W10, block); + EmitValueToReg(gep_info.index_value, PhysReg::W10, slots, block); block.Append(Opcode::LoadStackAddr, {Operand::Reg(PhysReg::X9), Operand::FrameIndex(gep_info.base_slot)}); - block.Append(Opcode::AddRR_UXTW, {Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::X9), - Operand::Reg(PhysReg::W10)}); - block.Append(Opcode::LoadIndirect, - {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::LoadIndirectScaled, + {Operand::Reg(value_reg), Operand::Reg(PhysReg::X9), + Operand::Reg(PhysReg::W10)}); } block.Append(Opcode::StoreStack, @@ -1120,7 +1080,10 @@ std::unique_ptr LowerToMIR(const ir::Module& module) { } if (!can_emit) continue; // 跳过无法发射的值 - PhysReg tmp = phi_info.is_float ? PhysReg::S8 : PhysReg::W8; + // Phi edge stores may be inserted immediately before fused cbz/cbnz + // branches, which use w8 as the condition register. Use a different + // scratch register so edge materialization does not clobber the branch. + PhysReg tmp = phi_info.is_float ? PhysReg::S10 : PhysReg::W10; MachineBasicBlock tmp_block("__phi_tmp__"); EmitValueToReg(val, tmp, slots, tmp_block); tmp_block.Append(Opcode::StoreStack, diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index c432d74..6609495 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -299,10 +299,19 @@ void PromoteToVRegs(MachineFunction& func, if (!ops.empty()) PromoteUse(ops[0]); if (ops.size() > 1) PromoteUse(ops[1]); break; + case Opcode::StoreIndirectScaled: + if (!ops.empty()) PromoteUse(ops[0]); + if (ops.size() > 1) PromoteUse(ops[1]); + if (ops.size() > 2) PromoteUse(ops[2]); + break; case Opcode::LoadIndirect: if (ops.size() > 1) PromoteUse(ops[1]); break; + case Opcode::LoadIndirectScaled: + if (ops.size() > 1) PromoteUse(ops[1]); + if (ops.size() > 2) PromoteUse(ops[2]); + break; case Opcode::AddRI: case Opcode::SubRI: case Opcode::LsrRI: case Opcode::LslRI: @@ -334,6 +343,7 @@ void PromoteToVRegs(MachineFunction& func, case Opcode::MovReg: case Opcode::FMovReg: case Opcode::LoadStack: case Opcode::LoadStackOffset: case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadIndirectScaled: case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: case Opcode::AddRI: case Opcode::SubRI: case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: @@ -382,9 +392,18 @@ void GetVRegDefsUses(const MachineInstr& inst, if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId()); if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); break; + case Opcode::StoreIndirectScaled: + if (!ops.empty() && ops[0].IsVReg()) uses.insert(ops[0].GetVRegId()); + if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); + if (ops.size() > 2 && ops[2].IsVReg()) uses.insert(ops[2].GetVRegId()); + break; case Opcode::LoadIndirect: if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); break; + case Opcode::LoadIndirectScaled: + if (ops.size() > 1 && ops[1].IsVReg()) uses.insert(ops[1].GetVRegId()); + if (ops.size() > 2 && ops[2].IsVReg()) uses.insert(ops[2].GetVRegId()); + break; case Opcode::AddRI: case Opcode::SubRI: case Opcode::LsrRI: case Opcode::LslRI: case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI: @@ -419,6 +438,7 @@ void GetVRegDefsUses(const MachineInstr& inst, case Opcode::MovReg: case Opcode::FMovReg: case Opcode::LoadStack: case Opcode::LoadStackOffset: case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadIndirectScaled: case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: case Opcode::AddRI: case Opcode::SubRI: case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: @@ -514,6 +534,14 @@ void GetPhysRegDefsUses(const MachineInstr& inst, if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) phys_uses.insert(ops[1].GetReg()); break; + case Opcode::StoreIndirectScaled: + if (!ops.empty() && ops[0].IsReg() && !ShouldPromote(ops[0].GetReg())) + phys_uses.insert(ops[0].GetReg()); + if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) + phys_uses.insert(ops[1].GetReg()); + if (ops.size() > 2 && ops[2].IsReg() && !ShouldPromote(ops[2].GetReg())) + phys_uses.insert(ops[2].GetReg()); + break; case Opcode::AddRI: case Opcode::SubRI: case Opcode::LsrRI: case Opcode::LslRI: case Opcode::FSqrtRR: case Opcode::SIToFP: case Opcode::FPToSI: @@ -545,6 +573,12 @@ void GetPhysRegDefsUses(const MachineInstr& inst, if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) phys_uses.insert(ops[1].GetReg()); break; + case Opcode::LoadIndirectScaled: + if (ops.size() > 1 && ops[1].IsReg() && !ShouldPromote(ops[1].GetReg())) + phys_uses.insert(ops[1].GetReg()); + if (ops.size() > 2 && ops[2].IsReg() && !ShouldPromote(ops[2].GetReg())) + phys_uses.insert(ops[2].GetReg()); + break; default: break; } @@ -555,6 +589,7 @@ void GetPhysRegDefsUses(const MachineInstr& inst, case Opcode::MovReg: case Opcode::FMovReg: case Opcode::LoadStack: case Opcode::LoadStackOffset: case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadIndirectScaled: case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: case Opcode::AddRI: case Opcode::SubRI: case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index f17077d..a1408c9 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -74,6 +74,7 @@ std::optional GetWrittenReg(const MachineInstr& inst) { case Opcode::LoadStackOffset: case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadIndirectScaled: case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: case Opcode::AddRI: @@ -119,6 +120,8 @@ bool ReadsReg(const MachineInstr& inst, PhysReg reg) { case Opcode::FPToSI: case Opcode::LslRI: return reads_operand(1); + case Opcode::LoadIndirectScaled: + return reads_operand(1) || reads_operand(2); case Opcode::AddRR: case Opcode::AddRR_UXTW: case Opcode::SubRR: @@ -141,6 +144,8 @@ bool ReadsReg(const MachineInstr& inst, PhysReg reg) { return reads_operand(0); case Opcode::StoreIndirect: return reads_operand(0) || reads_operand(1); + case Opcode::StoreIndirectScaled: + return reads_operand(0) || reads_operand(1) || reads_operand(2); case Opcode::Ret: return false; default: @@ -158,6 +163,7 @@ bool CanElideIfOverwritten(const MachineInstr& inst) { case Opcode::LoadStackOffset: case Opcode::LoadStackAddr: case Opcode::LoadIndirect: + case Opcode::LoadIndirectScaled: case Opcode::LoadGlobal: case Opcode::LoadGlobalAddr: case Opcode::AddRI: @@ -184,6 +190,7 @@ bool CanElideIfOverwritten(const MachineInstr& inst) { bool IsMemoryClobber(const MachineInstr& inst) { switch (inst.GetOpcode()) { case Opcode::StoreIndirect: + case Opcode::StoreIndirectScaled: case Opcode::StoreGlobal: case Opcode::Bl: return true; diff --git a/优化方案.md b/优化方案.md new file mode 100644 index 0000000..4c3a037 --- /dev/null +++ b/优化方案.md @@ -0,0 +1,41 @@ +# Lab6 循环优化推进记录 + +## 1. 循环不变代码外提(LICM) + +- 方案:对可交换的不变表达式(`add`、`mul`、`eq/ne cmp`)做操作数规范化,外提到 preheader 后能复用 `a + b` 与 `b + a` 这类等价表达式。 +- 代码位置:`src/ir/passes/LICM.cpp`,新增 `IsCommutativeExpr`,并在 `MakeExprKey` 中规范化 key。 + +## 2. 强度削弱 + +- 方案:同一循环内多个 `iv * C` 共享同一个尺度递推 phi,`iv * 0` 与 `iv * 1` 直接替换为常量或归纳变量;步长增量直接常量化为 `step * C`,减少 latch 中多余乘法。 +- 代码位置:`src/ir/passes/StrengthReduction.cpp`,新增 `GetOrCreateScaledRecurrence`,按 scale 复用递推值。 + +## 3. 循环展开 + +- 方案:补强 CFG 边维护的基础设施,删除旧 terminator 时同步清理旧前驱/后继边,避免循环展开、习语替换等 pass 改写分支后留下过期 CFG 边。 +- 代码位置:`src/ir/BasicBlock.cpp`,在 `BasicBlock::RemoveInstruction` 中处理 `br/condbr` 的 CFG 边。 + +## 4. 循环分裂 + +- 方案:分裂前额外检查两组语句的 load/store 基址依赖;只要跨组存在 store-store 或 store-load 同基址关系,就放弃分裂,保证循环拆分不改变内存可见顺序。 +- 代码位置:`src/ir/passes/LoopFission.cpp`,新增 `CollectMemoryBases` 与 `HasCrossMemoryDependence`。 + +## 5. 循环并行化 + +- 方案:普通点式循环并行化先经过 `LoopInfo::IsParallelCandidate` 过滤,再进入 worker 抽取逻辑;保留已有 guarded fill 特例,减少非候选循环误判和无效分析。 +- 代码位置:`src/ir/passes/LoopParallelize.cpp`,在 `BuildParallelCandidate` 中加入候选过滤。 + +## 6. 验证中发现的后端正确性修复 + +- 方案:phi 消除插入边上赋值时不再使用 `w8/s8` 作为临时寄存器,避免覆盖 `cbz/cbnz` 融合分支刚准备好的条件寄存器。 +- 代码位置:`src/mir/Lowering.cpp`,phi elimination 的临时寄存器改为 `w10/s10`。 + +## 7. ASM 后端数组访存优化 + +- 方案:GEP 变量下标不再先写入临时栈槽,直接保存原始 IR 下标值,访存时发射到 `w10`;同时把变量下标数组访存改为 AArch64 scaled addressing(`[base, index, uxtw #2]`),减少热循环中的 `store/load index`、`lsl` 与 `add`。 +- 代码位置:`src/mir/Lowering.cpp`、`include/mir/MIR.h`、`src/mir/AsmPrinter.cpp`、`src/mir/RegAlloc.cpp`、`src/mir/passes/Peephole.cpp`。 + +## 回归编译错误修复 + +- 方案:IR 文本打印补齐 `ConstantFloat` 的 LLVM 十六进制浮点常量格式,把浮点二元运算/比较输出为 `fadd/fsub/fmul/fdiv`、`fcmp o*`;比较结果打印为 `i1` 后再 `zext` 回内部 `i32` 约定;删除不可达块时同步清理后继 phi 入边;scaled 间接访存补齐 peephole 的读寄存器与内存 clobber 描述。 +- 代码位置:`src/ir/IRPrinter.cpp`、`src/ir/passes/CFGSimplify.cpp`、`src/mir/passes/Peephole.cpp`。