You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/include/ir/IR.h

550 lines
16 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

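// 最初只支持 i32、i32*、void 以及最小的内存/算术指令，用于演示。
// 注意：下面的清单仍描述最初的最小版本；float32、数组、函数类型、label 类型，
// 以及比较、分支、调用、phi、gep 等指令现已在本头文件中声明（实现以对应 .cpp 为准）。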
//
// 当前已经实现:
// 1. 基础类型系统：void / i32 / i32*
// 2. Value 体系：Value / ConstantValue / ConstantInt / Function / BasicBlock / User / GlobalValue / Instruction
// 3. 最小指令集：Add / Alloca / Load / Store / Ret
// 4. BasicBlock / Function / Module 三层组织结构
// 5. IRBuilder：便捷创建常量和最小指令
// 6. def-use 关系的轻量实现:
// - Instruction 保存 operand 列表
// - Value 保存 uses
// - 支持 ReplaceAllUsesWith 的简化实现
//
// 当前尚未实现或只做了最小占位:
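// 1. 完整类型系统：数组、函数类型、label 类型等（本头文件已声明 ArrayType、FunctionType 和 Type::Kind::Label，实现是否完整以对应 .cpp 为准）
// 2. 更完整的指令系统：br / condbr / call / phi / gep 等（本头文件已声明对应的指令类，实现是否完整以对应 .cpp 为准）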
// 3. 更成熟的 Use 管理(例如 LLVM 风格的双向链式结构)
// 4. 更完整的 IR verifier 和优化基础设施
//
// 当前需要特别说明的两个简化点:
// 1. BasicBlock 虽然已经纳入 Value 体系，但其类型最初用 void 作为占位；
//    Type 现已提供 label 类型（Type::GetLabelType()），BasicBlock 是否已改用它以对应 .cpp 为准。
// 2. ConstantValue 体系最初只实现了 ConstantInt；ConstantFloat 与 ConstantArray 现已在本头文件中声明，
//    后续可以继续补充更完整的常量种类。
//
// 建议的扩展顺序:
// 1. 先补更多指令和类型
// 2. 再补控制流相关 IR
// 3. 最后再考虑把 Value/User/Use 进一步抽象成更完整的框架
#pragma once
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iosfwd>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
namespace ir {
// Forward declarations: the IR classes refer to one another through
// pointers, so every entity is declared before any class body.
class Type;
class Value;
class User;
class GlobalValue;
class Instruction;
class BasicBlock;
class Function;
class Module;
// Constant hierarchy.
class ConstantValue;
class ConstantInt;
class ConstantFloat;
class ConstantArray;
// ======================== Use 类 ========================
class Use {
public:
Use() = default;
Use(Value* value, User* user, size_t operand_index)
: value_(value), user_(user), operand_index_(operand_index) {}
Value* GetValue() const { return value_; }
User* GetUser() const { return user_; }
size_t GetOperandIndex() const { return operand_index_; }
void SetValue(Value* value) { value_ = value; }
void SetUser(User* user) { user_ = user; }
void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; }
private:
Value* value_ = nullptr;
User* user_ = nullptr;
size_t operand_index_ = 0;
};
// ======================== Context 类 ========================
// Redeclared so this section is self-contained; these duplicate the
// namespace-level forward declarations and are harmless.
class ConstantInt;
class ConstantFloat;

namespace detail {

static_assert(sizeof(float) == sizeof(std::uint32_t),
              "float constants are keyed by their 32-bit pattern");

/// Returns the raw IEEE-754 bit pattern of `v`.
inline std::uint32_t FloatBits(float v) noexcept {
  std::uint32_t bits = 0;
  std::memcpy(&bits, &v, sizeof bits);
  return bits;
}

/// Hash for float map keys, computed from the bit pattern so that it agrees
/// with FloatBitsEqual.
struct FloatBitsHash {
  std::size_t operator()(float v) const noexcept {
    return std::hash<std::uint32_t>{}(FloatBits(v));
  }
};

/// Bitwise equality for float map keys. Unlike `==`, it keeps +0.0f and
/// -0.0f apart and treats a NaN as equal to a NaN with identical bits.
struct FloatBitsEqual {
  bool operator()(float a, float b) const noexcept {
    return FloatBits(a) == FloatBits(b);
  }
};

}  // namespace detail

/// Owns constant objects (one per key in the maps below) and generates
/// temporary value names.
class Context {
 public:
  Context() = default;
  // Declared out of line: ConstantInt / ConstantFloat are still incomplete
  // at this point, and the unique_ptr members need the complete types to
  // delete them.
  ~Context();

  // Constant creation.
  ConstantInt* GetConstInt(int v);
  ConstantFloat* GetConstFloat(float v);

  // Temporary-name generation.
  std::string NextTemp();

 private:
  std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
  // Keyed by bit pattern rather than by `==`: with std::equal_to<float>,
  // -0.0f compares equal to +0.0f (so the two would share one constant),
  // and NaN never compares equal to itself (so a NaN key can never be found).
  // The key type stays `float`, so find/emplace/operator[] usage is unchanged.
  std::unordered_map<float, std::unique_ptr<ConstantFloat>,
                     detail::FloatBitsHash, detail::FloatBitsEqual>
      const_floats_;
  int temp_index_ = -1;
};
// ======================== Type 类型体系 ========================
/// Base of the IR type hierarchy. Scalar types are identified by their Kind
/// alone; parameterized types (pointer, array, function) are subclasses.
class Type {
 public:
  /// Type category. Enumerator order is part of the interface.
  enum class Kind {
    Void,
    Int32,
    Float32,
    Pointer,
    Array,
    Function,
    Label
  };

  explicit Type(Kind k);
  virtual ~Type() = default;

  Kind GetKind() const;

  // Category predicates.
  bool IsVoid() const;
  bool IsInt32() const;
  bool IsFloat32() const;
  bool IsLabel() const;
  bool IsPointer() const;
  bool IsArray() const;
  bool IsFunction() const;

  // Pointer-to-scalar predicates.
  bool IsPtrInt32() const;    // i32*; kept for compatibility with the older interface
  bool IsPtrFloat32() const;  // float32*

  // Equality (defined out of line).
  bool operator==(const Type& other) const;
  bool operator!=(const Type& other) const;

  // Static singleton instances of the basic types.
  static const std::shared_ptr<Type>& GetVoidType();
  static const std::shared_ptr<Type>& GetInt32Type();
  static const std::shared_ptr<Type>& GetFloat32Type();
  static const std::shared_ptr<Type>& GetLabelType();
  static const std::shared_ptr<Type>& GetPtrInt32Type();

  // Factories for composite types.
  static std::shared_ptr<Type> GetPointerType(std::shared_ptr<Type> pointee);
  static std::shared_ptr<Type> GetArrayType(std::shared_ptr<Type> elem,
                                            size_t size);
  static std::shared_ptr<Type> GetFunctionType(
      std::shared_ptr<Type> ret, std::vector<std::shared_ptr<Type>> params);

 private:
  Kind kind_;
};
/// Pointer type: a Type of kind Pointer that records the type it points to.
class PointerType : public Type {
 public:
  // explicit: a bare std::shared_ptr<Type> must not silently convert into a
  // PointerType (e.g. when passed where a PointerType is expected).
  explicit PointerType(std::shared_ptr<Type> pointee)
      : Type(Type::Kind::Pointer), pointee_(std::move(pointee)) {}

  /// The pointed-to type.
  const std::shared_ptr<Type>& GetPointeeType() const { return pointee_; }

 private:
  std::shared_ptr<Type> pointee_;
};
/// Fixed-size array type: `GetSize()` elements of `GetElementType()`.
class ArrayType : public Type {
 public:
  ArrayType(std::shared_ptr<Type> elem, size_t size)
      : Type{Type::Kind::Array}, elem_type_{std::move(elem)}, size_{size} {}

  /// Element type.
  const std::shared_ptr<Type>& GetElementType() const { return elem_type_; }
  /// Number of elements.
  size_t GetSize() const { return size_; }

 private:
  std::shared_ptr<Type> elem_type_;
  size_t size_;
};
/// Function signature type: a return type plus ordered parameter types.
class FunctionType : public Type {
 public:
  FunctionType(std::shared_ptr<Type> ret,
               std::vector<std::shared_ptr<Type>> params)
      : Type{Type::Kind::Function},
        ret_type_{std::move(ret)},
        param_types_(std::move(params)) {}

  /// Return type.
  const std::shared_ptr<Type>& GetReturnType() const { return ret_type_; }
  /// Parameter types, in declaration order.
  const std::vector<std::shared_ptr<Type>>& GetParamTypes() const {
    return param_types_;
  }

 private:
  std::shared_ptr<Type> ret_type_;
  std::vector<std::shared_ptr<Type>> param_types_;
};
// ======================== Value 类 ========================
/// Root of the value hierarchy: anything with a type and a name that can be
/// used as an operand. Each Value keeps the list of Uses that refer to it.
class Value {
 public:
  Value(std::shared_ptr<Type> ty, std::string name);
  virtual ~Value() = default;

  // --- identity ---
  const std::shared_ptr<Type>& GetType() const;
  const std::string& GetName() const;
  void SetName(std::string n);

  // --- type queries ---
  bool IsVoid() const;
  bool IsInt32() const;
  bool IsFloat32() const;
  bool IsPtrInt32() const;    // legacy name kept for compatibility: checks for i32*
  bool IsPtrFloat32() const;  // checks for float32*

  // --- what kind of value this is ---
  bool IsConstant() const;
  bool IsInstruction() const;
  bool IsUser() const;
  bool IsFunction() const;
  bool IsGlobalValue() const;

  // --- def-use bookkeeping ---
  void AddUse(User* user, size_t operand_index);
  void RemoveUse(User* user, size_t operand_index);
  const std::vector<Use>& GetUses() const;
  void ReplaceAllUsesWith(Value* new_value);

 protected:
  std::shared_ptr<Type> type_;
  std::string name_;
  std::vector<Use> uses_;
};
// ======================== 常量体系 ========================
/// Common base of all constant values.
class ConstantValue : public Value {
 public:
  // explicit: `name` has a default, so this is callable with one argument,
  // and a std::shared_ptr<Type> must not implicitly become a ConstantValue.
  explicit ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
};
/// Integer constant whose value is fixed at construction.
class ConstantInt : public ConstantValue {
 public:
  ConstantInt(std::shared_ptr<Type> ty, int v);

  /// The stored integer.
  int GetValue() const { return value_; }

 private:
  int value_ = 0;
};
/// Single-precision floating-point constant whose value is fixed at
/// construction.
class ConstantFloat : public ConstantValue {
 public:
  ConstantFloat(std::shared_ptr<Type> ty, float v);

  /// The stored float.
  float GetValue() const { return value_; }

 private:
  float value_ = 0.0f;
};
/// Constant array: a simple aggregate that holds constant elements.
///
/// Elements are kept as raw pointers. NOTE(review): ownership of the elements
/// is not visible here; presumably they are owned elsewhere. Confirm.
class ConstantArray : public ConstantValue {
 public:
  ConstantArray(std::shared_ptr<Type> ty, std::vector<ConstantValue*> elems);

  /// The element list.
  const std::vector<ConstantValue*>& GetElements() const {
    return elements_;
  }

 private:
  std::vector<ConstantValue*> elements_;
};
// ======================== User 类 ========================
/// A Value that reads other Values as operands. The operand list itself is
/// private; subclasses extend it only through AddOperand.
class User : public Value {
 public:
  User(std::shared_ptr<Type> ty, std::string name);

  /// Operand count.
  size_t GetNumOperands() const;
  /// Operand at position `index`.
  Value* GetOperand(size_t index) const;
  /// Replaces the operand at position `index`.
  void SetOperand(size_t index, Value* value);

 protected:
  /// Appends an operand (for use by subclass constructors).
  void AddOperand(Value* value);

 private:
  std::vector<Value*> operands_;
};
// ======================== GlobalValue 类 ========================
/// Module-level named value (Module::CreateGlobalVariable returns one) with
/// an optional constant initializer.
class GlobalValue : public User {
 public:
  GlobalValue(std::shared_ptr<Type> ty, std::string name);

  /// The initializer, or nullptr if none has been set.
  ConstantValue* GetInitializer() const { return init_; }
  /// Stores the pointer only; it is not registered as an operand.
  void SetInitializer(ConstantValue* init) { init_ = init; }

 private:
  ConstantValue* init_ = nullptr;
};
// ======================== 指令操作码 ========================
/// Instruction operation codes. Enumerator order is unchanged, so the
/// underlying integer values are stable.
enum class Opcode {
  // Arithmetic.
  Add,
  Sub,
  Mul,
  Div,
  Mod,
  // Bitwise and shifts.
  And,
  Or,
  Xor,
  Shl,
  LShr,
  AShr,
  // Comparisons.
  ICmp,
  FCmp,
  // Memory.
  Alloca,
  Load,
  Store,
  // Control flow.
  Ret,
  Br,
  CondBr,
  // Function calls.
  Call,
  // Array / address computation.
  GEP,
  // SSA merge.
  Phi
};
// ======================== Instruction 类 ========================
/// Base class of every instruction: an opcode plus the BasicBlock that
/// contains it.
class Instruction : public User {
 public:
  Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");

  Opcode GetOpcode() const;
  bool IsTerminator() const;

  /// Containing block; nullptr until the instruction is placed in one.
  BasicBlock* GetParent() const;
  void SetParent(BasicBlock* parent);

 private:
  Opcode opcode_;
  BasicBlock* parent_ = nullptr;
};
/// Binary (two-operand) instruction; `op` selects the operation.
class BinaryInst : public Instruction {
 public:
  BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
             std::string name);

  /// Left-hand operand.
  Value* GetLhs() const;
  /// Right-hand operand.
  Value* GetRhs() const;
};
/// Comparison instruction (icmp / fcmp); the opcode selects the kind and the
/// predicate selects the relation.
class CmpInst : public Instruction {
 public:
  enum Predicate {
    EQ, NE, LT, LE, GT, GE
  };

  CmpInst(Opcode op, Predicate pred, Value* lhs, Value* rhs, std::string name);

  Predicate GetPredicate() const { return pred_; }

  // Read the operands from the operand list rather than only from the cached
  // lhs_/rhs_. User::SetOperand is non-virtual and cannot reach these private
  // members, so the cached copies go stale when an operand is replaced
  // (presumably including via ReplaceAllUsesWith). Fall back to the cache in
  // case the constructor did not register the operands.
  // NOTE(review): assumes operands are registered as (lhs, rhs). Confirm in the .cpp.
  Value* GetLhs() const { return GetNumOperands() > 0 ? GetOperand(0) : lhs_; }
  Value* GetRhs() const { return GetNumOperands() > 1 ? GetOperand(1) : rhs_; }

 private:
  Predicate pred_ = EQ;
  Value* lhs_ = nullptr;
  Value* rhs_ = nullptr;
};
/// Return instruction.
class ReturnInst : public Instruction {
 public:
  // explicit: `val` has a default, so this is callable with one argument,
  // and a std::shared_ptr<Type> must not implicitly become a ReturnInst.
  explicit ReturnInst(std::shared_ptr<Type> void_ty, Value* val = nullptr);

  /// The returned value. NOTE(review): presumably nullptr for `ret void`; confirm.
  Value* GetValue() const;
};
/// Unconditional branch to a single target block.
class BranchInst : public Instruction {
 public:
  // explicit: prevents a BasicBlock* from implicitly converting into a
  // BranchInst.
  explicit BranchInst(BasicBlock* target);

  /// The branch destination.
  BasicBlock* GetTarget() const;
};
/// Conditional branch: jumps to one of two blocks depending on `cond`.
class CondBranchInst : public Instruction {
 public:
  CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb);

  /// Branch condition.
  Value* GetCond() const;
  /// Destination taken when the condition holds.
  BasicBlock* GetTrueBlock() const;
  /// Destination taken otherwise.
  BasicBlock* GetFalseBlock() const;
};
/// Function call instruction.
class CallInst : public Instruction {
 public:
  CallInst(Function* callee, std::vector<Value*> args, std::string name);

  Function* GetCallee() const;

  /// Call arguments. This returns a const reference, so the instruction needs
  /// storage of its own for the reference to point at; args_ provides it.
  /// NOTE(review): the out-of-line definitions should fill args_ in the
  /// constructor and return it here. Confirm in the .cpp.
  const std::vector<Value*>& GetArgs() const;

 private:
  std::vector<Value*> args_;
};
/// Phi instruction (for SSA form): selects a value according to the
/// predecessor block control arrived from.
class PhiInst : public Instruction {
 public:
  PhiInst(std::shared_ptr<Type> ty, std::string name);

  /// Records that `val` flows in from `block`.
  void AddIncoming(Value* val, BasicBlock* block);

  /// (value, block) pairs. This returns a const reference, so each PhiInst
  /// needs its own storage; incomings_ provides it.
  /// NOTE(review): AddIncoming / GetIncomings in the .cpp should use incomings_.
  const std::vector<std::pair<Value*, BasicBlock*>>& GetIncomings() const;

 private:
  std::vector<std::pair<Value*, BasicBlock*>> incomings_;
};
/// GetElementPtr instruction: address computation into arrays/aggregates.
class GetElementPtrInst : public Instruction {
 public:
  GetElementPtrInst(std::shared_ptr<Type> ty, Value* ptr,
                    std::vector<Value*> indices, std::string name);

  /// Base pointer.
  Value* GetPtr() const;

  /// Index list. This returns a const reference, so the instruction needs
  /// storage of its own; indices_ provides it.
  /// NOTE(review): the constructor in the .cpp should fill indices_.
  const std::vector<Value*>& GetIndices() const;

 private:
  std::vector<Value*> indices_;
};
/// Stack allocation. Per the parameter name, `ptr_ty` is the resulting
/// pointer type.
class AllocaInst : public Instruction {
 public:
  AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name);
};
/// Load of a value of type `val_ty` through pointer `ptr`.
class LoadInst : public Instruction {
 public:
  LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name);

  /// Source pointer.
  Value* GetPtr() const;
};
/// Store of `val` through pointer `ptr`; the instruction itself is typed with
/// `void_ty`.
class StoreInst : public Instruction {
 public:
  StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr);

  /// The value being stored.
  Value* GetValue() const;
  /// Destination pointer.
  Value* GetPtr() const;
};
// ======================== BasicBlock 类 ========================
class BasicBlock : public Value {
public:
explicit BasicBlock(std::string name);
Function* GetParent() const;
void SetParent(Function* parent);
bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
throw std::runtime_error("BasicBlock 已有 terminator不能继续追加指令: " +
name_);
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.push_back(std::move(inst));
return ptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_;
};
// ======================== Function 类 ========================
/// A function: owns its basic blocks and records its parameters and full
/// signature.
class Function : public Value {
 public:
  /// Creates a function called `name` returning `ret_type`, with optional
  /// parameter types.
  Function(std::string name, std::shared_ptr<Type> ret_type,
           std::vector<std::shared_ptr<Type>> param_types = {});

  // --- blocks ---
  BasicBlock* CreateBlock(const std::string& name);
  BasicBlock* GetEntry();
  const BasicBlock* GetEntry() const;
  const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;

  // --- parameters ---
  const std::vector<Value*>& GetParams() const { return params_; }
  void AddParam(Value* param);

  /// Full signature (cached in func_type_).
  std::shared_ptr<FunctionType> GetFunctionType() const;

 private:
  BasicBlock* entry_ = nullptr;
  std::vector<std::unique_ptr<BasicBlock>> blocks_;  // owned blocks
  // Non-owning parameter values; normally Argument values (a dedicated
  // Argument class is yet to be defined).
  std::vector<Value*> params_;
  std::shared_ptr<FunctionType> func_type_;  // cached function type
};
// ======================== Module 类 ========================
/// Top-level IR container: owns its Context, its functions and its global
/// variables.
class Module {
 public:
  Module() = default;

  Context& GetContext();
  const Context& GetContext() const;

  /// Creates a function, optionally with parameter types.
  Function* CreateFunction(const std::string& name,
                           std::shared_ptr<Type> ret_type,
                           std::vector<std::shared_ptr<Type>> param_types = {});

  /// Creates a global variable with an optional initializer.
  GlobalValue* CreateGlobalVariable(const std::string& name,
                                    std::shared_ptr<Type> ty,
                                    ConstantValue* init = nullptr);

  const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
  const std::vector<std::unique_ptr<GlobalValue>>& GetGlobalVariables() const;

 private:
  // Members are destroyed in reverse declaration order, so global_vars_ and
  // functions_ are destroyed before context_. Presumably this matters because
  // IR objects point at context-owned constants, so keep context_ first.
  Context context_;
  std::vector<std::unique_ptr<Function>> functions_;
  std::vector<std::unique_ptr<GlobalValue>> global_vars_;
};
// ======================== IRBuilder 类 ========================
/// Convenience front end for creating constants and instructions.
///
/// Stores a reference to the Context, so the Context must outlive the
/// builder.
class IRBuilder {
 public:
  // explicit: `bb` has a default, so this is callable with one argument, and
  // a Context& must not implicitly convert into an IRBuilder.
  explicit IRBuilder(Context& ctx, BasicBlock* bb = nullptr);

  // Current insertion block (stored in insert_block_).
  void SetInsertPoint(BasicBlock* bb);
  BasicBlock* GetInsertBlock() const;

  // Constants.
  ConstantInt* CreateConstInt(int v);
  ConstantFloat* CreateConstFloat(float v);

  // Arithmetic.
  BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
                           const std::string& name);
  BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name);
  BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name);

  // Comparisons.
  CmpInst* CreateICmp(CmpInst::Predicate pred, Value* lhs, Value* rhs,
                      const std::string& name);
  CmpInst* CreateFCmp(CmpInst::Predicate pred, Value* lhs, Value* rhs,
                      const std::string& name);

  // Memory.
  AllocaInst* CreateAllocaI32(const std::string& name);
  AllocaInst* CreateAllocaFloat(const std::string& name);
  AllocaInst* CreateAlloca(std::shared_ptr<Type> ty, const std::string& name);
  LoadInst* CreateLoad(Value* ptr, const std::string& name);
  StoreInst* CreateStore(Value* val, Value* ptr);

  // Control flow.
  ReturnInst* CreateRet(Value* v);
  BranchInst* CreateBr(BasicBlock* target);
  CondBranchInst* CreateCondBr(Value* cond, BasicBlock* true_bb,
                               BasicBlock* false_bb);

  // Function calls.
  CallInst* CreateCall(Function* callee, std::vector<Value*> args,
                       const std::string& name);

  // Array access.
  GetElementPtrInst* CreateGEP(Value* ptr, std::vector<Value*> indices,
                               const std::string& name);

  // Phi.
  PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name);

 private:
  Context& ctx_;
  // Default to null so the member is never indeterminate, even if a
  // constructor does not set it.
  BasicBlock* insert_block_ = nullptr;
};
// ======================== IRPrinter 类 ========================
/// Writes a textual representation of a Module to an output stream.
class IRPrinter {
 public:
  /// Prints `module` to `os`.
  void Print(const Module& module, std::ostream& os);
};
} // namespace ir