You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

528 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#pragma once
#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
// Forward declarations for the IR class hierarchy, grouped by role so the
// classes below can refer to each other through pointers.

// Core value/type infrastructure.
class Type;
class Value;
class User;
class Argument;

// Constants.
class ConstantValue;
class ConstantInt;
class ConstantFloat;
class ConstantZero;
class ConstantArray;

// Module-level values.
class GlobalValue;
class GlobalVariable;

// Instructions.
class Instruction;
class BinaryInst;
class CompareInst;
class ReturnInst;
class AllocaInst;
class LoadInst;
class StoreInst;
class BranchInst;
class CondBranchInst;
class CallInst;
class GetElementPtrInst;
class CastInst;

// Containers.
class BasicBlock;
class Function;
class Use {
public:
Use() = default;
Use(Value* value, User* user, size_t operand_index)
: value_(value), user_(user), operand_index_(operand_index) {}
Value* GetValue() const { return value_; }
User* GetUser() const { return user_; }
size_t GetOperandIndex() const { return operand_index_; }
private:
Value* value_ = nullptr;
User* user_ = nullptr;
size_t operand_index_ = 0;
};
/// Per-module owner of constants plus a source of fresh names.
///
/// Constants created through CreateOwnedConstant (and, presumably, the ones
/// cached in const_ints_/const_floats_) are owned here. They are handed out
/// as raw, non-owning pointers that other IR objects keep (ConstantArray
/// elements, GlobalVariable initializers). A Context therefore must never
/// be duplicated.
class Context {
 public:
  Context() = default;
  // Defined out of line (not shown). This is likely because ConstantInt,
  // ConstantFloat and ConstantValue are only forward-declared at this point,
  // while the unique_ptr members need complete types to delete their contents.
  ~Context();

  // Non-copyable. The implicitly-declared copy operations are not deleted:
  // the unordered_map/vector-of-unique_ptr members declare copy constructors
  // that only fail when instantiated. A stray copy would therefore error deep
  // inside the standard library, and std::is_copy_constructible<Context>
  // would wrongly report true. Deleting them gives a clear diagnostic.
  // Because a destructor and the copy operations are user-declared, no move
  // operations are generated either, so the object stays at a fixed address.
  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

  /// Integer constant for `v`. const_ints_ is keyed by the value, which
  /// suggests per-value caching (implementation not shown; confirm).
  ConstantInt* GetConstInt(int v);

  /// Float constant for `v`. const_floats_ is keyed by a uint32_t, presumably
  /// the float's bit pattern (confirm in the implementation).
  ConstantFloat* GetConstFloat(float v);

  /// Heap-allocates a T from `args` and keeps ownership of it.
  ///
  /// @tparam T  a type whose unique_ptr converts to unique_ptr<ConstantValue>
  /// @return a non-owning pointer, valid for the lifetime of this Context.
  template <typename T, typename... Args>
  T* CreateOwnedConstant(Args&&... args) {
    auto value = std::make_unique<T>(std::forward<Args>(args)...);
    auto* ptr = value.get();
    owned_constants_.push_back(std::move(value));
    return ptr;
  }

  /// Fresh temporary name (presumably advances temp_index_, which starts at -1).
  std::string NextTemp();
  /// Fresh block name built from `prefix` (presumably advances block_index_).
  std::string NextBlock(const std::string& prefix);

 private:
  std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
  std::unordered_map<uint32_t, std::unique_ptr<ConstantFloat>> const_floats_;
  std::vector<std::unique_ptr<ConstantValue>> owned_constants_;
  int temp_index_ = -1;
  int block_index_ = -1;
};
/// Structural description of an IR type.
///
/// One class covers every kind. The constructors suggest which fields matter
/// per kind: element type for pointers/arrays, array size for arrays, and
/// return/parameter types for functions. The exact invariants live in the
/// out-of-line definitions.
class Type {
 public:
  // Enumerator order fixes the underlying values; keep it stable.
  enum class Kind {
    Void,
    Int1,
    Int32,
    Float32,
    Pointer,
    Array,
    Function,
  };

