You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

759 lines
24 KiB

#pragma once
#include "utils.h"
#include <iosfwd>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
// Forward declarations so the classes below can refer to each other by
// pointer before their full definitions appear later in this header.
class Argument;
class BasicBlock;
class ConstantArrayValue;
class ConstantFloat;
class ConstantI1;
class ConstantInt;
class Function;
class Instruction;
class Type;
class User;
class Value;
class Use {
public:
Use() = default;
Use(Value* value, User* user, size_t operand_index)
: value_(value), user_(user), operand_index_(operand_index) {}
Value* GetValue() const { return value_; }
User* GetUser() const { return user_; }
size_t GetOperandIndex() const { return operand_index_; }
void SetValue(Value* value) { value_ = value; }
void SetUser(User* user) { user_ = user; }
void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; }
private:
Value* value_ = nullptr;
User* user_ = nullptr;
size_t operand_index_ = 0;
};
/// Owns cached integer/boolean constant objects (held in maps keyed by the
/// constant's value) and hands out fresh names for temporaries and blocks.
///
/// Context owns its constants through std::unique_ptr, so it must not be
/// copied. The copy operations are deleted explicitly: std::unordered_map's
/// copy constructor is not SFINAE-friendly, so without this the
/// implicitly-declared copy constructor would still be reported as usable
/// (std::is_copy_constructible) and only fail deep inside the standard
/// library when actually instantiated.
class Context {
 public:
  Context() = default;
  // Declared here, defined out of line.
  ~Context();
  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

  /// Returns the i32 constant object for `v` (presumably created once per
  /// value and cached in const_ints_ — confirm in the .cpp).
  ConstantInt* GetConstInt(int v);
  /// Returns the i1 constant object for `v` (presumably cached in
  /// const_bools_ — confirm in the .cpp).
  ConstantI1* GetConstBool(bool v);
  /// Produces a new temporary name.
  std::string NextTemp();
  /// Produces a new block name built from `prefix`.
  std::string NextBlockName(const std::string& prefix = "bb");

 private:
  std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
  std::unordered_map<bool, std::unique_ptr<ConstantI1>> const_bools_;
  // Both counters start at -1; presumably pre-incremented so the first name
  // uses index 0 — TODO confirm against NextTemp/NextBlockName.
  int temp_index_ = -1;
  int block_index_ = -1;
};
/// Describes the type of an IR value: a scalar kind (void, i1, i32, float,
/// label, function), or a pointer/array kind that additionally carries an
/// element type and, for arrays, an element count.
class Type {
 public:
  enum class Kind {
    Void,
    Int1,
    Int32,
    Float,
    Label,
    Function,
    Pointer,
    PtrInt32 = Pointer,  // alias: same enumerator value as Pointer
    Array
  };

  explicit Type(Kind kind);
  Type(Kind kind, std::shared_ptr<Type> element_type, size_t num_elements = 0);

  // --- Shared type instances (defined out of line) --------------------
  static const std::shared_ptr<Type>& GetVoidType();
  static const std::shared_ptr<Type>& GetInt1Type();
  static const std::shared_ptr<Type>& GetInt32Type();
  static const std::shared_ptr<Type>& GetFloatType();
  static const std::shared_ptr<Type>& GetLabelType();
  static const std::shared_ptr<Type>& GetBoolType();
  static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> pointee = nullptr);
  static const std::shared_ptr<Type>& GetPtrInt32Type();
  static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> element_type,
                                            size_t num_elements);

  // --- Kind queries ----------------------------------------------------
  Kind GetKind() const { return kind_; }
  bool IsVoid() const { return HasKind(Kind::Void); }
  bool IsInt1() const { return HasKind(Kind::Int1); }
  bool IsInt32() const { return HasKind(Kind::Int32); }
  bool IsFloat() const { return HasKind(Kind::Float); }
  bool IsLabel() const { return HasKind(Kind::Label); }
  bool IsFunction() const { return HasKind(Kind::Function); }
  // Bool is another name for i1.
  bool IsBool() const { return IsInt1(); }
  bool IsPointer() const { return HasKind(Kind::Pointer); }
  // True for every pointer, whatever its pointee.
  bool IsPtrInt32() const { return IsPointer(); }
  bool IsArray() const { return HasKind(Kind::Array); }

  // --- Composite-type data ---------------------------------------------
  std::shared_ptr<Type> GetElementType() const { return element_type_; }
  size_t GetNumElements() const { return num_elements_; }

  int GetSize() const;
  void Print(std::ostream& os) const;

 private:
  bool HasKind(Kind k) const { return kind_ == k; }

  Kind kind_;
  std::shared_ptr<Type> element_type_;
  size_t num_elements_{0};
};
/// Root of the IR value hierarchy: constants, arguments, users/instructions,
/// basic blocks and functions all derive from it. A Value has a (possibly
/// null) type, a name, and a list of Use records for the users that refer
/// to it.
class Value {
 public:
  Value(std::shared_ptr<Type> ty, std::string name);
  virtual ~Value() = default;

