You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

127 lines
3.6 KiB

// ir/IR.cpp
#include "ir/IR.h"

#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
namespace ir {
// Destructor. All owned constants live in the Context's map members as
// unique_ptr values, so member destruction releases them; no extra work here.
Context::~Context() {}
/// Returns the interned i32 constant holding `v`, creating it on first request.
/// The Context owns the returned object; repeated calls with the same value
/// return the same pointer.
ConstantInt* Context::GetConstInt(int v) {
  auto cached = const_ints_.find(v);
  if (cached == const_ints_.end()) {
    // First request for this value: build it and remember where it landed.
    cached = const_ints_
                 .emplace(v, std::make_unique<ConstantInt>(Type::GetInt32Type(), v))
                 .first;
  }
  return cached->second.get();
}
/// Returns the interned f32 constant holding `v`, creating it on first request.
/// Constants are keyed by the exact bit pattern of `v` rather than by float
/// equality, so e.g. 0.0f and -0.0f map to distinct objects, and NaNs with
/// the same bits share one object.
ConstantFloat* Context::GetConstFloat(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(float));  // bit-exact key; sidesteps float == pitfalls
  auto hit = const_floats_.find(bits);
  if (hit != const_floats_.end()) return hit->second.get();

  auto owned = std::make_unique<ConstantFloat>(Type::GetFloatType(), v);
  ConstantFloat* const result = owned.get();
  const_floats_[bits] = std::move(owned);  // map takes ownership
  return result;
}
/// Returns an interned constant array of type `ty` holding `elements`.
///
/// If fewer elements than `ty->GetElementCount()` are supplied, the tail is
/// padded with the zero constant of the element type (i32 -> GetConstInt(0),
/// float -> GetConstFloat(0.0f)). Requests with the same type object and the
/// same element pointers return the same ConstantArray; comparing elements by
/// pointer relies on scalar constants being interned by GetConstInt/GetConstFloat.
///
/// @param ty        array type; its element count fixes the final length.
/// @param elements  initial elements (taken by value; moved into the result).
/// @return non-owning pointer to the cached ConstantArray.
/// @throws std::runtime_error if more elements than the type holds are given,
///         or if padding is needed for an element type other than i32/float.
ConstantArray* Context::GetConstArray(std::shared_ptr<ArrayType> ty,
                                      std::vector<ConstantValue*> elements) {
  const size_t expected_size = ty->GetElementCount();
  if (elements.size() > expected_size) {
    throw std::runtime_error("Array constant size mismatch");
  }
  if (elements.size() < expected_size) {
    // Choose the zero filler once. Only scalar element types have a known zero
    // here; any other element type (e.g. a nested array) cannot be padded, so
    // fail loudly rather than loop forever without making progress.
    // NOTE(review): nested arrays could be padded with GetAggregateZero(elem_type)
    // if ConstantAggregateZero derives from ConstantValue — confirm in IR.h.
    auto elem_type = ty->GetElementType();
    ConstantValue* zero = nullptr;
    if (elem_type->IsInt32()) {
      zero = GetConstInt(0);
    } else if (elem_type->IsFloat()) {
      zero = GetConstFloat(0.0f);
    } else {
      throw std::runtime_error("Array constant padding: unsupported element type");
    }
    elements.resize(expected_size, zero);
  }

  // Cache key: identity of the array type object plus the element pointers.
  struct ArrayKey {
    std::shared_ptr<ArrayType> type;
    std::vector<ConstantValue*> elements;
    bool operator==(const ArrayKey& other) const {
      return type == other.type && elements == other.elements;
    }
  };
  struct ArrayKeyHash {
    size_t operator()(const ArrayKey& key) const {
      size_t hash = std::hash<Type*>{}(key.type.get());
      for (auto* elem : key.elements) {
        // boost::hash_combine-style mixing of each element pointer.
        hash ^= std::hash<ConstantValue*>{}(elem) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
      }
      return hash;
    }
  };

  // NOTE(review): this cache is a function-local static. It is shared by every
  // Context, outlives all of them, and is not synchronized. Its keys hold raw
  // element pointers owned by whichever Context created them. It should become
  // a Context member; that needs a change to the Context declaration
  // (presumably in ir/IR.h).
  static std::unordered_map<ArrayKey, std::unique_ptr<ConstantArray>, ArrayKeyHash> cache;
  ArrayKey key{ty, elements};
  auto it = cache.find(key);
  if (it != cache.end()) {
    return it->second.get();
  }
  auto constant = std::make_unique<ConstantArray>(ty, std::move(elements));
  auto* ptr = constant.get();
  cache.emplace(std::move(key), std::move(constant));  // key known absent: single insert
  return ptr;
}
/// Returns the shared zero constant for `ty`, creating it on first request.
/// Entries are keyed by the raw Type pointer, i.e. by type-object identity.
/// NOTE(review): presumably ConstantZero retains `ty`, keeping that address
/// valid for as long as the entry exists — confirm.
ConstantZero* Context::GetZeroConstant(std::shared_ptr<Type> ty) {
  auto* const type_key = ty.get();
  auto existing = zero_constants_.find(type_key);
  if (existing == zero_constants_.end()) {
    auto owned = std::make_unique<ConstantZero>(ty);
    auto* const created = owned.get();
    zero_constants_[type_key] = std::move(owned);
    return created;
  }
  return existing->second.get();
}
/// Returns the shared aggregate zero-initializer constant for `ty`, building
/// it on first request. Lookup is by the identity of the Type object
/// (raw pointer key), matching GetZeroConstant.
ConstantAggregateZero* Context::GetAggregateZero(std::shared_ptr<Type> ty) {
  Type* const identity = ty.get();
  const auto found = aggregate_zeros_.find(identity);
  if (found != aggregate_zeros_.end()) return found->second.get();

  std::unique_ptr<ConstantAggregateZero> fresh = std::make_unique<ConstantAggregateZero>(ty);
  ConstantAggregateZero* const handle = fresh.get();
  aggregate_zeros_[identity] = std::move(fresh);  // Context keeps ownership
  return handle;
}
/// Returns a new temporary name: pre-increments `temp_index_` and yields
/// "%t" followed by the new counter value in decimal.
std::string Context::NextTemp() {
  return "%t" + std::to_string(++temp_index_);
}
} // namespace ir