You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/GlobalValue.cpp

195 lines
5.8 KiB

// ir/GlobalValue.cpp
#include "ir/IR.h"
#include <stdexcept>
namespace ir {
namespace {
ConstantValue* GetScalarZeroConstant(const Type& type) {
if (type.IsInt32()) {
static ConstantInt* zero_i32 = new ConstantInt(Type::GetInt32Type(), 0);
return zero_i32;
}
if (type.IsFloat()) {
static ConstantFloat* zero_f32 = new ConstantFloat(Type::GetFloatType(), 0.0f);
return zero_f32;
}
if (type.IsInt1()) {
static ConstantInt* zero_i1 = new ConstantInt(Type::GetInt1Type(), 0);
return zero_i1;
}
return nullptr;
}
} // namespace
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {}
void GlobalValue::SetInitializer(ConstantValue* init) {
if (!init) {
throw std::runtime_error("GlobalValue::SetInitializer: init is null");
}
// 获取实际的值类型(用于类型检查)
std::shared_ptr<Type> value_type = GetValueType();
// 类型检查
bool type_match = CheckTypeCompatibility(value_type, init);
if (!type_match) {
throw std::runtime_error("GlobalValue::SetInitializer: type mismatch");
}
initializer_.clear();
initializer_.push_back(init);
}
void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
if (init.empty()) {
initializer_.clear();
return;
}
// 获取实际的值类型
std::shared_ptr<Type> value_type = GetValueType();
// 类型检查
if (value_type->IsArray()) {
auto* array_ty = static_cast<ArrayType*>(value_type.get());
size_t array_size = array_ty->GetElementCount();
if (init.size() > array_size) {
throw std::runtime_error("GlobalValue::SetInitializer: too many initializers");
}
// 检查每个初始化值的类型
auto* elem_type = array_ty->GetElementType().get();
for (size_t i = 0; i < init.size(); ++i) {
auto* elem = init[i];
if (!elem) {
throw std::runtime_error("GlobalValue::SetInitializer: null initializer at index " + std::to_string(i));
}
bool elem_match = false;
if (elem_type->IsInt32() && elem->GetType()->IsInt32()) {
elem_match = true;
} else if (elem_type->IsFloat() && elem->GetType()->IsFloat()) {
elem_match = true;
} else if (elem_type->IsInt1() && elem->GetType()->IsInt1()) {
elem_match = true;
}
if (!elem_match) {
throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch at index " + std::to_string(i));
}
}
}
else if (value_type->IsInt32() || value_type->IsFloat() || value_type->IsInt1()) {
if (init.size() != 1) {
throw std::runtime_error("GlobalValue::SetInitializer: scalar requires exactly one initializer");
}
if (!init[0]) {
throw std::runtime_error("GlobalValue::SetInitializer: null initializer");
}
if ((value_type->IsInt32() && !init[0]->GetType()->IsInt32()) ||
(value_type->IsFloat() && !init[0]->GetType()->IsFloat()) ||
(value_type->IsInt1() && !init[0]->GetType()->IsInt1())) {
throw std::runtime_error("GlobalValue::SetInitializer: type mismatch");
}
}
else {
throw std::runtime_error("GlobalValue::SetInitializer: unsupported type");
}
initializer_ = init;
}
// 辅助方法:获取实际的值类型(处理指针包装)
std::shared_ptr<Type> GlobalValue::GetValueType() const {
if (GetType()->IsPtrInt32()) {
return Type::GetInt32Type();
} else if (GetType()->IsPtrFloat()) {
return Type::GetFloatType();
} else if (GetType()->IsPtrInt1()) {
return Type::GetInt1Type();
}
return GetType();
}
// 辅助方法:检查类型兼容性
bool GlobalValue::CheckTypeCompatibility(std::shared_ptr<Type> value_type,
ConstantValue* init) const {
// 检查标量类型
if (value_type->IsInt32() && init->GetType()->IsInt32()) {
return true;
} else if (value_type->IsFloat() && init->GetType()->IsFloat()) {
return true;
} else if (value_type->IsInt1() && init->GetType()->IsInt1()) {
return true;
}
// 检查数组类型:允许用单个标量初始化整个数组
else if (value_type->IsArray()) {
auto* array_ty = static_cast<ArrayType*>(value_type.get());
auto* elem_type = array_ty->GetElementType().get();
if (elem_type->IsInt32() && init->GetType()->IsInt32()) {
return true;
} else if (elem_type->IsFloat() && init->GetType()->IsFloat()) {
return true;
} else if (elem_type->IsInt1() && init->GetType()->IsInt1()) {
return true;
}
// 也可以允许 ConstantArray 作为初始化器
else if (init->GetType()->IsArray()) {
auto* init_array = static_cast<ConstantArray*>(init);
return init_array->IsValid();
}
}
// 检查指针类型(用于数组参数)
else if (value_type->IsPtrInt32() && init->GetType()->IsInt32()) {
return true;
} else if (value_type->IsPtrFloat() && init->GetType()->IsFloat()) {
return true;
}
return false;
}
// 添加获取数组元素的便捷方法
ConstantValue* GlobalValue::GetArrayElement(size_t index) const {
if (!GetType()->IsArray()) {
return nullptr;
}
auto* array_ty = dynamic_cast<ArrayType*>(GetType().get());
if (!array_ty) {
return nullptr;
}
if (index >= static_cast<size_t>(array_ty->GetElementCount())) {
return nullptr;
}
if (index >= initializer_.size()) {
return GetScalarZeroConstant(*array_ty->GetElementType());
}
return initializer_[index];
}
// 添加获取数组元素数量的方法
size_t GlobalValue::GetArraySize() const {
if (!IsArrayConstant()) {
return 0;
}
return initializer_.size();
}
// 添加判断是否为数组常量的方法
bool GlobalValue::IsArrayConstant() const {
return GetType()->IsArray() && !initializer_.empty();
}
} // namespace ir