  // --- Identity --------------------------------------------------------
  const std::shared_ptr<Type>& GetType() const { return type_; }
  const std::string& GetName() const { return name_; }
  void SetName(std::string name) { name_ = std::move(name); }

  // --- Type predicates: all false when the value has no type ----------
  bool IsVoid() const { return type_ ? type_->IsVoid() : false; }
  bool IsInt32() const { return type_ ? type_->IsInt32() : false; }
  bool IsPtrInt32() const { return type_ ? type_->IsPtrInt32() : false; }
  bool IsFloat() const { return type_ ? type_->IsFloat() : false; }
  bool IsBool() const { return type_ ? type_->IsBool() : false; }
  bool IsArray() const { return type_ ? type_->IsArray() : false; }
  bool IsLabel() const { return type_ ? type_->IsLabel() : false; }

  // --- Class predicates, overridden by subclasses ----------------------
  virtual bool IsConstant() const { return false; }
  virtual bool IsInstruction() const { return false; }
  virtual bool IsUser() const { return false; }
  virtual bool IsFunction() const { return false; }
  virtual bool IsArgument() const { return false; }

  // --- Use-list bookkeeping (defined out of line) ----------------------
  void AddUse(User* user, size_t operand_index);
  void RemoveUse(User* user, size_t operand_index);
  const std::vector<Use>& GetUses() const { return uses_; }
  void ReplaceAllUsesWith(Value* new_value);

  virtual void Print(std::ostream& os) const;

 protected:
  std::shared_ptr<Type> type_;
  std::string name_;
  std::vector<Use> uses_;
};
/// classof-based type queries.
/// `isa<T>(v)` is true when `v` is non-null and `T::classof(v)` accepts it.
template <typename T>
inline bool isa(const Value* value) {
  if (value == nullptr) {
    return false;
  }
  return T::classof(value);
}

/// Checked downcast: returns `value` as a T* when `isa<T>` accepts it and the
/// dynamic_cast succeeds, otherwise nullptr (including for a null `value`).
template <typename T>
inline T* dyncast(Value* value) {
  if (!isa<T>(value)) {
    return nullptr;
  }
  return dynamic_cast<T*>(value);
}

/// Const overload of dyncast with identical semantics.
template <typename T>
inline const T* dyncast(const Value* value) {
  if (!isa<T>(value)) {
    return nullptr;
  }
  return dynamic_cast<const T*>(value);
}
/// Common base of ConstantInt, ConstantFloat and ConstantI1. IsConstant() is
/// pinned to true (final) for this whole subtree.
class ConstantValue : public Value {
 public:
  // explicit: because `name` is defaulted, this constructor is callable with
  // a single shared_ptr<Type>. Without `explicit`, any type pointer would
  // implicitly convert to a ConstantValue.
  explicit ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
  bool IsConstant() const override final { return true; }
};
/// Integer constant whose payload is a plain `int`.
class ConstantInt : public ConstantValue {
 public:
  ConstantInt(std::shared_ptr<Type> ty, int value);

  int GetValue() const { return value_; }

