You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/Instruction.cpp

282 lines
8.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// IR 指令体系:
// - 二元运算/比较、load/store、call、br/condbr、ret、phi、alloca 等
// - 指令操作数与结果类型管理,支持打印与优化
#include "ir/IR.h"

#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include "utils/Log.h"
namespace ir {
// A User is a Value that refers to other Values through an ordered operand
// list. The type and name are handed straight to the Value base.
User::User(std::shared_ptr<Type> ty, std::string name)
    : Value(std::move(ty), std::move(name)) {}

// Number of operands currently held by this User.
size_t User::GetNumOperands() const {
  const size_t count = operands_.size();
  return count;
}

// Returns operand `index`.
// Throws std::out_of_range when `index` is not a valid operand position.
Value* User::GetOperand(size_t index) const {
  if (index < operands_.size()) {
    return operands_[index];
  }
  throw std::out_of_range("User operand index out of range");
}
void User::SetOperand(size_t index, Value* value) {
if (index >= operands_.size()) {
throw std::out_of_range("User operand index out of range");
}
if (!value) {
throw std::runtime_error(FormatError("ir", "User operand 不能为空"));
}
auto* old = operands_[index];
if (old == value) {
return;
}
if (old) {
old->RemoveUse(this, index);
}
operands_[index] = value;
value->AddUse(this, index);
}
// Appends `value` as a new trailing operand and registers the matching use
// (this, new position) on `value`.
// Throws std::runtime_error when `value` is null.
void User::AddOperand(Value* value) {
  if (value == nullptr) {
    throw std::runtime_error(FormatError("ir", "User operand 不能为空"));
  }
  const auto slot = operands_.size();  // position the new operand will occupy
  operands_.emplace_back(value);
  value->AddUse(this, slot);
}
// Base constructor for every IR instruction: records the opcode and forwards
// the result type and name to User.
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
    : User(std::move(ty), std::move(name)), opcode_(op) {}

// The operation this instruction performs.
Opcode Instruction::GetOpcode() const { return opcode_; }

// True for the opcodes that end a basic block: Ret, Br and CondBr.
bool Instruction::IsTerminator() const {
  switch (opcode_) {
    case Opcode::Ret:
    case Opcode::Br:
    case Opcode::CondBr:
      return true;
    default:
      return false;
  }
}

// The basic block that currently holds this instruction (as last set via
// SetParent).
BasicBlock* Instruction::GetParent() const {
  return parent_;
}

// Records the basic block holding this instruction.
void Instruction::SetParent(BasicBlock* parent) {
  parent_ = parent;
}
// Creates a two-operand instruction `lhs <op> rhs` with result type `ty`.
//
// Accepted opcodes: Add, Sub, Mul, Div, Mod, And, Or, FAdd, FSub, FMul, FDiv.
// Type rules enforced here:
//   - lhs and rhs must be non-null and carry type information;
//   - lhs and rhs must have the same type kind;
//   - And/Or: operands must be i32 or i1;
//   - all other opcodes: operands must be i32 or float;
//   - the result type must have the same kind as the operands.
// NOTE(review): the opcode is not cross-checked against the operand type
// (e.g. Add on float or FAdd on i32 is accepted); confirm whether callers rely
// on that before tightening it.
// Throws std::runtime_error on any violation. Operands are only registered
// after every check has passed.
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
                       Value* rhs, std::string name)
    : Instruction(op, std::move(ty), std::move(name)) {
  // Only genuine two-operand arithmetic/logical opcodes are allowed.
  switch (op) {
    case Opcode::Add:
    case Opcode::Sub:
    case Opcode::Mul:
    case Opcode::Div:
    case Opcode::Mod:
    case Opcode::And:
    case Opcode::Or:
    case Opcode::FAdd:
    case Opcode::FSub:
    case Opcode::FMul:
    case Opcode::FDiv:
      break;
    case Opcode::Not:
      // Not is unary and must be built with a different instruction class.
      throw std::runtime_error(FormatError("ir", "Not是一元操作符应使用其他指令"));
    default:
      throw std::runtime_error(FormatError("ir", "BinaryInst 不支持的操作码"));
  }

  // The checks below dereference both operands, so reject null up front
  // instead of crashing.
  if (!lhs || !rhs) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
  }
  if (!type_ || !lhs->GetType() || !rhs->GetType()) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息"));
  }
  if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind()) {
    throw std::runtime_error(FormatError("ir", "BinaryInst 操作数类型不匹配"));
  }

  const bool is_logical = (op == Opcode::And || op == Opcode::Or);
  if (is_logical) {
    if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsInt1()) {
      throw std::runtime_error(FormatError("ir", "逻辑运算仅支持 i32/i1"));
    }
    // A logical op produces a value of the same kind as its operands (i1 or i32).
    if (type_->GetKind() != lhs->GetType()->GetKind()) {
      throw std::runtime_error(
          FormatError("ir", "逻辑运算结果类型与操作数类型不匹配"));
    }
  } else {
    if (!lhs->GetType()->IsInt32() && !lhs->GetType()->IsFloat()) {
      throw std::runtime_error(
          FormatError("ir", "BinaryInst 只支持 int32 和 float 类型"));
    }
    // An arithmetic op produces a value of the same kind as its operands.
    if (type_->GetKind() != lhs->GetType()->GetKind()) {
      throw std::runtime_error(
          FormatError("ir", "BinaryInst 结果类型与操作数类型不匹配"));
    }
  }

  AddOperand(lhs);
  AddOperand(rhs);
}
// The constructor registers lhs first and rhs second, so the two sides live at
// operand positions 0 and 1.
Value* BinaryInst::GetLhs() const {
  constexpr size_t kLhsSlot = 0;
  return GetOperand(kLhsSlot);
}
Value* BinaryInst::GetRhs() const {
  constexpr size_t kRhsSlot = 1;
  return GetOperand(kRhsSlot);
}
// Creates a `ret`. The instruction itself yields no value, so `void_ty` must be
// the void type. `val` is optional: when null, no operand is recorded
// (a bare `ret`).
// Throws std::runtime_error when `void_ty` is missing or not void.
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
    : Instruction(Opcode::Ret, std::move(void_ty), "") {
  const bool has_void_type = type_ && type_->IsVoid();
  if (!has_void_type) {
    throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
  }
  if (val != nullptr) AddOperand(val);
}
// The returned value, or nullptr for a `ret` constructed without one.
Value* ReturnInst::GetValue() const {
  return GetNumOperands() > 0 ? GetOperand(0) : nullptr;
}
// Creates an `alloca` whose result type is `ptr_ty`. Only i32*, float* and
// array result types are accepted.
// Throws std::runtime_error for any other (or missing) type.
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
    : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
  const bool supported =
      type_ && (type_->IsPtrInt32() || type_->IsPtrFloat() || type_->IsArray());
  if (!supported) {
    throw std::runtime_error(
        FormatError("ir", "AllocaInst 仅支持 i32* / float* / array"));
  }
}
// Creates a `load` of a value of type `val_ty` from address `ptr`.
// The loaded type must be i32, float or i1. The address must be typed i32*,
// float*, i1* or array. The pointee is not cross-checked against `val_ty`.
// Throws std::runtime_error on a missing pointer or an unsupported type.
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
    : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
  if (ptr == nullptr) {
    throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
  }

  const bool loads_scalar =
      type_ && (type_->IsInt32() || type_->IsFloat() || type_->IsInt1());
  if (!loads_scalar) {
    throw std::runtime_error(
        FormatError("ir", "LoadInst 仅支持加载 i32/float/i1"));
  }

  const auto& addr_ty = ptr->GetType();
  const bool addressable =
      addr_ty && (addr_ty->IsPtrInt32() || addr_ty->IsPtrFloat() ||
                  addr_ty->IsArray() || addr_ty->IsPtrInt1());
  if (!addressable) {
    throw std::runtime_error(
        FormatError("ir", "LoadInst 仅支持从指针或数组地址加载"));
  }

  AddOperand(ptr);
}

