You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/Instruction.cpp

271 lines
9.0 KiB

// IR 指令体系:
// - 二元运算/比较、load/store、call、br/condbr、ret、phi、alloca 等
// - 指令操作数与结果类型管理,支持打印与优化
#include "ir/IR.h"
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>
#include "utils/Log.h"
namespace ir {
User::User(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
size_t User::GetNumOperands() const { return operands_.size(); }
Value* User::GetOperand(size_t index) const {
if (index >= operands_.size()) {
throw std::out_of_range("User operand index out of range");
}
return operands_[index];
}
void User::SetOperand(size_t index, Value* value) {
if (index >= operands_.size()) {
throw std::out_of_range("User operand index out of range");
}
if (!value) {
throw std::runtime_error(FormatError("ir", "User operand 不能为空"));
}
auto* old = operands_[index];
if (old == value) {
return;
}
if (old) {
old->RemoveUse(this, index);
}
operands_[index] = value;
value->AddUse(this, index);
}
void User::AddOperand(Value* value) {
if (!value) {
throw std::runtime_error(FormatError("ir", "User operand 不能为空"));
}
size_t operand_index = operands_.size();
operands_.push_back(value);
value->AddUse(this, operand_index);
}
// Instruction: a User tagged with an opcode and owned by a basic block.
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
    : User(std::move(ty), std::move(name)), opcode_(op) {}

Opcode Instruction::GetOpcode() const { return opcode_; }

// Ret, Br and CondBr end a basic block; every other opcode is a
// non-terminator.
bool Instruction::IsTerminator() const {
  switch (opcode_) {
    case Opcode::Ret:
    case Opcode::Br:
    case Opcode::CondBr:
      return true;
    default:
      return false;
  }
}

// Enclosing basic block; may be null until SetParent is called.
BasicBlock* Instruction::GetParent() const { return parent_; }

void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
// BinaryInst: an arithmetic instruction (Add/Sub/Mul/Div/Mod) over two
// operands. Both operands and the result must share one kind, and that kind
// must be i32 or float.
// Throws std::runtime_error on any violated precondition.
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
                       Value* rhs, std::string name)
    : Instruction(op, std::move(ty), std::move(name)) {
  const bool arithmetic = op == Opcode::Add || op == Opcode::Sub ||
                          op == Opcode::Mul || op == Opcode::Div ||
                          op == Opcode::Mod;
  if (!arithmetic) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持算术操作"));
  }
  if (lhs == nullptr || rhs == nullptr) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
  }
  if (!type_) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息"));
  }
  const auto& lhs_ty = lhs->GetType();
  const auto& rhs_ty = rhs->GetType();
  if (!lhs_ty || !rhs_ty) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息"));
  }
  const auto kind = lhs_ty->GetKind();
  if (rhs_ty->GetKind() != kind || type_->GetKind() != kind) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
  }
  const bool supported = type_->IsInt32() || type_->IsFloat32();
  if (!supported) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32 和 float"));
  }
  AddOperand(lhs);
  AddOperand(rhs);
}

// Left-hand operand (operand 0).
Value* BinaryInst::GetLhs() const { return GetOperand(0); }

// Right-hand operand (operand 1).
Value* BinaryInst::GetRhs() const { return GetOperand(1); }
// ReturnInst: terminates a function. The instruction itself is void-typed.
// `val` may be nullptr to represent a void return (`ret void`); in that case
// no operand is attached.
// Throws std::runtime_error when `void_ty` is missing or not void.
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
    : Instruction(Opcode::Ret, std::move(void_ty), "") {
  if (!type_ || !type_->IsVoid()) {
    throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
  }
  // Only attach an operand when a value is actually returned.
  if (val) {
    AddOperand(val);
  }
}

// Returns the returned value, or nullptr for a void return. The constructor
// attaches no operand in that case, so GetOperand(0) would throw
// std::out_of_range; check the operand count first.
Value* ReturnInst::GetValue() const {
  return GetNumOperands() == 0 ? nullptr : GetOperand(0);
}
// AllocaInst: reserves storage; the instruction's own type is the pointer to
// that storage.
// Throws std::runtime_error when `ptr_ty` is missing or not a pointer type.
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
    : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
  const bool is_pointer = type_ && type_->IsPointer();
  if (!is_pointer) {
    throw std::runtime_error(FormatError("ir", "AllocaInst 类型必须为指针"));
  }
}
// LoadInst: reads a value of type `val_ty` through pointer `ptr`.
// `ptr` must be pointer-typed, and its pointee type must equal `val_ty`.
// Throws std::runtime_error otherwise.
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
    : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
  if (ptr == nullptr) {
    throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
  }
  const auto& addr_ty = ptr->GetType();
  if (!addr_ty || !addr_ty->IsPointer()) {
    throw std::runtime_error(FormatError("ir", "LoadInst ptr 必须为指针类型"));
  }
  // Safe downcast: IsPointer() was checked just above.
  const auto* pointer_ty = static_cast<const PointerType*>(addr_ty.get());
  const bool matches = type_ && *type_ == *pointer_ty->GetPointeeType();
  if (!matches) {
    throw std::runtime_error(FormatError("ir", "LoadInst 类型不匹配"));
  }
  AddOperand(ptr);
}

