// ir/GlobalValue.cpp #include "ir/IR.h" #include namespace ir { namespace { ConstantValue* GetScalarZeroConstant(const Type& type) { if (type.IsInt32()) { static ConstantInt* zero_i32 = new ConstantInt(Type::GetInt32Type(), 0); return zero_i32; } if (type.IsFloat()) { static ConstantFloat* zero_f32 = new ConstantFloat(Type::GetFloatType(), 0.0f); return zero_f32; } if (type.IsInt1()) { static ConstantInt* zero_i1 = new ConstantInt(Type::GetInt1Type(), 0); return zero_i1; } return nullptr; } } // namespace GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) : User(std::move(ty), std::move(name)) {} void GlobalValue::SetInitializer(ConstantValue* init) { if (!init) { throw std::runtime_error("GlobalValue::SetInitializer: init is null"); } // 获取实际的值类型(用于类型检查) std::shared_ptr value_type = GetValueType(); // 类型检查 bool type_match = CheckTypeCompatibility(value_type, init); if (!type_match) { throw std::runtime_error("GlobalValue::SetInitializer: type mismatch"); } initializer_.clear(); initializer_.push_back(init); } void GlobalValue::SetInitializer(const std::vector& init) { if (init.empty()) { initializer_.clear(); return; } // 获取实际的值类型 std::shared_ptr value_type = GetValueType(); // 类型检查 if (value_type->IsArray()) { auto* array_ty = static_cast(value_type.get()); size_t array_size = array_ty->GetElementCount(); if (init.size() > array_size) { throw std::runtime_error("GlobalValue::SetInitializer: too many initializers"); } // 检查每个初始化值的类型 auto* elem_type = array_ty->GetElementType().get(); for (size_t i = 0; i < init.size(); ++i) { auto* elem = init[i]; if (!elem) { throw std::runtime_error("GlobalValue::SetInitializer: null initializer at index " + std::to_string(i)); } bool elem_match = false; if (elem_type->IsInt32() && elem->GetType()->IsInt32()) { elem_match = true; } else if (elem_type->IsFloat() && elem->GetType()->IsFloat()) { elem_match = true; } else if (elem_type->IsInt1() && elem->GetType()->IsInt1()) { elem_match = true; } if (!elem_match) { throw std::runtime_error("GlobalValue::SetInitializer: element type mismatch at index " + std::to_string(i)); } } } else if (value_type->IsInt32() || value_type->IsFloat() || value_type->IsInt1()) { if (init.size() != 1) { throw std::runtime_error("GlobalValue::SetInitializer: scalar requires exactly one initializer"); } if (!init[0]) { throw std::runtime_error("GlobalValue::SetInitializer: null initializer"); } if ((value_type->IsInt32() && !init[0]->GetType()->IsInt32()) || (value_type->IsFloat() && !init[0]->GetType()->IsFloat()) || (value_type->IsInt1() && !init[0]->GetType()->IsInt1())) { throw std::runtime_error("GlobalValue::SetInitializer: type mismatch"); } } else { throw std::runtime_error("GlobalValue::SetInitializer: unsupported type"); } initializer_ = init; } // 辅助方法:获取实际的值类型(处理指针包装) std::shared_ptr GlobalValue::GetValueType() const { if (GetType()->IsPtrInt32()) { return Type::GetInt32Type(); } else if (GetType()->IsPtrFloat()) { return Type::GetFloatType(); } else if (GetType()->IsPtrInt1()) { return Type::GetInt1Type(); } return GetType(); } // 辅助方法:检查类型兼容性 bool GlobalValue::CheckTypeCompatibility(std::shared_ptr value_type, ConstantValue* init) const { // 检查标量类型 if (value_type->IsInt32() && init->GetType()->IsInt32()) { return true; } else if (value_type->IsFloat() && init->GetType()->IsFloat()) { return true; } else if (value_type->IsInt1() && init->GetType()->IsInt1()) { return true; } // 检查数组类型:允许用单个标量初始化整个数组 else if (value_type->IsArray()) { auto* array_ty = static_cast(value_type.get()); auto* elem_type = array_ty->GetElementType().get(); if (elem_type->IsInt32() && init->GetType()->IsInt32()) { return true; } else if (elem_type->IsFloat() && init->GetType()->IsFloat()) { return true; } else if (elem_type->IsInt1() && init->GetType()->IsInt1()) { return true; } // 也可以允许 ConstantArray 作为初始化器 else if (init->GetType()->IsArray()) { auto* init_array = static_cast(init); return init_array->IsValid(); } } // 检查指针类型(用于数组参数) else if (value_type->IsPtrInt32() && init->GetType()->IsInt32()) { return true; } else if (value_type->IsPtrFloat() && init->GetType()->IsFloat()) { return true; } return false; } // 添加获取数组元素的便捷方法 ConstantValue* GlobalValue::GetArrayElement(size_t index) const { if (!GetType()->IsArray()) { return nullptr; } auto* array_ty = dynamic_cast(GetType().get()); if (!array_ty) { return nullptr; } if (index >= static_cast(array_ty->GetElementCount())) { return nullptr; } if (index >= initializer_.size()) { return GetScalarZeroConstant(*array_ty->GetElementType()); } return initializer_[index]; } // 添加获取数组元素数量的方法 size_t GlobalValue::GetArraySize() const { if (!IsArrayConstant()) { return 0; } return initializer_.size(); } // 添加判断是否为数组常量的方法 bool GlobalValue::IsArrayConstant() const { return GetType()->IsArray() && !initializer_.empty(); } } // namespace ir