#pragma once

#include <cmath>
#include <cstddef>
#include <memory>
#include <ostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Declarations for a small LLVM-style SSA IR: types, values, constants,
// instructions, basic blocks, functions and modules, plus an IRBuilder and a
// textual IRPrinter. Most members are defined out of line (presumably in the
// matching .cpp file -- confirm).
namespace ir {

class Type;
class Value;
class User;
class ConstantValue;
class ConstantInt;
class ConstantFloat;
class ConstantZero;
class ConstantArray;
class GlobalValue;
class GlobalVariable;
class Argument;
class Instruction;
class BinaryInst;
class CompareInst;
class ReturnInst;
class AllocaInst;
class LoadInst;
class StoreInst;
class BranchInst;
class CondBranchInst;
class CallInst;
class GetElementPtrInst;
class CastInst;
class BasicBlock;
class Function;

/// One entry of a value's use list: `user_` reads `value_` as its operand
/// number `operand_index_`. Holds only non-owning raw pointers.
class Use {
 public:
  Use() = default;
  Use(Value* value, User* user, size_t operand_index)
      : value_(value), user_(user), operand_index_(operand_index) {}

  Value* GetValue() const { return value_; }
  User* GetUser() const { return user_; }
  size_t GetOperandIndex() const { return operand_index_; }

 private:
  Value* value_ = nullptr;
  User* user_ = nullptr;
  size_t operand_index_ = 0;
};

/// Owns IR constants and generates fresh names for temporaries and blocks.
///
/// Integer and float constants have dedicated per-value maps
/// (`const_ints_`, `const_floats_`); any other constant built through
/// CreateOwnedConstant is simply kept alive in `owned_constants_`.
class Context {
 public:
  Context() = default;
  ~Context();

  // The members own their constants through unique_ptr, so a Context cannot
  // be copied; declare that explicitly instead of failing inside std::vector.
  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

  /// Returns the ConstantInt for `v` (out of line; presumably cached in
  /// `const_ints_` -- confirm).
  ConstantInt* GetConstInt(int v);
  /// Returns the ConstantFloat for `v` (out of line; presumably cached in
  /// `const_floats_` -- confirm).
  ConstantFloat* GetConstFloat(float v);

  /// Constructs a `T` from `args`, transfers ownership to this Context and
  /// returns a non-owning pointer that stays valid as long as the Context
  /// lives. `T` must derive from Value.
  template <typename T, typename... Args>
  T* CreateOwnedConstant(Args&&... args) {
    auto value = std::make_unique<T>(std::forward<Args>(args)...);
    auto* ptr = value.get();
    owned_constants_.push_back(std::move(value));
    return ptr;
  }

  /// Returns a fresh temporary name (out of line).
  std::string NextTemp();
  /// Returns a fresh block name derived from `prefix` (out of line).
  std::string NextBlock(const std::string& prefix);

 private:
  std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
  // NOTE(review): with a float key, +0.0f and -0.0f are the same key and a
  // NaN key never matches itself. Keying on the bit pattern would avoid both;
  // confirm how GetConstFloat uses this map.
  std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
  std::vector<std::unique_ptr<Value>> owned_constants_;
  // Counters start at -1, presumably so the first generated name uses index 0
  // (confirm in NextTemp / NextBlock).
  int temp_index_ = -1;
  int block_index_ = -1;
};

/// IR type descriptor. One class covers every kind; which fields matter
/// depends on `kind_`:
///   - the element type is set by the (kind, element_type[, array_size])
///     constructors (presumably Pointer / Array);
///   - the array size is set by the three-argument constructor;
///   - the return and parameter types are set by the function-type
///     constructor.
class Type {
 public:
  enum class Kind { Void, Int1, Int32, Float32, Pointer, Array, Function };