// Address being loaded from (operand 0).
Value* LoadInst::GetPtr() const { return GetOperand(0); }
// StoreInst: writes `val` through pointer `ptr`; the instruction itself is
// void-typed. `val`'s type must equal the pointee type of `ptr`.
// Operand layout: [0] = value, [1] = pointer.
// Throws std::runtime_error on any violated precondition.
StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
    : Instruction(Opcode::Store, std::move(void_ty), "") {
  if (val == nullptr) {
    throw std::runtime_error(FormatError("ir", "StoreInst 缺少 value"));
  }
  if (ptr == nullptr) {
    throw std::runtime_error(FormatError("ir", "StoreInst 缺少 ptr"));
  }
  const bool returns_void = type_ && type_->IsVoid();
  if (!returns_void) {
    throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
  }
  const auto& addr_ty = ptr->GetType();
  if (!addr_ty || !addr_ty->IsPointer()) {
    throw std::runtime_error(FormatError("ir", "StoreInst ptr 必须为指针类型"));
  }
  // Safe downcast: IsPointer() was checked just above.
  const auto* pointer_ty = static_cast<const PointerType*>(addr_ty.get());
  const auto& stored_ty = val->GetType();
  if (!stored_ty || *stored_ty != *pointer_ty->GetPointeeType()) {
    throw std::runtime_error(FormatError("ir", "StoreInst 类型不匹配"));
  }
  AddOperand(val);
  AddOperand(ptr);
}

// Value being stored (operand 0).
Value* StoreInst::GetValue() const { return GetOperand(0); }

// Destination address (operand 1).
Value* StoreInst::GetPtr() const { return GetOperand(1); }
// CmpInst: integer (ICmp) or floating (FCmp) comparison under predicate
// `pred`. The result is typed i32 (Type::GetInt32Type()).
// Throws std::runtime_error for any other opcode or a missing operand.
// NOTE(review): lhs_/rhs_ duplicate operands 0/1 and are not updated by
// User::SetOperand; confirm callers read operands via GetOperand.
CmpInst::CmpInst(Opcode op, Predicate pred, Value* lhs, Value* rhs, std::string name)
    : Instruction(op, Type::GetInt32Type(), std::move(name)),
      pred_(pred),
      lhs_(lhs),
      rhs_(rhs) {
  const bool is_compare = (op == Opcode::ICmp) || (op == Opcode::FCmp);
  if (!is_compare) {
    throw std::runtime_error(FormatError("ir", "CmpInst 仅支持 ICmp 和 FCmp"));
  }
  if (lhs == nullptr || rhs == nullptr) {
    throw std::runtime_error(FormatError("ir", "CmpInst 缺少操作数"));
  }
  AddOperand(lhs);
  AddOperand(rhs);
}
// BranchInst: unconditional jump to `target`, stored as operand 0.
// Throws std::runtime_error when `target` is null.
BranchInst::BranchInst(BasicBlock* target)
    : Instruction(Opcode::Br, Type::GetVoidType(), "") {
  if (target == nullptr) {
    throw std::runtime_error(FormatError("ir", "BranchInst 缺少目标基本块"));
  }
  AddOperand(target);
}

// Jump destination; operand 0 is always a BasicBlock by construction.
BasicBlock* BranchInst::GetTarget() const {
  Value* const dest = GetOperand(0);
  return static_cast<BasicBlock*>(dest);
}
// CondBranchInst: two-way branch on `cond`.
// Operand layout: [0] = condition, [1] = true target, [2] = false target.
// Throws std::runtime_error if any of the three is null.
CondBranchInst::CondBranchInst(Value* cond, BasicBlock* true_bb, BasicBlock* false_bb)
    : Instruction(Opcode::CondBr, Type::GetVoidType(), "") {
  const bool has_all = cond && true_bb && false_bb;
  if (!has_all) {
    throw std::runtime_error(FormatError("ir", "CondBranchInst 缺少操作数"));
  }
  AddOperand(cond);
  AddOperand(true_bb);
  AddOperand(false_bb);
}

Value* CondBranchInst::GetCond() const { return GetOperand(0); }

// Operands 1 and 2 are always BasicBlocks by construction.
BasicBlock* CondBranchInst::GetTrueBlock() const {
  Value* const dest = GetOperand(1);
  return static_cast<BasicBlock*>(dest);
}

BasicBlock* CondBranchInst::GetFalseBlock() const {
  Value* const dest = GetOperand(2);
  return static_cast<BasicBlock*>(dest);
}
CallInst::CallInst(Function* callee, std::vector<Value*> args, std::string name)
: Instruction(Opcode::Call, callee->GetType(), std::move(name)) {
if (!callee) {
throw std::runtime_error(FormatError("ir", "CallInst 缺少被调用函数"));
}
AddOperand(callee);
for (auto* arg : args) {
if (!arg) {
throw std::runtime_error(FormatError("ir", "CallInst 参数不能为空"));
}
AddOperand(arg);
}
}
// PhiInst: SSA merge node. Incoming edges are stored as consecutive operand
// pairs (value, predecessor block).
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
    : Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}

// Appends one (value, block) pair as two consecutive operands.
// Throws std::runtime_error if either is null.
void PhiInst::AddIncoming(Value* val, BasicBlock* block) {
  const bool complete = val != nullptr && block != nullptr;
  if (!complete) {
    throw std::runtime_error(FormatError("ir", "PhiInst AddIncoming 参数不能为空"));
  }
  AddOperand(val);
  AddOperand(block);
}
// GetElementPtrInst: address computation from a base pointer and indices.
// Operand layout: [0] = base pointer, [1..] = indices in order.
// Throws std::runtime_error when the pointer or any index is null. Every
// input is validated before any operand is attached (matching the other
// constructors in this file), so no AddUse is registered on an instruction
// whose construction then fails.
GetElementPtrInst::GetElementPtrInst(std::shared_ptr<Type> ty, Value* ptr,
                                     std::vector<Value*> indices, std::string name)
    : Instruction(Opcode::GEP, std::move(ty), std::move(name)) {
  if (!ptr) {
    throw std::runtime_error(FormatError("ir", "GetElementPtrInst 缺少指针"));
  }
  for (auto* idx : indices) {
    if (!idx) {
      throw std::runtime_error(FormatError("ir", "GetElementPtrInst 索引不能为空"));
    }
  }
  AddOperand(ptr);
  for (auto* idx : indices) {
    AddOperand(idx);
  }
}
// The callee is stored as operand 0; it is always a Function by construction.
Function* CallInst::GetCallee() const {
  Value* const first = GetOperand(0);
  return static_cast<Function*>(first);
}
const std::vector<Value*>& CallInst::GetArgs() const {
// 返回参数列表(跳过被调用函数)
static std::vector<Value*> args;
args.clear();
for (size_t i = 1; i < GetNumOperands(); ++i) {
args.push_back(GetOperand(i));
}
return args;
}
// Returns the phi's incoming edges as (value, predecessor block) pairs.
// Operands are stored pairwise by AddIncoming: [2k] = value, [2k+1] = block.
//
// The signature returns a reference, so the vector must outlive this call.
// Each PhiInst gets its own buffer, so references obtained from different
// phis never alias. A single shared static buffer would be cleared and
// refilled by GetIncomings() on *any* other phi. A buffer is only rewritten
// by a later GetIncomings() on the same instruction.
// Buffers are never erased, so memory grows with the number of distinct
// PhiInsts queried. Access is not synchronized.
// TODO(review): prefer a member cache or return-by-value once IR.h can change.
const std::vector<std::pair<Value*, BasicBlock*>>& PhiInst::GetIncomings() const {
  static std::unordered_map<const PhiInst*,
                            std::vector<std::pair<Value*, BasicBlock*>>>
      incomings_by_phi;
  auto& incomings = incomings_by_phi[this];
  incomings.clear();
  const size_t num_operands = GetNumOperands();
  incomings.reserve(num_operands / 2);
  for (size_t i = 0; i < num_operands; i += 2) {
    Value* val = GetOperand(i);
    BasicBlock* block = static_cast<BasicBlock*>(GetOperand(i + 1));
    incomings.emplace_back(val, block);
  }
  return incomings;
}
// Base pointer of the address computation (operand 0).
Value* GetElementPtrInst::GetPtr() const {
  Value* const base = GetOperand(0);
  return base;
}
const std::vector<Value*>& GetElementPtrInst::GetIndices() const {
// 返回索引列表(跳过指针)
static std::vector<Value*> indices;
indices.clear();
for (size_t i = 1; i < GetNumOperands(); ++i) {
indices.push_back(GetOperand(i));
}
return indices;
}
} // namespace ir