  // ---- Construction -------------------------------------------------------
  explicit Type(Kind kind);
  Type(Kind kind, std::shared_ptr<Type> element_type);
  Type(Kind kind, std::shared_ptr<Type> element_type, size_t array_size);
  Type(std::shared_ptr<Type> return_type,
       std::vector<std::shared_ptr<Type>> params);

  // ---- Primitive types ----------------------------------------------------
  // These return by const reference, so each referenced shared_ptr must
  // outlive the call (presumably a function-local static).
  static const std::shared_ptr<Type>& GetVoidType();
  static const std::shared_ptr<Type>& GetInt1Type();
  static const std::shared_ptr<Type>& GetInt32Type();
  static const std::shared_ptr<Type>& GetFloatType();
  static const std::shared_ptr<Type>& GetPtrInt32Type();

  // ---- Derived-type factories ----------------------------------------------
  static std::shared_ptr<Type> GetPointerType(
      std::shared_ptr<Type> element_type);
  static std::shared_ptr<Type> GetArrayType(
      std::shared_ptr<Type> element_type, size_t array_size);
  static std::shared_ptr<Type> GetFunctionType(
      std::shared_ptr<Type> return_type,
      std::vector<std::shared_ptr<Type>> param_types);

  // ---- Field accessors ------------------------------------------------------
  Kind GetKind() const;
  const std::shared_ptr<Type>& GetElementType() const;
  size_t GetArraySize() const;
  const std::shared_ptr<Type>& GetReturnType() const;
  const std::vector<std::shared_ptr<Type>>& GetParamTypes() const;

  // ---- Kind predicates ------------------------------------------------------
  bool IsVoid() const;
  bool IsInt1() const;
  bool IsInt32() const;
  bool IsFloat32() const;
  bool IsPointer() const;
  bool IsArray() const;
  bool IsFunction() const;
  bool IsPtrInt32() const;

  // ---- Category predicates ----------------------------------------------
  bool IsScalar() const;
  bool IsInteger() const;
  bool IsNumeric() const;

  /// Equality with `other` (presumably structural; defined out of line).
  bool Equals(const Type& other) const;

 private:
  Kind kind_;
  std::shared_ptr<Type> element_type_;
  size_t array_size_{0};
  std::shared_ptr<Type> return_type_;
  std::vector<std::shared_ptr<Type>> param_types_;
};
/// Root of the IR value hierarchy. Every value has a type, a name (possibly
/// empty) and the list of Use edges that refer to it.
class Value {
 public:
  Value(std::shared_ptr<Type> ty, std::string name);
  virtual ~Value() = default;

  // ---- Identity -------------------------------------------------------------
  const std::shared_ptr<Type>& GetType() const;
  const std::string& GetName() const;
  void SetName(std::string name);

  // ---- Type checks (names mirror Type's predicates) ---------------------
  bool IsVoid() const;
  bool IsInt1() const;
  bool IsInt32() const;
  bool IsFloat32() const;
  bool IsPointer() const;
  bool IsArray() const;
  bool IsPtrInt32() const;
  bool IsFunctionValue() const;

  // ---- Hierarchy checks (presumably test the dynamic class) ----------------
  bool IsConstant() const;
  bool IsInstruction() const;
  bool IsUser() const;
  bool IsFunction() const;
  bool IsGlobalVariable() const;
  bool IsArgument() const;

  // ---- Use-list maintenance ---------------------------------------------------
  void AddUse(User* user, size_t operand_index);
  void RemoveUse(User* user, size_t operand_index);
  const std::vector<Use>& GetUses() const;
  void ReplaceAllUsesWith(Value* new_value);

 protected:
  std::shared_ptr<Type> type_;
  std::string name_;
  std::vector<Use> uses_;
};
/// Abstract base of all compile-time constants. The name defaults to empty.
class ConstantValue : public Value {
 public:
  ConstantValue(std::shared_ptr<Type> ty, std::string name = "");

  /// Whether the constant is zero; each subclass defines what that means for
  /// its own representation.
  virtual bool IsZeroValue() const = 0;
};
/// Integer constant; the value is stored as a plain `int` next to its type.
class ConstantInt : public ConstantValue {
 public:
  ConstantInt(std::shared_ptr<Type> ty, int value);

  int GetValue() const {
    return value_;
  }