  explicit Type(Kind kind);
  Type(Kind kind, std::shared_ptr<Type> element_type);
  Type(Kind kind, std::shared_ptr<Type> element_type, size_t array_size);
  Type(std::shared_ptr<Type> return_type,
       std::vector<std::shared_ptr<Type>> params);

  // Accessors for the primitive types. They return a const reference, which
  // suggests shared singletons -- confirm in the implementation.
  static const std::shared_ptr<Type>& GetVoidType();
  static const std::shared_ptr<Type>& GetInt1Type();
  static const std::shared_ptr<Type>& GetInt32Type();
  static const std::shared_ptr<Type>& GetFloatType();
  static std::shared_ptr<Type> GetPointerType(
      std::shared_ptr<Type> element_type);
  static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> element_type,
                                            size_t array_size);
  static std::shared_ptr<Type> GetFunctionType(
      std::shared_ptr<Type> return_type,
      std::vector<std::shared_ptr<Type>> param_types);
  static const std::shared_ptr<Type>& GetPtrInt32Type();

  Kind GetKind() const;
  const std::shared_ptr<Type>& GetElementType() const;
  size_t GetArraySize() const;
  const std::shared_ptr<Type>& GetReturnType() const;
  const std::vector<std::shared_ptr<Type>>& GetParamTypes() const;

  bool IsVoid() const;
  bool IsInt1() const;
  bool IsInt32() const;
  bool IsFloat32() const;
  bool IsPointer() const;
  bool IsArray() const;
  bool IsFunction() const;
  bool IsScalar() const;
  bool IsInteger() const;
  bool IsNumeric() const;
  bool IsPtrInt32() const;
  /// Structural comparison (out of line).
  bool Equals(const Type& other) const;

 private:
  Kind kind_;
  std::shared_ptr<Type> element_type_;
  size_t array_size_ = 0;
  std::shared_ptr<Type> return_type_;
  std::vector<std::shared_ptr<Type>> param_types_;
};

/// Base class for everything that can appear as an operand: it carries a
/// type, a name, and the list of uses that read it.
class Value {
 public:
  Value(std::shared_ptr<Type> ty, std::string name);
  virtual ~Value() = default;

  const std::shared_ptr<Type>& GetType() const;
  const std::string& GetName() const;
  void SetName(std::string name);

  // Type queries (forwarding to the type, presumably).
  bool IsVoid() const;
  bool IsInt1() const;
  bool IsInt32() const;
  bool IsFloat32() const;
  bool IsPointer() const;
  bool IsArray() const;
  bool IsFunctionValue() const;
  bool IsPtrInt32() const;

  // Kind-of-value queries.
  bool IsConstant() const;
  bool IsInstruction() const;
  bool IsUser() const;
  bool IsFunction() const;
  bool IsGlobalVariable() const;
  bool IsArgument() const;

  /// Records / removes the fact that `user` reads this value as operand
  /// `operand_index`.
  void AddUse(User* user, size_t operand_index);
  void RemoveUse(User* user, size_t operand_index);
  const std::vector<Use>& GetUses() const;
  /// Redirects the recorded uses of this value to `new_value` (out of line).
  void ReplaceAllUsesWith(Value* new_value);

 protected:
  std::shared_ptr<Type> type_;
  std::string name_;
  std::vector<Use> uses_;
};

/// Abstract base for compile-time constants.
class ConstantValue : public Value {
 public:
  ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
  /// True when the constant equals the type's all-zero value.
  virtual bool IsZeroValue() const = 0;
};

class ConstantInt : public ConstantValue {
 public:
  ConstantInt(std::shared_ptr<Type> ty, int value);
  int GetValue() const { return value_; }
  bool IsZeroValue() const override { return value_ == 0; }

 private:
  int value_ = 0;
};

