From d94dce0488d4b7b8941d6c9cc5919ffca348783d Mon Sep 17 00:00:00 2001 From: Su Xing Date: Sun, 26 Mar 2023 15:35:03 +0800 Subject: [PATCH] Refine IR --- src/IR.cpp | 38 +++++++++++++++------- src/IR.h | 94 ++++++++++++++++++++++++++++++++++++------------------ 2 files changed, 89 insertions(+), 43 deletions(-) diff --git a/src/IR.cpp b/src/IR.cpp index 99dd5cb..190bf3d 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -45,6 +45,21 @@ Type *Type::getFunctionType(Type *returnType, return FunctionType::get(returnType, paramTypes); } +int Type::getSize() const { + switch (kind) { + case kInt: + case kFloat: + return 4; + case kLabel: + case kPointer: + case kFunction: + return 8; + case kVoid: + return 0; + } + return 0; +} + PointerType *PointerType::get(Type *baseType) { static std::map> pointerTypes; auto iter = pointerTypes.find(baseType); @@ -59,16 +74,15 @@ PointerType *PointerType::get(Type *baseType) { FunctionType *FunctionType::get(Type *returnType, const std::vector ¶mTypes) { static std::set> functionTypes; - auto iter = std::find_if(functionTypes.begin(), functionTypes.end(), - [&](const std::unique_ptr &type) -> bool { - if (returnType != type->getReturnType() or - paramTypes.size() != type->getParamTypes().size()) - return false; - for (int i = 0; i < paramTypes.size(); ++i) - if (paramTypes[i] != type->getParamTypes()[i]) - return false; - return true; - }); + auto iter = + std::find_if(functionTypes.begin(), functionTypes.end(), + [&](const std::unique_ptr &type) -> bool { + if (returnType != type->getReturnType() or + paramTypes.size() != type->getParamTypes().size()) + return false; + return std::equal(paramTypes.begin(), paramTypes.end(), + type->getParamTypes().begin()); + }); if (iter != functionTypes.end()) return iter->get(); auto type = new FunctionType(returnType, paramTypes); @@ -83,7 +97,7 @@ void Value::replaceAllUsesWith(Value *value) { uses.clear(); } -ConstantValue *ConstantValue::getInt(int value, const std::string &name) { +ConstantValue *ConstantValue::get(int value, const std::string &name) { static std::map> intConstants; auto iter = intConstants.find(value); if (iter != intConstants.end()) @@ -94,7 +108,7 @@ ConstantValue *ConstantValue::getInt(int value, const std::string &name) { return result.first->second.get(); } -ConstantValue *ConstantValue::getFloat(float value, const std::string &name) { +ConstantValue *ConstantValue::get(float value, const std::string &name) { static std::map> floatConstants; auto iter = floatConstants.find(value); if (iter != floatConstants.end()) diff --git a/src/IR.h b/src/IR.h index 204f9c7..c670834 100644 --- a/src/IR.h +++ b/src/IR.h @@ -22,11 +22,21 @@ template struct range { explicit range(iterator b, iterator e) : b(b), e(e) {} iterator begin() { return b; } iterator end() { return e; } + auto size() const { return std::distance(b, e); } + auto empty() const { return b == e; } }; template range make_range(IterT b, IterT e) { return range(b, e); } +template +range make_range(ContainerT &c) { + return make_range(c.begin(), c.end()); +} +template +range make_range(const ContainerT &c) { + return make_range(c.begin(), c.end()); +} //===----------------------------------------------------------------------===// // Types @@ -66,7 +76,6 @@ public: static Type *getPointerType(Type *baseType); static Type *getFunctionType(Type *returnType, const std::vector ¶mTypes = {}); - static int getTypeSize(); public: Kind getKind() const { return kind; } @@ -76,6 +85,7 @@ public: bool isLabel() const { return kind == kLabel; } bool isPointer() const { return kind == kPointer; } bool isFunction() const { return kind == kFunction; } + int getSize() const; }; // class Type class PointerType : public Type { @@ -107,7 +117,7 @@ public: public: Type *getReturnType() const { return returnType; } - const std::vector &getParamTypes() const { return paramTypes; } + auto getParamTypes() const { return make_range(paramTypes); } int getNumParams() const { return paramTypes.size(); } }; // class FunctionType @@ -120,6 +130,7 @@ public: class User; class Value; +// `Use` represents the relation between a `Value` and its `User` class Use { public: enum Kind { @@ -173,8 +184,6 @@ public: }; // class Value class ConstantValue : public Value { - friend class IRBuilder; - protected: union { int iConstant; @@ -188,8 +197,8 @@ protected: : Value(Type::getFloatType(), name), fConstant(value) {} public: - static ConstantValue *getInt(int value, const std::string &name = ""); - static ConstantValue *getFloat(float value, const std::string &name = ""); + static ConstantValue *get(int value, const std::string &name = ""); + static ConstantValue *get(float value, const std::string &name = ""); public: int getInt() const { @@ -238,6 +247,7 @@ protected: arguments(), successors(), predecessors() {} public: + int getNumInstructions() const { return instructions.size(); } int getNumArguments() const { return arguments.size(); } int getNumPredecessors() const { return predecessors.size(); } int getNumSuccessors() const { return successors.size(); } @@ -259,15 +269,27 @@ protected: User(Type *type, const std::string &name = "") : Value(type, name), operands() {} +public: + struct operand_iterator : std::vector::iterator { + using Base = std::vector::iterator; + using Base::Base; + using value_type = Value *; + value_type operator->() { return operator*().getValue(); } + }; + public: int getNumOperands() const { return operands.size(); } - const std::vector &getOperands() const { return operands; } - const Use &getOperand(int index) const { return operands[index]; } + auto operand_begin() const { return operands.begin(); } + auto operand_end() const { return operands.end(); } + auto getOperands() const { + return make_range(operand_begin(), operand_end()); + } + Value *getOperand(int index) const { return operands[index].getValue(); } void addOperand(Value *value, Use::Kind mode = Use::kRead) { operands.emplace_back(mode, operands.size(), this, value); value->addUse(&operands.back()); } - void addOperands(const std::vector &operands) { + template void addOperands(const ContainerT &operands) { for (auto value : operands) addOperand(value); } @@ -343,15 +365,19 @@ public: (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); return kind & CondOpMask; } - bool isTerminator() { + bool isTerminator() const { static constexpr uint64_t TerminatorOpMask = kCondBr | kBr | kReturn; return kind & TerminatorOpMask; } - bool isCommutative() { + bool isCommutative() const { static constexpr uint64_t CommutativeOpMask = kAdd | kMul | kICmpEQ | kICmpNE | kFAdd | kFMul | kFCmpEQ | kFCmpNE; return kind & CommutativeOpMask; } + bool isMemory() const { + static constexpr uint64_t MemoryOpMask = kAlloca | kLoad | kStore; + return kind & MemoryOpMask; + } }; // class Instruction class Function; @@ -365,7 +391,7 @@ protected: public: Function *getCallee(); auto getArguments() { - return make_range(std::next(operands.begin()), operands.end()); + return make_range(std::next(operand_begin()), operand_end()); } }; // class CallInst @@ -380,7 +406,7 @@ protected: } public: - Value *getOperand() const { return operands[0].getValue(); } + Value *getOperand() const { return User::getOperand(0); } }; // class UnaryInst class BinaryInst : public Instruction { @@ -395,8 +421,8 @@ protected: } public: - Value *getLhs() const { return operands[0].getValue(); } - Value *getRhs() const { return operands[1].getValue(); } + Value *getLhs() const { return getOperand(0); } + Value *getRhs() const { return getOperand(1); } }; // class BinaryInst class TerminatorInst : public Instruction { @@ -413,6 +439,12 @@ protected: if (value) addOperand(value); } + +public: + bool hasReturnValue() const { return not operands.empty(); } + Value *getReturnValue() const { + return hasReturnValue() ? getOperand(0) : nullptr; + } }; // class ReturnInst class BranchInst : public TerminatorInst { @@ -438,7 +470,7 @@ protected: public: BasicBlock *getBlock() const { - return dynamic_cast(getOperand(0).getValue()); + return dynamic_cast(getOperand(0)); } auto getArguments() const { return make_range(std::next(operands.begin()), operands.end()); @@ -464,10 +496,10 @@ protected: public: BasicBlock *getThenBlock() const { - return dynamic_cast(getOperand(0).getValue()); + return dynamic_cast(getOperand(0)); } BasicBlock *getElseBlock() const { - return dynamic_cast(getOperand(1).getValue()); + return dynamic_cast(getOperand(1)); } auto getThenArguments() const { auto begin = operands.begin() + 2; @@ -498,8 +530,8 @@ protected: public: int getNumDims() const { return getNumOperands(); } - auto &getDims() const { return operands; } - Value *getDim(int index) { return getOperand(index).getValue(); } + auto getDims() const { return getOperands(); } + Value *getDim(int index) { return getOperand(index); } }; // class AllocaInst class LoadInst : public MemoryInst { @@ -517,11 +549,11 @@ protected: public: int getNumIndices() const { return getNumOperands() - 1; } - Value *getPointer() const { return operands.front().getValue(); } + Value *getPointer() const { return getOperand(0); } auto getIndices() const { - return make_range(std::next(operands.begin()), operands.end()); + return make_range(std::next(operand_begin()), operand_end()); } - Value *getIndex(int index) const { return getOperand(index + 1).getValue(); } + Value *getIndex(int index) const { return getOperand(index + 1); } }; // class LoadInst class StoreInst : public MemoryInst { @@ -539,12 +571,12 @@ protected: public: int getNumIndices() const { return getNumOperands() - 2; } - Value *getValue() const { return operands[0].getValue(); } - Value *getPointer() const { return operands[1].getValue(); } + Value *getValue() const { return getOperand(0); } + Value *getPointer() const { return getOperand(1); } auto getIndices() const { - return make_range(operands.begin() + 2, operands.end()); + return make_range(operand_begin() + 2, operand_end()); } - Value *getIndex(int index) const { return getOperand(index + 2).getValue(); } + Value *getIndex(int index) const { return getOperand(index + 2); } }; // class StoreInst class Module; @@ -568,10 +600,10 @@ public: Type *getReturnType() const { return dynamic_cast(getType())->getReturnType(); } - auto &getParamTypes() const { + auto getParamTypes() const { return dynamic_cast(getType())->getParamTypes(); } - block_list &getBasicBlocks() { return blocks; } + auto getBasicBlocks() { return make_range(blocks); } BasicBlock *getEntryBlock() { return blocks.front().get(); } BasicBlock *addBasicBlock(const std::string &name = "") { blocks.emplace_back(new BasicBlock(this, name)); @@ -599,7 +631,7 @@ protected: public: int getNumDims() const { return getNumOperands(); } - Value *getDim(int index) { return getOperand(index).getValue(); } + Value *getDim(int index) { return getOperand(index); } }; // class GlobalValue class Module { @@ -648,7 +680,7 @@ inline CallInst::CallInst(Function *callee, const std::vector args, } inline Function *CallInst::getCallee() { - return dynamic_cast(getOperand(0).getValue()); + return dynamic_cast(getOperand(0)); } } // namespace sysy \ No newline at end of file