  static bool classof(const Value* value) {
    // Cheap virtual pre-filter first, then the authoritative dynamic_cast.
    if (value == nullptr || !value->IsConstant()) {
      return false;
    }
    return dynamic_cast<const ConstantInt*>(value) != nullptr;
  }

 private:
  int value_;
};
/// Floating-point constant whose payload is a single-precision `float`.
class ConstantFloat : public ConstantValue {
 public:
  ConstantFloat(std::shared_ptr<Type> ty, float value);

  float GetValue() const { return value_; }

  static bool classof(const Value* value) {
    // Cheap virtual pre-filter first, then the authoritative dynamic_cast.
    if (value == nullptr || !value->IsConstant()) {
      return false;
    }
    return dynamic_cast<const ConstantFloat*>(value) != nullptr;
  }

 private:
  float value_;
};
/// One-bit (boolean) constant whose payload is a `bool`.
class ConstantI1 : public ConstantValue {
 public:
  ConstantI1(std::shared_ptr<Type> ty, bool value);

  bool GetValue() const { return value_; }

  static bool classof(const Value* value) {
    // Cheap virtual pre-filter first, then the authoritative dynamic_cast.
    if (value == nullptr || !value->IsConstant()) {
      return false;
    }
    return dynamic_cast<const ConstantI1*>(value) != nullptr;
  }

 private:
  bool value_;
};
/// Constant array initializer: an element list plus the array's dimensions.
/// (Presumably the elements are flattened in row-major order — confirm in
/// the .cpp / Print implementation.)
///
/// NOTE(review): this class derives directly from Value, not from
/// ConstantValue, so IsConstant() returns false for it. Confirm that callers
/// never rely on IsConstant() to recognize constant arrays.
class ConstantArrayValue : public Value {
 public:
  ConstantArrayValue(std::shared_ptr<Type> array_type,
                     const std::vector<Value*>& elements,
                     const std::vector<size_t>& dims,
                     const std::string& name = "");

  const std::vector<Value*>& GetElements() const { return elements_; }
  const std::vector<size_t>& GetDims() const { return dims_; }

  void Print(std::ostream& os) const override;

  static bool classof(const Value* value) {
    // dynamic_cast of a null pointer yields null, so no separate null check.
    return dynamic_cast<const ConstantArrayValue*>(value) != nullptr;
  }

 private:
  std::vector<Value*> elements_;
  std::vector<size_t> dims_;
};
/// Instruction opcodes. Enumerators are listed in their original order, so
/// their underlying values are unchanged; `Ret` is an alias of `Return`.
enum class Opcode {
  // Integer arithmetic.
  Add, Sub, Mul, Div, Rem,
  // Floating-point arithmetic.
  FAdd, FSub, FMul, FDiv, FRem,
  // Bitwise logic and shifts.
  And, Or, Xor, Shl, AShr, LShr,
  // Integer comparisons.
  ICmpEQ, ICmpNE, ICmpLT, ICmpGT, ICmpLE, ICmpGE,
  // Floating-point comparisons.
  FCmpEQ, FCmpNE, FCmpLT, FCmpGT, FCmpLE, FCmpGE,
  // Unary operations and int<->float conversions.
  Neg, Not, FNeg, FtoI, IToF,
  // Calls and control flow.
  Call, CondBr, Br, Return, Ret = Return, Unreachable,
  // Memory.
  Alloca, Load, Store, Memset, GetElementPtr,
  // SSA merge and zero-extension.
  Phi, Zext
};
/// A Value that consumes other values through an ordered operand list. Each
/// operand slot is stored as a Use record (value, user, operand index).
class User : public Value {
 public:
  User(std::shared_ptr<Type> ty, std::string name);
  bool IsUser() const override final { return true; }

  // --- Operand access --------------------------------------------------
  size_t GetNumOperands() const { return operands_.size(); }
  Value* GetOperand(size_t index) const;
  void SetOperand(size_t index, Value* value);

  // --- Operand-list edits (defined out of line) ------------------------
  void AddOperand(Value* value);
  void AddOperands(const std::vector<Value*>& values);
  void RemoveOperand(size_t index);
  void ClearAllOperands();

