From 026717d3d839a4a53f877a59bf79b3c945f17537 Mon Sep 17 00:00:00 2001 From: Su Xing Date: Fri, 24 Mar 2023 10:12:41 +0800 Subject: [PATCH] Refine IR and IRBuilder --- src/IR.h | 37 ++++++++++++++++++++++++++++--------- src/IRBuilder.h | 5 +++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/IR.h b/src/IR.h index 29abe4a..581f5b8 100644 --- a/src/IR.h +++ b/src/IR.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -597,21 +598,39 @@ public: class Module { protected: - std::list> functions; - std::list> globals; + std::map> functions; + std::map> globals; + // std::list> functions; + // std::list> globals; public: Module() = default; public: - Function *addFunction(Type *type, const std::string &name) { - functions.emplace_back(new Function(type, name)); - return functions.back().get(); + Function *createFunction(const std::string &name, Type *type) { + auto result = functions.try_emplace(name, new Function(type, name)); + if (not result.second) + return nullptr; + return result.first->second.get(); }; - GlobalValue *addGlobalValue(Type *type, const std::vector &dims = {}, - const std::string &name = "") { - globals.emplace_back(new GlobalValue(type, dims, name)); - return globals.back().get(); + GlobalValue *createGlobalValue(const std::string &name, Type *type, + const std::vector &dims = {}) { + auto result = globals.try_emplace(name, new GlobalValue(type, dims, name)); + if (not result.second) + return nullptr; + return result.first->second.get(); + } + Function *getFunction(const std::string &name) const { + auto result = functions.find(name); + if (result == functions.end()) + return nullptr; + return result->second.get(); + } + GlobalValue *getGlobalValue(const std::string &name) const { + auto result = globals.find(name); + if (result == globals.end()) + return nullptr; + return result->second.get(); } }; // class Module diff --git a/src/IRBuilder.h b/src/IRBuilder.h index 3608b6b..806c046 100644 --- a/src/IRBuilder.h +++ b/src/IRBuilder.h @@ -12,6 +12,7 @@ private: BasicBlock::iterator position; public: + IRBuilder() = default; IRBuilder(BasicBlock *block) : block(block), position(block->end()) {} IRBuilder(BasicBlock *block, BasicBlock::iterator position) : block(block), position(position) {} @@ -19,6 +20,10 @@ public: public: BasicBlock *getBasicBlock() const { return block; } BasicBlock::iterator getPosition() const { return position; } + void setPosition(BasicBlock *block, BasicBlock::iterator position) { + this->block = block; + this->position = position; + } void setPosition(BasicBlock::iterator position) { this->position = position; } public: