You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

159 lines
5.3 KiB

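// IR type system implementation: the void, i32, f32 and label primitive types,
// plus the pointer, array and function composite types.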
#include "ir/IR.h"
#include <unordered_map>
#include <functional>
namespace ir {
// 用于缓存复合类型的静态映射(简单实现)
static std::unordered_map<std::size_t, std::shared_ptr<Type>> pointer_cache;
static std::unordered_map<std::size_t, std::shared_ptr<Type>> array_cache;
static std::unordered_map<std::size_t, std::shared_ptr<Type>> function_cache;
// Simple hash-combining helper: folds the value `v` into the running hash
// `seed` and returns the new hash. All arithmetic is unsigned and wraps.
static std::size_t hash_combine(std::size_t seed, std::size_t v) {
  const std::size_t shifted_left = seed << 6;
  const std::size_t shifted_right = seed >> 2;
  const std::size_t mixed = v + 0x9e3779b9 + shifted_left + shifted_right;
  return seed ^ mixed;
}
Type::Type(Kind k) : kind_(k) {}
// Returns the shared `void` type. The instance is a function-local static, so
// it is created on the first call and the same object is returned afterwards.
const std::shared_ptr<Type>& Type::GetVoidType() {
  static const auto void_type = std::make_shared<Type>(Kind::Void);
  return void_type;
}
// Returns the shared i32 type. The instance is a function-local static, so it
// is created on the first call and the same object is returned afterwards.
const std::shared_ptr<Type>& Type::GetInt32Type() {
  static const auto int32_type = std::make_shared<Type>(Kind::Int32);
  return int32_type;
}
// Returns the shared f32 type. The instance is a function-local static, so it
// is created on the first call and the same object is returned afterwards.
const std::shared_ptr<Type>& Type::GetFloat32Type() {
  static const auto float32_type = std::make_shared<Type>(Kind::Float32);
  return float32_type;
}
// Returns the shared label type. The instance is a function-local static, so
// it is created on the first call and the same object is returned afterwards.
const std::shared_ptr<Type>& Type::GetLabelType() {
  static const auto label_type = std::make_shared<Type>(Kind::Label);
  return label_type;
}
// Kept for compatibility with older code that used a dedicated i32* type:
// returns the pointer type whose pointee is i32, obtained once through
// GetPointerType(GetInt32Type()) and then reused.
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
  static const auto int32_pointer_type = GetPointerType(GetInt32Type());
  return int32_pointer_type;
}
// Returns the pointer type whose pointee is `pointee`, creating and caching it
// on the first request.
//
// The cache key is simply the address of the pointee object. The original
// author notes that a more robust identity should be used in a real
// implementation, but that this is sufficient for demonstration purposes.
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> pointee) {
  const auto key = reinterpret_cast<std::size_t>(pointee.get());

  const auto hit = pointer_cache.find(key);
  if (hit == pointer_cache.end()) {
    // First request for this pointee: build the type and remember it.
    std::shared_ptr<Type> created = std::make_shared<PointerType>(pointee);
    pointer_cache.emplace(key, created);
    return created;
  }
  return hit->second;
}
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> elem, size_t size) {
// 使用元素类型指针和大小组合哈希
std::size_t seed = 0;
seed = hash_combine(seed, reinterpret_cast<std::size_t>(elem.get()));
seed = hash_combine(seed, size);
auto it = array_cache.find(seed);
if (it != array_cache.end()) {
return it->second;
}
auto arr_type = std::make_shared<ArrayType>(elem, size);
array_cache[seed] = arr_type;
return arr_type;
}
std::shared_ptr<Type> Type::GetFunctionType(std::shared_ptr<Type> ret,
std::vector<std::shared_ptr<Type>> params) {
// 哈希组合:返回类型 + 参数类型列表
std::size_t seed = reinterpret_cast<std::size_t>(ret.get());
for (const auto& p : params) {
seed = hash_combine(seed, reinterpret_cast<std::size_t>(p.get()));
}
auto it = function_cache.find(seed);
if (it != function_cache.end()) {
return it->second;
}
auto func_type = std::make_shared<FunctionType>(ret, std::move(params));
function_cache[seed] = func_type;
return func_type;
}
// Returns the kind tag stored in this type.
Type::Kind Type::GetKind() const {
  return kind_;
}

// Kind predicates: each one reports whether this type's kind tag equals the
// corresponding Kind enumerator.
bool Type::IsVoid() const {
  return kind_ == Kind::Void;
}

bool Type::IsInt32() const {
  return kind_ == Kind::Int32;
}

bool Type::IsFloat32() const {
  return kind_ == Kind::Float32;
}

bool Type::IsPointer() const {
  return kind_ == Kind::Pointer;
}

bool Type::IsArray() const {
  return kind_ == Kind::Array;
}

bool Type::IsFunction() const {
  return kind_ == Kind::Function;
}

bool Type::IsLabel() const {
  return kind_ == Kind::Label;
}
// Kept for compatibility with older code: reports whether this type is a
// pointer whose pointee type is i32.
bool Type::IsPtrInt32() const {
  if (IsPointer()) {
    // The kind tag says Pointer, so this object is viewed as a PointerType.
    const auto& as_pointer = static_cast<const PointerType&>(*this);
    return as_pointer.GetPointeeType()->IsInt32();
  }
  return false;
}
// Reports whether this type is a pointer whose pointee type is f32.
bool Type::IsPtrFloat32() const {
  if (IsPointer()) {
    // The kind tag says Pointer, so this object is viewed as a PointerType.
    const auto& as_pointer = static_cast<const PointerType&>(*this);
    return as_pointer.GetPointeeType()->IsFloat32();
  }
  return false;
}
// Structural equality between two types.
//
// Types of different kinds are never equal. Void, Int32, Float32 and Label
// carry no further data, so equal kinds mean equal types. Pointer types compare
// their pointee types, array types compare size and element type, and function
// types compare the return type and then the parameter types position by
// position. Any other kind compares unequal.
bool Type::operator==(const Type& other) const {
  if (kind_ != other.kind_) {
    return false;
  }

  if (kind_ == Kind::Void || kind_ == Kind::Int32 ||
      kind_ == Kind::Float32 || kind_ == Kind::Label) {
    return true;
  }

  if (kind_ == Kind::Pointer) {
    const auto& lhs = static_cast<const PointerType&>(*this);
    const auto& rhs = static_cast<const PointerType&>(other);
    return *lhs.GetPointeeType() == *rhs.GetPointeeType();
  }

  if (kind_ == Kind::Array) {
    const auto& lhs = static_cast<const ArrayType&>(*this);
    const auto& rhs = static_cast<const ArrayType&>(other);
    // Size is checked first; element types are compared only when sizes agree.
    if (!(lhs.GetSize() == rhs.GetSize())) {
      return false;
    }
    return *lhs.GetElementType() == *rhs.GetElementType();
  }

  if (kind_ == Kind::Function) {
    const auto& lhs = static_cast<const FunctionType&>(*this);
    const auto& rhs = static_cast<const FunctionType&>(other);
    if (*lhs.GetReturnType() != *rhs.GetReturnType()) {
      return false;
    }

    const auto& lhs_params = lhs.GetParamTypes();
    const auto& rhs_params = rhs.GetParamTypes();
    if (lhs_params.size() != rhs_params.size()) {
      return false;
    }
    std::size_t index = 0;
    while (index < lhs_params.size()) {
      if (*lhs_params[index] != *rhs_params[index]) {
        return false;
      }
      ++index;
    }
    return true;
  }

  // Unrecognised kind.
  return false;
}
// Structural inequality: the exact negation of operator==.
bool Type::operator!=(const Type& other) const {
  const bool equal = (*this == other);
  return !equal;
}
} // namespace ir