From 184d5c7cb5abd425e7e42a9adc1c5bbed0500e61 Mon Sep 17 00:00:00 2001 From: pu9sp4t32 <2931381969@qq.com> Date: Thu, 28 May 2026 11:34:18 +0800 Subject: [PATCH] =?UTF-8?q?lab4=E5=AE=9E=E9=AA=8C=EF=BC=8C=E6=89=B9?= =?UTF-8?q?=E9=87=8F=E6=B5=8B=E8=AF=95=E8=84=9A=E6=9C=AC=E5=B7=B2=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- $null | 0 CMakeLists.txt | 3 +- doc/Lab4-基本标量优化.md | 5 + include/ir/IR.h | 26 ++ include/utils/CLI.h | 1 + src/ir/BasicBlock.cpp | 50 ++++ src/ir/Function.cpp | 4 + src/ir/Instruction.cpp | 94 +++++++ src/ir/passes/CFGSimplify.cpp | 137 +++++++++- src/ir/passes/CSE.cpp | 98 ++++++- src/ir/passes/ConstFold.cpp | 241 +++++++++++++++++- src/ir/passes/ConstProp.cpp | 80 +++++- src/ir/passes/DCE.cpp | 57 ++++- src/ir/passes/Mem2Reg.cpp | 111 +++++++- src/ir/passes/PassManager.cpp | 24 +- src/main.cpp | 3 + src/mir/LLVMAsmBackend.cpp | 43 +++- src/utils/CLI.cpp | 7 +- src/utils/Log.cpp | 3 +- sylib/sylib.c | 12 +- .../runtime/src/atn/ProfilingATNSimulator.cpp | 2 + 21 files changed, 969 insertions(+), 32 deletions(-) create mode 100644 $null diff --git a/$null b/$null new file mode 100644 index 0000000..e69de29 diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ac5b22..5c4c727 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,7 +33,7 @@ target_include_directories(build_options INTERFACE option(COMPILER_ENABLE_WARNINGS "Enable common compiler warnings" ON) if(COMPILER_ENABLE_WARNINGS) if(MSVC) - target_compile_options(build_options INTERFACE /W4) + target_compile_options(build_options INTERFACE /W4 /utf-8) else() target_compile_options(build_options INTERFACE -Wall -Wextra -Wpedantic) endif() @@ -52,6 +52,7 @@ set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.1 add_library(antlr4_runtime STATIC) target_compile_features(antlr4_runtime PUBLIC cxx_std_17) +target_compile_definitions(antlr4_runtime PUBLIC ANTLR4CPP_STATIC) target_include_directories(antlr4_runtime PUBLIC "${ANTLR4_RUNTIME_SRC_DIR}" diff --git a/doc/Lab4-基本标量优化.md b/doc/Lab4-基本标量优化.md index 1b22974..a65a52d 100644 --- a/doc/Lab4-基本标量优化.md +++ b/doc/Lab4-基本标量优化.md @@ -109,3 +109,8 @@ cmake --build build -j "$(nproc)" 目标:脚本自动读取同名 `.in`,并将程序输出与退出码和同名 `.out` 比对,确保优化后程序行为与优化前保持一致。 完成 Lab4 后,应对 `test/test_case` 下全部测试用例逐个回归;如有需要,也可以自行编写批量测试脚本统一执行。 + + +批量测试脚本: +bash test/test_result/lab4_batch/run_all.sh + \ No newline at end of file diff --git a/include/ir/IR.h b/include/ir/IR.h index 75ce986..8922d56 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -228,10 +228,13 @@ class User : public Value { size_t GetNumOperands() const; Value* GetOperand(size_t index) const; void SetOperand(size_t index, Value* value); + void RemoveOperand(size_t index); protected: // 统一的 operand 入口。 void AddOperand(Value* value); + virtual void OnOperandChanged(size_t index, Value* value); + virtual void OnOperandRemoving(size_t index); private: std::vector operands_; @@ -360,6 +363,8 @@ class CallInst : public Instruction { const std::vector& GetArgs() const { return args_; } private: + void OnOperandChanged(size_t index, Value* value) override; + void OnOperandRemoving(size_t index) override; std::vector args_; }; @@ -367,10 +372,13 @@ class PhiInst : public Instruction { public: PhiInst(std::shared_ptr ty, std::string name); void AddIncoming(Value* value, BasicBlock* block); + void RemoveIncomingFrom(BasicBlock* block); const std::vector& GetIncomingValues() const; const std::vector& GetIncomingBlocks() const; private: + void OnOperandChanged(size_t index, Value* value) override; + void OnOperandRemoving(size_t index) override; std::vector incoming_values_; std::vector incoming_blocks_; }; @@ -383,6 +391,8 @@ class GepInst : public Instruction { const std::vector& GetIndices() const { return indices_; } private: + void OnOperandChanged(size_t index, Value* value) override; + void OnOperandRemoving(size_t index) override; std::vector indices_; }; @@ -394,10 +404,17 @@ class BasicBlock : public Value { void SetParent(Function* parent); bool HasTerminator() const; const std::vector>& GetInstructions() const; + std::vector>& GetMutableInstructions(); const std::vector& GetPredecessors() const; const std::vector& GetSuccessors() const; void AddPredecessor(BasicBlock* pred); void AddSuccessor(BasicBlock* succ); + void ClearPredecessors(); + void ClearSuccessors(); + void RemovePredecessor(BasicBlock* pred); + void RemoveSuccessor(BasicBlock* succ); + void EraseInstruction(Instruction* inst); + void ReplaceTerminator(std::unique_ptr inst); template T* Append(Args&&... args) { if (HasTerminator()) { @@ -443,6 +460,7 @@ class Function : public Value { BasicBlock* GetEntry(); const BasicBlock* GetEntry() const; const std::vector>& GetBlocks() const; + std::vector>& GetMutableBlocks(); const std::vector>& GetArguments() const; size_t GetNumArgs() const; Argument* GetArg(size_t index); @@ -548,4 +566,12 @@ class IRPrinter { void Print(const Module& module, std::ostream& os); }; +bool RunMem2Reg(Module& module); +bool RunConstFold(Module& module); +bool RunConstProp(Module& module); +bool RunCSE(Module& module); +bool RunDCE(Module& module); +bool RunCFGSimplify(Module& module); +bool RunScalarOptimizationPipeline(Module& module); + } // namespace ir diff --git a/include/utils/CLI.h b/include/utils/CLI.h index 4b3a781..c067ec8 100644 --- a/include/utils/CLI.h +++ b/include/utils/CLI.h @@ -8,6 +8,7 @@ struct CLIOptions { bool emit_parse_tree = false; bool emit_ir = true; bool emit_asm = false; + bool optimize_ir = true; bool show_help = false; }; diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index 8d9affa..f16a977 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -9,6 +9,7 @@ #include "ir/IR.h" +#include #include namespace ir { @@ -32,6 +33,10 @@ const std::vector>& BasicBlock::GetInstructions() return instructions_; } +std::vector>& BasicBlock::GetMutableInstructions() { + return instructions_; +} + // 前驱/后继接口先保留给后续 CFG 扩展使用。 // 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。 const std::vector& BasicBlock::GetPredecessors() const { @@ -58,6 +63,51 @@ void BasicBlock::AddSuccessor(BasicBlock* succ) { successors_.push_back(succ); } +void BasicBlock::ClearPredecessors() { predecessors_.clear(); } + +void BasicBlock::ClearSuccessors() { successors_.clear(); } + +void BasicBlock::RemovePredecessor(BasicBlock* pred) { + predecessors_.erase(std::remove(predecessors_.begin(), predecessors_.end(), pred), + predecessors_.end()); +} + +void BasicBlock::RemoveSuccessor(BasicBlock* succ) { + successors_.erase(std::remove(successors_.begin(), successors_.end(), succ), + successors_.end()); +} + +void BasicBlock::EraseInstruction(Instruction* inst) { + if (!inst) return; + auto it = std::find_if(instructions_.begin(), instructions_.end(), + [&](const auto& ptr) { return ptr.get() == inst; }); + if (it == instructions_.end()) return; + for (size_t i = 0; i < inst->GetNumOperands(); ++i) { + if (auto* operand = inst->GetOperand(i)) { + operand->RemoveUse(inst, i); + } + } + inst->SetParent(nullptr); + instructions_.erase(it); +} + +void BasicBlock::ReplaceTerminator(std::unique_ptr inst) { + if (!inst || !inst->IsTerminator()) return; + if (!instructions_.empty() && instructions_.back()->IsTerminator()) { + auto* old = instructions_.back().get(); + for (size_t i = 0; i < old->GetNumOperands(); ++i) { + if (auto* operand = old->GetOperand(i)) { + operand->RemoveUse(old, i); + } + } + old->SetParent(nullptr); + instructions_.pop_back(); + } + inst->SetParent(this); + instructions_.push_back(std::move(inst)); + LinkSuccessorsIfNeeded(instructions_.back().get()); +} + void BasicBlock::LinkSuccessorsIfNeeded(Instruction* inst) { if (!inst) return; if (auto* br = dynamic_cast(inst)) { diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index f27dd98..f056c0e 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -48,6 +48,10 @@ const std::vector>& Function::GetBlocks() const { return blocks_; } +std::vector>& Function::GetMutableBlocks() { + return blocks_; +} + const std::vector>& Function::GetArguments() const { return args_; } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 6e053c8..dc05420 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -3,6 +3,7 @@ // - 指令操作数与结果类型管理,支持打印与优化 #include "ir/IR.h" +#include #include #include "utils/Log.h" @@ -36,6 +37,7 @@ void User::SetOperand(size_t index, Value* value) { } operands_[index] = value; value->AddUse(this, index); + OnOperandChanged(index, value); } void User::AddOperand(Value* value) { @@ -47,6 +49,27 @@ void User::AddOperand(Value* value) { value->AddUse(this, operand_index); } +void User::RemoveOperand(size_t index) { + if (index >= operands_.size()) { + throw std::out_of_range("User operand index out of range"); + } + OnOperandRemoving(index); + if (auto* old = operands_[index]) { + old->RemoveUse(this, index); + } + for (size_t i = index + 1; i < operands_.size(); ++i) { + if (auto* value = operands_[i]) { + value->RemoveUse(this, i); + value->AddUse(this, i - 1); + } + } + operands_.erase(operands_.begin() + static_cast(index)); +} + +void User::OnOperandChanged(size_t, Value*) {} + +void User::OnOperandRemoving(size_t) {} + Instruction::Instruction(Opcode op, std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)), opcode_(op) {} @@ -312,6 +335,22 @@ CallInst::CallInst(std::shared_ptr ret_ty, Value* callee, Value* CallInst::GetCallee() const { return GetOperand(0); } +void CallInst::OnOperandChanged(size_t index, Value* value) { + if (index == 0) return; + size_t arg_index = index - 1; + if (arg_index < args_.size()) { + args_[arg_index] = value; + } +} + +void CallInst::OnOperandRemoving(size_t index) { + if (index == 0) return; + size_t arg_index = index - 1; + if (arg_index < args_.size()) { + args_.erase(args_.begin() + static_cast(arg_index)); + } +} + PhiInst::PhiInst(std::shared_ptr ty, std::string name) : Instruction(Opcode::Phi, std::move(ty), std::move(name)) {} @@ -328,6 +367,17 @@ void PhiInst::AddIncoming(Value* value, BasicBlock* block) { AddOperand(block); } +void PhiInst::RemoveIncomingFrom(BasicBlock* block) { + for (size_t i = 0; i < incoming_blocks_.size();) { + if (incoming_blocks_[i] != block) { + ++i; + continue; + } + RemoveOperand(2 * i + 1); + RemoveOperand(2 * i); + } +} + const std::vector& PhiInst::GetIncomingValues() const { return incoming_values_; } @@ -336,6 +386,34 @@ const std::vector& PhiInst::GetIncomingBlocks() const { return incoming_blocks_; } +void PhiInst::OnOperandChanged(size_t index, Value* value) { + size_t incoming_index = index / 2; + if (index % 2 == 0) { + if (incoming_index < incoming_values_.size()) { + incoming_values_[incoming_index] = value; + } + return; + } + if (incoming_index < incoming_blocks_.size()) { + incoming_blocks_[incoming_index] = static_cast(value); + } +} + +void PhiInst::OnOperandRemoving(size_t index) { + size_t incoming_index = index / 2; + if (index % 2 == 0) { + if (incoming_index < incoming_values_.size()) { + incoming_values_.erase(incoming_values_.begin() + + static_cast(incoming_index)); + } + return; + } + if (incoming_index < incoming_blocks_.size()) { + incoming_blocks_.erase(incoming_blocks_.begin() + + static_cast(incoming_index)); + } +} + GepInst::GepInst(std::shared_ptr result_ptr_ty, Value* base_ptr, std::vector indices, std::string name) : Instruction(Opcode::Gep, std::move(result_ptr_ty), std::move(name)), @@ -360,4 +438,20 @@ GepInst::GepInst(std::shared_ptr result_ptr_ty, Value* base_ptr, Value* GepInst::GetBasePtr() const { return GetOperand(0); } +void GepInst::OnOperandChanged(size_t index, Value* value) { + if (index == 0) return; + size_t idx = index - 1; + if (idx < indices_.size()) { + indices_[idx] = value; + } +} + +void GepInst::OnOperandRemoving(size_t index) { + if (index == 0) return; + size_t idx = index - 1; + if (idx < indices_.size()) { + indices_.erase(indices_.begin() + static_cast(idx)); + } +} + } // namespace ir diff --git a/src/ir/passes/CFGSimplify.cpp b/src/ir/passes/CFGSimplify.cpp index 3779397..ed5bac0 100644 --- a/src/ir/passes/CFGSimplify.cpp +++ b/src/ir/passes/CFGSimplify.cpp @@ -1,4 +1,135 @@ -// CFG 简化: -// - 删除不可达块、合并空块、简化分支等 -// - 改善 IR 结构,便于后续优化与后端生成 +#include "ir/IR.h" + +#include +#include +#include + +namespace ir { +namespace { + +Instruction* Terminator(BasicBlock& block) { + auto& insts = block.GetMutableInstructions(); + if (insts.empty()) return nullptr; + auto* inst = insts.back().get(); + return inst && inst->IsTerminator() ? inst : nullptr; +} + +void AddCFGEdge(BasicBlock* from, BasicBlock* to) { + if (!from || !to) return; + from->AddSuccessor(to); + to->AddPredecessor(from); +} + +void RebuildCFG(Function& func) { + for (const auto& block : func.GetBlocks()) { + if (!block) continue; + block->ClearPredecessors(); + block->ClearSuccessors(); + } + for (const auto& block : func.GetBlocks()) { + if (!block) continue; + auto* term = Terminator(*block); + if (auto* br = dynamic_cast(term)) { + AddCFGEdge(block.get(), br->GetDest()); + } else if (auto* cbr = dynamic_cast(term)) { + AddCFGEdge(block.get(), cbr->GetTrueDest()); + AddCFGEdge(block.get(), cbr->GetFalseDest()); + } + } +} + +bool SimplifyBranches(Function& func) { + bool changed = false; + for (const auto& block : func.GetBlocks()) { + if (!block) continue; + auto* cbr = dynamic_cast(Terminator(*block)); + if (!cbr) continue; + BasicBlock* dest = nullptr; + if (cbr->GetTrueDest() == cbr->GetFalseDest()) { + dest = cbr->GetTrueDest(); + } else if (auto* cond = dynamic_cast(cbr->GetCond())) { + dest = cond->GetValue() != 0 ? cbr->GetTrueDest() : cbr->GetFalseDest(); + } + if (dest) { + block->ReplaceTerminator(std::make_unique(dest)); + changed = true; + } + } + return changed; +} + +std::unordered_set ReachableBlocks(Function& func) { + std::unordered_set reachable; + std::queue work; + if (auto* entry = func.GetEntry()) { + reachable.insert(entry); + work.push(entry); + } + while (!work.empty()) { + auto* block = work.front(); + work.pop(); + for (auto* succ : block->GetSuccessors()) { + if (succ && reachable.insert(succ).second) { + work.push(succ); + } + } + } + return reachable; +} + +bool RemoveUnreachable(Function& func) { + auto reachable = ReachableBlocks(func); + bool changed = false; + + for (const auto& block : func.GetBlocks()) { + if (!block || reachable.count(block.get()) == 0) continue; + for (const auto& inst : block->GetInstructions()) { + auto* phi = dynamic_cast(inst.get()); + if (!phi) continue; + for (const auto& other : func.GetBlocks()) { + if (other && reachable.count(other.get()) == 0) { + phi->RemoveIncomingFrom(other.get()); + } + } + } + } + + auto& blocks = func.GetMutableBlocks(); + for (auto it = blocks.begin(); it != blocks.end();) { + auto* block = it->get(); + if (!block || reachable.count(block) != 0) { + ++it; + continue; + } + auto& insts = block->GetMutableInstructions(); + while (!insts.empty()) { + block->EraseInstruction(insts.back().get()); + } + it = blocks.erase(it); + changed = true; + } + return changed; +} + +} // namespace + +bool RunCFGSimplify(Module& module) { + bool changed = false; + for (const auto& func : module.GetFunctions()) { + if (!func || func->IsDeclaration()) continue; + RebuildCFG(*func); + bool local_changed = SimplifyBranches(*func); + if (local_changed) { + RebuildCFG(*func); + } + local_changed |= RemoveUnreachable(*func); + if (local_changed) { + RebuildCFG(*func); + } + changed |= local_changed; + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/CSE.cpp b/src/ir/passes/CSE.cpp index 4b24dd0..1442c6a 100644 --- a/src/ir/passes/CSE.cpp +++ b/src/ir/passes/CSE.cpp @@ -1,4 +1,94 @@ -// 公共子表达式消除(CSE): -// - 识别并复用重复计算的等价表达式 -// - 典型放置在 ConstFold 之后、DCE 之前 -// - 当前为 Lab4 的框架占位,具体算法由实验实现 +#include "ir/IR.h" + +#include +#include +#include +#include +#include + +namespace ir { +namespace { + +bool IsCommutative(Opcode op) { + return op == Opcode::Add || op == Opcode::Mul || op == Opcode::FAdd || + op == Opcode::FMul; +} + +std::string ValueId(Value* value) { + std::ostringstream oss; + oss << reinterpret_cast(value); + return oss.str(); +} + +std::string KeyFor(const Instruction& inst) { + std::ostringstream oss; + oss << static_cast(inst.GetOpcode()) << ":"; + if (auto* icmp = dynamic_cast(&inst)) { + oss << static_cast(icmp->GetPredicate()) << ":"; + } else if (auto* fcmp = dynamic_cast(&inst)) { + oss << static_cast(fcmp->GetPredicate()) << ":"; + } + + std::vector operands; + for (size_t i = 0; i < inst.GetNumOperands(); ++i) { + operands.push_back(ValueId(inst.GetOperand(i))); + } + if (IsCommutative(inst.GetOpcode()) && operands.size() == 2) { + std::sort(operands.begin(), operands.end()); + } + for (const auto& operand : operands) { + oss << operand << ","; + } + return oss.str(); +} + +bool IsCSECandidate(const Instruction& inst) { + switch (inst.GetOpcode()) { + case Opcode::Add: + case Opcode::Sub: + case Opcode::Mul: + case Opcode::SDiv: + case Opcode::SRem: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: + case Opcode::ICmp: + case Opcode::FCmp: + case Opcode::SIToFP: + case Opcode::FPToSI: + case Opcode::ZExt: + case Opcode::Gep: + return true; + default: + return false; + } +} + +} // namespace + +bool RunCSE(Module& module) { + bool changed = false; + for (const auto& func : module.GetFunctions()) { + if (!func || func->IsDeclaration()) continue; + for (const auto& block : func->GetBlocks()) { + if (!block) continue; + std::unordered_map available; + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (!inst || !IsCSECandidate(*inst)) continue; + std::string key = KeyFor(*inst); + auto it = available.find(key); + if (it != available.end()) { + inst->ReplaceAllUsesWith(it->second); + changed = true; + } else { + available.emplace(std::move(key), inst); + } + } + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/ConstFold.cpp b/src/ir/passes/ConstFold.cpp index 19f2d43..fdad6ee 100644 --- a/src/ir/passes/ConstFold.cpp +++ b/src/ir/passes/ConstFold.cpp @@ -1,4 +1,239 @@ -// IR 常量折叠: -// - 折叠可判定的常量表达式 -// - 简化常量控制流分支(按实现范围裁剪) +#include "ir/IR.h" + +#include + +namespace ir { +namespace { + +ConstantInt* AsConstInt(Value* value) { + return dynamic_cast(value); +} + +ConstantFloat* AsConstFloat(Value* value) { + return dynamic_cast(value); +} + +bool IsZero(Value* value) { + if (auto* i = AsConstInt(value)) return i->GetValue() == 0; + if (auto* f = AsConstFloat(value)) return f->GetValue() == 0.0f; + return false; +} + +bool IsOne(Value* value) { + if (auto* i = AsConstInt(value)) return i->GetValue() == 1; + if (auto* f = AsConstFloat(value)) return f->GetValue() == 1.0f; + return false; +} + +ConstantValue* FoldBinary(BinaryInst& inst, Context& ctx) { + auto* li = AsConstInt(inst.GetLhs()); + auto* ri = AsConstInt(inst.GetRhs()); + if (li && ri) { + int lhs = li->GetValue(); + int rhs = ri->GetValue(); + switch (inst.GetOpcode()) { + case Opcode::Add: + return ctx.GetConstInt(lhs + rhs); + case Opcode::Sub: + return ctx.GetConstInt(lhs - rhs); + case Opcode::Mul: + return ctx.GetConstInt(lhs * rhs); + case Opcode::SDiv: + if (rhs != 0) return ctx.GetConstInt(lhs / rhs); + break; + case Opcode::SRem: + if (rhs != 0) return ctx.GetConstInt(lhs % rhs); + break; + default: + break; + } + } + + auto* lf = AsConstFloat(inst.GetLhs()); + auto* rf = AsConstFloat(inst.GetRhs()); + if (lf && rf) { + float lhs = lf->GetValue(); + float rhs = rf->GetValue(); + switch (inst.GetOpcode()) { + case Opcode::FAdd: + return ctx.GetConstFloat(lhs + rhs); + case Opcode::FSub: + return ctx.GetConstFloat(lhs - rhs); + case Opcode::FMul: + return ctx.GetConstFloat(lhs * rhs); + case Opcode::FDiv: + return ctx.GetConstFloat(lhs / rhs); + default: + break; + } + } + return nullptr; +} + +Value* SimplifyBinary(BinaryInst& inst) { + auto op = inst.GetOpcode(); + auto* lhs = inst.GetLhs(); + auto* rhs = inst.GetRhs(); + switch (op) { + case Opcode::Add: + case Opcode::FAdd: + if (IsZero(rhs)) return lhs; + if (IsZero(lhs)) return rhs; + break; + case Opcode::Sub: + case Opcode::FSub: + if (IsZero(rhs)) return lhs; + break; + case Opcode::Mul: + if (IsOne(rhs)) return lhs; + if (IsOne(lhs)) return rhs; + if (IsZero(rhs)) return rhs; + if (IsZero(lhs)) return lhs; + break; + case Opcode::FMul: + if (IsOne(rhs)) return lhs; + if (IsOne(lhs)) return rhs; + break; + case Opcode::SDiv: + case Opcode::FDiv: + if (IsOne(rhs)) return lhs; + break; + case Opcode::SRem: + if (IsOne(rhs)) return rhs; + break; + default: + break; + } + return nullptr; +} + +ConstantInt* FoldICmp(ICmpInst& inst, Context& ctx) { + auto* lhs = AsConstInt(inst.GetLhs()); + auto* rhs = AsConstInt(inst.GetRhs()); + if (!lhs || !rhs) return nullptr; + int l = lhs->GetValue(); + int r = rhs->GetValue(); + bool result = false; + switch (inst.GetPredicate()) { + case ICmpPredicate::Eq: + result = l == r; + break; + case ICmpPredicate::Ne: + result = l != r; + break; + case ICmpPredicate::Slt: + result = l < r; + break; + case ICmpPredicate::Sle: + result = l <= r; + break; + case ICmpPredicate::Sgt: + result = l > r; + break; + case ICmpPredicate::Sge: + result = l >= r; + break; + } + return ctx.GetConstBool(result); +} + +ConstantInt* FoldFCmp(FCmpInst& inst, Context& ctx) { + auto* lhs = AsConstFloat(inst.GetLhs()); + auto* rhs = AsConstFloat(inst.GetRhs()); + if (!lhs || !rhs) return nullptr; + float l = lhs->GetValue(); + float r = rhs->GetValue(); + bool ordered = !std::isnan(l) && !std::isnan(r); + bool result = false; + switch (inst.GetPredicate()) { + case FCmpPredicate::Oeq: + result = ordered && l == r; + break; + case FCmpPredicate::One: + result = ordered && l != r; + break; + case FCmpPredicate::Olt: + result = ordered && l < r; + break; + case FCmpPredicate::Ole: + result = ordered && l <= r; + break; + case FCmpPredicate::Ogt: + result = ordered && l > r; + break; + case FCmpPredicate::Oge: + result = ordered && l >= r; + break; + } + return ctx.GetConstBool(result); +} + +ConstantValue* FoldCast(CastInst& inst, Context& ctx) { + switch (inst.GetOpcode()) { + case Opcode::SIToFP: + if (auto* c = AsConstInt(inst.GetValue())) { + return ctx.GetConstFloat(static_cast(c->GetValue())); + } + break; + case Opcode::FPToSI: + if (auto* c = AsConstFloat(inst.GetValue())) { + return ctx.GetConstInt(static_cast(c->GetValue())); + } + break; + case Opcode::ZExt: + if (auto* c = AsConstInt(inst.GetValue())) { + return ctx.GetConstInt(c->GetValue() != 0 ? 1 : 0); + } + break; + default: + break; + } + return nullptr; +} + +Value* SimplifyPhi(PhiInst& phi) { + const auto& values = phi.GetIncomingValues(); + if (values.empty()) return nullptr; + auto* first = values.front(); + for (auto* value : values) { + if (value != first) return nullptr; + } + return first; +} + +} // namespace + +bool RunConstFold(Module& module) { + bool changed = false; + auto& ctx = module.GetContext(); + for (const auto& func : module.GetFunctions()) { + if (!func || func->IsDeclaration()) continue; + for (const auto& block : func->GetBlocks()) { + if (!block) continue; + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + Value* replacement = nullptr; + if (auto* bin = dynamic_cast(inst)) { + replacement = FoldBinary(*bin, ctx); + if (!replacement) replacement = SimplifyBinary(*bin); + } else if (auto* icmp = dynamic_cast(inst)) { + replacement = FoldICmp(*icmp, ctx); + } else if (auto* fcmp = dynamic_cast(inst)) { + replacement = FoldFCmp(*fcmp, ctx); + } else if (auto* cast = dynamic_cast(inst)) { + replacement = FoldCast(*cast, ctx); + } else if (auto* phi = dynamic_cast(inst)) { + replacement = SimplifyPhi(*phi); + } + if (replacement && replacement != inst) { + inst->ReplaceAllUsesWith(replacement); + changed = true; + } + } + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/ConstProp.cpp b/src/ir/passes/ConstProp.cpp index 1768b71..81b9cd1 100644 --- a/src/ir/passes/ConstProp.cpp +++ b/src/ir/passes/ConstProp.cpp @@ -1,5 +1,77 @@ -// 常量传播(Constant Propagation): -// - 沿 use-def 关系传播已知常量 -// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会 -// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用 +#include "ir/IR.h" + +#include + +namespace ir { +namespace { + +ConstantValue* AsConstant(Value* value) { + return dynamic_cast(value); +} + +ConstantValue* ScalarConstInitializer(Value* ptr) { + auto* global = dynamic_cast(ptr); + if (!global || !global->IsConst()) return nullptr; + auto* init = global->GetInitializer(); + if (!init) return nullptr; + if (dynamic_cast(init)) return nullptr; + return init; +} + +bool IsScalarStackSlot(Value* ptr) { + auto* alloca = dynamic_cast(ptr); + if (!alloca) return false; + const auto& ty = alloca->GetAllocatedType(); + return ty && (ty->IsInt1() || ty->IsInt32() || ty->IsFloat()); +} + +} // namespace + +bool RunConstProp(Module& module) { + bool changed = false; + for (const auto& func : module.GetFunctions()) { + if (!func || func->IsDeclaration()) continue; + for (const auto& block : func->GetBlocks()) { + if (!block) continue; + std::unordered_map known_memory; + for (const auto& inst_ptr : block->GetInstructions()) { + auto* inst = inst_ptr.get(); + if (auto* load = dynamic_cast(inst)) { + auto* ptr = load->GetPtr(); + ConstantValue* known = nullptr; + auto it = known_memory.find(ptr); + if (it != known_memory.end()) { + known = it->second; + } else { + known = ScalarConstInitializer(ptr); + } + if (known) { + load->ReplaceAllUsesWith(known); + changed = true; + } + continue; + } + + if (auto* store = dynamic_cast(inst)) { + auto* ptr = store->GetPtr(); + if (IsScalarStackSlot(ptr)) { + if (auto* c = AsConstant(store->GetValue())) { + known_memory[ptr] = c; + } else { + known_memory.erase(ptr); + } + } + continue; + } + + if (inst->GetOpcode() == Opcode::Call) { + known_memory.clear(); + } + } + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/DCE.cpp b/src/ir/passes/DCE.cpp index 5a0db91..cbf83d4 100644 --- a/src/ir/passes/DCE.cpp +++ b/src/ir/passes/DCE.cpp @@ -1,4 +1,55 @@ -// 死代码删除(DCE): -// - 删除无用指令与无用基本块 -// - 通常与 CFG 简化配合使用 +#include "ir/IR.h" + +namespace ir { +namespace { + +bool HasSideEffect(const Instruction& inst) { + switch (inst.GetOpcode()) { + case Opcode::Store: + case Opcode::Ret: + case Opcode::Br: + case Opcode::CondBr: + case Opcode::Call: + return true; + default: + return false; + } +} + +bool IsDead(const Instruction& inst) { + if (inst.IsVoid()) return false; + if (HasSideEffect(inst)) return false; + return inst.GetUses().empty(); +} + +} // namespace + +bool RunDCE(Module& module) { + bool changed = false; + bool local_changed = true; + while (local_changed) { + local_changed = false; + for (const auto& func : module.GetFunctions()) { + if (!func || func->IsDeclaration()) continue; + for (const auto& block : func->GetBlocks()) { + if (!block) continue; + auto& insts = block->GetMutableInstructions(); + for (auto it = insts.begin(); it != insts.end();) { + auto* inst = it->get(); + if (inst && IsDead(*inst)) { + block->EraseInstruction(inst); + local_changed = true; + changed = true; + it = insts.begin(); + } else { + ++it; + } + } + } + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/Mem2Reg.cpp b/src/ir/passes/Mem2Reg.cpp index 0b052ba..43d1b32 100644 --- a/src/ir/passes/Mem2Reg.cpp +++ b/src/ir/passes/Mem2Reg.cpp @@ -1,4 +1,109 @@ -// Mem2Reg(SSA 构造): -// - 将局部变量的 alloca/load/store 提升为 SSA 形式 -// - 插入 PHI 并重写使用,依赖支配树等分析 +#include "ir/IR.h" + +#include + +namespace ir { +namespace { + +bool IsPromotableType(const AllocaInst& alloca) { + const auto& ty = alloca.GetAllocatedType(); + return ty && (ty->IsInt1() || ty->IsInt32() || ty->IsFloat()); +} + +bool IsDirectUseOf(Value* ptr, Instruction* inst) { + if (auto* load = dynamic_cast(inst)) { + return load->GetPtr() == ptr; + } + if (auto* store = dynamic_cast(inst)) { + return store->GetPtr() == ptr; + } + return false; +} + +BasicBlock* SingleUseBlock(AllocaInst& alloca) { + BasicBlock* use_block = nullptr; + for (const auto& use : alloca.GetUses()) { + auto* inst = dynamic_cast(use.GetUser()); + if (!inst || !IsDirectUseOf(&alloca, inst)) return nullptr; + auto* parent = inst->GetParent(); + if (!parent) return nullptr; + if (!use_block) { + use_block = parent; + } else if (use_block != parent) { + return nullptr; + } + } + return use_block; +} + +bool CanPromoteInBlock(AllocaInst& alloca, BasicBlock& block) { + bool has_value = false; + bool saw_use = false; + for (const auto& inst_ptr : block.GetInstructions()) { + auto* inst = inst_ptr.get(); + if (!inst || !IsDirectUseOf(&alloca, inst)) continue; + saw_use = true; + if (auto* store = dynamic_cast(inst)) { + if (store->GetValue() == &alloca) return false; + has_value = true; + } else if (dynamic_cast(inst)) { + if (!has_value) return false; + } + } + return saw_use; +} + +bool PromoteInBlock(AllocaInst& alloca, BasicBlock& block) { + Value* current = nullptr; + std::vector erase; + for (const auto& inst_ptr : block.GetInstructions()) { + auto* inst = inst_ptr.get(); + if (!inst || !IsDirectUseOf(&alloca, inst)) continue; + if (auto* store = dynamic_cast(inst)) { + current = store->GetValue(); + erase.push_back(store); + } else if (auto* load = dynamic_cast(inst)) { + load->ReplaceAllUsesWith(current); + erase.push_back(load); + } + } + for (auto* inst : erase) { + block.EraseInstruction(inst); + } + if (alloca.GetUses().empty()) { + if (auto* parent = alloca.GetParent()) { + parent->EraseInstruction(&alloca); + } + } + return !erase.empty(); +} + +} // namespace + +bool RunMem2Reg(Module& module) { + bool changed = false; + for (const auto& func : module.GetFunctions()) { + if (!func || func->IsDeclaration()) continue; + std::vector allocas; + for (const auto& block : func->GetBlocks()) { + if (!block) continue; + for (const auto& inst : block->GetInstructions()) { + if (auto* alloca = dynamic_cast(inst.get())) { + if (IsPromotableType(*alloca)) { + allocas.push_back(alloca); + } + } + } + } + for (auto* alloca : allocas) { + if (!alloca || alloca->GetUses().empty()) continue; + auto* block = SingleUseBlock(*alloca); + if (!block || !CanPromoteInBlock(*alloca, *block)) continue; + changed |= PromoteInBlock(*alloca, *block); + } + } + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/PassManager.cpp b/src/ir/passes/PassManager.cpp index 044328f..8ecc1f7 100644 --- a/src/ir/passes/PassManager.cpp +++ b/src/ir/passes/PassManager.cpp @@ -1 +1,23 @@ -// IR Pass 管理骨架。 +#include "ir/IR.h" + +namespace ir { + +bool RunScalarOptimizationPipeline(Module& module) { + bool changed = false; + changed |= RunMem2Reg(module); + + for (int i = 0; i < 8; ++i) { + bool iter_changed = false; + iter_changed |= RunConstFold(module); + iter_changed |= RunConstProp(module); + iter_changed |= RunCSE(module); + iter_changed |= RunDCE(module); + iter_changed |= RunCFGSimplify(module); + changed |= iter_changed; + if (!iter_changed) break; + } + + return changed; +} + +} // namespace ir diff --git a/src/main.cpp b/src/main.cpp index 5767e6a..da87ec5 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -36,6 +36,9 @@ int main(int argc, char** argv) { auto sema = RunSema(*comp_unit); auto module = GenerateIR(*comp_unit, sema); + if (opts.optimize_ir) { + ir::RunScalarOptimizationPipeline(*module); + } if (opts.emit_ir) { ir::IRPrinter printer; if (need_blank_line) { diff --git a/src/mir/LLVMAsmBackend.cpp b/src/mir/LLVMAsmBackend.cpp index 42c310f..1e56df1 100644 --- a/src/mir/LLVMAsmBackend.cpp +++ b/src/mir/LLVMAsmBackend.cpp @@ -1,5 +1,6 @@ #include "mir/MIR.h" +#include #include #include #include @@ -18,6 +19,18 @@ namespace { std::string ShellQuote(const std::filesystem::path& path) { std::string raw = path.string(); +#if defined(_WIN32) + std::string quoted = "\""; + for (char ch : raw) { + if (ch == '"') { + quoted += "\\\""; + } else { + quoted += ch; + } + } + quoted += "\""; + return quoted; +#else std::string quoted = "'"; for (char ch : raw) { if (ch == '\'') { @@ -28,6 +41,7 @@ std::string ShellQuote(const std::filesystem::path& path) { } quoted += "'"; return quoted; +#endif } std::string ReadTextFile(const std::filesystem::path& path) { @@ -41,20 +55,35 @@ std::string ReadTextFile(const std::filesystem::path& path) { return oss.str(); } -} // namespace - -void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os) { - auto tmp_dir = std::filesystem::temp_directory_path(); - std::string pattern = (tmp_dir / "nudt_lab3_XXXXXX").string(); +std::filesystem::path CreateTempDir() { + auto base = std::filesystem::temp_directory_path(); +#if defined(_WIN32) + auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count(); + for (int i = 0; i < 100; ++i) { + auto candidate = + base / ("nudt_lab3_" + std::to_string(seed) + "_" + std::to_string(i)); + std::error_code ec; + if (std::filesystem::create_directory(candidate, ec)) { + return candidate; + } + } + throw std::runtime_error(FormatError("mir", "创建临时目录失败")); +#else + std::string pattern = (base / "nudt_lab3_XXXXXX").string(); std::vector dir_template(pattern.begin(), pattern.end()); dir_template.push_back('\0'); - char* created = mkdtemp(dir_template.data()); if (!created) { throw std::runtime_error(FormatError("mir", "创建临时目录失败")); } + return std::filesystem::path(created); +#endif +} + +} // namespace - std::filesystem::path work_dir(created); +void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os) { + std::filesystem::path work_dir = CreateTempDir(); const auto ir_file = work_dir / "module.ll"; const auto asm_file = work_dir / "module.s"; const auto err_file = work_dir / "clang.err"; diff --git a/src/utils/CLI.cpp b/src/utils/CLI.cpp index 21b6d20..8bea2ee 100644 --- a/src/utils/CLI.cpp +++ b/src/utils/CLI.cpp @@ -15,7 +15,7 @@ CLIOptions ParseCLI(int argc, char** argv) { if (argc <= 1) { throw std::runtime_error(FormatError( "cli", - "用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] ")); + "用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] [--no-opt] ")); } for (int i = 1; i < argc; ++i) { @@ -58,6 +58,11 @@ CLIOptions ParseCLI(int argc, char** argv) { continue; } + if (std::strcmp(arg, "--no-opt") == 0) { + opt.optimize_ir = false; + continue; + } + if (arg[0] == '-') { throw std::runtime_error( FormatError("cli", std::string("未知参数: ") + arg + diff --git a/src/utils/Log.cpp b/src/utils/Log.cpp index e540ba8..82921de 100644 --- a/src/utils/Log.cpp +++ b/src/utils/Log.cpp @@ -50,13 +50,14 @@ void PrintHelp(std::ostream& os) { os << "SysY Compiler\n" << "\n" << "用法:\n" - << " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] \n" + << " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] [--no-opt] \n" << "\n" << "选项:\n" << " -h, --help 打印帮助信息并退出\n" << " --emit-parse-tree 仅在显式模式下启用语法树输出\n" << " --emit-ir 仅在显式模式下启用 IR 输出\n" << " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n" + << " --no-opt 关闭 Lab4 IR 标量优化管线\n" << "\n" << "说明:\n" << " - 默认输出 IR\n" diff --git a/sylib/sylib.c b/sylib/sylib.c index 1b357a5..4a14d08 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -6,7 +6,17 @@ int getint() { return v; } -int getch() { return getchar(); } +int getch() { + int c = getchar(); + if (c == '\r') { + int next = getchar(); + if (next != '\n' && next != EOF) { + ungetc(next, stdin); + } + return '\n'; + } + return c; +} int getarray(int a[]) { int n = 0; diff --git a/third_party/antlr4-runtime-4.13.2/runtime/src/atn/ProfilingATNSimulator.cpp b/third_party/antlr4-runtime-4.13.2/runtime/src/atn/ProfilingATNSimulator.cpp index 9fd86d6..5df0611 100755 --- a/third_party/antlr4-runtime-4.13.2/runtime/src/atn/ProfilingATNSimulator.cpp +++ b/third_party/antlr4-runtime-4.13.2/runtime/src/atn/ProfilingATNSimulator.cpp @@ -11,6 +11,8 @@ #include "atn/ProfilingATNSimulator.h" +#include + using namespace antlr4; using namespace antlr4::atn; using namespace antlr4::dfa;