 protected:
  // One entry per operand, in operand order.
  std::vector<Use> operands_;
};
class Argument : public Value {
public:
Argument(std::shared_ptr<Type> type, std::string name, size_t index);
size_t GetIndex() const { return index_; }
bool IsArgument() const override final { return true; }
static bool classof(const Value* value) {
return value && dynamic_cast<const Argument*>(value) != nullptr;
}
private:
size_t index_;
};
class GlobalValue : public User {
public:
GlobalValue(std::shared_ptr<Type> object_type, const std::string& name,
bool is_const = false, Value* init = nullptr);
bool IsConstant() const override { return is_const_; }
bool HasInitializer() const { return init_ != nullptr; }
Value* GetInitializer() const { return init_; }
std::shared_ptr<Type> GetObjectType() const { return object_type_; }
void SetConstant(bool is_const) { is_const_ = is_const; }
void SetInitializer(Value* init) { init_ = init; }
static bool classof(const Value* value) {
return value && dynamic_cast<const GlobalValue*>(value) != nullptr;
}
private:
std::shared_ptr<Type> object_type_;
bool is_const_ = false;
Value* init_ = nullptr;
};
/// Base class of all instructions: an opcode plus a non-owning pointer to the
/// containing BasicBlock (blocks own their instructions via unique_ptr).
class Instruction : public User {
 public:
  Instruction(Opcode opcode, std::shared_ptr<Type> ty,
              BasicBlock* parent = nullptr, const std::string& name = "");
  bool IsInstruction() const override final { return true; }

  Opcode GetOpcode() const { return opcode_; }
  bool IsTerminator() const;

  BasicBlock* GetParent() const { return parent_; }
  void SetParent(BasicBlock* parent) { parent_ = parent; }

  static bool classof(const Value* value) {
    return value && value->IsInstruction();
  }

 private:
  Opcode opcode_;
  // Default member initializer so the pointer is never indeterminate, even on
  // a constructor path that does not set it. This matches BasicBlock::parent_.
  // A constructor mem-initializer still takes precedence.
  BasicBlock* parent_ = nullptr;
};
/// Two-operand instruction (arithmetic, bitwise, shift or compare, selected
/// by `opcode`). Operand 0 is the left-hand side, operand 1 the right-hand side.
class BinaryInst : public Instruction {
 public:
  BinaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
             BasicBlock* parent = nullptr, const std::string& name = "");

  Value* GetLhs() const { return GetOperand(0); }
  Value* GetRhs() const { return GetOperand(1); }

  // Defined out of line.
  static bool classof(const Value* value);
};
/// Single-operand instruction (negation, not, or int/float conversion,
/// selected by `opcode`). The operand is stored as operand 0.
class UnaryInst : public Instruction {
 public:
  UnaryInst(Opcode opcode, std::shared_ptr<Type> ty, Value* operand,
            BasicBlock* parent = nullptr, const std::string& name = "");

  Value* GetOprd() const { return GetOperand(0); }

  // Defined out of line.
  static bool classof(const Value* value);
};
/// Function return. When a value is returned it is stored as operand 0;
/// a return without a value has no operands.
class ReturnInst : public Instruction {
 public:
  // explicit: every parameter is defaulted, so without it a bare Value* would
  // implicitly convert to a ReturnInst. UnreachableInst is explicit for the
  // same reason.
  explicit ReturnInst(Value* value = nullptr, BasicBlock* parent = nullptr);

  bool HasReturnValue() const { return GetNumOperands() > 0; }
  /// Returns the returned value, or nullptr for a value-less return.
  Value* GetReturnValue() const {
    return HasReturnValue() ? GetOperand(0) : nullptr;
  }
  /// Synonym for GetReturnValue().
  Value* GetValue() const { return GetReturnValue(); }