class ConstantFloat : public ConstantValue {
 public:
  ConstantFloat(std::shared_ptr<Type> ty, float value);
  float GetValue() const { return value_; }
  // Only +0.0f is the zero value. -0.0f compares equal to 0.0f but carries
  // a sign bit, so treating it as zero would lose the sign wherever a zero
  // value is replaced by a canonical +0 (as LLVM's zeroinitializer does).
  bool IsZeroValue() const override {
    return value_ == 0.0f && !std::signbit(value_);
  }

 private:
  float value_ = 0.0f;
};

/// A zero constant of an arbitrary type.
class ConstantZero : public ConstantValue {
 public:
  explicit ConstantZero(std::shared_ptr<Type> ty);
  bool IsZeroValue() const override { return true; }
};

/// An array constant built from (non-owned) element constants.
class ConstantArray : public ConstantValue {
 public:
  ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements);
  const std::vector<ConstantValue*>& GetElements() const { return elements_; }
  bool IsZeroValue() const override;

 private:
  std::vector<ConstantValue*> elements_;
};

enum class Opcode {
  Add,
  Sub,
  Mul,
  SDiv,
  SRem,
  FAdd,
  FSub,
  FMul,
  FDiv,
  Alloca,
  Load,
  Store,
  ICmp,
  FCmp,
  Br,
  CondBr,
  Call,
  GEP,
  SIToFP,
  FPToSI,
  ZExt,
  Ret,
};

enum class ICmpPred { Eq, Ne, Slt, Sle, Sgt, Sge };
enum class FCmpPred { Oeq, One, Olt, Ole, Ogt, Oge };

/// A value that reads other values through an ordered operand list.
class User : public Value {
 public:
  User(std::shared_ptr<Type> ty, std::string name);
  size_t GetNumOperands() const;
  Value* GetOperand(size_t index) const;
  void SetOperand(size_t index, Value* value);

 protected:
  void AddOperand(Value* value);

 private:
  std::vector<Value*> operands_;
};

class GlobalValue : public Value {
 public:
  GlobalValue(std::shared_ptr<Type> ty, std::string name);
};

/// Module-level variable with an optional (non-owned) initializer.
class GlobalVariable : public GlobalValue {
 public:
  GlobalVariable(std::string name, std::shared_ptr<Type> value_type,
                 ConstantValue* initializer, bool is_constant);
  const std::shared_ptr<Type>& GetValueType() const { return value_type_; }
  ConstantValue* GetInitializer() const { return initializer_; }
  // NOTE(review): this hides the non-virtual Value::IsConstant(). Here it
  // means "declared const"; through a Value* the other overload answers a
  // different question ("is this a constant value").
  bool IsConstant() const { return is_constant_; }

 private:
  std::shared_ptr<Type> value_type_;
  ConstantValue* initializer_ = nullptr;
  bool is_constant_ = false;
};

/// Base class of all instructions; knows its opcode and owning block.
class Instruction : public User {
 public:
  Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
  Opcode GetOpcode() const;
  bool IsTerminator() const;
  BasicBlock* GetParent() const;
  void SetParent(BasicBlock* parent);

 private:
  Opcode opcode_;
  BasicBlock* parent_ = nullptr;
};

class BinaryInst : public Instruction {
 public:
  BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
             std::string name);
  Value* GetLhs() const;
  Value* GetRhs() const;
};

/// Integer or floating-point comparison; `is_float_compare_` selects which
/// predicate field is meaningful.
class CompareInst : public Instruction {
 public:
  CompareInst(ICmpPred pred, Value* lhs, Value* rhs, std::string name);
  CompareInst(FCmpPred pred, Value* lhs, Value* rhs, std::string name);
  bool IsFloatCompare() const { return is_float_compare_; }
  ICmpPred GetICmpPred() const { return icmp_pred_; }
  FCmpPred GetFCmpPred() const { return fcmp_pred_; }
  Value* GetLhs() const;
  Value* GetRhs() const;

 private:
  bool is_float_compare_ = false;
  ICmpPred icmp_pred_ = ICmpPred::Eq;
  FCmpPred fcmp_pred_ = FCmpPred::Oeq;
};

