diff --git a/src/IR.cpp b/src/IR.cpp index bee67fc..0e5312a 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -4,8 +4,8 @@ #include #include #include -#include #include +#include #include namespace sysy { @@ -47,24 +47,30 @@ Type *Type::getFunctionType(Type *returnType, PointerType *PointerType::get(Type *baseType) { static std::map> pointerTypes; + auto iter = pointerTypes.find(baseType); + if (iter != pointerTypes.end()) + return iter->second.get(); auto type = new PointerType(baseType); assert(type); auto result = pointerTypes.try_emplace(baseType, type); return result.first->second.get(); } -bool operator<(const FunctionType &lhs, const FunctionType &rhs) { - return lhs.getReturnType() < rhs.getReturnType() or - lhs.getParamTypes().size() < rhs.getParamTypes().size() and - std::lexicographical_compare(lhs.getParamTypes().begin(), - lhs.getParamTypes().end(), - rhs.getParamTypes().begin(), - rhs.getParamTypes().end()); -} - FunctionType *FunctionType::get(Type *returnType, const std::vector ¶mTypes) { static std::set> functionTypes; + auto iter = std::find_if(functionTypes.begin(), functionTypes.end(), + [&](const std::unique_ptr &type) -> bool { + if (returnType != type->getReturnType() or + paramTypes.size() != type->getParamTypes().size()) + return false; + for (int i = 0; i < paramTypes.size(); ++i) + if (paramTypes[i] != type->getParamTypes()[i]) + return false; + return true; + }); + if (iter != functionTypes.end()) + return iter->get(); auto type = new FunctionType(returnType, paramTypes); assert(type); auto result = functionTypes.emplace(type); @@ -77,6 +83,19 @@ void Value::replaceAllUsesWith(Value *value) { uses.clear(); } +ConstantValue *ConstantValue::getInt(int value, const std::string &name) { + static std::map intConstants; + + auto inst = new ConstantValue(value); + assert(inst); + return inst; +} +ConstantValue *ConstantValue::getFloat(float value, const std::string &name) { + auto inst = new ConstantValue(value); + assert(inst); + return inst; +} + void User::setOperand(int index, Value *value) { assert(index < getNumOperands()); operands[index].setValue(value); diff --git a/src/IR.h b/src/IR.h index 581f5b8..204f9c7 100644 --- a/src/IR.h +++ b/src/IR.h @@ -37,6 +37,9 @@ template range make_range(IterT b, IterT e) { // targets. // 2. `PointerType` and `FunctionType` derive from `Type` and represent pointer // type and function type, respectively. +// +// NOTE `Type` and its derived classes have their ctors declared as 'protected'. +// Users must use Type::getXXXType() methods to obtain `Type` pointers. //===----------------------------------------------------------------------===// class Type { @@ -169,6 +172,36 @@ public: void removeUse(Use *use) { uses.remove(use); } }; // class Value +class ConstantValue : public Value { + friend class IRBuilder; + +protected: + union { + int iConstant; + float fConstant; + }; + +protected: + ConstantValue(int value, const std::string &name = "") + : Value(Type::getIntType(), name), iConstant(value) {} + ConstantValue(float value, const std::string &name = "") + : Value(Type::getFloatType(), name), fConstant(value) {} + +public: + static ConstantValue *getInt(int value, const std::string &name = ""); + static ConstantValue *getFloat(float value, const std::string &name = ""); + +public: + int getInt() const { + assert(isInt()); + return iConstant; + } + float getFloat() const { + assert(isFloat()); + return fConstant; + } +}; // class ConstantValue + class BasicBlock; class Argument : public Value { protected: @@ -215,6 +248,7 @@ public: block_list &getSuccessors() { return successors; } iterator begin() { return instructions.begin(); } iterator end() { return instructions.end(); } + iterator terminator() { return std::prev(end()); } }; // class BasicBlock class User : public Value { @@ -242,29 +276,6 @@ public: const std::string &getName() const { return name; } }; // class User -// class Constant : public User { -// protected: -// union scalar { -// int iConstant; -// float fConstant; -// }; -// std::vector dims; -// std::vector data; - -// protected: -// Constant(Type *type, const std::string &name = "") : User(type, name) {} - -// public: -// int getInt() const { -// assert(isInt()); -// return iConstant; -// } -// float getFloat() const { -// assert(isFloat()); -// return fConstant; -// } -// }; // class ConstantInst - class Instruction : public User { public: enum Kind : uint64_t { @@ -273,17 +284,14 @@ public: kAdd = 0x1UL << 0, kSub = 0x1UL << 1, kMul = 0x1UL << 2, - kSDiv = 0x1UL << 3, - kSRem = 0x1UL << 4, + kDiv = 0x1UL << 3, + kRem = 0x1UL << 4, kICmpEQ = 0x1UL << 5, kICmpNE = 0x1UL << 6, kICmpLT = 0x1UL << 7, kICmpGT = 0x1UL << 8, kICmpLE = 0x1UL << 9, kICmpGE = 0x1UL << 10, - kAShr = 0x1UL << 11, - kLShr = 0x1UL << 12, - kShl = 0x1UL << 13, kFAdd = 0x1UL << 14, kFSub = 0x1UL << 15, kFMul = 0x1UL << 16, @@ -301,19 +309,18 @@ public: kFNeg = 0x1UL << 26, kFtoI = 0x1UL << 28, kIToF = 0x1UL << 29, - kBitCast = 0x1UL << 30, // call - kCall = 0x1UL << 33, + kCall = 0x1UL << 30, // terminator kCondBr = 0x1UL << 31, kBr = 0x1UL << 32, - kReturn = 0x1UL << 34, + kReturn = 0x1UL << 33, // mem op - kAlloca = 0x1UL << 35, - kLoad = 0x1UL << 36, - kStore = 0x1UL << 37, + kAlloca = 0x1UL << 34, + kLoad = 0x1UL << 35, + kStore = 0x1UL << 36, // constant - // kConstant = 0x1UL << 38, + // kConstant = 0x1UL << 37, }; protected: @@ -330,17 +337,10 @@ public: BasicBlock *getParent() const { return parent; } void setParent(BasicBlock *bb) { parent = bb; } - bool isInteger() const { - static constexpr uint64_t IntegerOpMask = - kAdd | kSub | kMul | kSDiv | kSRem | kICmpEQ | kICmpNE | kICmpLT | - kICmpGT | kICmpLE | kICmpGE | kAShr | kLShr | kShl | kNeg | kNot | - kIToF; - return kind & IntegerOpMask; - } bool isCmp() const { static constexpr uint64_t CondOpMask = - kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE | kFCmpEQ | - kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE; + (kICmpEQ | kICmpNE | kICmpLT | kICmpGT | kICmpLE | kICmpGE) | + (kFCmpEQ | kFCmpNE | kFCmpLT | kFCmpGT | kFCmpLE | kFCmpGE); return kind & CondOpMask; } bool isTerminator() { @@ -547,11 +547,13 @@ public: Value *getIndex(int index) const { return getOperand(index + 2).getValue(); } }; // class StoreInst +class Module; class Function : public Value { friend class Module; protected: - Function(Type *type, const std::string &name) : Value(type, name) { + Function(Module *parent, Type *type, const std::string &name) + : Value(type, name), parent(parent), blocks() { blocks.emplace_back(new BasicBlock(this, "entry")); } @@ -559,6 +561,7 @@ public: using block_list = std::list>; protected: + Module *parent; block_list blocks; public: @@ -585,9 +588,12 @@ class GlobalValue : public User { friend class Module; protected: - GlobalValue(Type *type, const std::vector &dims = {}, - const std::string &name = "") - : User(type, name) { + Module *parent; + +protected: + GlobalValue(Module *parent, Type *type, const std::string &name, + const std::vector &dims = {}) + : User(type, name), parent(parent) { addOperands(dims); } @@ -600,22 +606,21 @@ class Module { protected: std::map> functions; std::map> globals; - // std::list> functions; - // std::list> globals; public: Module() = default; public: Function *createFunction(const std::string &name, Type *type) { - auto result = functions.try_emplace(name, new Function(type, name)); + auto result = functions.try_emplace(name, new Function(this, type, name)); if (not result.second) return nullptr; return result.first->second.get(); }; GlobalValue *createGlobalValue(const std::string &name, Type *type, const std::vector &dims = {}) { - auto result = globals.try_emplace(name, new GlobalValue(type, dims, name)); + auto result = + globals.try_emplace(name, new GlobalValue(this, type, name, dims)); if (not result.second) return nullptr; return result.first->second.get(); diff --git a/src/IRBuilder.h b/src/IRBuilder.h index 806c046..2aec108 100644 --- a/src/IRBuilder.h +++ b/src/IRBuilder.h @@ -27,14 +27,6 @@ public: void setPosition(BasicBlock::iterator position) { this->position = position; } public: - CallInst *createCallInst(Function *callee, - const std::vector args = {}, - const std::string &name = "") { - auto inst = new CallInst(callee, args, block, name); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } UnaryInst *createUnaryInst(Instruction::Kind kind, Type *type, Value *operand, const std::string &name = "") { @@ -43,6 +35,26 @@ public: block->getInstructions().emplace(position, inst); return inst; } + UnaryInst *createNegInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kNeg, Type::getIntType(), operand, + name); + } + UnaryInst *createNotInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kNot, Type::getIntType(), operand, + name); + } + UnaryInst *createFtoIInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kFtoI, Type::getIntType(), operand, + name); + } + UnaryInst *createFNegInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kFNeg, Type::getFloatType(), operand, + name); + } + UnaryInst *createIToFInst(Value *operand, const std::string &name = "") { + return createUnaryInst(Instruction::kIToF, Type::getFloatType(), operand, + name); + } BinaryInst *createBinaryInst(Instruction::Kind kind, Type *type, Value *lhs, Value *rhs, const std::string &name = "") { auto inst = new BinaryInst(kind, type, lhs, rhs, block, name); @@ -50,6 +62,116 @@ public: block->getInstructions().emplace(position, inst); return inst; } + BinaryInst *createAddInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kAdd, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createSubInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kSub, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createMulInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kMul, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createDivInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kDiv, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createRemInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kRem, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createICmpEQInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kICmpEQ, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createICmpNEInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kICmpNE, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createICmpLTInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kICmpLT, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createICmpLEInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kICmpLE, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createICmpGTInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kICmpGT, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createICmpGEInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kICmpGE, Type::getIntType(), lhs, rhs, + name); + } + BinaryInst *createFAddInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFAdd, Type::getFloatType(), lhs, rhs, + name); + } + BinaryInst *createFSubInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFSub, Type::getFloatType(), lhs, rhs, + name); + } + BinaryInst *createFMulInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFMul, Type::getFloatType(), lhs, rhs, + name); + } + BinaryInst *createFDivInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFDiv, Type::getFloatType(), lhs, rhs, + name); + } + BinaryInst *createFRemInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFRem, Type::getFloatType(), lhs, rhs, + name); + } + BinaryInst *createFCmpEQInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpEQ, Type::getFloatType(), lhs, + rhs, name); + } + BinaryInst *createFCmpNEInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpNE, Type::getFloatType(), lhs, + rhs, name); + } + BinaryInst *createFCmpLTInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpLT, Type::getFloatType(), lhs, + rhs, name); + } + BinaryInst *createFCmpLEInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpLE, Type::getFloatType(), lhs, + rhs, name); + } + BinaryInst *createFCmpGTInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpGT, Type::getFloatType(), lhs, + rhs, name); + } + BinaryInst *createFCmpGEInst(Value *lhs, Value *rhs, + const std::string &name = "") { + return createBinaryInst(Instruction::kFCmpGE, Type::getFloatType(), lhs, + rhs, name); + } ReturnInst *createReturnInst(Value *value = nullptr) { auto inst = new ReturnInst(value); assert(inst);