  /// True exactly when the stored integer is 0.
  bool IsZeroValue() const override {
    return !value_;
  }

 private:
  int value_{0};
};
/// Single-precision floating-point constant.
class ConstantFloat : public ConstantValue {
 public:
  ConstantFloat(std::shared_ptr<Type> ty, float value);

  float GetValue() const {
    return value_;
  }

  /// Numeric zero test via `==`: both +0.0f and -0.0f report true, NaN
  /// reports false.
  bool IsZeroValue() const override {
    constexpr float kZero = 0.0f;
    return kZero == value_;
  }

 private:
  float value_{0.0f};
};
/// Constant that is zero for whatever type it carries; IsZeroValue() is
/// unconditionally true.
class ConstantZero : public ConstantValue {
 public:
  explicit ConstantZero(std::shared_ptr<Type> ty);

  bool IsZeroValue() const override {
    return true;
  }
};
/// Array constant built from per-element constants. The elements are held as
/// non-owning pointers; their owner lives elsewhere (presumably a Context).
class ConstantArray : public ConstantValue {
 public:
  ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements);

  const std::vector<ConstantValue*>& GetElements() const {
    return elements_;
  }

  /// Defined out of line.
  bool IsZeroValue() const override;

 private:
  std::vector<ConstantValue*> elements_;
};
/// Instruction opcodes. Enumerator order fixes the underlying integer values,
/// so new opcodes should not be inserted in the middle without auditing users.
enum class Opcode {
  // Integer arithmetic.
  Add, Sub, Mul, SDiv, SRem,
  // Floating-point arithmetic.
  FAdd, FSub, FMul, FDiv,
  // Memory.
  Alloca, Load, Store,
  // Comparisons.
  ICmp, FCmp,
  // Branches and calls.
  Br, CondBr, Call,
  // Address computation.
  GEP,
  // Conversions.
  SIToFP, FPToSI, ZExt,
  // Function exit.
  Ret,
};
/// Integer comparison predicates (the S prefix follows LLVM's "signed" naming).
enum class ICmpPred {
  Eq,
  Ne,
  Slt,
  Sle,
  Sgt,
  Sge,
};

/// Floating-point comparison predicates (the O prefix follows LLVM's
/// "ordered" naming).
enum class FCmpPred {
  Oeq,
  One,
  Olt,
  Ole,
  Ogt,
  Oge,
};
/// A Value that refers to other values through an ordered operand list.
/// Only subclasses can append operands (AddOperand is protected); anyone can
/// read operands or replace a slot.
class User : public Value {
 public:
  User(std::shared_ptr<Type> ty, std::string name);

  // ---- Operand access ---------------------------------------------------------
  size_t GetNumOperands() const;
  Value* GetOperand(size_t index) const;
  /// Replaces slot `index` (presumably keeps the Use lists in sync; confirm).
  void SetOperand(size_t index, Value* value);

 protected:
  void AddOperand(Value* value);

 private:
  std::vector<Value*> operands_;
};
/// Common base of module-level values (GlobalVariable and Function).
class GlobalValue : public Value {
 public:
  GlobalValue(std::shared_ptr<Type> ty,
              std::string name);
};
/// Module-level variable. `value_type` is the type of the stored object; the
/// Value's own type is set by the out-of-line constructor (presumably a
/// pointer to `value_type`; confirm).
class GlobalVariable : public GlobalValue {
 public:
  GlobalVariable(std::string name, std::shared_ptr<Type> value_type,
                 ConstantValue* initializer, bool is_constant);

  /// Type of the variable's contents.
  const std::shared_ptr<Type>& GetValueType() const {
    return value_type_;
  }

  /// Non-owning pointer to the initializer constant.
  ConstantValue* GetInitializer() const {
    return initializer_;
  }

  /// Whether the variable was declared constant.
  ///
  /// NOTE: this hides the non-virtual Value::IsConstant() with the same
  /// signature, so the result depends on the static type used for the call.
  /// Calling through a Value* reaches the base version, which is presumably
  /// an "is this a ConstantValue" check. Confirm before relying on either.
  bool IsConstant() const {
    return is_constant_;
  }