class ReturnInst : public Instruction {
 public:
  explicit ReturnInst(Value* value);
  ReturnInst();
  Value* GetValue() const;
};

class AllocaInst : public Instruction {
 public:
  AllocaInst(std::shared_ptr<Type> allocated_type, std::string name);
  const std::shared_ptr<Type>& GetAllocatedType() const {
    return allocated_type_;
  }

 private:
  std::shared_ptr<Type> allocated_type_;
};

class LoadInst : public Instruction {
 public:
  LoadInst(Value* ptr, std::shared_ptr<Type> value_type, std::string name);
  Value* GetPtr() const;
};

class StoreInst : public Instruction {
 public:
  StoreInst(Value* value, Value* ptr);
  Value* GetValue() const;
  Value* GetPtr() const;
};

class BranchInst : public Instruction {
 public:
  explicit BranchInst(BasicBlock* target);
  BasicBlock* GetTarget() const;
};

class CondBranchInst : public Instruction {
 public:
  CondBranchInst(Value* cond, BasicBlock* true_block, BasicBlock* false_block);
  Value* GetCond() const;
  BasicBlock* GetTrueBlock() const;
  BasicBlock* GetFalseBlock() const;
};

class CallInst : public Instruction {
 public:
  CallInst(Function* callee, std::vector<Value*> args, std::string name);
  Function* GetCallee() const;
  std::vector<Value*> GetArgs() const;
};

class GetElementPtrInst : public Instruction {
 public:
  GetElementPtrInst(Value* base_ptr, std::vector<Value*> indices,
                    std::shared_ptr<Type> result_type, std::string name);
  Value* GetBasePtr() const;
  std::vector<Value*> GetIndices() const;
  std::shared_ptr<Type> GetSourceElementType() const;
};

class CastInst : public Instruction {
 public:
  CastInst(Opcode op, Value* value, std::shared_ptr<Type> dst_type,
           std::string name);
  Value* GetValue() const;
};

/// Owns an ordered list of instructions and records CFG edges.
class BasicBlock : public Value {
 public:
  explicit BasicBlock(std::string name);
  Function* GetParent() const;
  void SetParent(Function* parent);
  bool HasTerminator() const;
  /// Records `succ` as a successor (out of line; whether it also updates
  /// `succ`'s predecessor list is not visible here -- confirm).
  void AddSuccessor(BasicBlock* succ);
  const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
  const std::vector<BasicBlock*>& GetPredecessors() const;
  const std::vector<BasicBlock*>& GetSuccessors() const;

  /// Constructs a `T` (an Instruction subclass) from `args`, sets this block
  /// as its parent, appends it, and returns a non-owning pointer to it.
  /// Throws std::runtime_error if the block already ends in a terminator.
  template <typename T, typename... Args>
  T* Append(Args&&... args) {
    if (HasTerminator()) {
      throw std::runtime_error("BasicBlock 已有 terminator,不能继续追加指令: " +
                               name_);
    }
    auto inst = std::make_unique<T>(std::forward<Args>(args)...);
    auto* ptr = inst.get();
    ptr->SetParent(this);
    instructions_.push_back(std::move(inst));
    return ptr;
  }

 private:
  Function* parent_ = nullptr;
  std::vector<std::unique_ptr<Instruction>> instructions_;
  std::vector<BasicBlock*> predecessors_;
  std::vector<BasicBlock*> successors_;
};

/// Formal parameter `index_` of function `parent_`.
class Argument : public Value {
 public:
  Argument(std::shared_ptr<Type> ty, std::string name, size_t index,
           Function* parent);
  size_t GetIndex() const { return index_; }
  Function* GetParent() const { return parent_; }

 private:
  size_t index_ = 0;
  Function* parent_ = nullptr;
};