  static bool classof(const Value* value) {
    return value && Instruction::classof(value) &&
           static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Return;
  }
};
/// Allocation of storage for one object of `allocated_type` (opcode Alloca).
/// The instruction's own result type is set by the out-of-line constructor
/// (presumably a pointer to `allocated_type` — confirm in the .cpp).
class AllocaInst : public Instruction {
 public:
  // explicit: the constructor is callable with just the allocated type, which
  // would otherwise allow an implicit shared_ptr<Type> -> AllocaInst conversion.
  explicit AllocaInst(std::shared_ptr<Type> allocated_type,
                      BasicBlock* parent = nullptr,
                      const std::string& name = "");

  std::shared_ptr<Type> GetAllocatedType() const { return allocated_type_; }

  static bool classof(const Value* value) {
    return value && Instruction::classof(value) &&
           static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Alloca;
  }

 private:
  std::shared_ptr<Type> allocated_type_;
};
/// Load of a `value_type` value through the pointer stored as operand 0.
class LoadInst : public Instruction {
 public:
  LoadInst(std::shared_ptr<Type> value_type, Value* ptr,
           BasicBlock* parent = nullptr, const std::string& name = "");

  Value* GetPtr() const { return GetOperand(0); }

  static bool classof(const Value* value) {
    if (!Instruction::classof(value)) {
      return false;
    }
    const auto* inst = static_cast<const Instruction*>(value);
    return inst->GetOpcode() == Opcode::Load;
  }
};
/// Store of operand 0 (the value) through operand 1 (the pointer).
class StoreInst : public Instruction {
 public:
  StoreInst(Value* value, Value* ptr, BasicBlock* parent = nullptr);

  Value* GetValue() const { return GetOperand(0); }
  Value* GetPtr() const { return GetOperand(1); }

  static bool classof(const Value* value) {
    if (!Instruction::classof(value)) {
      return false;
    }
    const auto* inst = static_cast<const Instruction*>(value);
    return inst->GetOpcode() == Opcode::Store;
  }
};
/// Unconditional branch (opcode Br) to `dest`.
class UncondBrInst : public Instruction {
 public:
  // explicit: with `parent` defaulted this is callable with a single
  // BasicBlock*, which would otherwise convert implicitly into a branch.
  explicit UncondBrInst(BasicBlock* dest, BasicBlock* parent = nullptr);

  // Defined out of line.
  BasicBlock* GetDest() const;

  static bool classof(const Value* value) {
    return value && Instruction::classof(value) &&
           static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Br;
  }
};
/// Conditional branch. The condition is operand 0; the then/else targets
/// are retrieved by out-of-line accessors.
class CondBrInst : public Instruction {
 public:
  CondBrInst(Value* cond, BasicBlock* then_block, BasicBlock* else_block,
             BasicBlock* parent = nullptr);

  Value* GetCondition() const { return GetOperand(0); }
  BasicBlock* GetThenBlock() const;
  BasicBlock* GetElseBlock() const;

  static bool classof(const Value* value) {
    if (!Instruction::classof(value)) {
      return false;
    }
    const auto* inst = static_cast<const Instruction*>(value);
    return inst->GetOpcode() == Opcode::CondBr;
  }
};
/// Marks a point in the control flow that must never be reached; it has no
/// operands of its own.
class UnreachableInst : public Instruction {
 public:
  explicit UnreachableInst(BasicBlock* parent = nullptr);

  static bool classof(const Value* value) {
    if (!Instruction::classof(value)) {
      return false;
    }
    const auto* inst = static_cast<const Instruction*>(value);
    return inst->GetOpcode() == Opcode::Unreachable;
  }
};
/// Call of `callee` with `args`. The callee and arguments are retrieved by
/// out-of-line accessors.
class CallInst : public Instruction {
 public:
  // explicit: all parameters but `callee` are defaulted, so without it any
  // Function* would implicitly convert to a CallInst.
  explicit CallInst(Function* callee, const std::vector<Value*>& args = {},
                    BasicBlock* parent = nullptr, const std::string& name = "");

  Function* GetCallee() const;
  /// Returns a copy of the argument list.
  std::vector<Value*> GetArguments() const;