// The address being loaded from (the constructor's only operand).
Value* LoadInst::GetPtr() const {
  return GetOperand(0);
}
// Creates a `store` of `val` through address `ptr`. The instruction itself
// yields no value, so `void_ty` must be the void type.
//
// Accepted shapes:
//   - scalar store: val is i32/float/i1 and ptr is i32*/float*/i1*
//     (the scalar element types are not cross-checked against each other);
//   - aggregate store: val and ptr are BOTH array-typed with the same kind.
// Throws std::runtime_error on any violation. Operands are registered only
// after every check has passed.
StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
    : Instruction(Opcode::Store, std::move(void_ty), "") {
  if (!val) {
    throw std::runtime_error(FormatError("ir", "StoreInst 缺少 value"));
  }
  if (!ptr) {
    throw std::runtime_error(FormatError("ir", "StoreInst 缺少 ptr"));
  }
  if (!type_ || !type_->IsVoid()) {
    throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
  }
  if (!val->GetType() ||
      (!val->GetType()->IsInt32() && !val->GetType()->IsFloat() &&
       !val->GetType()->IsInt1() && !val->GetType()->IsArray())) {
    throw std::runtime_error(
        FormatError("ir", "StoreInst 仅支持存储 i32/float/i1/array"));
  }
  if (!ptr->GetType() ||
      (!ptr->GetType()->IsPtrInt32() && !ptr->GetType()->IsPtrFloat() &&
       !ptr->GetType()->IsArray() && !ptr->GetType()->IsPtrInt1())) {
    throw std::runtime_error(FormatError("ir", "StoreInst 仅支持写入指针或数组地址"));
  }

  // Aggregate stores must be array-to-array. Checking only `ptr` (as before)
  // let an array-typed value be stored through a scalar pointer unnoticed.
  const bool val_is_array = val->GetType()->IsArray();
  const bool ptr_is_array = ptr->GetType()->IsArray();
  if (val_is_array || ptr_is_array) {
    if (!val_is_array || !ptr_is_array ||
        val->GetType()->GetKind() != ptr->GetType()->GetKind()) {
      throw std::runtime_error(
          FormatError("ir", "StoreInst 聚合存储要求 value/ptr 具有相同数组类型"));
    }
  }

  AddOperand(val);
  AddOperand(ptr);
}
// The constructor registers the stored value first and the destination address
// second.
Value* StoreInst::GetValue() const {
  constexpr size_t kValueSlot = 0;
  return GetOperand(kValueSlot);
}
Value* StoreInst::GetPtr() const {
  constexpr size_t kPtrSlot = 1;
  return GetOperand(kPtrSlot);
}
// The function this call targets. The callee is kept in its own field, not in
// the operand list.
Function* CallInst::GetCallee() const {
  return callee_;
}

