You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

543 lines
17 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

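// Extended IR library:
// - Complete set of basic types: void / i1 / i32 / float / ptr / array /
//   function / label
// - Instructions: arithmetic, comparison, branch, call, phi, gep, type
//   conversion, etc.
// - Constants: int / float / array
// - Complete interfaces for BasicBlock / Function / Module / IRBuilder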
#pragma once
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
// Forward declarations, grouped by the part of the IR they belong to.
// Type system:
class Type;
// Roots of the value hierarchy:
class Value;
class User;
class Argument;
// Constants:
class ConstantValue;
class ConstantInt;
class ConstantFloat;
class ConstantArray;
// Module-level values:
class GlobalValue;
class GlobalVariable;
// Code structure:
class Instruction;
class BasicBlock;
class Function;
// Use 表示一个 Value 的一次使用记录。
// 当前实现设计:
// - value被使用的值
// - user使用该值的 User
// - operand_index该值在 user 操作数列表中的位置
class Use {
public:
Use() = default;
Use(Value* value, User* user, size_t operand_index)
: value_(value), user_(user), operand_index_(operand_index) {}
Value* GetValue() const { return value_; }
User* GetUser() const { return user_; }
size_t GetOperandIndex() const { return operand_index_; }
void SetValue(Value* value) { value_ = value; }
void SetUser(User* user) { user_ = user; }
void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; }
private:
Value* value_ = nullptr;
User* user_ = nullptr;
size_t operand_index_ = 0;
};
// IR context: central owner of shared resources (currently constants and the
// temporary-name counter) so they can be reused across a module.
//
// The context owns every constant it creates through unique_ptr containers.
// Copying it is therefore impossible, so the copy operations are deleted
// explicitly. Previously they were implicitly declared, and any attempt to
// copy failed only deep inside container template instantiation. Move is not
// provided either (the user-declared destructor already suppressed it).
class Context {
 public:
  Context() = default;
  ~Context();

  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

  // Returns a deduplicated i32 constant with value `v`.
  ConstantInt* GetConstInt(int v);
  // Returns a boolean constant for `v`. Presumably deduplicated through
  // const_bools_ like GetConstInt — confirm in the implementation.
  ConstantInt* GetConstBool(bool v);
  // Returns a float constant for `v`. NOTE(review): const_floats_ is keyed by
  // uint32_t, which suggests deduplication by the float's bit pattern — confirm.
  ConstantFloat* GetConstFloat(float v);
  // Creates an array constant of type `array_ty` from `elements`; the context
  // keeps ownership (stored in const_arrays_).
  ConstantArray* CreateConstArray(std::shared_ptr<Type> array_ty,
                                  std::vector<ConstantValue*> elements);
  // Returns the next temporary name; presumably derived from temp_index_.
  std::string NextTemp();

 private:
  std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
  std::unordered_map<int, std::unique_ptr<ConstantInt>> const_bools_;
  std::unordered_map<uint32_t, std::unique_ptr<ConstantFloat>> const_floats_;
  std::vector<std::unique_ptr<ConstantArray>> const_arrays_;
  int temp_index_ = -1;
};
// Describes the type of an IR value: scalar, pointer, array, function or label.
class Type {
 public:
  // Every type category the IR models.
  enum class Kind { Void, Int1, Int32, Float, Pointer, Array, Function, Label };

  // ---- Construction -------------------------------------------------------
  // A type of kind `k` with no element or signature information.
  explicit Type(Kind k);
  // A type carrying an element type and an element count.
  Type(Kind k, std::shared_ptr<Type> elem, size_t count);
  // A type carrying a signature: return type, parameter types, vararg flag.
  Type(Kind k, std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
       bool is_vararg);

  // ---- Shared singleton types ---------------------------------------------
  // These return references to static shared objects, so the same type can be
  // compared directly, e.g. Type::GetInt32Type() == Type::GetInt32Type().
  static const std::shared_ptr<Type>& GetVoidType();
  static const std::shared_ptr<Type>& GetInt1Type();
  static const std::shared_ptr<Type>& GetInt32Type();
  static const std::shared_ptr<Type>& GetFloatType();
  static const std::shared_ptr<Type>& GetLabelType();

  // ---- Derived-type factories ---------------------------------------------
  // NOTE(review): whether these intern their results is not visible here;
  // prefer Equals() over pointer comparison for derived types.
  static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> elem);
  static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> elem,
                                            size_t count);
  static std::shared_ptr<Type> GetFunctionType(
      std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params,
      bool is_vararg = false);

  // ---- Kind queries -------------------------------------------------------
  Kind GetKind() const;
  bool IsVoid() const;
  bool IsInt1() const;
  bool IsInt32() const;
  bool IsFloat() const;
  bool IsPointer() const;
  bool IsArray() const;
  bool IsFunction() const;
  bool IsLabel() const;

  // ---- Component accessors ------------------------------------------------
  const std::shared_ptr<Type>& GetElementType() const;
  size_t GetArraySize() const;
  const std::shared_ptr<Type>& GetReturnType() const;
  const std::vector<std::shared_ptr<Type>>& GetParamTypes() const;
  bool IsVarArg() const;

  // Compares this type with `other`.
  bool Equals(const Type& other) const;

 private:
  Kind kind_;
  std::shared_ptr<Type> elem_type_;
  size_t array_size_ = 0;
  std::shared_ptr<Type> ret_type_;
  std::vector<std::shared_ptr<Type>> param_types_;
  bool is_vararg_ = false;
};
// Root of the IR value hierarchy. Every value has a type, a name, and a list
// of Use records describing where it is consumed.
class Value {
 public:
  Value(std::shared_ptr<Type> ty, std::string name);
  virtual ~Value() = default;

  // ---- Type and name ------------------------------------------------------
  const std::shared_ptr<Type>& GetType() const;
  const std::string& GetName() const;
  void SetName(std::string n);

  // ---- Queries on the value's type ----------------------------------------
  bool IsVoid() const;
  bool IsInt1() const;
  bool IsInt32() const;
  bool IsFloat() const;
  bool IsPointer() const;
  bool IsArray() const;
  bool IsFunctionType() const;
  bool IsPtrInt32() const;

  // ---- Queries on the value's category ------------------------------------
  bool IsConstant() const;
  bool IsInstruction() const;
  bool IsUser() const;
  bool IsFunction() const;

  // ---- Def-use bookkeeping ------------------------------------------------
  void AddUse(User* user, size_t operand_index);
  void RemoveUse(User* user, size_t operand_index);
  const std::vector<Use>& GetUses() const;
  void ReplaceAllUsesWith(Value* new_value);

 protected:
  std::shared_ptr<Type> type_;
  std::string name_;
  // NOTE(review): Value is implicitly copyable. A copy would duplicate uses_
  // even though no User actually references the copy. Consider deleting the
  // copy operations once callers are confirmed not to copy values.
  std::vector<Use> uses_;
};
// Base class of the constant hierarchy. Concrete constants in this header are
// ConstantInt, ConstantFloat and ConstantArray; more kinds may be added later.
//
// The constructor is callable with a single argument, so it is marked
// `explicit` to forbid an accidental implicit conversion from
// std::shared_ptr<Type> to ConstantValue. Derived classes still initialize the
// base directly in their member-initializer lists, which `explicit` permits.
class ConstantValue : public Value {
 public:
  explicit ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
};
// Integer constant; the payload is stored as a plain `int`. The IR type
// (e.g. i32 or i1) is whatever `ty` was given.
class ConstantInt : public ConstantValue {
 public:
  ConstantInt(std::shared_ptr<Type> ty, int v);

  // Returns the stored integer payload.
  int GetValue() const;

 private:
  int value_ = 0;
};

inline int ConstantInt::GetValue() const { return value_; }
// Floating-point constant holding a single-precision `float` payload.
class ConstantFloat : public ConstantValue {
 public:
  ConstantFloat(std::shared_ptr<Type> ty, float v);

  // Returns the stored float payload.
  float GetValue() const;

 private:
  float value_ = 0.0F;
};

inline float ConstantFloat::GetValue() const { return value_; }
// Array constant: an array type plus a list of element constants.
// The element pointers are non-owning.
class ConstantArray : public ConstantValue {
 public:
  ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elements);

  // Returns the element constants in order.
  const std::vector<ConstantValue*>& GetElements() const;

 private:
  std::vector<ConstantValue*> elements_;
};

inline const std::vector<ConstantValue*>& ConstantArray::GetElements() const {
  return elements_;
}
// Instruction opcodes. More opcodes are expected to be added later.
// The enumerator order is kept stable; do not reorder.
enum class Opcode {
  // Integer arithmetic.
  Add, Sub, Mul, SDiv, SRem,
  // Floating-point arithmetic.
  FAdd, FSub, FMul, FDiv,
  // Memory.
  Alloca, Load, Store,
  // Control flow.
  Ret, Br, CondBr,
  // Comparisons.
  ICmp, FCmp,
  // Calls, SSA merge and address computation.
  Call, Phi, Gep,
  // Conversions.
  SIToFP, FPToSI, ZExt
};

// Integer comparison predicates (spellings follow LLVM's icmp predicates).
enum class ICmpPredicate {
  Eq, Ne,
  Slt, Sle,
  Sgt, Sge
};

// Floating-point comparison predicates (spellings follow LLVM's fcmp predicates).
enum class FCmpPredicate {
  Oeq, One,
  Olt, Ole,
  Ogt, Oge
};
// Abstract base for every IR object that consumes other Values as inputs
// (operands). In this header, both Instruction and GlobalValue derive from User.
class User : public Value {
 public:
  User(std::shared_ptr<Type> ty, std::string name);

  // Operand access.
  size_t GetNumOperands() const;
  Value* GetOperand(size_t index) const;
  void SetOperand(size_t index, Value* value);

 protected:
  // Single entry point for appending an operand. Presumably also records the
  // corresponding Use on `value` — confirm in the implementation.
  void AddOperand(Value* value);

 private:
  std::vector<Value*> operands_;
};
// Base class for module-level (global) values. It adds nothing over User yet.
// Linkage and printing semantics are not modelled at this level; GlobalVariable
// adds the initializer.
class GlobalValue : public User {
 public:
  GlobalValue(std::shared_ptr<Type> ty, std::string name);
};
// A module-level variable with a stored value type, an optional constant
// initializer (non-owning), and a const flag.
class GlobalVariable : public GlobalValue {
 public:
  GlobalVariable(std::shared_ptr<Type> value_ty, std::string name,
                 ConstantValue* init, bool is_const);

  // Type of the stored value.
  const std::shared_ptr<Type>& GetValueType() const;
  // Initializer constant; may be null (the member defaults to nullptr).
  ConstantValue* GetInitializer() const;
  // Whether the variable was declared constant.
  bool IsConst() const;

 private:
  std::shared_ptr<Type> value_type_;
  ConstantValue* initializer_{nullptr};
  bool is_const_{false};
};
// Base class of all instructions: a User with an opcode and an owning block.
class Instruction : public User {
 public:
  Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");

  Opcode GetOpcode() const;
  // True for block-ending instructions (presumably Ret/Br/CondBr — confirm).
  bool IsTerminator() const;

  // Containing basic block (non-owning; null until attached).
  BasicBlock* GetParent() const;
  void SetParent(BasicBlock* parent);

 private:
  Opcode opcode_;
  BasicBlock* parent_{nullptr};
};
// Two-operand arithmetic instruction; `op` selects the operation.
class BinaryInst : public Instruction {
 public:
  BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
             std::string name);

  // Left and right operands.
  Value* GetLhs() const;
  Value* GetRhs() const;
};
// Integer comparison of two operands under an ICmpPredicate.
class ICmpInst : public Instruction {
 public:
  ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name);

  ICmpPredicate GetPredicate() const;
  Value* GetLhs() const;
  Value* GetRhs() const;

 private:
  ICmpPredicate pred_;
};

inline ICmpPredicate ICmpInst::GetPredicate() const { return pred_; }
// Floating-point comparison of two operands under an FCmpPredicate.
class FCmpInst : public Instruction {
 public:
  FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name);

  FCmpPredicate GetPredicate() const;
  Value* GetLhs() const;
  Value* GetRhs() const;

 private:
  FCmpPredicate pred_;
};

inline FCmpPredicate FCmpInst::GetPredicate() const { return pred_; }
// Type conversion of one source value to `dst_ty`; `op` picks the conversion.
class CastInst : public Instruction {
 public:
  CastInst(Opcode op, std::shared_ptr<Type> dst_ty, Value* src,
           std::string name);

  // The value being converted.
  Value* GetValue() const;
};
// Unconditional branch to a single destination block.
class BranchInst : public Instruction {
 public:
  explicit BranchInst(BasicBlock* dest);

  // Target block.
  BasicBlock* GetDest() const;
};
// Conditional branch: jumps to one of two blocks depending on `cond`.
class CondBrInst : public Instruction {
 public:
  CondBrInst(Value* cond, BasicBlock* true_dest, BasicBlock* false_dest);

  // Branch condition.
  Value* GetCond() const;
  // Destinations for the true and false outcomes.
  BasicBlock* GetTrueDest() const;
  BasicBlock* GetFalseDest() const;
};
// Function return, with or without a returned value.
class ReturnInst : public Instruction {
 public:
  // Form without a returned value.
  explicit ReturnInst(std::shared_ptr<Type> void_ty);
  // Form returning `val`.
  ReturnInst(std::shared_ptr<Type> void_ty, Value* val);

  bool HasReturnValue() const;
  Value* GetValue() const;
};
// Stack allocation of an object of the given allocated type.
class AllocaInst : public Instruction {
 public:
  AllocaInst(std::shared_ptr<Type> allocated_ty, std::string name);

  // Type of the allocated object (not the resulting pointer type).
  const std::shared_ptr<Type>& GetAllocatedType() const;

 private:
  std::shared_ptr<Type> allocated_type_;
};
// Reads a value of type `val_ty` through pointer `ptr`.
class LoadInst : public Instruction {
 public:
  LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name);

  // Source address.
  Value* GetPtr() const;
};
// Writes `val` through pointer `ptr`; the instruction itself has type `void_ty`.
class StoreInst : public Instruction {
 public:
  StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr);

  // Stored value and destination address.
  Value* GetValue() const;
  Value* GetPtr() const;
};
// Call of `callee` with the given argument list.
class CallInst : public Instruction {
 public:
  CallInst(std::shared_ptr<Type> ret_ty, Value* callee,
           std::vector<Value*> args, std::string name);

  Value* GetCallee() const;
  // Argument list as stored in args_.
  const std::vector<Value*>& GetArgs() const;

 private:
  // NOTE(review): args_ is kept separately from User's operand list. If the
  // arguments are also registered as operands, a later SetOperand /
  // ReplaceAllUsesWith would not update this copy — confirm.
  std::vector<Value*> args_;
};

inline const std::vector<Value*>& CallInst::GetArgs() const { return args_; }
// SSA phi node: pairs of (incoming value, predecessor block) stored in two
// parallel vectors.
class PhiInst : public Instruction {
 public:
  PhiInst(std::shared_ptr<Type> ty, std::string name);

  // Adds one (value, block) pair.
  void AddIncoming(Value* value, BasicBlock* block);
  const std::vector<Value*>& GetIncomingValues() const;
  const std::vector<BasicBlock*>& GetIncomingBlocks() const;

 private:
  // NOTE(review): if incoming values are also recorded as operands, these
  // vectors may go stale after operand rewrites — confirm.
  std::vector<Value*> incoming_values_;
  std::vector<BasicBlock*> incoming_blocks_;
};
// Address computation ("getelementptr") from a base pointer and index list.
class GepInst : public Instruction {
 public:
  GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
          std::vector<Value*> indices, std::string name);

  Value* GetBasePtr() const;
  // Index list as stored in indices_.
  const std::vector<Value*>& GetIndices() const;

 private:
  std::vector<Value*> indices_;
};

inline const std::vector<Value*>& GepInst::GetIndices() const {
  return indices_;
}
// A basic block: an owning, ordered list of instructions plus CFG edges.
// BasicBlock is part of the Value hierarchy and uses the label type.
class BasicBlock : public Value {
 public:
  explicit BasicBlock(std::string name);

  // Owning function (non-owning pointer).
  Function* GetParent() const;
  void SetParent(Function* parent);

  bool HasTerminator() const;
  const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;

  // CFG edges.
  const std::vector<BasicBlock*>& GetPredecessors() const;
  const std::vector<BasicBlock*>& GetSuccessors() const;
  void AddPredecessor(BasicBlock* pred);
  void AddSuccessor(BasicBlock* succ);

  // Constructs a T from `args`, attaches it to this block, appends it, and
  // lets the block update its successor links. Returns a non-owning pointer;
  // the block keeps ownership.
  // Throws std::runtime_error if the block already ends in a terminator.
  template <typename T, typename... Args>
  T* Append(Args&&... args) {
    if (HasTerminator()) {
      throw std::runtime_error("BasicBlock 已有 terminator不能继续追加指令: " +
                               name_);
    }
    std::unique_ptr<T> owned = std::make_unique<T>(std::forward<Args>(args)...);
    T* const created = owned.get();
    created->SetParent(this);
    instructions_.push_back(std::move(owned));
    LinkSuccessorsIfNeeded(created);
    return created;
  }

 private:
  void LinkSuccessorsIfNeeded(Instruction* inst);

  Function* parent_ = nullptr;
  std::vector<std::unique_ptr<Instruction>> instructions_;
  std::vector<BasicBlock*> predecessors_;
  std::vector<BasicBlock*> successors_;
};
// A function: an owning list of basic blocks plus an owning list of arguments.
// A function created with is_declaration == true is a declaration only.
//
// The constructor takes a `func_type`, and the Type system has
// Type::Kind::Function, so the signature is available through
// GetFunctionType(); GetReturnType() exposes the return type.
// NOTE(review): an older note said type_ held only the return type; confirm
// what the constructor actually stores in type_.
class Function : public Value {
 public:
  Function(std::string name, std::shared_ptr<Type> func_type,
           bool is_declaration = false);

  // ---- Blocks -------------------------------------------------------------
  // Creates and appends a block; block_name_counts_ is presumably used to
  // keep block names unique — confirm.
  BasicBlock* CreateBlock(const std::string& name);
  BasicBlock* GetEntry();
  const BasicBlock* GetEntry() const;
  const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;

  // ---- Arguments ----------------------------------------------------------
  const std::vector<std::unique_ptr<Argument>>& GetArguments() const;
  size_t GetNumArgs() const;
  Argument* GetArg(size_t index);

  // ---- Signature ----------------------------------------------------------
  std::shared_ptr<Type> GetFunctionType() const;
  std::shared_ptr<Type> GetReturnType() const;
  bool IsDeclaration() const;

 private:
  BasicBlock* entry_ = nullptr;
  std::vector<std::unique_ptr<BasicBlock>> blocks_;
  std::vector<std::unique_ptr<Argument>> args_;
  std::unordered_map<std::string, size_t> block_name_counts_;
  bool is_declaration_ = false;
};
// A formal parameter of a Function, identified by its position.
class Argument : public Value {
 public:
  Argument(std::shared_ptr<Type> ty, std::string name, size_t index);

  // Zero-based position recorded at construction.
  size_t GetIndex() const;

 private:
  size_t index_{0};
};

inline size_t Argument::GetIndex() const { return index_; }
// Top-level IR container: owns a Context, the functions and the globals.
class Module {
 public:
  Module() = default;

  // Shared context (constants, temporary names).
  Context& GetContext();
  const Context& GetContext() const;

  // ---- Functions ----------------------------------------------------------
  // Creates a function from a return type only.
  // NOTE(review): how the full signature is derived from `ret_type` is not
  // visible here — confirm in the implementation.
  Function* CreateFunction(const std::string& name,
                           std::shared_ptr<Type> ret_type);
  // Creates a function definition from a complete function type.
  Function* CreateFunctionWithType(const std::string& name,
                                   std::shared_ptr<Type> func_type);
  // Creates a function declaration (no body) from a complete function type.
  Function* CreateFunctionDecl(const std::string& name,
                               std::shared_ptr<Type> func_type);

  // ---- Globals ------------------------------------------------------------
  GlobalVariable* CreateGlobalVariable(const std::string& name,
                                       std::shared_ptr<Type> value_type,
                                       ConstantValue* init, bool is_const);

  const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
  const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobals() const;

 private:
  Context context_;
  std::vector<std::unique_ptr<Function>> functions_;
  std::vector<std::unique_ptr<GlobalVariable>> globals_;
};
// IRBuilder creates instructions at a current insertion block and hands out
// constants from a shared Context.
//
// The builder stores a reference to the Context, so the Context must outlive
// the builder.
class IRBuilder {
 public:
  // `ctx` supplies constants; `bb` is presumably the initial insertion block.
  IRBuilder(Context& ctx, BasicBlock* bb);
  void SetInsertPoint(BasicBlock* bb);
  BasicBlock* GetInsertBlock() const;

  // ---- Constants ----------------------------------------------------------
  ConstantInt* CreateConstInt(int v);
  ConstantInt* CreateConstBool(bool v);
  ConstantFloat* CreateConstFloat(float v);

  // ---- Arithmetic ---------------------------------------------------------
  // Generic form; the named helpers below presumably forward here with the
  // matching Opcode — confirm.
  BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
                           const std::string& name);
  BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateSDiv(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateSRem(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name);

  // ---- Comparisons --------------------------------------------------------
  ICmpInst* CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs,
                       const std::string& name);
  FCmpInst* CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs,
                       const std::string& name);

  // ---- Conversions --------------------------------------------------------
  CastInst* CreateSIToFP(Value* src, std::shared_ptr<Type> dst_ty,
                         const std::string& name);
  CastInst* CreateFPToSI(Value* src, std::shared_ptr<Type> dst_ty,
                         const std::string& name);
  CastInst* CreateZExt(Value* src, std::shared_ptr<Type> dst_ty,
                       const std::string& name);

  // ---- Memory -------------------------------------------------------------
  AllocaInst* CreateAlloca(std::shared_ptr<Type> ty,
                           const std::string& name);
  AllocaInst* CreateAllocaI32(const std::string& name);
  LoadInst* CreateLoad(Value* ptr, const std::string& name);
  StoreInst* CreateStore(Value* val, Value* ptr);
  GepInst* CreateGep(Value* base_ptr, std::vector<Value*> indices,
                     const std::string& name);

  // ---- Calls and SSA ------------------------------------------------------
  CallInst* CreateCall(Value* callee, std::vector<Value*> args,
                       const std::string& name);
  PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name);

  // ---- Terminators --------------------------------------------------------
  BranchInst* CreateBr(BasicBlock* dest);
  CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_dest,
                           BasicBlock* false_dest);
  ReturnInst* CreateRet(Value* v);
  ReturnInst* CreateRetVoid();

 private:
  Context& ctx_;
  // Default-initialized to null so the builder never holds an indeterminate
  // pointer, even on a constructor path that does not set it. A constructor
  // mem-initializer still overrides this default.
  BasicBlock* insert_block_ = nullptr;
};
// Writes a textual form of a Module to an output stream.
class IRPrinter {
 public:
  // Prints `module` to `os`.
  void Print(const Module& module, std::ostream& os);
};
} // namespace ir