  static bool classof(const Value* value) {
    return value && Instruction::classof(value) &&
           static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Call;
  }
};
/// Address computation. Operand 0 is the base pointer; operands 1..N are the
/// indices, so index `i` lives in operand `i + 1`.
class GetElementPtrInst : public Instruction {
 public:
  GetElementPtrInst(std::shared_ptr<Type> source_type, Value* ptr,
                    const std::vector<Value*>& indices,
                    BasicBlock* parent = nullptr,
                    const std::string& name = "");

  Value* GetPointer() const { return GetOperand(0); }

  /// Number of index operands (all operands except the base pointer).
  size_t GetNumIndices() const {
    const size_t total = GetNumOperands();
    if (total == 0) {
      return 0;
    }
    return total - 1;
  }

  Value* GetIndex(size_t index) const { return GetOperand(index + 1); }

  std::shared_ptr<Type> GetSourceType() const { return source_type_; }

  static bool classof(const Value* value) {
    if (!Instruction::classof(value)) {
      return false;
    }
    const auto* inst = static_cast<const Instruction*>(value);
    return inst->GetOpcode() == Opcode::GetElementPtr;
  }

 private:
  std::shared_ptr<Type> source_type_;
};
/// SSA phi node. Operands are stored in pairs: incoming value `i` is
/// operand `2*i` (and, presumably, its incoming block is operand `2*i+1` —
/// confirm against AddIncoming / GetIncomingBlock in the .cpp).
class PhiInst : public Instruction {
 public:
  // explicit: with `parent` and `name` defaulted this is callable with a
  // single shared_ptr<Type>, which would otherwise convert implicitly.
  explicit PhiInst(std::shared_ptr<Type> type, BasicBlock* parent = nullptr,
                   const std::string& name = "");

  void AddIncoming(Value* value, BasicBlock* block);

  /// Number of (value, block) pairs.
  int GetNumIncomings() const {
    return static_cast<int>(GetNumOperands() / 2);
  }
  Value* GetIncomingValue(int index) const {
    return GetOperand(static_cast<size_t>(2 * index));
  }
  BasicBlock* GetIncomingBlock(int index) const;

  static bool classof(const Value* value) {
    return value && Instruction::classof(value) &&
           static_cast<const Instruction*>(value)->GetOpcode() == Opcode::Phi;
  }
};
/// Zero-extension of operand 0 to `target_type`.
class ZextInst : public Instruction {
 public:
  ZextInst(Value* value, std::shared_ptr<Type> target_type,
           BasicBlock* parent = nullptr, const std::string& name = "");

  Value* GetValue() const { return GetOperand(0); }

  static bool classof(const Value* value) {
    if (!Instruction::classof(value)) {
      return false;
    }
    const auto* inst = static_cast<const Instruction*>(value);
    return inst->GetOpcode() == Opcode::Zext;
  }
};
/// Memset instruction. Operand layout:
///   0 = destination pointer, 1 = fill value, 2 = length, 3 = is-volatile flag.
class MemsetInst : public Instruction {
 public:
  MemsetInst(Value* dst, Value* value, Value* len, Value* is_volatile,
             BasicBlock* parent = nullptr);

  Value* GetDest() const { return GetOperand(0); }
  Value* GetValue() const { return GetOperand(1); }
  Value* GetLength() const { return GetOperand(2); }
  Value* GetIsVolatile() const { return GetOperand(3); }

  static bool classof(const Value* value) {
    if (!Instruction::classof(value)) {
      return false;
    }
    const auto* inst = static_cast<const Instruction*>(value);
    return inst->GetOpcode() == Opcode::Memset;
  }
};
class BasicBlock : public Value {
public:
explicit BasicBlock(const std::string& name);
BasicBlock(Function* parent, const std::string& name);
Function* GetParent() const { return parent_; }
void SetParent(Function* parent) { parent_ = parent; }
bool HasTerminator() const;
std::vector<std::unique_ptr<Instruction>>& GetInstructions() {
return instructions_;
}
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const {
return instructions_;
}
void EraseInstruction(Instruction* inst);
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
const std::vector<BasicBlock*>& GetPredecessors() const {
return predecessors_;
}
const std::vector<BasicBlock*>& GetSuccessors() const {
return successors_;
}
template <typename T, typename... Args>
T* Insert(size_t index, Args&&... args) {
if (index > instructions_.size()) {
throw std::out_of_range("BasicBlock insert index out of range");
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.insert(instructions_.begin() + static_cast<long long>(index),
std::move(inst));
return ptr;
}
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
throw std::runtime_error("BasicBlock already has terminator");
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.push_back(std::move(inst));
return ptr;
}
static bool classof(const Value* value) {
return value && dynamic_cast<const BasicBlock*>(value) != nullptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_;
};
/// An IR function: its signature (return and parameter types), the Argument
/// objects and basic blocks it owns, and a pointer to its entry block.
class Function : public Value {
 public:
  Function(std::string name, std::shared_ptr<Type> ret_type,
           const std::vector<std::shared_ptr<Type>>& param_types = {},
           const std::vector<std::string>& param_names = {},
           bool is_external = false);
  bool IsFunction() const override final { return true; }

