You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

533 lines
19 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

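// IR (intermediate representation): type system, Value hierarchy, instruction
// set, basic blocks, functions, modules, and IRBuilder.
//
// Implemented:
// 1. Type system: void / i1 / i32 / i32*
// 2. Value hierarchy: ConstantInt / Function / BasicBlock / GlobalVariable / Instruction
// 3. Instruction set: Add/Sub/Mul/Div/Mod/ICmp/Alloca/Load/Store/Ret/Br/CondBr/Call/ZExt
// 4. Global variables / external function declarations
// 5. IRBuilder convenience interface
// 6. use-def relations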
#pragma once
#include <cstddef>
#include <iosfwd>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
namespace ir {
class Type;
class Value;
class User;
class ConstantValue;
class ConstantInt;
class ConstantFloat;
class GlobalVariable;
class Instruction;
class BasicBlock;
class Function;
class Module;
// ─── Use ──────────────────────────────────────────────────────────────────────
class Use {
public:
Use() = default;
Use(Value* value, User* user, size_t operand_index)
: value_(value), user_(user), operand_index_(operand_index) {}
Value* GetValue() const { return value_; }
User* GetUser() const { return user_; }
size_t GetOperandIndex() const { return operand_index_; }
void SetValue(Value* value) { value_ = value; }
void SetUser(User* user) { user_ = user; }
void SetOperandIndex(size_t operand_index) { operand_index_ = operand_index; }
private:
Value* value_ = nullptr;
User* user_ = nullptr;
size_t operand_index_ = 0;
};
// ─── Context ──────────────────────────────────────────────────────────────────
// Holds IR-wide state: ConstantInt / ConstantFloat objects owned through
// unique_ptr in maps keyed by their numeric value, plus a counter used for
// temporary names.
class Context {
public:
Context() = default;
// Destructor is defined out of line.
~Context();
// Returns a ConstantInt for v. Presumably looks up / inserts into
// const_ints_ so equal values share one object — confirm in the .cpp.
ConstantInt* GetConstInt(int v);
// Float counterpart of GetConstInt (presumably backed by const_floats_).
ConstantFloat* GetConstFloat(float v);
// Returns a fresh temporary name; presumably derived from temp_index_ —
// confirm the exact format in the implementation.
std::string NextTemp();
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
// NOTE(review): float keys are compared with ==, so 0.0f and -0.0f map to the
// same entry and a NaN key never matches itself. Confirm this is acceptable.
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
// Starts at -1; presumably incremented before use so the first temp is 0 —
// TODO confirm.
int temp_index_ = -1;
};
// ─── Type ─────────────────────────────────────────────────────────────────────
// IR type descriptor: a wrapper around a single Kind tag.
class Type {
public:
// Supported kinds. Judging by the names: void, 1-bit int, 32-bit int,
// 32-bit float, pointer to i32 and pointer to f32.
enum class Kind { Void, Int1, Int32, Float32, PtrInt32, PtrFloat32 };
explicit Type(Kind k);
// One accessor per kind, each returning a reference to a shared_ptr<Type>.
// Presumably each returns a single shared instance — confirm in the .cpp.
static const std::shared_ptr<Type>& GetVoidType();
static const std::shared_ptr<Type>& GetInt1Type();
static const std::shared_ptr<Type>& GetInt32Type();
static const std::shared_ptr<Type>& GetFloat32Type();
static const std::shared_ptr<Type>& GetPtrInt32Type();
static const std::shared_ptr<Type>& GetPtrFloat32Type();
// Kind accessor and per-kind predicates (defined out of line).
Kind GetKind() const;
bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const;
bool IsFloat32() const;
bool IsPtrInt32() const;
bool IsPtrFloat32() const;
private:
Kind kind_;
};
// ─── Value ────────────────────────────────────────────────────────────────────
// Polymorphic base of everything that can appear in the IR. Stores a type,
// a name and a list of Use records.
class Value {
public:
Value(std::shared_ptr<Type> ty, std::string name);
// Virtual so derived objects can be destroyed through a Value*.
virtual ~Value() = default;
const std::shared_ptr<Type>& GetType() const;
const std::string& GetName() const;
void SetName(std::string n);
// Type predicates; presumably forwarded to the Type held in type_.
bool IsVoid() const;
bool IsInt1() const;
bool IsInt32() const;
bool IsFloat32() const;
bool IsPtrInt32() const;
bool IsPtrFloat32() const;
// Dynamic-kind predicates for the subclasses of Value (defined out of line).
bool IsConstant() const;
bool IsInstruction() const;
bool IsUser() const;
bool IsFunction() const;
bool IsBasicBlock() const;
bool IsGlobalVariable() const;
// Use-list maintenance. Presumably AddUse/RemoveUse add or drop the entry
// (user, operand_index) in uses_ — confirm in the implementation.
void AddUse(User* user, size_t operand_index);
void RemoveUse(User* user, size_t operand_index);
const std::vector<Use>& GetUses() const;
// Presumably redirects every recorded use of this value to new_value —
// confirm in the implementation.
void ReplaceAllUsesWith(Value* new_value);
protected:
std::shared_ptr<Type> type_;
std::string name_;
std::vector<Use> uses_;
};
// ─── ConstantValue ─────────────────────────────────────────────────────────────
// Common base class of the constant values (ConstantInt, ConstantFloat).
class ConstantValue : public Value {
public:
// The name defaults to an empty string.
ConstantValue(std::shared_ptr<Type> ty, std::string name = "");
};
// Integer constant carrying an int payload.
class ConstantInt : public ConstantValue {
public:
ConstantInt(std::shared_ptr<Type> ty, int v);
// Returns the stored integer.
int GetValue() const { return value_; }
private:
// Value-initialized (0) until set by the constructor.
int value_{};
};
// Floating-point constant carrying a float payload.
class ConstantFloat : public ConstantValue {
public:
ConstantFloat(std::shared_ptr<Type> ty, float v);
// Returns the stored float.
float GetValue() const { return value_; }
private:
// Value-initialized (0.0f) until set by the constructor.
float value_{};
};
// ─── Opcode ───────────────────────────────────────────────────────────────────
// Instruction opcodes, grouped by category.
enum class Opcode {
// Integer arithmetic
Add, Sub, Mul, Div, Mod,
// Floating-point arithmetic
FAdd, FSub, FMul, FDiv,
// Comparison (result is i1)
ICmp, FCmp,
// Memory
Alloca, Load, Store,
// Address computation
Gep,
// Control flow
Ret, Br, CondBr,
// Function call
Call,
// Type conversion
ZExt, SIToFP, FPToSI,
};
// ICmp predicates. The spellings match LLVM's icmp predicates (equal, not
// equal, and presumably signed less/greater comparisons).
enum class ICmpPredicate { EQ, NE, SLT, SLE, SGT, SGE };
// FCmp predicates. The spellings match LLVM's fcmp predicates with the "O"
// prefix (presumably ordered comparisons — confirm against the printer).
enum class FCmpPredicate { OEQ, ONE, OLT, OLE, OGT, OGE };
// ─── User ─────────────────────────────────────────────────────────────────────
// A Value that references other Values through an ordered operand list.
class User : public Value {
public:
User(std::shared_ptr<Type> ty, std::string name);
size_t GetNumOperands() const;
// Operand access by position. Bounds handling is defined out of line and not
// visible here.
Value* GetOperand(size_t index) const;
// Replaces the operand at `index`. Presumably also updates the use lists of
// the old and new values — confirm in the implementation.
void SetOperand(size_t index, Value* value);
protected:
// Appends an operand; only available to subclasses.
void AddOperand(Value* value);
private:
// Non-owning pointers to the operands.
std::vector<Value*> operands_;
};
// ─── GlobalValue (placeholder) ────────────────────────────────────────────────
// Adds nothing beyond User in this header. Note that GlobalVariable in this
// file derives from Value directly, not from GlobalValue.
class GlobalValue : public User {
public:
GlobalValue(std::shared_ptr<Type> ty, std::string name);
};
// ─── GlobalVariable ────────────────────────────────────────────────────────────
// Represents a global integer variable or constant. Its type is i32*, so it
// can be used directly as the pointer of a load/store.
class GlobalVariable : public Value {
 public:
  // num_elements defaults to 1, i.e. a scalar.
  GlobalVariable(std::string name, bool is_const, int init_val,
                 int num_elements = 1);

  // Whether the global was created as a constant.
  bool IsConst() const { return is_const_; }

  // The integer initializer passed at construction.
  int GetInitVal() const { return init_val_; }

  // Element count; more than one element means the global is an array.
  int GetNumElements() const { return num_elements_; }
  bool IsArray() const { return GetNumElements() > 1; }

  // The "pointer type" of a GlobalVariable is i32*; access it with load/store.

 private:
  bool is_const_;
  int init_val_;
  int num_elements_;
};
// ─── Instruction ──────────────────────────────────────────────────────────────
// Base class of all instructions: a User tagged with an Opcode and a
// non-owning back-pointer to the BasicBlock that contains it.
class Instruction : public User {
public:
// The name defaults to an empty string.
Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name = "");
Opcode GetOpcode() const;
// Presumably true for Ret / Br / CondBr — confirm in the implementation.
bool IsTerminator() const;
BasicBlock* GetParent() const;
// Called by BasicBlock::Append with the block the instruction is added to.
void SetParent(BasicBlock* parent);
private:
Opcode opcode_;
// Null until SetParent is called.
BasicBlock* parent_ = nullptr;
};
// Binary arithmetic instruction: i32 × i32 → i32.
// NOTE(review): the constructor takes the result type explicitly and
// IRBuilder::CreateFAdd/FSub/FMul/FDiv also return BinaryInst*, so float
// operands appear to be supported too; the line above may be outdated.
class BinaryInst : public Instruction {
public:
BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
std::string name);
// Left / right operand accessors (defined out of line).
Value* GetLhs() const;
Value* GetRhs() const;
};
// Integer comparison instruction: i32 × i32 → i1.
class ICmpInst : public Instruction {
public:
ICmpInst(ICmpPredicate pred, Value* lhs, Value* rhs, std::string name);
// The comparison predicate given at construction.
ICmpPredicate GetPredicate() const { return pred_; }
Value* GetLhs() const;
Value* GetRhs() const;
private:
ICmpPredicate pred_;
};
// Unconditional branch.
class BrInst : public Instruction {
public:
explicit BrInst(BasicBlock* target);
// The branch destination (defined out of line).
BasicBlock* GetTarget() const;
};
// Conditional branch.
class CondBrInst : public Instruction {
public:
CondBrInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb);
// The condition and the two destination blocks (defined out of line).
Value* GetCond() const;
BasicBlock* GetTrueBB() const;
BasicBlock* GetFalseBB() const;
};
// Function call.
// When the callee is nullptr the call targets an external function and
// callee_name_ is used instead.
class CallInst : public Instruction {
 public:
  // Call a known Function (defined in the module).
  CallInst(Function* callee, std::vector<Value*> args, std::string name);

  // Call an externally declared function (name + return type).
  CallInst(std::string callee_name, std::shared_ptr<Type> ret_type,
           std::vector<Value*> args, std::string name);

  // Callee accessors. GetCallee() is null exactly when IsExternal() is true.
  Function* GetCallee() const { return callee_; }
  const std::string& GetCalleeName() const { return callee_name_; }
  bool IsExternal() const { return !callee_; }

  // Argument count and access are forwarded to the operand accessors.
  size_t GetNumArgs() const { return GetNumOperands(); }
  Value* GetArg(size_t i) const { return GetOperand(i); }

 private:
  Function* callee_{nullptr};
  std::string callee_name_;
};
// Zero extension: i1 → i32.
class ZExtInst : public Instruction {
public:
ZExtInst(Value* val, std::string name);
// The value being extended (defined out of line).
Value* GetSrc() const;
};
// Floating-point comparison instruction: f32 × f32 → i1.
class FCmpInst : public Instruction {
public:
FCmpInst(FCmpPredicate pred, Value* lhs, Value* rhs, std::string name);
// The comparison predicate given at construction.
FCmpPredicate GetPredicate() const { return pred_; }
Value* GetLhs() const;
Value* GetRhs() const;
private:
FCmpPredicate pred_;
};
// Signed integer to floating point: i32 → f32.
class SIToFPInst : public Instruction {
public:
SIToFPInst(Value* val, std::string name);
// The value being converted (defined out of line).
Value* GetSrc() const;
};
// Floating point to signed integer: f32 → i32.
class FPToSIInst : public Instruction {
public:
FPToSIInst(Value* val, std::string name);
// The value being converted (defined out of line).
Value* GetSrc() const;
};
// Return statement; a nullptr val means a void return.
class ReturnInst : public Instruction {
public:
// Return with a value.
explicit ReturnInst(Value* val);
// Void return.
ReturnInst();
// True when the instruction has at least one operand.
bool HasValue() const { return GetNumOperands() > 0; }
Value* GetValue() const; // may be nullptr
};
// Alloca instruction producing a pointer of type ptr_ty. It covers either a
// single element or num_elements elements.
class AllocaInst : public Instruction {
 public:
  // Scalar alloca (num_elements == 1).
  AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name);

  // Array alloca (num_elements > 1).
  AllocaInst(std::shared_ptr<Type> ptr_ty, int num_elements, std::string name);

  // Element count; more than one element means this is an array alloca.
  int GetNumElements() const { return num_elements_; }
  bool IsArray() const { return GetNumElements() > 1; }

 private:
  int num_elements_{1};
};
// GetElementPtr: ptr + index → i32*
class GepInst : public Instruction {
public:
GepInst(Value* base_ptr, Value* index, std::string name);
// The base pointer and the index operand (defined out of line).
Value* GetBasePtr() const;
Value* GetIndex() const;
};
// Load instruction: takes a pointer operand; the type of the loaded value is
// passed explicitly as val_ty.
class LoadInst : public Instruction {
public:
LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name);
// The pointer being loaded from (defined out of line).
Value* GetPtr() const;
};
// Store instruction: takes a value and a pointer. The constructor takes no
// name.
class StoreInst : public Instruction {
public:
StoreInst(Value* val, Value* ptr);
// The stored value and the destination pointer (defined out of line).
Value* GetValue() const;
Value* GetPtr() const;
};
// ─── BasicBlock ───────────────────────────────────────────────────────────────
class BasicBlock : public Value {
public:
explicit BasicBlock(std::string name);
Function* GetParent() const;
void SetParent(Function* parent);
bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
throw std::runtime_error("BasicBlock 已有 terminator: " + name_);
}
auto inst = std::make_unique<T>(std::forward<Args>(args)...);
auto* ptr = inst.get();
ptr->SetParent(this);
instructions_.push_back(std::move(inst));
return ptr;
}
private:
Function* parent_ = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions_;
std::vector<BasicBlock*> predecessors_;
std::vector<BasicBlock*> successors_;
};
// ─── Argument ─────────────────────────────────────────────────────────────────
// Formal function parameter; as an SSA value it can be used by instructions
// such as store.
class Argument : public Value {
public:
// Forwards type and name to Value; adds no state of its own.
Argument(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
};
// ─── Function ─────────────────────────────────────────────────────────────────
// A function: owns its basic blocks and its formal arguments. IsVoidReturn()
// tests the inherited type_ for void, so type_ appears to hold the return
// type passed to the constructor.
class Function : public Value {
public:
Function(std::string name, std::shared_ptr<Type> ret_type);
// Creates a block with the given name. Presumably the block is owned by
// blocks_ and the first one created becomes the entry — confirm in the .cpp.
BasicBlock* CreateBlock(const std::string& name);
BasicBlock* GetEntry();
const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
// Argument values (in order, one-to-one with GetParamNames/GetParamTypes).
// NOTE(review): GetParamNames / GetParamTypes are not declared in this class
// as shown; the reference above may be outdated.
Argument* AddArgument(std::shared_ptr<Type> ty, const std::string& name);
Argument* GetArgument(size_t i) const;
// Number of arguments added so far.
size_t GetNumArgs() const { return args_.size(); }
// True when the function's type is void.
bool IsVoidReturn() const { return type_->IsVoid(); }
private:
BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_;
std::vector<std::unique_ptr<Argument>> args_;
};
// ─── ExternalFuncDecl ─────────────────────────────────────────────────────────
// Declaration record for an external function: name, return type, parameter
// types, and whether it is variadic.
struct ExternalFuncDecl {
std::string name;
std::shared_ptr<Type> ret_type;
std::vector<std::shared_ptr<Type>> param_types;
// Non-variadic unless set otherwise.
bool is_variadic = false;
};
// ─── Module ───────────────────────────────────────────────────────────────────
// Top-level IR container: owns a Context, the functions, the global
// variables and the list of external function declarations. Each list has a
// name-keyed map beside it.
class Module {
public:
Module() = default;
Context& GetContext();
const Context& GetContext() const;
// Creates a function. Presumably it is owned by functions_ and registered in
// func_map_ — confirm in the .cpp.
Function* CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type);
// Lookup by name. The result for an unknown name is not visible here
// (presumably nullptr — TODO confirm).
Function* GetFunction(const std::string& name) const;
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
// Creates a global variable; num_elements defaults to 1.
GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const,
int init_val, int num_elements = 1);
// Lookup by name. The result for an unknown name is not visible here
// (presumably nullptr — TODO confirm).
GlobalVariable* GetGlobalVariable(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVariables() const;
// Records an external function declaration; non-variadic by default.
void DeclareExternalFunc(const std::string& name,
std::shared_ptr<Type> ret_type,
std::vector<std::shared_ptr<Type>> param_types,
bool is_variadic = false);
bool HasExternalDecl(const std::string& name) const;
const std::vector<ExternalFuncDecl>& GetExternalDecls() const;
private:
Context context_;
// Owning storage plus non-owning name → object maps.
std::vector<std::unique_ptr<Function>> functions_;
std::unordered_map<std::string, Function*> func_map_;
std::vector<std::unique_ptr<GlobalVariable>> globals_;
std::unordered_map<std::string, GlobalVariable*> global_map_;
std::vector<ExternalFuncDecl> external_decls_;
// Presumably maps a declaration name to its position in external_decls_ —
// confirm in the implementation.
std::unordered_map<std::string, size_t> external_decl_index_;
};
// ─── IRBuilder ────────────────────────────────────────────────────────────────
// Convenience interface for creating IR. Holds a reference to a Context and a
// pointer to the current insertion block. Presumably every Create* method
// that returns an instruction appends it to that block — confirm in the .cpp.
class IRBuilder {
public:
IRBuilder(Context& ctx, BasicBlock* bb);
void SetInsertPoint(BasicBlock* bb);
BasicBlock* GetInsertBlock() const;
// Constants
ConstantInt* CreateConstInt(int v);
ConstantFloat* CreateConstFloat(float v);
// Integer arithmetic
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name);
// Floating-point arithmetic
BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name);
// Comparison (returns i1)
ICmpInst* CreateICmp(ICmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name);
FCmpInst* CreateFCmp(FCmpPredicate pred, Value* lhs, Value* rhs,
const std::string& name);
// Memory
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaF32(const std::string& name);
AllocaInst* CreateAllocaArray(int num_elements, const std::string& name);
GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
// Control flow
ReturnInst* CreateRet(Value* v);
ReturnInst* CreateRetVoid();
BrInst* CreateBr(BasicBlock* target);
CondBrInst* CreateCondBr(Value* cond, BasicBlock* true_bb,
BasicBlock* false_bb);
// Calls (the result name defaults to an empty string)
CallInst* CreateCall(Function* callee, std::vector<Value*> args,
const std::string& name = "");
CallInst* CreateCallExternal(const std::string& callee_name,
std::shared_ptr<Type> ret_type,
std::vector<Value*> args,
const std::string& name = "");
// Type conversion
ZExtInst* CreateZExt(Value* val, const std::string& name);
SIToFPInst* CreateSIToFP(Value* val, const std::string& name);
FPToSIInst* CreateFPToSI(Value* val, const std::string& name);
private:
// Held by reference: the Context must outlive the builder.
Context& ctx_;
// Non-owning pointer to the current insertion block.
BasicBlock* insert_block_;
};
// ─── IRPrinter ────────────────────────────────────────────────────────────────
// Writes a textual form of a Module to an output stream. The output format
// is defined in the implementation and not visible here.
class IRPrinter {
public:
void Print(const Module& module, std::ostream& os);
};
} // namespace ir