diff --git a/src/IR.h b/src/IR.h index b1c6155..5f53177 100644 --- a/src/IR.h +++ b/src/IR.h @@ -11,6 +11,33 @@ namespace sysy { +template struct range { + using iterator = IterT; + using value_type = typename std::iterator_traits::value_type; + using reference = typename std::iterator_traits::reference; + + iterator b; + iterator e; + explicit range(iterator b, iterator e) : b(b), e(e) {} + iterator begin() { return b; } + iterator end() { return e; } +}; + +template range make_range(IterT b, IterT e) { + return range(b, e); +} + +//===----------------------------------------------------------------------===// +// Types +// +// The SysY type system is quite simple. +// 1. The base class `Type` is used to represent all primitive scalar types, +// include `int`, `float`, `void`, and the label type representing branch +// targets. +// 2. `PointerType` and `FunctionType` derive from `Type` and represent pointer +// type and function type, respectively. +//===----------------------------------------------------------------------===// + class Type { public: enum Kind { @@ -31,8 +58,10 @@ public: static Type *getIntType(); static Type *getFloatType(); static Type *getVoidType(); - static Type *getPointerType(); static Type *getLabelType(); + static Type *getPointerType(Type *baseType); + static Type *getFunctionType(Type *returnType, + const std::vector ¶mTypes = {}); static int getTypeSize(); public: @@ -53,7 +82,7 @@ protected: PointerType(Type *baseType) : Type(kPointer), baseType(baseType) {} public: - Type *get(Type *baseType); + static PointerType *get(Type *baseType); public: Type *getBaseType() const { return baseType; } @@ -66,17 +95,25 @@ private: protected: FunctionType(Type *returnType) : Type(kFunction), returnType(returnType) {} - FunctionType(Type *returnType, const std::vector paramTypes = {}) + FunctionType(Type *returnType, const std::vector ¶mTypes = {}) : Type(kFunction), returnType(returnType), paramTypes(paramTypes) {} public: - Type *get(Type *baseType, const std::vector paramTypes = {}); + static FunctionType *get(Type *returnType, + const std::vector ¶mTypes = {}); public: Type *getReturnType() const { return returnType; } const std::vector &getParamTypes() const { return paramTypes; } + int getNumParams() const { return paramTypes.size(); } }; // class FunctionType +//===----------------------------------------------------------------------===// +// Values +// +// description +//===----------------------------------------------------------------------===// + class User; class Value; @@ -113,10 +150,12 @@ public: class Value { protected: Type *type; - std::vector uses; + std::string name; + std::list uses; protected: - Value(Type *type) : type(type), uses() {} + Value(Type *type, const std::string &name = "") + : type(type), name(name), uses() {} virtual ~Value() {} public: @@ -124,48 +163,12 @@ public: bool isInt() const { return type->isInt(); } bool isFloat() const { return type->isFloat(); } bool isPointer() const { return type->isPointer(); } - const std::vector &getUses() { return uses; } + const std::list &getUses() { return uses; } void addUse(Use *use) { uses.push_back(use); } void replaceAllUsesWith(Value *value); + void removeUse(Use *use) { uses.remove(use); } }; // class Value -class User : public Value { -protected: - std::vector operands; - std::string name; - -protected: - User(Type *type, const std::vector &operands = {}, - const std::string &name = "") - : Value(type), operands(), name(name) { - for (auto op : operands) - addOperand(op); - } - -public: - struct OperandIterator : public std::vector::const_iterator { - OperandIterator(const std::vector::const_iterator &iter) - : std::vector::const_iterator(iter) {} - using value_type = Value *; - value_type operator*() { return operator->()->getValue(); } - }; - - OperandIterator op_begin() const { - return OperandIterator(operands.begin()); - }; - OperandIterator op_end() const { return OperandIterator(operands.end()); }; - int getNumOperands() const { return operands.size(); } - const std::vector &getOperands() const { return operands; } - Value *getOperand(int index) const { return operands[index].getValue(); } - void addOperand(Value *value) { - operands.emplace_back(Use::kRead, operands.size(), this, value); - value->addUse(&operands.back()); - } - void replaceOperand(int index, Value *value); - const std::string &getName() const { return name; } - -}; // class User - class BasicBlock; class Argument : public Value { protected: @@ -173,46 +176,95 @@ protected: int index; protected: - Argument(Type *type, BasicBlock *block, int index) - : Value(type), block(block), index(index) {} + Argument(Type *type, BasicBlock *block, int index, + const std::string &name = "") + : Value(type, name), block(block), index(index) {} }; class Instruction; -class BasicBlock : public User, public std::list> { +class Function; +class BasicBlock : public Value { + friend class Function; + public: - using arg_list = std::vector; - using arg_iterator = arg_list::iterator; + using inst_list = std::list>; + using iterator = inst_list::iterator; + using arg_list = std::vector; using block_list = std::vector; - using block_iterator = block_list::iterator; protected: + Function *parent; + inst_list instructions; arg_list arguments; block_list successors; block_list predecessors; protected: - BasicBlock(const std::string &name = "") - : User(Type::getLabelType(), {}, name), - std::list>() {} + explicit BasicBlock(Function *parent, const std::string &name = "") + : Value(Type::getLabelType(), name), parent(parent), instructions(), + arguments(), successors(), predecessors() {} public: - arg_iterator arg_begin() { return arguments.begin(); } - arg_iterator arg_end() { return arguments.end(); } - block_iterator pred_begin() { return predecessors.begin(); } - block_iterator pred_end() { return predecessors.end(); } - block_iterator succ_begin() { return successors.begin(); } - block_iterator succ_end() { return successors.end(); } int getNumArguments() const { return arguments.size(); } int getNumPredecessors() const { return predecessors.size(); } int getNumSuccessors() const { return successors.size(); } - const arg_list &getArguments() const { return arguments; } - const block_list &getPredecessors() const { return predecessors; } - const block_list &getSuccessors() const { return successors; } - Argument *getArgument(int index) const { return arguments[index]; } - BasicBlock *getPredecessor(int index) const { return predecessors[index]; } - BasicBlock *getSuccessor(int index) const { return successors[index]; } + Function *getParent() const { return parent; } + inst_list &getInstructions() { return instructions; } + arg_list &getArguments() { return arguments; } + block_list &getPredecessors() { return predecessors; } + block_list &getSuccessors() { return successors; } + iterator begin() { return instructions.begin(); } + iterator end() { return instructions.end(); } }; // class BasicBlock +class User : public Value { +protected: + std::vector operands; + +protected: + User(Type *type, const std::string &name = "") + : Value(type, name), operands() {} + +public: + int getNumOperands() const { return operands.size(); } + const std::vector &getOperands() const { return operands; } + const Use &getOperand(int index) const { return operands[index]; } + void addOperand(Value *value, Use::Kind mode = Use::kRead) { + operands.emplace_back(mode, operands.size(), this, value); + value->addUse(&operands.back()); + } + void addOperands(const std::vector &operands) { + for (auto value : operands) + addOperand(value); + } + void replaceOperand(int index, Value *value); + void setOperand(int index, Value *value); + const std::string &getName() const { return name; } +}; // class User + +// class Constant : public User { +// protected: +// union scalar { +// int iConstant; +// float fConstant; +// }; +// std::vector dims; +// std::vector data; + +// protected: +// Constant(Type *type, const std::string &name = "") : User(type, name) {} + +// public: +// int getInt() const { +// assert(isInt()); +// return iConstant; +// } +// float getFloat() const { +// assert(isFloat()); +// return fConstant; +// } +// }; // class ConstantInst + class Instruction : public User { public: enum Kind : uint64_t { @@ -261,7 +313,7 @@ public: kLoad = 0x1UL << 36, kStore = 0x1UL << 37, // constant - kConstant = 0x1UL << 38, + // kConstant = 0x1UL << 38, }; protected: @@ -269,14 +321,14 @@ protected: BasicBlock *parent; protected: - Instruction(Kind kind, Type *type, const std::vector &oprands = {}, - BasicBlock *parent = nullptr, const std::string &name = "") - : User(type, oprands, name), kind(kind), parent(parent) {} + Instruction(Kind kind, Type *type, BasicBlock *parent = nullptr, + const std::string &name = "") + : User(type, name), kind(kind), parent(parent) {} public: Kind getKind() const { return kind; } - BasicBlock *getBasicBlock() const { return parent; } - void setBasicBlock(BasicBlock *bb) { parent = bb; } + BasicBlock *getParent() const { return parent; } + void setParent(BasicBlock *bb) { parent = bb; } bool isInteger() const { static constexpr uint64_t IntegerOpMask = @@ -300,65 +352,44 @@ public: kAdd | kMul | kICmpEQ | kICmpNE | kFAdd | kFMul | kFCmpEQ | kFCmpNE; return kind & CommutativeOpMask; } - - // static bool isReverse(Instruction *x, Instruction *y); - // int getRank(); }; // class Instruction -class ConstantInst : public Instruction { -protected: - union { - int iConstant; - float fConstant; - }; - -protected: - ConstantInst(Type *type, BasicBlock *parent = nullptr, - const std::string &name = "") - : Instruction(kConstant, type, {}, parent, name) {} - -public: - int getInt() const { - assert(isInt()); - return iConstant; - } - float getFloat() const { - assert(isFloat()); - return fConstant; - } -}; // class ConstantInst - class Function; class CallInst : public Instruction { + friend class IRBuilder; + protected: CallInst(Function *callee, const std::vector args = {}, BasicBlock *parent = nullptr, const std::string &name = ""); - // : Instruction(kCall, callee->getReturnType(), {}, parent, - // name) {} public: Function *getCallee(); - OperandIterator arg_begin() const { return std::next(op_begin()); } - OperandIterator arg_end() const { return op_end(); } + auto getArguments() { + return make_range(std::next(operands.begin()), operands.end()); + } }; // class CallInst class UnaryInst : public Instruction { + friend class IRBuilder; + protected: - UnaryInst(Kind kind, Type *type, Value *operand, BasicBlock *parent, + UnaryInst(Kind kind, Type *type, Value *operand, BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kind, type, {}, parent, name) { + : Instruction(kind, type, parent, name) { addOperand(operand); } public: - Value *getOperand() const { return *op_begin(); } + Value *getOperand() const { return operands[0].getValue(); } }; // class UnaryInst class BinaryInst : public Instruction { + friend class IRBuilder; + protected: BinaryInst(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, const std::string &name = "") - : Instruction(kind, type, {}, parent, name) { + : Instruction(kind, type, parent, name) { addOperand(lhs); addOperand(rhs); } @@ -374,10 +405,14 @@ protected: }; // TerminatorInst class ReturnInst : public TerminatorInst { + friend class IRBuilder; + protected: - ReturnInst(const std::vector &values = {}, - BasicBlock *parent = nullptr) - : TerminatorInst(kReturn, Type::getVoidType(), values, parent, "") {} + ReturnInst(Value *value = nullptr, BasicBlock *parent = nullptr) + : TerminatorInst(kReturn, Type::getVoidType(), parent, "") { + if (value) + addOperand(value); + } }; // class ReturnInst class BranchInst : public TerminatorInst { @@ -389,14 +424,60 @@ public: bool isConditional() const { return kind == kCondBr; } }; // class BranchInst +class UncondBrInst : public BranchInst { + friend class IRBuilder; + +protected: + UncondBrInst(BasicBlock *block, std::vector args, + BasicBlock *parent = nullptr) + : BranchInst(kCondBr, Type::getVoidType(), parent, "") { + assert(block->getNumArguments() == args.size()); + addOperand(block); + addOperands(args); + } + +public: + BasicBlock *getBlock() const { + return dynamic_cast(getOperand(0).getValue()); + } + auto getArguments() const { + return make_range(std::next(operands.begin()), operands.end()); + } +}; // class UncondBrInst + class CondBrInst : public BranchInst { + friend class IRBuilder; + protected: - CondBrInst(Value *condition, BasicBlock *trueBlock, BasicBlock *falseBlock, - BasicBlock *parent = nullptr) - : BranchInst(kCondBr, Type::getVoidType(), {}, parent, "") { + CondBrInst(Value *condition, BasicBlock *thenBlock, BasicBlock *elseBlock, + const std::vector &thenArgs, + const std::vector &elseArgs, BasicBlock *parent = nullptr) + : BranchInst(kCondBr, Type::getVoidType(), parent, "") { + assert(thenBlock->getNumArguments() == thenArgs.size() and + elseBlock->getNumArguments() == elseArgs.size()); addOperand(condition); - addOperand(trueBlock); - addOperand(falseBlock); + addOperand(thenBlock); + addOperand(elseBlock); + addOperands(thenArgs); + addOperands(elseArgs); + } + +public: + BasicBlock *getThenBlock() const { + return dynamic_cast(getOperand(0).getValue()); + } + BasicBlock *getElseBlock() const { + return dynamic_cast(getOperand(1).getValue()); + } + auto getThenArguments() const { + auto begin = operands.begin() + 2; + auto end = begin + getThenBlock()->getNumArguments(); + return make_range(begin, end); + } + auto getElseArguments() const { + auto begin = operands.begin() + 2 + getThenBlock()->getNumArguments(); + auto end = operands.end(); + return make_range(begin, end); } }; // class CondBrInst @@ -406,70 +487,145 @@ protected: }; // class MemoryInst class AllocaInst : public MemoryInst { + friend class IRBuilder; + protected: AllocaInst(Type *type, const std::vector &dims = {}, BasicBlock *parent = nullptr, const std::string &name = "") - : MemoryInst(kAlloca, type, dims, parent, name) {} + : MemoryInst(kAlloca, type, parent, name) { + addOperands(dims); + } public: int getNumDims() const { return getNumOperands(); } - Value *getDim(int index) { return getOperand(index); } + auto &getDims() const { return operands; } + Value *getDim(int index) { return getOperand(index).getValue(); } }; // class AllocaInst class LoadInst : public MemoryInst { + friend class IRBuilder; + protected: LoadInst(Value *pointer, const std::vector &indices = {}, BasicBlock *parent = nullptr, const std::string &name = "") : MemoryInst( kLoad, dynamic_cast(pointer->getType())->getBaseType(), - indices, parent, name) {} + parent, name) { + addOperands(indices); + } public: - Value *getPointer() const { return operands.front().getValue(); } int getNumIndices() const { return getNumOperands() - 1; } - Value *getIndex(int index) const { return getOperand(index + 1); } + Value *getPointer() const { return operands.front().getValue(); } + auto getIndices() const { + return make_range(std::next(operands.begin()), operands.end()); + } + Value *getIndex(int index) const { return getOperand(index + 1).getValue(); } }; // class LoadInst class StoreInst : public MemoryInst { + friend class IRBuilder; + protected: StoreInst(Value *value, Value *pointer, const std::vector &indices = {}, BasicBlock *parent = nullptr, const std::string &name = "") - : MemoryInst(kStore, Type::getVoidType(), {}, parent, name) { + : MemoryInst(kStore, Type::getVoidType(), parent, name) { addOperand(value); addOperand(pointer); - for (auto index : indices) - addOperand(index); + addOperands(indices); } public: + int getNumIndices() const { return getNumOperands() - 2; } Value *getValue() const { return operands[0].getValue(); } Value *getPointer() const { return operands[1].getValue(); } - int getNumIndices() const { return getNumOperands() - 2; } - Value *getIndex(int index) const { return getOperand(index + 2); } + auto getIndices() const { + return make_range(operands.begin() + 2, operands.end()); + } + Value *getIndex(int index) const { return getOperand(index + 2).getValue(); } }; // class StoreInst -class Function : public Value, public std::list> { +class Function : public Value { + friend class Module; + +protected: + Function(Type *type, const std::string &name) : Value(type, name) { + blocks.emplace_back(new BasicBlock(this, "entry")); + } + +public: + using block_list = std::list>; + protected: - Function(Type *type) - : Value(type), std::list>() {} + block_list blocks; + +public: + Type *getReturnType() const { + return dynamic_cast(getType())->getReturnType(); + } + auto &getParamTypes() const { + return dynamic_cast(getType())->getParamTypes(); + } + block_list &getBasicBlocks() { return blocks; } + BasicBlock *getEntryBlock() { return blocks.front().get(); } + BasicBlock *addBasicBlock(const std::string &name = "") { + blocks.emplace_back(new BasicBlock(this, name)); + return blocks.back().get(); + } + void removeBasicBlock(BasicBlock *block) { + blocks.remove_if([&](std::unique_ptr &b) -> bool { + return block == b.get(); + }); + } }; // class Function class GlobalValue : public User { + friend class Module; + protected: GlobalValue(Type *type, const std::vector &dims = {}, const std::string &name = "") - : User(type, dims, name) {} + : User(type, name) { + addOperands(dims); + } + public: int getNumDims() const { return getNumOperands(); } - Value *getDim(int index) { return getOperand(index); } + Value *getDim(int index) { return getOperand(index).getValue(); } }; // class GlobalValue class Module { protected: - std::list functions; - std::list globals; + std::list> functions; + std::list> globals; + +public: + Module() = default; + +public: + Function *addFunction(Type *type, const std::string &name) { + functions.emplace_back(new Function(type, name)); + return functions.back().get(); + }; + GlobalValue *addGlobalValue(Type *type, const std::vector &dims = {}, + const std::string &name = "") { + globals.emplace_back(type, dims, name); + return globals.back().get(); + } }; // class Module +inline CallInst::CallInst(Function *callee, const std::vector args, + BasicBlock *parent, const std::string &name) + : Instruction(kCall, callee->getReturnType(), parent, name) { + addOperand(callee); + for (auto arg : args) + addOperand(arg); +} + +inline Function *CallInst::getCallee() { + return dynamic_cast(getOperand(0).getValue()); +} + } // namespace sysy \ No newline at end of file