  // --- Signature -------------------------------------------------------
  std::shared_ptr<Type> GetReturnType() const { return return_type_; }
  const std::vector<std::shared_ptr<Type>>& GetParamTypes() const {
    return param_types_;
  }
  const std::vector<std::unique_ptr<Argument>>& GetArguments() const {
    return arguments_;
  }
  Argument* GetArgument(size_t index) const;

  // --- External flag (presumably "declared, no body" — confirm) --------
  bool IsExternal() const { return is_external_; }
  void SetExternal(bool is_external) { is_external_ = is_external; }

  // --- Blocks ----------------------------------------------------------
  BasicBlock* GetEntryBlock() const { return entry_; }
  // Shorthand for GetEntryBlock().
  BasicBlock* GetEntry() const { return GetEntryBlock(); }
  void SetEntryBlock(BasicBlock* bb) { entry_ = bb; }
  BasicBlock* EnsureEntryBlock();
  BasicBlock* CreateBlock(const std::string& name);
  BasicBlock* AddBlock(std::unique_ptr<BasicBlock> block);
  std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() { return blocks_; }
  const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const {
    return blocks_;
  }

  static bool classof(const Value* value) {
    return value != nullptr && value->IsFunction();
  }

 private:
  std::shared_ptr<Type> return_type_;
  std::vector<std::shared_ptr<Type>> param_types_;
  std::vector<std::unique_ptr<Argument>> arguments_;
  bool is_external_ = false;
  BasicBlock* entry_ = nullptr;
  std::vector<std::unique_ptr<BasicBlock>> blocks_;
};
/// Top-level IR container. It owns the Context, every Function and every
/// GlobalValue; the *_map_ members index them by name for lookup.
class Module {
 public:
  Module() = default;

  Context& GetContext() { return context_; }
  const Context& GetContext() const { return context_; }

  // --- Functions -------------------------------------------------------
  Function* CreateFunction(const std::string& name, std::shared_ptr<Type> ret_type,
                           const std::vector<std::shared_ptr<Type>>& param_types = {},
                           const std::vector<std::string>& param_names = {},
                           bool is_external = false);
  Function* GetFunction(const std::string& name) const;
  const std::vector<std::unique_ptr<Function>>& GetFunctions() const {
    return functions_;
  }

  // --- Globals ---------------------------------------------------------
  GlobalValue* CreateGlobalValue(const std::string& name,
                                 std::shared_ptr<Type> object_type,
                                 bool is_const = false, Value* init = nullptr);
  GlobalValue* GetGlobalValue(const std::string& name) const;
  const std::vector<std::unique_ptr<GlobalValue>>& GetGlobalValues() const {
    return globals_;
  }

 private:
  // Members are destroyed in reverse declaration order: the globals go
  // first, then the functions, and the Context last.
  Context context_;
  std::vector<std::unique_ptr<Function>> functions_;
  std::map<std::string, Function*> function_map_;
  std::vector<std::unique_ptr<GlobalValue>> globals_;
  std::map<std::string, GlobalValue*> global_map_;
};
/// Convenience factory that creates constants and instructions and places
/// instructions into the current insertion block.
class IRBuilder {
 public:
  IRBuilder(Context& ctx, BasicBlock* bb);