// The argument list captured when the call was constructed.
// NOTE(review): args_ is a copy taken in the constructor, while the same
// arguments are also registered as operands. SetOperand in this file updates
// only the operand list, so after an operand rewrite this vector can be stale;
// GetOperand(i) reflects the current argument. Confirm nothing else re-syncs it.
const std::vector<Value*>& CallInst::GetArgs() const {
  return args_;
}
// Creates a `getelementptr` computing an address from `base` and `indices`.
// Operand layout: operand 0 is the base, operands 1..N are the indices in
// order. The result type `ptr_ty` is not validated here.
//
// Every input is validated BEFORE any operand is registered. Previously a null
// index was only caught by AddOperand partway through, after `base` and any
// earlier indices had already recorded a use of this half-built instruction.
// Whether those uses are cleaned up when the constructor throws depends on
// ~User/~Value, which are not visible here.
// Throws std::runtime_error when `base` or any index is null.
GEPInst::GEPInst(std::shared_ptr<Type> ptr_ty,
                 Value* base,
                 const std::vector<Value*>& indices,
                 const std::string& name)
    : Instruction(Opcode::GEP, std::move(ptr_ty), name) {
  if (!base) {
    throw std::runtime_error(FormatError("ir", "GEPInst 缺少 base"));
  }
  for (auto* index : indices) {
    if (!index) {
      throw std::runtime_error(FormatError("ir", "GEPInst 索引不能为空"));
    }
  }

  AddOperand(base);
  for (auto* index : indices) {
    AddOperand(index);
  }
}
// The base address of the GEP. The constructor always registers it as the
// first operand.
Value* GEPInst::GetBase() const {
  constexpr size_t kBaseSlot = 0;
  return GetOperand(kBaseSlot);
}
const std::vector<Value*>& GEPInst::GetIndices() const {
// 需要返回索引列表但Instruction只存储操作数
// 这是一个设计问题:要么修改架构,要么提供辅助方法
// 简化实现返回空vector或创建临时vector
static std::vector<Value*> indices;
indices.clear();
// 索引从操作数1开始
for (size_t i = 1; i < GetNumOperands(); ++i) {
indices.push_back(GetOperand(i));
}
return indices;
}
// Creates a call to `callee` with arguments `args` and result type `ret_ty`.
// The callee is stored in callee_ (it is not an operand). The arguments are
// copied into args_ and also registered as operands 0..N-1 in order.
//
// Every argument is validated BEFORE any operand is registered. Previously the
// null check ran in the same loop as AddOperand, so a null argument at
// position k threw after args 0..k-1 had already recorded a use of this
// half-built call. Whether such uses are cleaned up when the constructor throws
// depends on ~User/~Value, which are not visible here.
// Throws std::runtime_error when `callee` or any argument is null.
CallInst::CallInst(std::shared_ptr<Type> ret_ty, Function* callee,
                   const std::vector<Value*>& args, const std::string& name)
    : Instruction(Opcode::Call, std::move(ret_ty), name),  // name is const&, so it is copied here
      callee_(callee),
      args_(args) {
  if (!callee_) {
    throw std::runtime_error(FormatError("ir", "CallInst 缺少被调用函数"));
  }
  for (auto* arg : args_) {
    if (!arg) {
      throw std::runtime_error(FormatError("ir", "CallInst 参数不能为 null"));
    }
  }
  for (auto* arg : args_) {
    AddOperand(arg);
  }
}
} // namespace ir