forked from NUDT-compiler/nudt-compiler-cpp
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
195 lines
5.8 KiB
195 lines
5.8 KiB
// ir/GlobalValue.cpp
|
|
#include "ir/IR.h"
|
|
|
|
#include <stdexcept>
|
|
|
|
namespace ir {
|
|
|
|
namespace {
|
|
|
|
ConstantValue* GetScalarZeroConstant(const Type& type) {
|
|
if (type.IsInt32()) {
|
|
static ConstantInt* zero_i32 = new ConstantInt(Type::GetInt32Type(), 0);
|
|
return zero_i32;
|
|
}
|
|
if (type.IsFloat()) {
|
|
static ConstantFloat* zero_f32 = new ConstantFloat(Type::GetFloatType(), 0.0f);
|
|
return zero_f32;
|
|
}
|
|
if (type.IsInt1()) {
|
|
static ConstantInt* zero_i1 = new ConstantInt(Type::GetInt1Type(), 0);
|
|
return zero_i1;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Creates a global value of type `ty` named `name`. Both arguments are moved
// straight into the User base; GlobalValue adds no extra construction logic.
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
    : User(std::move(ty),
           std::move(name)) {}
|
|
|
|
// Replaces the initializer with a single constant.
//
// `init` is validated with CheckTypeCompatibility() against GetValueType(), so a global
// whose type is a pointer to i32 / float / i1 is checked against that pointee type.
// Throws std::runtime_error if `init` is null or its type is incompatible; in that case
// the previously stored initializer is left untouched.
void GlobalValue::SetInitializer(ConstantValue* init) {
  if (init == nullptr) {
    throw std::runtime_error("GlobalValue::SetInitializer: init is null");
  }

  if (!CheckTypeCompatibility(GetValueType(), init)) {
    throw std::runtime_error("GlobalValue::SetInitializer: type mismatch");
  }

  // Discard whatever was stored before (single value or list) and keep only `init`.
  initializer_.assign(1, init);
}
|
|
|
|
// Replaces the initializer with a list of constants.
//
// An empty list just clears the current initializer, with no type checks. A non-empty
// list is validated against GetValueType() before anything is stored:
//   * array value type  -- at most GetElementCount() entries, each non-null and of the
//     same scalar kind (i32 / float / i1) as the array's element type;
//   * scalar value type (i32 / float / i1) -- exactly one non-null entry of that kind;
//   * any other value type -- rejected as unsupported.
// Throws std::runtime_error describing the first violation found; the previously
// stored initializer is only replaced once every check has passed.
void GlobalValue::SetInitializer(const std::vector<ConstantValue*>& init) {
  if (init.empty()) {
    initializer_.clear();
    return;
  }

  // Pointer-wrapped scalars are checked against their pointee type.
  const std::shared_ptr<Type> target = GetValueType();

  if (target->IsArray()) {
    auto* arr = static_cast<ArrayType*>(target.get());
    const size_t capacity = arr->GetElementCount();
    if (init.size() > capacity) {
      throw std::runtime_error(
          "GlobalValue::SetInitializer: too many initializers");
    }

    // Every entry must be a non-null scalar of the element type's kind.
    const Type* slot = arr->GetElementType().get();
    size_t pos = 0;
    for (ConstantValue* item : init) {
      if (item == nullptr) {
        throw std::runtime_error(
            "GlobalValue::SetInitializer: null initializer at index " +
            std::to_string(pos));
      }
      const auto& item_type = item->GetType();
      const bool same_kind = (slot->IsInt32() && item_type->IsInt32()) ||
                             (slot->IsFloat() && item_type->IsFloat()) ||
                             (slot->IsInt1() && item_type->IsInt1());
      if (!same_kind) {
        throw std::runtime_error(
            "GlobalValue::SetInitializer: element type mismatch at index " +
            std::to_string(pos));
      }
      ++pos;
    }
  } else if (target->IsInt32() || target->IsFloat() || target->IsInt1()) {
    if (init.size() != 1) {
      throw std::runtime_error(
          "GlobalValue::SetInitializer: scalar requires exactly one initializer");
    }

    ConstantValue* only = init.front();
    if (only == nullptr) {
      throw std::runtime_error("GlobalValue::SetInitializer: null initializer");
    }

    const auto& only_type = only->GetType();
    const bool mismatch = (target->IsInt32() && !only_type->IsInt32()) ||
                          (target->IsFloat() && !only_type->IsFloat()) ||
                          (target->IsInt1() && !only_type->IsInt1());
    if (mismatch) {
      throw std::runtime_error("GlobalValue::SetInitializer: type mismatch");
    }
  } else {
    throw std::runtime_error("GlobalValue::SetInitializer: unsupported type");
  }

  initializer_ = init;
}
|
|
|
|
// 辅助方法:获取实际的值类型(处理指针包装)
|
|
std::shared_ptr<Type> GlobalValue::GetValueType() const {
|
|
if (GetType()->IsPtrInt32()) {
|
|
return Type::GetInt32Type();
|
|
} else if (GetType()->IsPtrFloat()) {
|
|
return Type::GetFloatType();
|
|
} else if (GetType()->IsPtrInt1()) {
|
|
return Type::GetInt1Type();
|
|
}
|
|
return GetType();
|
|
}
|
|
|
|
// 辅助方法:检查类型兼容性
|
|
bool GlobalValue::CheckTypeCompatibility(std::shared_ptr<Type> value_type,
|
|
ConstantValue* init) const {
|
|
// 检查标量类型
|
|
if (value_type->IsInt32() && init->GetType()->IsInt32()) {
|
|
return true;
|
|
} else if (value_type->IsFloat() && init->GetType()->IsFloat()) {
|
|
return true;
|
|
} else if (value_type->IsInt1() && init->GetType()->IsInt1()) {
|
|
return true;
|
|
}
|
|
// 检查数组类型:允许用单个标量初始化整个数组
|
|
else if (value_type->IsArray()) {
|
|
auto* array_ty = static_cast<ArrayType*>(value_type.get());
|
|
auto* elem_type = array_ty->GetElementType().get();
|
|
|
|
if (elem_type->IsInt32() && init->GetType()->IsInt32()) {
|
|
return true;
|
|
} else if (elem_type->IsFloat() && init->GetType()->IsFloat()) {
|
|
return true;
|
|
} else if (elem_type->IsInt1() && init->GetType()->IsInt1()) {
|
|
return true;
|
|
}
|
|
// 也可以允许 ConstantArray 作为初始化器
|
|
else if (init->GetType()->IsArray()) {
|
|
auto* init_array = static_cast<ConstantArray*>(init);
|
|
return init_array->IsValid();
|
|
}
|
|
}
|
|
// 检查指针类型(用于数组参数)
|
|
else if (value_type->IsPtrInt32() && init->GetType()->IsInt32()) {
|
|
return true;
|
|
} else if (value_type->IsPtrFloat() && init->GetType()->IsFloat()) {
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// 添加获取数组元素的便捷方法
|
|
ConstantValue* GlobalValue::GetArrayElement(size_t index) const {
|
|
if (!GetType()->IsArray()) {
|
|
return nullptr;
|
|
}
|
|
|
|
auto* array_ty = dynamic_cast<ArrayType*>(GetType().get());
|
|
if (!array_ty) {
|
|
return nullptr;
|
|
}
|
|
if (index >= static_cast<size_t>(array_ty->GetElementCount())) {
|
|
return nullptr;
|
|
}
|
|
if (index >= initializer_.size()) {
|
|
return GetScalarZeroConstant(*array_ty->GetElementType());
|
|
}
|
|
return initializer_[index];
|
|
}
|
|
|
|
// 添加获取数组元素数量的方法
|
|
size_t GlobalValue::GetArraySize() const {
|
|
if (!IsArrayConstant()) {
|
|
return 0;
|
|
}
|
|
return initializer_.size();
|
|
}
|
|
|
|
// 添加判断是否为数组常量的方法
|
|
bool GlobalValue::IsArrayConstant() const {
|
|
return GetType()->IsArray() && !initializer_.empty();
|
|
}
|
|
|
|
} // namespace ir
|