 private:
  std::shared_ptr<Type> value_type_;
  ConstantValue* initializer_{nullptr};
  bool is_constant_{false};
};
/// Base of every instruction: an opcode plus a back-pointer to the basic
/// block that holds it. BasicBlock::Append sets that pointer via SetParent.
class Instruction : public User {
 public:
  Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");

  Opcode GetOpcode() const;
  /// Defined out of line (presumably true for the branch/return opcodes).
  bool IsTerminator() const;

  // ---- Owning block -------------------------------------------------------
  BasicBlock* GetParent() const;
  void SetParent(BasicBlock* parent);

 private:
  Opcode opcode_;
  BasicBlock* parent_{nullptr};
};
/// Two-operand arithmetic instruction; `op` selects the operation.
class BinaryInst : public Instruction {
 public:
  BinaryInst(Opcode op, std::shared_ptr<Type> ty,
             Value* lhs, Value* rhs,
             std::string name);

  /// Operand accessors (defined out of line).
  Value* GetLhs() const;
  Value* GetRhs() const;
};
/// Integer or floating-point comparison. IsFloatCompare() says which
/// predicate accessor is meaningful; the other presumably keeps its default.
class CompareInst : public Instruction {
 public:
  CompareInst(ICmpPred pred, Value* lhs, Value* rhs, std::string name);
  CompareInst(FCmpPred pred, Value* lhs, Value* rhs, std::string name);

  bool IsFloatCompare() const {
    return is_float_compare_;
  }

  ICmpPred GetICmpPred() const {
    return icmp_pred_;
  }

  FCmpPred GetFCmpPred() const {
    return fcmp_pred_;
  }

  // ---- Operands (defined out of line) --------------------------------------
  Value* GetLhs() const;
  Value* GetRhs() const;

 private:
  bool is_float_compare_{false};
  ICmpPred icmp_pred_{ICmpPred::Eq};
  FCmpPred fcmp_pred_{FCmpPred::Oeq};
};
/// Function return, either with a value (`ReturnInst(value)`) or without
/// one (`ReturnInst()`).
class ReturnInst : public Instruction {
 public:
  ReturnInst();
  explicit ReturnInst(Value* value);

  /// The returned value (presumably null for the value-less form).
  Value* GetValue() const;
};
/// Alloca instruction; remembers the type of the object it allocates.
class AllocaInst : public Instruction {
 public:
  AllocaInst(std::shared_ptr<Type> allocated_type, std::string name);

  const std::shared_ptr<Type>& GetAllocatedType() const {
    return allocated_type_;
  }

 private:
  std::shared_ptr<Type> allocated_type_;
};
/// Load of a `value_type` value through the address operand `ptr`.
class LoadInst : public Instruction {
 public:
  LoadInst(Value* ptr,
           std::shared_ptr<Type> value_type,
           std::string name);

  /// Address operand (defined out of line).
  Value* GetPtr() const;
};
/// Store of `value` through `ptr`. The constructor takes no name.
class StoreInst : public Instruction {
 public:
  StoreInst(Value* value, Value* ptr);

  // ---- Operands (defined out of line) --------------------------------------
  Value* GetValue() const;
  Value* GetPtr() const;
};
/// Unconditional branch to a single target block.
class BranchInst : public Instruction {
 public:
  explicit BranchInst(BasicBlock* target);

  /// Destination block (defined out of line).
  BasicBlock* GetTarget() const;
};
/// Two-way branch on `cond` to `true_block` or `false_block`.
class CondBranchInst : public Instruction {
 public:
  CondBranchInst(Value* cond,
                 BasicBlock* true_block,
                 BasicBlock* false_block);

  Value* GetCond() const;
  /// Destination when the condition holds.
  BasicBlock* GetTrueBlock() const;
  /// Destination otherwise.
  BasicBlock* GetFalseBlock() const;
};
/// Call of `callee` with `args`.
class CallInst : public Instruction {
 public:
  CallInst(Function* callee,
           std::vector<Value*> args,
           std::string name);

  Function* GetCallee() const;
  /// Arguments, returned by value (a new vector on each call).
  std::vector<Value*> GetArgs() const;
};
/// Address computation from `base_ptr` and a list of indices; the result
/// type is supplied by the caller.
class GetElementPtrInst : public Instruction {
 public:
  GetElementPtrInst(Value* base_ptr,
                    std::vector<Value*> indices,
                    std::shared_ptr<Type> result_type,
                    std::string name);

