From 0e5a75eaf3d24a742bcec3e8c9d212cdbb7fbf1a Mon Sep 17 00:00:00 2001 From: jing <3030349106@qq.com> Date: Wed, 11 Mar 2026 23:04:28 +0800 Subject: [PATCH] =?UTF-8?q?fix(ir):=20=E4=BF=AE=E6=94=B9=E4=BA=86=E4=B8=80?= =?UTF-8?q?=E4=B8=8Bcontext=E7=9A=84=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/ir/BasicBlock.cpp | 12 ++++++++++ src/ir/Context.cpp | 7 +----- src/ir/Function.cpp | 1 + src/ir/IR.h | 37 ++++++++++++++++++---------- src/ir/IRBuilder.cpp | 44 ++++++---------------------------- src/ir/IRPrinter.cpp | 36 ++++++++++++++-------------- src/ir/Instruction.cpp | 53 ++++++++++++++++++++++------------------- src/ir/Module.cpp | 4 ++++ src/ir/Type.cpp | 6 ++--- src/ir/Value.cpp | 3 ++- src/irgen/IRGenDecl.cpp | 4 ++-- src/irgen/IRGenFunc.cpp | 9 +++++-- src/main.cpp | 2 +- 13 files changed, 112 insertions(+), 106 deletions(-) diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index ef257bc..681bb6d 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -12,6 +12,10 @@ BasicBlock::BasicBlock(std::string name) : name_(std::move(name)) {} const std::string& BasicBlock::name() const { return name_; } +Function* BasicBlock::parent() const { return parent_; } + +void BasicBlock::set_parent(Function* parent) { parent_ = parent; } + bool BasicBlock::HasTerminator() const { return !instructions_.empty() && instructions_.back()->IsTerminator(); } @@ -21,4 +25,12 @@ const std::vector>& BasicBlock::instructions() return instructions_; } +const std::vector& BasicBlock::predecessors() const { + return predecessors_; +} + +const std::vector& BasicBlock::successors() const { + return successors_; +} + } // namespace ir diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 49aa9b8..a119446 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -7,11 +7,6 @@ namespace ir { -Context& DefaultContext() { - static Context ctx; - return ctx; -} - Context::~Context() = default; const std::shared_ptr& Context::Void() { @@ -39,7 +34,7 @@ ConstantInt* Context::GetConstInt(int v) { auto it = const_ints_.find(v); if (it != const_ints_.end()) return it->second.get(); auto inserted = - const_ints_.emplace(v, std::make_unique(v)).first; + const_ints_.emplace(v, std::make_unique(Int32(), v)).first; return inserted->second.get(); } diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index 33312b0..d1a2c66 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -13,6 +13,7 @@ Function::Function(std::string name, std::shared_ptr ret_type) BasicBlock* Function::CreateBlock(const std::string& name) { auto block = std::make_unique(name); auto* ptr = block.get(); + ptr->set_parent(this); blocks_.push_back(std::move(block)); if (!entry_) { entry_ = ptr; diff --git a/src/ir/IR.h b/src/ir/IR.h index 2876333..0b8d2ed 100644 --- a/src/ir/IR.h +++ b/src/ir/IR.h @@ -2,6 +2,7 @@ // 可在此基础上扩展更多类型/指令 #pragma once +#include #include #include #include @@ -15,10 +16,12 @@ class Type; class ConstantInt; class Instruction; class BasicBlock; +class Function; // IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。 class Context { public: + Context() = default; ~Context(); const std::shared_ptr& Void(); const std::shared_ptr& Int32(); @@ -36,16 +39,14 @@ class Context { int temp_index_ = -1; }; -Context& DefaultContext(); - class Type { public: enum class Kind { Void, Int32, PtrInt32 }; explicit Type(Kind k); Kind kind() const; - static std::shared_ptr Void(); - static std::shared_ptr Int32(); - static std::shared_ptr PtrInt32(); + bool IsVoid() const; + bool IsInt32() const; + bool IsPtrInt32() const; private: Kind kind_; @@ -69,7 +70,7 @@ class Value { class ConstantInt : public Value { public: - explicit ConstantInt(int v); + ConstantInt(std::shared_ptr ty, int v); int value() const { return value_; } private: @@ -106,7 +107,7 @@ class BinaryInst : public Instruction { class ReturnInst : public Instruction { public: - explicit ReturnInst(Value* val); + ReturnInst(std::shared_ptr void_ty, Value* val); Value* value() const; private: @@ -115,12 +116,12 @@ class ReturnInst : public Instruction { class AllocaInst : public Instruction { public: - explicit AllocaInst(std::string name); + AllocaInst(std::shared_ptr ptr_ty, std::string name); }; class LoadInst : public Instruction { public: - LoadInst(Value* ptr, std::string name); + LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name); Value* ptr() const; private: @@ -129,7 +130,7 @@ class LoadInst : public Instruction { class StoreInst : public Instruction { public: - StoreInst(Value* val, Value* ptr); + StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr); Value* value() const; Value* ptr() const; @@ -142,8 +143,12 @@ class BasicBlock { public: explicit BasicBlock(std::string name); const std::string& name() const; + Function* parent() const; + void set_parent(Function* parent); bool HasTerminator() const; const std::vector>& instructions() const; + const std::vector& predecessors() const; + const std::vector& successors() const; template T* Append(Args&&... args) { if (HasTerminator()) { @@ -159,7 +164,10 @@ class BasicBlock { private: std::string name_; + Function* parent_ = nullptr; std::vector> instructions_; + std::vector predecessors_; + std::vector successors_; }; class Function : public Value { @@ -178,18 +186,22 @@ class Function : public Value { class Module { public: + Module() = default; + Context& context(); + const Context& context() const; // 创建函数时显式传入返回类型,便于在 IRGen 中根据语法树信息选择类型。 Function* CreateFunction(const std::string& name, std::shared_ptr ret_type); const std::vector>& functions() const; private: + Context context_; std::vector> functions_; }; class IRBuilder { public: - explicit IRBuilder(BasicBlock* bb); + IRBuilder(Context& ctx, BasicBlock* bb); void SetInsertPoint(BasicBlock* bb); BasicBlock* GetInsertBlock() const; @@ -204,12 +216,13 @@ class IRBuilder { ReturnInst* CreateRet(Value* v); private: + Context& ctx_; BasicBlock* insertBlock_; }; class IRPrinter { public: - void Print(const Module& module); + void Print(const Module& module, std::ostream& os); }; } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 58f08de..c30c48a 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -6,19 +6,8 @@ #include namespace ir { -namespace { - -bool IsArithmeticType(const std::shared_ptr& ty) { - return ty && ty->kind() == Type::Kind::Int32; -} - -bool IsPtrInt32Type(const std::shared_ptr& ty) { - return ty && ty->kind() == Type::Kind::PtrInt32; -} - -} // namespace - -IRBuilder::IRBuilder(BasicBlock* bb) : insertBlock_(bb) {} +IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) + : ctx_(ctx), insertBlock_(bb) {} void IRBuilder::SetInsertPoint(BasicBlock* bb) { insertBlock_ = bb; } @@ -26,7 +15,7 @@ BasicBlock* IRBuilder::GetInsertBlock() const { return insertBlock_; } ConstantInt* IRBuilder::CreateConstInt(int v) { // 常量不需要挂在基本块里,由 Context 负责去重与生命周期。 - return DefaultContext().GetConstInt(v); + return ctx_.GetConstInt(v); } BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, @@ -40,16 +29,6 @@ BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, if (!rhs) { throw std::runtime_error("IRBuilder::CreateBinary 缺少 rhs"); } - if (op != Opcode::Add) { - throw std::runtime_error("IRBuilder::CreateBinary 当前只支持 Add"); - } - if (!lhs->type() || !rhs->type() || - lhs->type()->kind() != rhs->type()->kind()) { - throw std::runtime_error("IRBuilder::CreateBinary 操作数类型不匹配"); - } - if (!IsArithmeticType(lhs->type())) { - throw std::runtime_error("IRBuilder::CreateBinary 当前只支持 i32 二元运算"); - } return insertBlock_->Append(op, lhs->type(), lhs, rhs, name); } @@ -62,7 +41,7 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { if (!insertBlock_) { throw std::runtime_error("IRBuilder 未设置插入点"); } - return insertBlock_->Append(name); + return insertBlock_->Append(ctx_.PtrInt32(), name); } LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { @@ -72,10 +51,7 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { if (!ptr) { throw std::runtime_error("IRBuilder::CreateLoad 缺少 ptr"); } - if (!IsPtrInt32Type(ptr->type())) { - throw std::runtime_error("IRBuilder::CreateLoad 当前只支持从 i32* 加载"); - } - return insertBlock_->Append(ptr, name); + return insertBlock_->Append(ctx_.Int32(), ptr, name); } StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { @@ -88,13 +64,7 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { if (!ptr) { throw std::runtime_error("IRBuilder::CreateStore 缺少 ptr"); } - if (!IsArithmeticType(val->type())) { - throw std::runtime_error("IRBuilder::CreateStore 当前只支持存储 i32"); - } - if (!IsPtrInt32Type(ptr->type())) { - throw std::runtime_error("IRBuilder::CreateStore 当前只支持写入 i32*"); - } - return insertBlock_->Append(val, ptr); + return insertBlock_->Append(ctx_.Void(), val, ptr); } ReturnInst* IRBuilder::CreateRet(Value* v) { @@ -104,7 +74,7 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { if (!v) { throw std::runtime_error("IRBuilder::CreateRet 缺少返回值"); } - return insertBlock_->Append(v); + return insertBlock_->Append(ctx_.Void(), v); } } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 1d6b7eb..71b51f3 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -4,9 +4,9 @@ #include "ir/IR.h" -#include -#include +#include #include +#include namespace ir { @@ -49,15 +49,15 @@ static std::string ValueToString(const Value* v) { return v ? v->name() : ""; } -void IRPrinter::Print(const Module& module) { +void IRPrinter::Print(const Module& module, std::ostream& os) { for (const auto& func : module.functions()) { - std::cout << "define " << TypeToString(*func->type()) << " @" - << func->name() << "() {\n"; + os << "define " << TypeToString(*func->type()) << " @" << func->name() + << "() {\n"; for (const auto& bb : func->blocks()) { if (!bb) { continue; } - std::cout << bb->name() << ":\n"; + os << bb->name() << ":\n"; for (const auto& instPtr : bb->instructions()) { const auto* inst = instPtr.get(); switch (inst->opcode()) { @@ -65,39 +65,39 @@ void IRPrinter::Print(const Module& module) { case Opcode::Sub: case Opcode::Mul: { auto* bin = static_cast(inst); - std::cout << " " << bin->name() << " = " << OpcodeToString(bin->opcode()) - << " " << TypeToString(*bin->lhs()->type()) << " " - << ValueToString(bin->lhs()) << ", " - << ValueToString(bin->rhs()) << "\n"; + os << " " << bin->name() << " = " << OpcodeToString(bin->opcode()) + << " " << TypeToString(*bin->lhs()->type()) << " " + << ValueToString(bin->lhs()) << ", " + << ValueToString(bin->rhs()) << "\n"; break; } case Opcode::Alloca: { auto* alloca = static_cast(inst); - std::cout << " " << alloca->name() << " = alloca i32\n"; + os << " " << alloca->name() << " = alloca i32\n"; break; } case Opcode::Load: { auto* load = static_cast(inst); - std::cout << " " << load->name() << " = load i32, i32* " - << ValueToString(load->ptr()) << "\n"; + os << " " << load->name() << " = load i32, i32* " + << ValueToString(load->ptr()) << "\n"; break; } case Opcode::Store: { auto* store = static_cast(inst); - std::cout << " store i32 " << ValueToString(store->value()) << ", i32* " - << ValueToString(store->ptr()) << "\n"; + os << " store i32 " << ValueToString(store->value()) << ", i32* " + << ValueToString(store->ptr()) << "\n"; break; } case Opcode::Ret: { auto* ret = static_cast(inst); - std::cout << " ret " << TypeToString(*ret->value()->type()) << " " - << ValueToString(ret->value()) << "\n"; + os << " ret " << TypeToString(*ret->value()->type()) << " " + << ValueToString(ret->value()) << "\n"; break; } } } } - std::cout << "}\n"; + os << "}\n"; } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index a808daa..e9cb5c4 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -6,18 +6,6 @@ #include namespace ir { -namespace { - -bool IsArithmeticType(const std::shared_ptr& ty) { - return ty && ty->kind() == Type::Kind::Int32; -} - -bool IsPtrInt32Type(const std::shared_ptr& ty) { - return ty && ty->kind() == Type::Kind::PtrInt32; -} - -} // namespace - Instruction::Instruction(Opcode op, std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)), opcode_(op) {} @@ -45,7 +33,7 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, type_->kind() != lhs_->type()->kind()) { throw std::runtime_error("BinaryInst 类型不匹配"); } - if (!IsArithmeticType(type_)) { + if (!type_->IsInt32()) { throw std::runtime_error("BinaryInst 当前只支持 i32"); } if (lhs_) { @@ -60,25 +48,37 @@ Value* BinaryInst::lhs() const { return lhs_; } Value* BinaryInst::rhs() const { return rhs_; } -ReturnInst::ReturnInst(Value* val) - : Instruction(Opcode::Ret, Type::Void(), ""), value_(val) { +ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) + : Instruction(Opcode::Ret, std::move(void_ty), ""), + value_(val) { if (!value_) { throw std::runtime_error("ReturnInst 缺少返回值"); } + if (!type_ || !type_->IsVoid()) { + throw std::runtime_error("ReturnInst 返回类型必须为 void"); + } value_->AddUser(this); } Value* ReturnInst::value() const { return value_; } -AllocaInst::AllocaInst(std::string name) - : Instruction(Opcode::Alloca, Type::PtrInt32(), std::move(name)) {} +AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) + : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { + if (!type_ || !type_->IsPtrInt32()) { + throw std::runtime_error("AllocaInst 当前只支持 i32*"); + } +} -LoadInst::LoadInst(Value* ptr, std::string name) - : Instruction(Opcode::Load, Type::Int32(), std::move(name)), ptr_(ptr) { +LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) + : Instruction(Opcode::Load, std::move(val_ty), std::move(name)), + ptr_(ptr) { if (!ptr_) { throw std::runtime_error("LoadInst 缺少 ptr"); } - if (!IsPtrInt32Type(ptr_->type())) { + if (!type_ || !type_->IsInt32()) { + throw std::runtime_error("LoadInst 当前只支持加载 i32"); + } + if (!ptr_->type() || !ptr_->type()->IsPtrInt32()) { throw std::runtime_error("LoadInst 当前只支持从 i32* 加载"); } ptr_->AddUser(this); @@ -86,18 +86,23 @@ LoadInst::LoadInst(Value* ptr, std::string name) Value* LoadInst::ptr() const { return ptr_; } -StoreInst::StoreInst(Value* val, Value* ptr) - : Instruction(Opcode::Store, Type::Void(), ""), value_(val), ptr_(ptr) { +StoreInst::StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr) + : Instruction(Opcode::Store, std::move(void_ty), ""), + value_(val), + ptr_(ptr) { if (!value_) { throw std::runtime_error("StoreInst 缺少 value"); } if (!ptr_) { throw std::runtime_error("StoreInst 缺少 ptr"); } - if (!IsArithmeticType(value_->type())) { + if (!type_ || !type_->IsVoid()) { + throw std::runtime_error("StoreInst 返回类型必须为 void"); + } + if (!value_->type() || !value_->type()->IsInt32()) { throw std::runtime_error("StoreInst 当前只支持存储 i32"); } - if (!IsPtrInt32Type(ptr_->type())) { + if (!ptr_->type() || !ptr_->type()->IsPtrInt32()) { throw std::runtime_error("StoreInst 当前只支持写入 i32*"); } value_->AddUser(this); diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 0a2df76..3e5d455 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -6,6 +6,10 @@ namespace ir { +Context& Module::context() { return context_; } + +const Context& Module::context() const { return context_; } + Function* Module::CreateFunction(const std::string& name, std::shared_ptr ret_type) { functions_.push_back(std::make_unique(name, std::move(ret_type))); diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index 4712485..1dbcc33 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -10,10 +10,10 @@ Type::Type(Kind k) : kind_(k) {} Type::Kind Type::kind() const { return kind_; } -std::shared_ptr Type::Void() { return DefaultContext().Void(); } +bool Type::IsVoid() const { return kind_ == Kind::Void; } -std::shared_ptr Type::Int32() { return DefaultContext().Int32(); } +bool Type::IsInt32() const { return kind_ == Kind::Int32; } -std::shared_ptr Type::PtrInt32() { return DefaultContext().PtrInt32(); } +bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } } // namespace ir diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 9bd2fa7..2dc23ae 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -18,6 +18,7 @@ void Value::AddUser(Instruction* user) { users_.push_back(user); } const std::vector& Value::users() const { return users_; } -ConstantInt::ConstantInt(int v) : Value(Type::Int32(), ""), value_(v) {} +ConstantInt::ConstantInt(std::shared_ptr ty, int v) + : Value(std::move(ty), ""), value_(v) {} } // namespace ir diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 72c8cfb..1ef1876 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -39,14 +39,14 @@ void IRGenImpl::GenVarDecl(SysYParser::VarDeclContext& decl) { if (storage_map_.find(&decl) != storage_map_.end()) { throw std::runtime_error("[irgen] 声明重复生成存储槽位"); } - auto* slot = builder_.CreateAllocaI32(ir::DefaultContext().NextTemp()); + auto* slot = builder_.CreateAllocaI32(module_.context().NextTemp()); storage_map_[&decl] = slot; ir::Value* init = nullptr; if (decl.exp()) { init = GenExpr(*decl.exp()); } else { - init = ir::DefaultContext().GetConstInt(0); + init = builder_.CreateConstInt(0); } builder_.CreateStore(init, slot); } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 8b2edc8..be90ecd 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -20,7 +20,10 @@ void VerifyFunctionStructure(const ir::Function& func) { } // namespace IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) - : module_(module), sema_(sema), func_(nullptr), builder_(nullptr) {} + : module_(module), + sema_(sema), + func_(nullptr), + builder_(module.context(), nullptr) {} void IRGenImpl::Gen(SysYParser::CompUnitContext& cu) { if (!cu.funcDef()) { @@ -37,7 +40,9 @@ void IRGenImpl::GenFuncDef(SysYParser::FuncDefContext& func) { throw std::runtime_error("[irgen] 缺少函数名"); } - func_ = module_.CreateFunction(func.Ident()->getText(), ir::Type::Int32()); + func_ = module_.CreateFunction( + func.Ident()->getText(), + std::make_shared(ir::Type::Kind::Int32)); builder_.SetInsertPoint(func_->entry()); storage_map_.clear(); diff --git a/src/main.cpp b/src/main.cpp index c798420..5b1ebc3 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -38,7 +38,7 @@ int main(int argc, char** argv) { if (need_blank_line) { std::cout << "\n"; } - printer.Print(*module); + printer.Print(*module, std::cout); need_blank_line = true; }