You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
3.6 KiB
127 lines
3.6 KiB
// ir/IR.cpp
|
|
#include "ir/IR.h"
|
|
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
|
|
|
|
namespace ir {
|
|
|
|
// Out-of-line defaulted destructor. Destroys the Context's owned constant
// tables (const_ints_, const_floats_, zero_constants_, aggregate_zeros_),
// whose values are owning smart pointers. Note that the function-local static
// cache inside GetConstArray is NOT a member and is not released here.
// NOTE(review): presumably defined in the .cpp rather than the header so the
// owned pointee types are complete at this point — confirm against ir/IR.h.
Context::~Context() = default;
|
|
|
|
// Returns the interned i32 constant for `v`.
//
// Each distinct value is materialized at most once per Context and owned by
// const_ints_; repeated calls with the same value return the same pointer.
ConstantInt* Context::GetConstInt(int v) {
  auto slot = const_ints_.find(v);
  if (slot == const_ints_.end()) {
    // First request for this value: create it and remember where it landed.
    slot = const_ints_
               .emplace(v, std::make_unique<ConstantInt>(Type::GetInt32Type(), v))
               .first;
  }
  return slot->second.get();
}
|
|
|
|
// Returns the interned float constant for `v`.
//
// The cache is keyed by the raw 32-bit pattern of `v` rather than by float
// comparison. Values that compare equal but differ bitwise (e.g. 0.0f vs
// -0.0f) therefore get distinct constants, and a given NaN bit pattern is
// found again even though NaN != NaN.
ConstantFloat* Context::GetConstFloat(float v) {
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(float));

  auto found = const_floats_.find(bits);
  if (found != const_floats_.end()) return found->second.get();

  auto fresh = std::make_unique<ConstantFloat>(Type::GetFloatType(), v);
  ConstantFloat* result = fresh.get();
  const_floats_[bits] = std::move(fresh);
  return result;
}
|
|
|
|
// Returns the cached ConstantArray of type `ty` holding `elements`.
//
// If fewer elements than ty->GetElementCount() are supplied, the tail is
// padded with the element type's zero: GetConstInt(0) for i32 and
// GetConstFloat(0.0f) for float.
//
// @param ty        array type of the constant; dereferenced, so must be non-null.
// @param elements  initializer values. Taken by value so it can be padded and
//                  then moved into the new constant.
// @return          non-owning pointer. Repeated requests with the same `ty`
//                  object and the same element pointers return the same object.
// @throws std::runtime_error if more elements than the array holds are given,
//         or if padding is needed for an element type other than i32/float.
ConstantArray* Context::GetConstArray(std::shared_ptr<ArrayType> ty,
                                      std::vector<ConstantValue*> elements) {
  const size_t expected_size = ty->GetElementCount();
  if (elements.size() > expected_size) {
    throw std::runtime_error("Array constant size mismatch");
  }
  if (elements.size() < expected_size) {
    // Resolve the padding value once, up front. The previous fill loop
    // appended nothing for element types other than i32/float (e.g. nested
    // arrays), so its size condition never changed and it spun forever.
    auto elem_type = ty->GetElementType();
    ConstantValue* zero = nullptr;
    if (elem_type->IsInt32()) {
      zero = GetConstInt(0);
    } else if (elem_type->IsFloat()) {
      zero = GetConstFloat(0.0f);
    } else {
      throw std::runtime_error(
          "Array constant padding: unsupported element type");
    }
    elements.resize(expected_size, zero);
  }

  // Cache key: the array type (by shared_ptr identity) plus the element
  // pointers (compared by pointer identity, not by value).
  struct ArrayKey {
    std::shared_ptr<ArrayType> type;
    std::vector<ConstantValue*> elements;

    bool operator==(const ArrayKey& other) const {
      return type == other.type && elements == other.elements;
    }
  };

  struct ArrayKeyHash {
    size_t operator()(const ArrayKey& key) const {
      size_t hash = std::hash<Type*>{}(key.type.get());
      for (auto* elem : key.elements) {
        // hash_combine-style mixing of each element pointer into the seed.
        hash ^= std::hash<ConstantValue*>{}(elem) + 0x9e3779b9 + (hash << 6) +
                (hash >> 2);
      }
      return hash;
    }
  };

  // NOTE(review): this cache is a function-local static, so it is shared by
  // every Context and outlives all of them. Its ConstantArray entries keep
  // pointers to elements owned by (and freed with) the Context that created
  // them, and the entries are only destroyed at program exit. A later
  // Context that reuses the same ArrayType and element addresses could
  // receive a stale entry. It should become a Context member declared in
  // ir/IR.h (as the original author's comment also noted).
  static std::unordered_map<ArrayKey, std::unique_ptr<ConstantArray>,
                            ArrayKeyHash>
      cache;

  ArrayKey key{ty, elements};
  auto it = cache.find(key);
  if (it != cache.end()) {
    return it->second.get();
  }

  auto constant = std::make_unique<ConstantArray>(ty, std::move(elements));
  ConstantArray* ptr = constant.get();
  cache.emplace(std::move(key), std::move(constant));
  return ptr;
}
|
|
|
|
// Returns the cached ConstantZero for `ty`.
//
// Entries are keyed by the Type object's address (identity), not by
// structural type equality, and are owned by zero_constants_.
ConstantZero* Context::GetZeroConstant(std::shared_ptr<Type> ty) {
  Type* const type_key = ty.get();
  auto hit = zero_constants_.find(type_key);
  if (hit == zero_constants_.end()) {
    // Miss: build the constant and hand ownership to the table.
    auto created = std::make_unique<ConstantZero>(ty);
    ConstantZero* raw = created.get();
    zero_constants_[type_key] = std::move(created);
    return raw;
  }
  return hit->second.get();
}
|
|
|
|
// Returns the cached ConstantAggregateZero for `ty`.
//
// Keyed by the Type object's address; repeated calls with the same Type
// object return the same pointer. Ownership stays with aggregate_zeros_.
ConstantAggregateZero* Context::GetAggregateZero(std::shared_ptr<Type> ty) {
  auto existing = aggregate_zeros_.find(ty.get());
  if (existing != aggregate_zeros_.end()) return existing->second.get();

  auto owned = std::make_unique<ConstantAggregateZero>(ty);
  ConstantAggregateZero* const handle = owned.get();
  aggregate_zeros_[ty.get()] = std::move(owned);
  return handle;
}
|
|
|
|
// Produces a fresh temporary name of the form "%t<N>".
//
// N is the value of temp_index_ after pre-incrementing it, so names are
// unique within this Context. The first name is "%t1" if the counter
// starts at 0.
std::string Context::NextTemp() {
  const auto index = ++temp_index_;
  std::ostringstream name;
  name << "%t" << index;
  return name.str();
}
|
|
|
|
} // namespace ir
|