  Value* GetBasePtr() const;
  /// Index operands, returned by value.
  std::vector<Value*> GetIndices() const;
  /// Defined out of line (presumably derived from the base pointer's type).
  std::shared_ptr<Type> GetSourceElementType() const;
};
/// Conversion of `value` to `dst_type`. IRBuilder produces it for SIToFP,
/// FPToSI and ZExt; the opcode is not validated here.
class CastInst : public Instruction {
 public:
  CastInst(Opcode op,
           Value* value,
           std::shared_ptr<Type> dst_type,
           std::string name);

  /// Source operand (defined out of line).
  Value* GetValue() const;
};
/// Ordered list of owned instructions plus predecessor/successor edge lists.
class BasicBlock : public Value {
 public:
  explicit BasicBlock(std::string name);

  // ---- Owning function ------------------------------------------------------
  Function* GetParent() const;
  void SetParent(Function* parent);

  // ---- Instructions ---------------------------------------------------------
  bool HasTerminator() const;
  const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;

  /// Heap-allocates a T from `args`, sets this block as its parent, and
  /// appends it. The block keeps ownership; the returned pointer is
  /// non-owning. The predecessor/successor lists are not touched here.
  ///
  /// @throws std::runtime_error if the block already has a terminator.
  template <typename T, typename... Args>
  T* Append(Args&&... args) {
    if (HasTerminator()) {
      throw std::runtime_error("BasicBlock 已有 terminator不能继续追加指令: " +
                               name_);
    }
    std::unique_ptr<T> owned = std::make_unique<T>(std::forward<Args>(args)...);
    T* const raw = owned.get();
    raw->SetParent(this);
    instructions_.push_back(std::move(owned));
    return raw;
  }

  // ---- CFG edges --------------------------------------------------------------
  void AddSuccessor(BasicBlock* succ);
  const std::vector<BasicBlock*>& GetPredecessors() const;
  const std::vector<BasicBlock*>& GetSuccessors() const;

 private:
  Function* parent_{nullptr};
  std::vector<std::unique_ptr<Instruction>> instructions_;
  std::vector<BasicBlock*> predecessors_;
  std::vector<BasicBlock*> successors_;
};
/// Formal parameter of a Function: its position in the parameter list and a
/// non-owning pointer to that function.
class Argument : public Value {
 public:
  Argument(std::shared_ptr<Type> ty, std::string name, size_t index,
           Function* parent);

  size_t GetIndex() const {
    return index_;
  }

  Function* GetParent() const {
    return parent_;
  }

 private:
  size_t index_{0};
  Function* parent_{nullptr};
};
/// Function definition or declaration. It owns its arguments and basic
/// blocks through unique_ptr.
class Function : public GlobalValue {
 public:
  Function(std::string name, std::shared_ptr<Type> function_type,
           bool is_declaration);

  // ---- Signature ------------------------------------------------------------
  const std::shared_ptr<Type>& GetFunctionType() const;
  const std::shared_ptr<Type>& GetReturnType() const;
  bool IsDeclaration() const {
    return is_declaration_;
  }

  // ---- Arguments ------------------------------------------------------------
  Argument* AddArgument(std::shared_ptr<Type> ty, const std::string& name);
  const std::vector<std::unique_ptr<Argument>>& GetArguments() const;

  // ---- Blocks ---------------------------------------------------------------
  BasicBlock* CreateBlock(const std::string& name);
  /// Entry block (entry_ starts as null; presumably set by CreateBlock).
  BasicBlock* GetEntry();
  const BasicBlock* GetEntry() const;
  const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;

 private:
  bool is_declaration_{false};
  BasicBlock* entry_{nullptr};
  std::vector<std::unique_ptr<Argument>> arguments_;
  std::vector<std::unique_ptr<BasicBlock>> blocks_;
};
/// Top-level IR container. It owns one Context plus the module's global
/// variables and functions (through unique_ptr).
///
/// Non-copyable and non-movable. A copy could not duplicate the unique_ptr
/// ownership. It would also be unsafe because other objects hold a Context&
/// to the owned context (IRBuilder stores one).
class Module {
 public:
  Module() = default;

