#pragma once #include "utils.h" #include #include #include #include #include #include #include #include namespace ir { class Value; class User; class BasicBlock; class Function; class Instruction; class Argument; class ConstantInt; class ConstantFloat; class ConstantI1; class ConstantArrayValue; class Type; class Use { public: Use() = default; Use(Value* value, User* user, size_t operand_index) : value_(value), user_(user), operand_index_(operand_index) {} Value* GetValue() const { return value_; } User* GetUser() const { return user_; } size_t GetOperandIndex() const { return operand_index_; } void SetValue(Value* value) { value_ = value; } void SetUser(User* user) { user_ = user; } void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; } private: Value* value_ = nullptr; User* user_ = nullptr; size_t operand_index_ = 0; }; class Context { public: Context() = default; ~Context(); ConstantInt* GetConstInt(int v); ConstantI1* GetConstBool(bool v); std::string NextTemp(); std::string NextBlockName(const std::string& prefix = "bb"); private: std::unordered_map> const_ints_; std::unordered_map> const_bools_; int temp_index_ = -1; int block_index_ = -1; }; class Type { public: enum class Kind { Void, Int1, Int32, Float, Label, Function, Pointer, PtrInt32 = Pointer, Array }; explicit Type(Kind kind); Type(Kind kind, std::shared_ptr element_type, size_t num_elements = 0); static const std::shared_ptr& GetVoidType(); static const std::shared_ptr& GetInt1Type(); static const std::shared_ptr& GetInt32Type(); static const std::shared_ptr& GetFloatType(); static const std::shared_ptr& GetLabelType(); static const std::shared_ptr& GetBoolType(); static std::shared_ptr GetPointerType(std::shared_ptr pointee = nullptr); static const std::shared_ptr& GetPtrInt32Type(); static std::shared_ptr GetArrayType(std::shared_ptr element_type, size_t num_elements); Kind GetKind() const { return kind_; } bool IsVoid() const { return kind_ == Kind::Void; } bool IsInt1() const { return kind_ == Kind::Int1; } bool IsInt32() const { return kind_ == Kind::Int32; } bool IsFloat() const { return kind_ == Kind::Float; } bool IsLabel() const { return kind_ == Kind::Label; } bool IsFunction() const { return kind_ == Kind::Function; } bool IsBool() const { return kind_ == Kind::Int1; } bool IsPointer() const { return kind_ == Kind::Pointer; } bool IsPtrInt32() const { return IsPointer(); } bool IsArray() const { return kind_ == Kind::Array; } std::shared_ptr GetElementType() const { return element_type_; } size_t GetNumElements() const { return num_elements_; } int GetSize() const; void Print(std::ostream& os) const; private: Kind kind_; std::shared_ptr element_type_; size_t num_elements_ = 0; }; class Value { public: Value(std::shared_ptr ty, std::string name); virtual ~Value() = default; const std::shared_ptr& GetType() const { return type_; } const std::string& GetName() const { return name_; } void SetName(std::string name) { name_ = std::move(name); } bool IsVoid() const { return type_ && type_->IsVoid(); } bool IsInt32() const { return type_ && type_->IsInt32(); } bool IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } bool IsFloat() const { return type_ && type_->IsFloat(); } bool IsBool() const { return type_ && type_->IsBool(); } bool IsArray() const { return type_ && type_->IsArray(); } bool IsLabel() const { return type_ && type_->IsLabel(); } virtual bool IsConstant() const { return false; } virtual bool IsInstruction() const { return false; } virtual bool IsUser() const { return false; } virtual bool IsFunction() const { return false; } virtual bool IsArgument() const { return false; } void AddUse(User* user, size_t operand_index); void RemoveUse(User* user, size_t operand_index); const std::vector& GetUses() const { return uses_; } void ReplaceAllUsesWith(Value* new_value); virtual void Print(std::ostream& os) const; protected: std::shared_ptr type_; std::string name_; std::vector uses_; }; template inline bool isa(const Value* value) { return value && T::classof(value); } template inline T* dyncast(Value* value) { return isa(value) ? dynamic_cast(value) : nullptr; } template inline const T* dyncast(const Value* value) { return isa(value) ? dynamic_cast(value) : nullptr; } class ConstantValue : public Value { public: ConstantValue(std::shared_ptr ty, std::string name = ""); bool IsConstant() const override final { return true; } }; class ConstantInt : public ConstantValue { public: ConstantInt(std::shared_ptr ty, int value); int GetValue() const { return value_; } static bool classof(const Value* value) { return value && value->IsConstant() && dynamic_cast(value) != nullptr; } private: int value_; }; class ConstantFloat : public ConstantValue { public: ConstantFloat(std::shared_ptr ty, float value); float GetValue() const { return value_; } static bool classof(const Value* value) { return value && value->IsConstant() && dynamic_cast(value) != nullptr; } private: float value_; }; class ConstantI1 : public ConstantValue { public: ConstantI1(std::shared_ptr ty, bool value); bool GetValue() const { return value_; } static bool classof(const Value* value) { return value && value->IsConstant() && dynamic_cast(value) != nullptr; } private: bool value_; }; class ConstantArrayValue : public Value { public: ConstantArrayValue(std::shared_ptr array_type, const std::vector& elements, const std::vector& dims, const std::string& name = ""); const std::vector& GetElements() const { return elements_; } const std::vector& GetDims() const { return dims_; } void Print(std::ostream& os) const override; static bool classof(const Value* value) { return value && dynamic_cast(value) != nullptr; } private: std::vector elements_; std::vector dims_; }; enum class Opcode { Add, Sub, Mul, Div, Rem, FAdd, FSub, FMul, FDiv, FRem, And, Or, Xor, Shl, AShr, LShr, ICmpEQ, ICmpNE, ICmpLT, ICmpGT, ICmpLE, ICmpGE, FCmpEQ, FCmpNE, FCmpLT, FCmpGT, FCmpLE, FCmpGE, Neg, Not, FNeg, FtoI, IToF, Call, CondBr, Br, Return, Ret = Return, Unreachable, Alloca, Load, Store, Memset, GetElementPtr, Phi, Zext }; class User : public Value { public: User(std::shared_ptr ty, std::string name); bool IsUser() const override final { return true; } size_t GetNumOperands() const { return operands_.size(); } Value* GetOperand(size_t index) const; void SetOperand(size_t index, Value* value); void AddOperand(Value* value); void AddOperands(const std::vector& values); void RemoveOperand(size_t index); void ClearAllOperands(); protected: std::vector operands_; }; class Argument : public Value { public: Argument(std::shared_ptr type, std::string name, size_t index); size_t GetIndex() const { return index_; } bool IsArgument() const override final { return true; } static bool classof(const Value* value) { return value && dynamic_cast(value) != nullptr; } private: size_t index_; }; class GlobalValue : public User { public: GlobalValue(std::shared_ptr object_type, const std::string& name, bool is_const = false, Value* init = nullptr); bool IsConstant() const override { return is_const_; } bool HasInitializer() const { return init_ != nullptr; } Value* GetInitializer() const { return init_; } std::shared_ptr GetObjectType() const { return object_type_; } void SetConstant(bool is_const) { is_const_ = is_const; } void SetInitializer(Value* init) { init_ = init; } static bool classof(const Value* value) { return value && dynamic_cast(value) != nullptr; } private: std::shared_ptr object_type_; bool is_const_ = false; Value* init_ = nullptr; }; class Instruction : public User { public: Instruction(Opcode opcode, std::shared_ptr ty, BasicBlock* parent = nullptr, const std::string& name = ""); bool IsInstruction() const override final { return true; } Opcode GetOpcode() const { return opcode_; } bool IsTerminator() const; BasicBlock* GetParent() const { return parent_; } void SetParent(BasicBlock* parent) { parent_ = parent; } static bool classof(const Value* value) { return value && value->IsInstruction(); } private: Opcode opcode_; BasicBlock* parent_; }; class BinaryInst : public Instruction { public: BinaryInst(Opcode opcode, std::shared_ptr ty, Value* lhs, Value* rhs, BasicBlock* parent = nullptr, const std::string& name = ""); Value* GetLhs() const { return GetOperand(0); } Value* GetRhs() const { return GetOperand(1); } static bool classof(const Value* value); }; class UnaryInst : public Instruction { public: UnaryInst(Opcode opcode, std::shared_ptr ty, Value* operand, BasicBlock* parent = nullptr, const std::string& name = ""); Value* GetOprd() const { return GetOperand(0); } static bool classof(const Value* value); }; class ReturnInst : public Instruction { public: ReturnInst(Value* value = nullptr, BasicBlock* parent = nullptr); bool HasReturnValue() const { return GetNumOperands() > 0; } Value* GetReturnValue() const { return HasReturnValue() ? GetOperand(0) : nullptr; } Value* GetValue() const { return GetReturnValue(); } static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Return; } }; class AllocaInst : public Instruction { public: AllocaInst(std::shared_ptr allocated_type, BasicBlock* parent = nullptr, const std::string& name = ""); std::shared_ptr GetAllocatedType() const { return allocated_type_; } static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Alloca; } private: std::shared_ptr allocated_type_; }; class LoadInst : public Instruction { public: LoadInst(std::shared_ptr value_type, Value* ptr, BasicBlock* parent = nullptr, const std::string& name = ""); Value* GetPtr() const { return GetOperand(0); } static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Load; } }; class StoreInst : public Instruction { public: StoreInst(Value* value, Value* ptr, BasicBlock* parent = nullptr); Value* GetValue() const { return GetOperand(0); } Value* GetPtr() const { return GetOperand(1); } static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Store; } }; class UncondBrInst : public Instruction { public: UncondBrInst(BasicBlock* dest, BasicBlock* parent = nullptr); BasicBlock* GetDest() const; static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Br; } }; class CondBrInst : public Instruction { public: CondBrInst(Value* cond, BasicBlock* then_block, BasicBlock* else_block, BasicBlock* parent = nullptr); Value* GetCondition() const { return GetOperand(0); } BasicBlock* GetThenBlock() const; BasicBlock* GetElseBlock() const; static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::CondBr; } }; class UnreachableInst : public Instruction { public: explicit UnreachableInst(BasicBlock* parent = nullptr); static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Unreachable; } }; class CallInst : public Instruction { public: CallInst(Function* callee, const std::vector& args = {}, BasicBlock* parent = nullptr, const std::string& name = ""); Function* GetCallee() const; std::vector GetArguments() const; static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Call; } }; class GetElementPtrInst : public Instruction { public: GetElementPtrInst(std::shared_ptr source_type, Value* ptr, const std::vector& indices, BasicBlock* parent = nullptr, const std::string& name = ""); Value* GetPointer() const { return GetOperand(0); } size_t GetNumIndices() const { return GetNumOperands() > 0 ? GetNumOperands() - 1 : 0; } Value* GetIndex(size_t index) const { return GetOperand(index + 1); } std::shared_ptr GetSourceType() const { return source_type_; } static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::GetElementPtr; } private: std::shared_ptr source_type_; }; class PhiInst : public Instruction { public: PhiInst(std::shared_ptr type, BasicBlock* parent = nullptr, const std::string& name = ""); void AddIncoming(Value* value, BasicBlock* block); int GetNumIncomings() const { return static_cast(GetNumOperands() / 2); } Value* GetIncomingValue(int index) const { return GetOperand(static_cast(2 * index)); } BasicBlock* GetIncomingBlock(int index) const; static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Phi; } }; class ZextInst : public Instruction { public: ZextInst(Value* value, std::shared_ptr target_type, BasicBlock* parent = nullptr, const std::string& name = ""); Value* GetValue() const { return GetOperand(0); } static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Zext; } }; class MemsetInst : public Instruction { public: MemsetInst(Value* dst, Value* value, Value* len, Value* is_volatile, BasicBlock* parent = nullptr); Value* GetDest() const { return GetOperand(0); } Value* GetValue() const { return GetOperand(1); } Value* GetLength() const { return GetOperand(2); } Value* GetIsVolatile() const { return GetOperand(3); } static bool classof(const Value* value) { return value && Instruction::classof(value) && static_cast(value)->GetOpcode() == Opcode::Memset; } }; class BasicBlock : public Value { public: explicit BasicBlock(const std::string& name); BasicBlock(Function* parent, const std::string& name); Function* GetParent() const { return parent_; } void SetParent(Function* parent) { parent_ = parent; } bool HasTerminator() const; std::vector>& GetInstructions() { return instructions_; } const std::vector>& GetInstructions() const { return instructions_; } void EraseInstruction(Instruction* inst); void AddPredecessor(BasicBlock* pred); void AddSuccessor(BasicBlock* succ); void RemovePredecessor(BasicBlock* pred); void RemoveSuccessor(BasicBlock* succ); const std::vector& GetPredecessors() const { return predecessors_; } const std::vector& GetSuccessors() const { return successors_; } template T* Insert(size_t index, Args&&... args) { if (index > instructions_.size()) { throw std::out_of_range("BasicBlock insert index out of range"); } auto inst = std::make_unique(std::forward(args)...); auto* ptr = inst.get(); ptr->SetParent(this); instructions_.insert(instructions_.begin() + static_cast(index), std::move(inst)); return ptr; } template T* Append(Args&&... args) { if (HasTerminator()) { throw std::runtime_error("BasicBlock already has terminator"); } auto inst = std::make_unique(std::forward(args)...); auto* ptr = inst.get(); ptr->SetParent(this); instructions_.push_back(std::move(inst)); return ptr; } static bool classof(const Value* value) { return value && dynamic_cast(value) != nullptr; } private: Function* parent_ = nullptr; std::vector> instructions_; std::vector predecessors_; std::vector successors_; }; class Function : public Value { public: Function(std::string name, std::shared_ptr ret_type, const std::vector>& param_types = {}, const std::vector& param_names = {}, bool is_external = false); bool IsFunction() const override final { return true; } std::shared_ptr GetReturnType() const { return return_type_; } const std::vector>& GetParamTypes() const { return param_types_; } const std::vector>& GetArguments() const { return arguments_; } Argument* GetArgument(size_t index) const; bool IsExternal() const { return is_external_; } void SetExternal(bool is_external) { is_external_ = is_external; } BasicBlock* GetEntryBlock() const { return entry_; } BasicBlock* GetEntry() const { return entry_; } void SetEntryBlock(BasicBlock* bb) { entry_ = bb; } BasicBlock* EnsureEntryBlock(); BasicBlock* CreateBlock(const std::string& name); BasicBlock* AddBlock(std::unique_ptr block); std::vector>& GetBlocks() { return blocks_; } const std::vector>& GetBlocks() const { return blocks_; } static bool classof(const Value* value) { return value && value->IsFunction(); } private: std::shared_ptr return_type_; std::vector> param_types_; std::vector> arguments_; bool is_external_ = false; BasicBlock* entry_ = nullptr; std::vector> blocks_; }; class Module { public: Module() = default; Context& GetContext() { return context_; } const Context& GetContext() const { return context_; } Function* CreateFunction(const std::string& name, std::shared_ptr ret_type, const std::vector>& param_types = {}, const std::vector& param_names = {}, bool is_external = false); Function* GetFunction(const std::string& name) const; const std::vector>& GetFunctions() const { return functions_; } GlobalValue* CreateGlobalValue(const std::string& name, std::shared_ptr object_type, bool is_const = false, Value* init = nullptr); GlobalValue* GetGlobalValue(const std::string& name) const; const std::vector>& GetGlobalValues() const { return globals_; } private: Context context_; std::vector> functions_; std::map function_map_; std::vector> globals_; std::map global_map_; }; class IRBuilder { public: IRBuilder(Context& ctx, BasicBlock* bb); void SetInsertPoint(BasicBlock* bb); BasicBlock* GetInsertBlock() const { return insert_block_; } ConstantInt* CreateConstInt(int v); ConstantFloat* CreateConstFloat(float v); ConstantI1* CreateConstBool(bool v); ConstantArrayValue* CreateConstArray(std::shared_ptr array_type, const std::vector& elements, const std::vector& dims, const std::string& name = ""); BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateRem(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateAnd(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateOr(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateXor(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateShl(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateAShr(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateLShr(Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateICmp(Opcode op, Value* lhs, Value* rhs, const std::string& name = ""); BinaryInst* CreateFCmp(Opcode op, Value* lhs, Value* rhs, const std::string& name = ""); UnaryInst* CreateNeg(Value* operand, const std::string& name = ""); UnaryInst* CreateNot(Value* operand, const std::string& name = ""); UnaryInst* CreateFNeg(Value* operand, const std::string& name = ""); UnaryInst* CreateFtoI(Value* operand, const std::string& name = ""); UnaryInst* CreateIToF(Value* operand, const std::string& name = ""); AllocaInst* CreateAlloca(std::shared_ptr allocated_type, const std::string& name = ""); LoadInst* CreateLoad(Value* ptr, std::shared_ptr value_type, const std::string& name = ""); LoadInst* CreateLoad(Value* ptr, const std::string& name = "") { return CreateLoad(ptr, Type::GetInt32Type(), name); } StoreInst* CreateStore(Value* val, Value* ptr); UncondBrInst* CreateBr(BasicBlock* dest); CondBrInst* CreateCondBr(Value* cond, BasicBlock* then_bb, BasicBlock* else_bb); ReturnInst* CreateRet(Value* val = nullptr); UnreachableInst* CreateUnreachable(); CallInst* CreateCall(Function* callee, const std::vector& args, const std::string& name = ""); GetElementPtrInst* CreateGEP(Value* ptr, std::shared_ptr source_type, const std::vector& indices, const std::string& name = ""); PhiInst* CreatePhi(std::shared_ptr type, const std::string& name = ""); ZextInst* CreateZext(Value* val, std::shared_ptr target_type, const std::string& name = ""); MemsetInst* CreateMemset(Value* dst, Value* val, Value* len, Value* is_volatile); private: Context& ctx_; BasicBlock* insert_block_; }; class IRPrinter { public: void Print(const Module& module, std::ostream& os); }; inline std::ostream& operator<<(std::ostream& os, const Type& type) { type.Print(os); return os; } inline std::ostream& operator<<(std::ostream& os, const Value& value) { value.Print(os); return os; } } // namespace ir