  // --- Insertion point -------------------------------------------------
  void SetInsertPoint(BasicBlock* bb);
  BasicBlock* GetInsertBlock() const { return insert_block_; }

  // --- Constants -------------------------------------------------------
  ConstantInt* CreateConstInt(int v);
  ConstantFloat* CreateConstFloat(float v);
  ConstantI1* CreateConstBool(bool v);
  ConstantArrayValue* CreateConstArray(std::shared_ptr<Type> array_type,
                                       const std::vector<Value*>& elements,
                                       const std::vector<size_t>& dims,
                                       const std::string& name = "");

  // --- Binary operations -----------------------------------------------
  BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
                           const std::string& name = "");
  BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateRem(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateAnd(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateOr(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateXor(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateShl(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateAShr(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateLShr(Value* lhs, Value* rhs, const std::string& name = "");
  BinaryInst* CreateICmp(Opcode op, Value* lhs, Value* rhs,
                         const std::string& name = "");
  BinaryInst* CreateFCmp(Opcode op, Value* lhs, Value* rhs,
                         const std::string& name = "");

  // --- Unary operations and conversions --------------------------------
  UnaryInst* CreateNeg(Value* operand, const std::string& name = "");
  UnaryInst* CreateNot(Value* operand, const std::string& name = "");
  UnaryInst* CreateFNeg(Value* operand, const std::string& name = "");
  UnaryInst* CreateFtoI(Value* operand, const std::string& name = "");
  UnaryInst* CreateIToF(Value* operand, const std::string& name = "");

  // --- Memory ----------------------------------------------------------
  AllocaInst* CreateAlloca(std::shared_ptr<Type> allocated_type,
                           const std::string& name = "");
  LoadInst* CreateLoad(Value* ptr, std::shared_ptr<Type> value_type,
                       const std::string& name = "");
  /// Loads through `ptr` with an i32 result type.
  /// NOTE(review): the result is always i32, whatever `ptr` points to;
  /// callers loading other types must use the overload taking value_type.
  LoadInst* CreateLoad(Value* ptr, const std::string& name = "") {
    const std::shared_ptr<Type>& i32 = Type::GetInt32Type();
    return CreateLoad(ptr, i32, name);
  }
  StoreInst* CreateStore(Value* val, Value* ptr);

  // --- Control flow ----------------------------------------------------
  UncondBrInst* CreateBr(BasicBlock* dest);
  CondBrInst* CreateCondBr(Value* cond, BasicBlock* then_bb,
                           BasicBlock* else_bb);
  ReturnInst* CreateRet(Value* val = nullptr);
  UnreachableInst* CreateUnreachable();
  CallInst* CreateCall(Function* callee, const std::vector<Value*>& args,
                       const std::string& name = "");

  // --- Addressing, SSA and casts ---------------------------------------
  GetElementPtrInst* CreateGEP(Value* ptr, std::shared_ptr<Type> source_type,
                               const std::vector<Value*>& indices,
                               const std::string& name = "");
  PhiInst* CreatePhi(std::shared_ptr<Type> type, const std::string& name = "");
  ZextInst* CreateZext(Value* val, std::shared_ptr<Type> target_type,
                       const std::string& name = "");
  MemsetInst* CreateMemset(Value* dst, Value* val, Value* len,
                           Value* is_volatile);

 private:
  Context& ctx_;
  BasicBlock* insert_block_;
};
/// Writes a textual form of `module` to `os`; the format is defined by the
/// out-of-line implementation.
class IRPrinter {
 public:
  void Print(const Module& module, std::ostream& os);
};
/// Stream insertion for a Type: delegates to Type::Print and returns the
/// stream so insertions can be chained.
inline std::ostream& operator<<(std::ostream& out, const Type& type) {
  type.Print(out);
  return out;
}

/// Stream insertion for any Value: delegates to the virtual Value::Print and
/// returns the stream so insertions can be chained.
inline std::ostream& operator<<(std::ostream& out, const Value& value) {
  value.Print(out);
  return out;
}
} // namespace ir