  // Unless a member's copy constructor is itself deleted, the implicit copy
  // operations are merely ill-formed on use: the std::vector<std::unique_ptr<>>
  // members declare copy constructors that only fail when instantiated. That
  // turns a stray copy into a deep standard-library error and makes
  // std::is_copy_constructible<Module> report true. Deleting them explicitly
  // gives a clear diagnostic. With copy operations user-declared, no move
  // operations are generated either, so context_ stays at a fixed address.
  Module(const Module&) = delete;
  Module& operator=(const Module&) = delete;

  Context& GetContext();
  const Context& GetContext() const;

  /// Creates a global (presumably stored in globals_) and returns a
  /// non-owning pointer.
  GlobalVariable* CreateGlobal(std::string name, std::shared_ptr<Type> value_type,
                               ConstantValue* initializer, bool is_constant);
  /// Creates a function (presumably stored in functions_) and returns a
  /// non-owning pointer.
  Function* CreateFunction(const std::string& name,
                           std::shared_ptr<Type> function_type,
                           bool is_declaration = false);

  /// Name lookups (presumably return nullptr when absent; confirm).
  Function* FindFunction(const std::string& name) const;
  GlobalVariable* FindGlobal(const std::string& name) const;

  const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobals() const;
  const std::vector<std::unique_ptr<Function>>& GetFunctions() const;

 private:
  Context context_;
  std::vector<std::unique_ptr<GlobalVariable>> globals_;
  std::vector<std::unique_ptr<Function>> functions_;
};
/// Instruction/constant factory bound to a Context and an insertion block.
/// The builder stores a reference to the Context, so that Context must
/// outlive it. Instructions are presumably appended to the current insertion
/// block (implementation not shown).
class IRBuilder {
 public:
  IRBuilder(Context& ctx, BasicBlock* bb);

  // ---- Insertion point ----------------------------------------------------
  void SetInsertPoint(BasicBlock* bb);
  BasicBlock* GetInsertBlock() const;

  // ---- Constants ------------------------------------------------------------
  ConstantInt* CreateConstInt(int v);
  ConstantFloat* CreateConstFloat(float v);
  ConstantValue* CreateZero(std::shared_ptr<Type> type);

  // ---- Arithmetic and comparisons ----------------------------------------
  BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
                           const std::string& name);
  CompareInst* CreateICmp(ICmpPred pred, Value* lhs, Value* rhs,
                          const std::string& name);
  CompareInst* CreateFCmp(FCmpPred pred, Value* lhs, Value* rhs,
                          const std::string& name);

  // ---- Memory -------------------------------------------------------------
  AllocaInst* CreateAlloca(std::shared_ptr<Type> allocated_type,
                           const std::string& name);
  AllocaInst* CreateAllocaI32(const std::string& name);
  LoadInst* CreateLoad(Value* ptr, const std::string& name);
  StoreInst* CreateStore(Value* val, Value* ptr);
  GetElementPtrInst* CreateGEP(Value* base_ptr,
                               const std::vector<Value*>& indices,
                               const std::string& name);

  // ---- Conversions --------------------------------------------------------
  CastInst* CreateSIToFP(Value* value, const std::string& name);
  CastInst* CreateFPToSI(Value* value, const std::string& name);
  CastInst* CreateZExt(Value* value, std::shared_ptr<Type> dst_type,
                       const std::string& name);

  // ---- Control flow and calls ---------------------------------------------
  BranchInst* CreateBr(BasicBlock* target);
  CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_block,
                               BasicBlock* false_block);
  CallInst* CreateCall(Function* callee, const std::vector<Value*>& args,
                       const std::string& name);
  ReturnInst* CreateRet(Value* value);
  ReturnInst* CreateRetVoid();

 private:
  Context& ctx_;
  BasicBlock* insert_block_{nullptr};
};
/// Writes a Module to an output stream (format defined by the out-of-line
/// implementation).
class IRPrinter {
 public:
  /// Prints `module` to `os`.
  void Print(const Module& module,
             std::ostream& os);
};
} // namespace ir