/// A function definition or declaration; owns its arguments and blocks.
class Function : public GlobalValue {
 public:
  Function(std::string name, std::shared_ptr<Type> function_type,
           bool is_declaration);
  const std::shared_ptr<Type>& GetFunctionType() const;
  const std::shared_ptr<Type>& GetReturnType() const;
  const std::vector<std::unique_ptr<Argument>>& GetArguments() const;
  bool IsDeclaration() const { return is_declaration_; }
  Argument* AddArgument(std::shared_ptr<Type> ty, const std::string& name);
  /// Creates and owns a new block (out of line; presumably the first one
  /// becomes `entry_` -- confirm).
  BasicBlock* CreateBlock(const std::string& name);
  BasicBlock* GetEntry();
  const BasicBlock* GetEntry() const;
  const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;

 private:
  bool is_declaration_ = false;
  BasicBlock* entry_ = nullptr;
  std::vector<std::unique_ptr<Argument>> arguments_;
  std::vector<std::unique_ptr<BasicBlock>> blocks_;
};

/// Top-level container: owns the Context, global variables and functions.
class Module {
 public:
  Module() = default;
  Context& GetContext();
  const Context& GetContext() const;
  GlobalVariable* CreateGlobal(std::string name,
                               std::shared_ptr<Type> value_type,
                               ConstantValue* initializer, bool is_constant);
  Function* CreateFunction(const std::string& name,
                           std::shared_ptr<Type> function_type,
                           bool is_declaration = false);
  Function* FindFunction(const std::string& name) const;
  GlobalVariable* FindGlobal(const std::string& name) const;
  const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobals() const;
  const std::vector<std::unique_ptr<Function>>& GetFunctions() const;

 private:
  // Declared first, so it is destroyed last: the constants it owns outlive
  // the globals and instructions that point at them.
  Context context_;
  std::vector<std::unique_ptr<GlobalVariable>> globals_;
  std::vector<std::unique_ptr<Function>> functions_;
};

/// Convenience factory that creates instructions at the end of the current
/// insertion block and constants through the bound Context.
class IRBuilder {
 public:
  IRBuilder(Context& ctx, BasicBlock* bb);
  void SetInsertPoint(BasicBlock* bb);
  BasicBlock* GetInsertBlock() const;

  ConstantInt* CreateConstInt(int v);
  ConstantFloat* CreateConstFloat(float v);
  ConstantValue* CreateZero(std::shared_ptr<Type> type);

  BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
                           const std::string& name);
  AllocaInst* CreateAlloca(std::shared_ptr<Type> allocated_type,
                           const std::string& name);
  AllocaInst* CreateAllocaI32(const std::string& name);
  LoadInst* CreateLoad(Value* ptr, const std::string& name);
  StoreInst* CreateStore(Value* val, Value* ptr);
  CompareInst* CreateICmp(ICmpPred pred, Value* lhs, Value* rhs,
                          const std::string& name);
  CompareInst* CreateFCmp(FCmpPred pred, Value* lhs, Value* rhs,
                          const std::string& name);
  BranchInst* CreateBr(BasicBlock* target);
  CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_block,
                               BasicBlock* false_block);
  CallInst* CreateCall(Function* callee, const std::vector<Value*>& args,
                       const std::string& name);
  GetElementPtrInst* CreateGEP(Value* base_ptr,
                               const std::vector<Value*>& indices,
                               const std::string& name);
  CastInst* CreateSIToFP(Value* value, const std::string& name);
  CastInst* CreateFPToSI(Value* value, const std::string& name);
  CastInst* CreateZExt(Value* value, std::shared_ptr<Type> dst_type,
                       const std::string& name);
  ReturnInst* CreateRet(Value* value);
  ReturnInst* CreateRetVoid();

 private:
  Context& ctx_;
  BasicBlock* insert_block_ = nullptr;
};

/// Writes a textual form of a Module to a stream.
class IRPrinter {
 public:
  void Print(const Module& module, std::ostream& os);
};

}  // namespace ir