fix(ir): 修改了一下context的管理

master
jing 1 week ago
parent fab6983d40
commit 0e5a75eaf3

@ -12,6 +12,10 @@ BasicBlock::BasicBlock(std::string name) : name_(std::move(name)) {}
const std::string& BasicBlock::name() const { return name_; }
Function* BasicBlock::parent() const { return parent_; }
void BasicBlock::set_parent(Function* parent) { parent_ = parent; }
bool BasicBlock::HasTerminator() const {
return !instructions_.empty() && instructions_.back()->IsTerminator();
}
@ -21,4 +25,12 @@ const std::vector<std::unique_ptr<Instruction>>& BasicBlock::instructions()
return instructions_;
}
const std::vector<BasicBlock*>& BasicBlock::predecessors() const {
return predecessors_;
}
const std::vector<BasicBlock*>& BasicBlock::successors() const {
return successors_;
}
} // namespace ir

@ -7,11 +7,6 @@
namespace ir {
Context& DefaultContext() {
static Context ctx;
return ctx;
}
Context::~Context() = default;
const std::shared_ptr<Type>& Context::Void() {
@ -39,7 +34,7 @@ ConstantInt* Context::GetConstInt(int v) {
auto it = const_ints_.find(v);
if (it != const_ints_.end()) return it->second.get();
auto inserted =
const_ints_.emplace(v, std::make_unique<ConstantInt>(v)).first;
const_ints_.emplace(v, std::make_unique<ConstantInt>(Int32(), v)).first;
return inserted->second.get();
}

@ -13,6 +13,7 @@ Function::Function(std::string name, std::shared_ptr<Type> ret_type)
BasicBlock* Function::CreateBlock(const std::string& name) {
auto block = std::make_unique<BasicBlock>(name);
auto* ptr = block.get();
ptr->set_parent(this);
blocks_.push_back(std::move(block));
if (!entry_) {
entry_ = ptr;

@ -2,6 +2,7 @@
// 可在此基础上扩展更多类型/指令
#pragma once
#include <iosfwd>
#include <memory>
#include <stdexcept>
#include <string>
@ -15,10 +16,12 @@ class Type;
class ConstantInt;
class Instruction;
class BasicBlock;
class Function;
// IR 上下文:集中管理类型、常量等共享资源,便于复用与扩展。
class Context {
public:
Context() = default;
~Context();
const std::shared_ptr<Type>& Void();
const std::shared_ptr<Type>& Int32();
@ -36,16 +39,14 @@ class Context {
int temp_index_ = -1;
};
Context& DefaultContext();
class Type {
public:
enum class Kind { Void, Int32, PtrInt32 };
explicit Type(Kind k);
Kind kind() const;
static std::shared_ptr<Type> Void();
static std::shared_ptr<Type> Int32();
static std::shared_ptr<Type> PtrInt32();
bool IsVoid() const;
bool IsInt32() const;
bool IsPtrInt32() const;
private:
Kind kind_;
@ -69,7 +70,7 @@ class Value {
class ConstantInt : public Value {
public:
explicit ConstantInt(int v);
ConstantInt(std::shared_ptr<Type> ty, int v);
int value() const { return value_; }
private:
@ -106,7 +107,7 @@ class BinaryInst : public Instruction {
class ReturnInst : public Instruction {
public:
explicit ReturnInst(Value* val);
ReturnInst(std::shared_ptr<Type> void_ty, Value* val);
Value* value() const;
private:
@ -115,12 +116,12 @@ class ReturnInst : public Instruction {
class AllocaInst : public Instruction {
public:
explicit AllocaInst(std::string name);
AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name);
};
class LoadInst : public Instruction {
public:
LoadInst(Value* ptr, std::string name);
LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name);
Value* ptr() const;
private:
@ -129,7 +130,7 @@ class LoadInst : public Instruction {
class StoreInst : public Instruction {
public:
StoreInst(Value* val, Value* ptr);
StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr);
Value* value() const;
Value* ptr() const;
@ -142,8 +143,12 @@ class BasicBlock {
public:
explicit BasicBlock(std::string name);
const std::string& name() const;
Function* parent() const;
void set_parent(Function* parent);
bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& instructions() const;
const std::vector<BasicBlock*>& predecessors() const;
const std::vector<BasicBlock*>& successors() const;
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
@ -159,7 +164,10 @@ class BasicBlock {
private:
std::string name_;
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_;
};
class Function : public Value {
@ -178,18 +186,22 @@ class Function : public Value {
class Module {
public:
Module() = default;
Context& context();
const Context& context() const;
// 创建函数时显式传入返回类型,便于在 IRGen 中根据语法树信息选择类型。
Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type);
const std::vector<std::unique_ptr<Function>>& functions() const;
private:
Context context_;
std::vector<std::unique_ptr<Function>> functions_;
};
class IRBuilder {
public:
explicit IRBuilder(BasicBlock* bb);
IRBuilder(Context& ctx, BasicBlock* bb);
void SetInsertPoint(BasicBlock* bb);
BasicBlock* GetInsertBlock() const;
@ -204,12 +216,13 @@ class IRBuilder {
ReturnInst* CreateRet(Value* v);
private:
Context& ctx_;
BasicBlock* insertBlock_;
};
class IRPrinter {
public:
void Print(const Module& module);
void Print(const Module& module, std::ostream& os);
};
} // namespace ir

@ -6,19 +6,8 @@
#include <stdexcept>
namespace ir {
namespace {
bool IsArithmeticType(const std::shared_ptr<Type>& ty) {
return ty && ty->kind() == Type::Kind::Int32;
}
bool IsPtrInt32Type(const std::shared_ptr<Type>& ty) {
return ty && ty->kind() == Type::Kind::PtrInt32;
}
} // namespace
IRBuilder::IRBuilder(BasicBlock* bb) : insertBlock_(bb) {}
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb)
: ctx_(ctx), insertBlock_(bb) {}
void IRBuilder::SetInsertPoint(BasicBlock* bb) { insertBlock_ = bb; }
@ -26,7 +15,7 @@ BasicBlock* IRBuilder::GetInsertBlock() const { return insertBlock_; }
ConstantInt* IRBuilder::CreateConstInt(int v) {
// 常量不需要挂在基本块里,由 Context 负责去重与生命周期。
return DefaultContext().GetConstInt(v);
return ctx_.GetConstInt(v);
}
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
@ -40,16 +29,6 @@ BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
if (!rhs) {
throw std::runtime_error("IRBuilder::CreateBinary 缺少 rhs");
}
if (op != Opcode::Add) {
throw std::runtime_error("IRBuilder::CreateBinary 当前只支持 Add");
}
if (!lhs->type() || !rhs->type() ||
lhs->type()->kind() != rhs->type()->kind()) {
throw std::runtime_error("IRBuilder::CreateBinary 操作数类型不匹配");
}
if (!IsArithmeticType(lhs->type())) {
throw std::runtime_error("IRBuilder::CreateBinary 当前只支持 i32 二元运算");
}
return insertBlock_->Append<BinaryInst>(op, lhs->type(), lhs, rhs, name);
}
@ -62,7 +41,7 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insertBlock_) {
throw std::runtime_error("IRBuilder 未设置插入点");
}
return insertBlock_->Append<AllocaInst>(name);
return insertBlock_->Append<AllocaInst>(ctx_.PtrInt32(), name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
@ -72,10 +51,7 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
if (!ptr) {
throw std::runtime_error("IRBuilder::CreateLoad 缺少 ptr");
}
if (!IsPtrInt32Type(ptr->type())) {
throw std::runtime_error("IRBuilder::CreateLoad 当前只支持从 i32* 加载");
}
return insertBlock_->Append<LoadInst>(ptr, name);
return insertBlock_->Append<LoadInst>(ctx_.Int32(), ptr, name);
}
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
@ -88,13 +64,7 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
if (!ptr) {
throw std::runtime_error("IRBuilder::CreateStore 缺少 ptr");
}
if (!IsArithmeticType(val->type())) {
throw std::runtime_error("IRBuilder::CreateStore 当前只支持存储 i32");
}
if (!IsPtrInt32Type(ptr->type())) {
throw std::runtime_error("IRBuilder::CreateStore 当前只支持写入 i32*");
}
return insertBlock_->Append<StoreInst>(val, ptr);
return insertBlock_->Append<StoreInst>(ctx_.Void(), val, ptr);
}
ReturnInst* IRBuilder::CreateRet(Value* v) {
@ -104,7 +74,7 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!v) {
throw std::runtime_error("IRBuilder::CreateRet 缺少返回值");
}
return insertBlock_->Append<ReturnInst>(v);
return insertBlock_->Append<ReturnInst>(ctx_.Void(), v);
}
} // namespace ir

@ -4,9 +4,9 @@
#include "ir/IR.h"
#include <iostream>
#include <string>
#include <ostream>
#include <stdexcept>
#include <string>
namespace ir {
@ -49,15 +49,15 @@ static std::string ValueToString(const Value* v) {
return v ? v->name() : "<null>";
}
void IRPrinter::Print(const Module& module) {
void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& func : module.functions()) {
std::cout << "define " << TypeToString(*func->type()) << " @"
<< func->name() << "() {\n";
os << "define " << TypeToString(*func->type()) << " @" << func->name()
<< "() {\n";
for (const auto& bb : func->blocks()) {
if (!bb) {
continue;
}
std::cout << bb->name() << ":\n";
os << bb->name() << ":\n";
for (const auto& instPtr : bb->instructions()) {
const auto* inst = instPtr.get();
switch (inst->opcode()) {
@ -65,39 +65,39 @@ void IRPrinter::Print(const Module& module) {
case Opcode::Sub:
case Opcode::Mul: {
auto* bin = static_cast<const BinaryInst*>(inst);
std::cout << " " << bin->name() << " = " << OpcodeToString(bin->opcode())
<< " " << TypeToString(*bin->lhs()->type()) << " "
<< ValueToString(bin->lhs()) << ", "
<< ValueToString(bin->rhs()) << "\n";
os << " " << bin->name() << " = " << OpcodeToString(bin->opcode())
<< " " << TypeToString(*bin->lhs()->type()) << " "
<< ValueToString(bin->lhs()) << ", "
<< ValueToString(bin->rhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
std::cout << " " << alloca->name() << " = alloca i32\n";
os << " " << alloca->name() << " = alloca i32\n";
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
std::cout << " " << load->name() << " = load i32, i32* "
<< ValueToString(load->ptr()) << "\n";
os << " " << load->name() << " = load i32, i32* "
<< ValueToString(load->ptr()) << "\n";
break;
}
case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst);
std::cout << " store i32 " << ValueToString(store->value()) << ", i32* "
<< ValueToString(store->ptr()) << "\n";
os << " store i32 " << ValueToString(store->value()) << ", i32* "
<< ValueToString(store->ptr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
std::cout << " ret " << TypeToString(*ret->value()->type()) << " "
<< ValueToString(ret->value()) << "\n";
os << " ret " << TypeToString(*ret->value()->type()) << " "
<< ValueToString(ret->value()) << "\n";
break;
}
}
}
}
std::cout << "}\n";
os << "}\n";
}
}

@ -6,18 +6,6 @@
#include <stdexcept>
namespace ir {
namespace {
bool IsArithmeticType(const std::shared_ptr<Type>& ty) {
return ty && ty->kind() == Type::Kind::Int32;
}
bool IsPtrInt32Type(const std::shared_ptr<Type>& ty) {
return ty && ty->kind() == Type::Kind::PtrInt32;
}
} // namespace
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)), opcode_(op) {}
@ -45,7 +33,7 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
type_->kind() != lhs_->type()->kind()) {
throw std::runtime_error("BinaryInst 类型不匹配");
}
if (!IsArithmeticType(type_)) {
if (!type_->IsInt32()) {
throw std::runtime_error("BinaryInst 当前只支持 i32");
}
if (lhs_) {
@ -60,25 +48,37 @@ Value* BinaryInst::lhs() const { return lhs_; }
Value* BinaryInst::rhs() const { return rhs_; }
ReturnInst::ReturnInst(Value* val)
: Instruction(Opcode::Ret, Type::Void(), ""), value_(val) {
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Ret, std::move(void_ty), ""),
value_(val) {
if (!value_) {
throw std::runtime_error("ReturnInst 缺少返回值");
}
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error("ReturnInst 返回类型必须为 void");
}
value_->AddUser(this);
}
Value* ReturnInst::value() const { return value_; }
AllocaInst::AllocaInst(std::string name)
: Instruction(Opcode::Alloca, Type::PtrInt32(), std::move(name)) {}
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
if (!type_ || !type_->IsPtrInt32()) {
throw std::runtime_error("AllocaInst 当前只支持 i32*");
}
}
LoadInst::LoadInst(Value* ptr, std::string name)
: Instruction(Opcode::Load, Type::Int32(), std::move(name)), ptr_(ptr) {
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
: Instruction(Opcode::Load, std::move(val_ty), std::move(name)),
ptr_(ptr) {
if (!ptr_) {
throw std::runtime_error("LoadInst 缺少 ptr");
}
if (!IsPtrInt32Type(ptr_->type())) {
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error("LoadInst 当前只支持加载 i32");
}
if (!ptr_->type() || !ptr_->type()->IsPtrInt32()) {
throw std::runtime_error("LoadInst 当前只支持从 i32* 加载");
}
ptr_->AddUser(this);
@ -86,18 +86,23 @@ LoadInst::LoadInst(Value* ptr, std::string name)
Value* LoadInst::ptr() const { return ptr_; }
StoreInst::StoreInst(Value* val, Value* ptr)
: Instruction(Opcode::Store, Type::Void(), ""), value_(val), ptr_(ptr) {
StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
: Instruction(Opcode::Store, std::move(void_ty), ""),
value_(val),
ptr_(ptr) {
if (!value_) {
throw std::runtime_error("StoreInst 缺少 value");
}
if (!ptr_) {
throw std::runtime_error("StoreInst 缺少 ptr");
}
if (!IsArithmeticType(value_->type())) {
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error("StoreInst 返回类型必须为 void");
}
if (!value_->type() || !value_->type()->IsInt32()) {
throw std::runtime_error("StoreInst 当前只支持存储 i32");
}
if (!IsPtrInt32Type(ptr_->type())) {
if (!ptr_->type() || !ptr_->type()->IsPtrInt32()) {
throw std::runtime_error("StoreInst 当前只支持写入 i32*");
}
value_->AddUser(this);

@ -6,6 +6,10 @@
namespace ir {
Context& Module::context() { return context_; }
const Context& Module::context() const { return context_; }
Function* Module::CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type) {
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type)));

@ -10,10 +10,10 @@ Type::Type(Kind k) : kind_(k) {}
Type::Kind Type::kind() const { return kind_; }
std::shared_ptr<Type> Type::Void() { return DefaultContext().Void(); }
bool Type::IsVoid() const { return kind_ == Kind::Void; }
std::shared_ptr<Type> Type::Int32() { return DefaultContext().Int32(); }
bool Type::IsInt32() const { return kind_ == Kind::Int32; }
std::shared_ptr<Type> Type::PtrInt32() { return DefaultContext().PtrInt32(); }
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
} // namespace ir

@ -18,6 +18,7 @@ void Value::AddUser(Instruction* user) { users_.push_back(user); }
const std::vector<Instruction*>& Value::users() const { return users_; }
ConstantInt::ConstantInt(int v) : Value(Type::Int32(), ""), value_(v) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: Value(std::move(ty), ""), value_(v) {}
} // namespace ir

@ -39,14 +39,14 @@ void IRGenImpl::GenVarDecl(SysYParser::VarDeclContext& decl) {
if (storage_map_.find(&decl) != storage_map_.end()) {
throw std::runtime_error("[irgen] 声明重复生成存储槽位");
}
auto* slot = builder_.CreateAllocaI32(ir::DefaultContext().NextTemp());
auto* slot = builder_.CreateAllocaI32(module_.context().NextTemp());
storage_map_[&decl] = slot;
ir::Value* init = nullptr;
if (decl.exp()) {
init = GenExpr(*decl.exp());
} else {
init = ir::DefaultContext().GetConstInt(0);
init = builder_.CreateConstInt(0);
}
builder_.CreateStore(init, slot);
}

@ -20,7 +20,10 @@ void VerifyFunctionStructure(const ir::Function& func) {
} // namespace
IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module), sema_(sema), func_(nullptr), builder_(nullptr) {}
: module_(module),
sema_(sema),
func_(nullptr),
builder_(module.context(), nullptr) {}
void IRGenImpl::Gen(SysYParser::CompUnitContext& cu) {
if (!cu.funcDef()) {
@ -37,7 +40,9 @@ void IRGenImpl::GenFuncDef(SysYParser::FuncDefContext& func) {
throw std::runtime_error("[irgen] 缺少函数名");
}
func_ = module_.CreateFunction(func.Ident()->getText(), ir::Type::Int32());
func_ = module_.CreateFunction(
func.Ident()->getText(),
std::make_shared<ir::Type>(ir::Type::Kind::Int32));
builder_.SetInsertPoint(func_->entry());
storage_map_.clear();

@ -38,7 +38,7 @@ int main(int argc, char** argv) {
if (need_blank_line) {
std::cout << "\n";
}
printer.Print(*module);
printer.Print(*module, std::cout);
need_blank_line = true;
}

Loading…
Cancel
Save