You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

671 lines
22 KiB

#include "irgen/IRGen.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace {

// Converts a row-major flat element index into one index per dimension.
//
// `dims` lists the extent of each dimension, outermost first; the last
// dimension varies fastest. Every extent is used as a divisor and must
// therefore be non-zero. The result has the same length as `dims`.
std::vector<int> ExpandLinearIndex(const std::vector<int>& dims, size_t flat_index) {
  std::vector<int> coords(dims.size(), 0);
  size_t remaining = flat_index;
  size_t axis = dims.size();
  // Peel off the innermost dimension first and walk outward.
  while (axis-- > 0) {
    const auto extent = static_cast<size_t>(dims[axis]);
    coords[axis] = static_cast<int>(remaining % extent);
    remaining /= extent;
  }
  return coords;
}

}  // namespace
// Returns the text of the identifier token `ident`. When the token is absent
// an error is reported against `ctx` via ThrowError.
// NOTE(review): relies on ThrowError never returning — confirm it is
// declared [[noreturn]].
std::string IRGenImpl::ExpectIdent(const antlr4::ParserRuleContext& ctx,
                                   antlr4::tree::TerminalNode* ident) const {
  if (ident != nullptr) {
    return ident->getText();
  }
  ThrowError(&ctx, "?????");
}
// Maps a basic-type parse node to its SemanticType: the INT token yields
// SemanticType::Int, the FLOAT token yields SemanticType::Float. A null node
// or a node carrying neither token is reported through ThrowError.
SemanticType IRGenImpl::ParseBType(SysYParser::BTypeContext* ctx) const {
  if (!ctx) {
    ThrowError(ctx, "??????");
  }
  const bool is_int = ctx->INT() != nullptr;
  if (is_int) {
    return SemanticType::Int;
  }
  const bool is_float = ctx->FLOAT() != nullptr;
  if (is_float) {
    return SemanticType::Float;
  }
  ThrowError(ctx, "????? int/float ????");
}
// Maps a function-return-type parse node to its SemanticType. The tokens are
// probed in the order VOID, INT, FLOAT; a null node or a node with none of
// them is reported through ThrowError.
SemanticType IRGenImpl::ParseFuncType(SysYParser::FuncTypeContext* ctx) const {
  if (!ctx) {
    ThrowError(ctx, "????????");
  }
  const bool is_void = ctx->VOID() != nullptr;
  if (is_void) {
    return SemanticType::Void;
  }
  const bool is_int = ctx->INT() != nullptr;
  if (is_int) {
    return SemanticType::Int;
  }
  const bool is_float = ctx->FLOAT() != nullptr;
  if (is_float) {
    return SemanticType::Float;
  }
  ThrowError(ctx, "????? void/int/float ??????");
}
// Returns the IR type corresponding to a scalar semantic type
// (Void -> void, Int -> i32, Float -> float).
// Throws std::runtime_error for a value outside the three known enumerators.
std::shared_ptr<ir::Type> IRGenImpl::GetIRScalarType(SemanticType type) const {
  if (type == SemanticType::Void) {
    return ir::Type::GetVoidType();
  }
  if (type == SemanticType::Int) {
    return ir::Type::GetInt32Type();
  }
  if (type == SemanticType::Float) {
    return ir::Type::GetFloatType();
  }
  throw std::runtime_error("unknown semantic type");
}
// Builds the nested IR array type for `base_type` with extents `dims`
// (outermost first). The innermost dimension is wrapped first, so `dims`
// is traversed back to front. With empty `dims` the scalar type is returned.
std::shared_ptr<ir::Type> IRGenImpl::BuildArrayType(
    SemanticType base_type, const std::vector<int>& dims) const {
  auto result = GetIRScalarType(base_type);
  for (size_t axis = dims.size(); axis-- > 0;) {
    result = ir::Type::GetArrayType(result, static_cast<size_t>(dims[axis]));
  }
  return result;
}
// Evaluates each array-dimension constant expression to an int and returns
// the extents in declaration order.
// Reports an error through ThrowError when a dimension node (or its additive
// expression) is missing, or when the evaluated extent is not positive.
std::vector<int> IRGenImpl::ParseArrayDims(
    const std::vector<SysYParser::ConstExpContext*>& dims_ctx) {
  std::vector<int> extents;
  extents.reserve(dims_ctx.size());
  for (SysYParser::ConstExpContext* node : dims_ctx) {
    const bool malformed = node == nullptr || node->addExp() == nullptr;
    if (malformed) {
      ThrowError(node, "???????????");
    }
    const auto folded =
        ConvertConst(EvalConstAddExp(*node->addExp()), SemanticType::Int);
    if (!(folded.int_value > 0)) {
      ThrowError(node, "??????????");
    }
    extents.push_back(folded.int_value);
  }
  return extents;
}
// Evaluates the explicit dimension expressions of an array parameter
// (the children returned by `ctx.exp()`) as integer constants and returns
// them in order. A non-positive extent is reported through ThrowError.
std::vector<int> IRGenImpl::ParseParamDims(SysYParser::FuncFParamContext& ctx) {
  std::vector<int> extents;
  const auto dim_exprs = ctx.exp();
  for (size_t i = 0; i < dim_exprs.size(); ++i) {
    auto* expr = dim_exprs[i];
    const auto folded = ConvertConst(EvalConstExp(*expr), SemanticType::Int);
    if (!(folded.int_value > 0)) {
      ThrowError(expr, "????????????");
    }
    extents.push_back(folded.int_value);
  }
  return extents;
}
// Registers every name introduced by a global declaration before any
// initializer is emitted.
//
// For each definition a global value is created in `module_` (scalar or array
// type, no initializer yet) and a matching SymbolEntry is inserted into the
// current scope. For scalar constants the initializer is additionally folded
// and cached in `const_scalar`.
//
// Errors (via ThrowError): redefinition in the current scope, a scalar
// constant whose initializer is not a single constant expression, or a
// declaration node that is neither a const nor a var declaration.
void IRGenImpl::PredeclareGlobalDecl(SysYParser::DeclContext& ctx) {
  auto declare_one = [&](const std::string& name, SemanticType type, bool is_const,
                         const std::vector<int>& dims,
                         const antlr4::ParserRuleContext* node) {
    if (symbols_.ContainsInCurrentScope(name)) {
      ThrowError(node, "????????: " + name);
    }
    SymbolEntry entry;
    entry.kind = is_const ? SymbolKind::Constant : SymbolKind::Variable;
    entry.type = type;
    entry.is_const = is_const;
    entry.is_array = !dims.empty();
    entry.is_param_array = false;
    entry.dims = dims;
    entry.ir_value = module_.CreateGlobalValue(
        name, dims.empty() ? GetIRScalarType(type) : BuildArrayType(type, dims),
        is_const, nullptr);
    if (!symbols_.Insert(name, entry)) {
      ThrowError(node, "????????: " + name);
    }
  };

  if (ctx.constDecl() != nullptr) {
    const auto type = ParseBType(ctx.constDecl()->bType());
    for (auto* def : ctx.constDecl()->constDef()) {
      const auto name = ExpectIdent(*def, def->Ident());
      const auto dims = ParseArrayDims(def->constExp());
      declare_one(name, type, true, dims, def);
      auto* symbol = symbols_.Lookup(name);
      if (symbol != nullptr && dims.empty()) {
        // A scalar constant must be initialized by one constant expression.
        // A braced list (or a missing initializer after parser error
        // recovery) has no constExp child, so check before dereferencing.
        auto* init_ctx = def->constInitVal();
        if (init_ctx == nullptr || init_ctx->constExp() == nullptr ||
            init_ctx->constExp()->addExp() == nullptr) {
          ThrowError(def,
                     "scalar constant requires a single constant-expression "
                     "initializer: " + name);
        }
        symbol->const_scalar =
            ConvertConst(EvalConstAddExp(*init_ctx->constExp()->addExp()), type);
      }
    }
    return;
  }

  if (ctx.varDecl() != nullptr) {
    const auto type = ParseBType(ctx.varDecl()->bType());
    for (auto* def : ctx.varDecl()->varDef()) {
      declare_one(ExpectIdent(*def, def->Ident()), type, false,
                  ParseArrayDims(def->constExp()), def);
    }
    return;
  }

  ThrowError(&ctx, "????");
}
// Emits a declaration at global scope by forwarding to EmitDecl with the
// global flag set.
void IRGenImpl::EmitGlobalDecl(SysYParser::DeclContext& ctx) {
  constexpr bool kIsGlobal = true;
  EmitDecl(ctx, kIsGlobal);
}
// Dispatches a declaration to the const or var emitter. `is_global` is
// forwarded unchanged. A node that is neither kind of declaration is
// reported through ThrowError.
void IRGenImpl::EmitDecl(SysYParser::DeclContext& ctx, bool is_global) {
  if (auto* const_decl = ctx.constDecl()) {
    EmitConstDecl(const_decl, is_global);
    return;
  }
  if (auto* var_decl = ctx.varDecl()) {
    EmitVarDecl(var_decl, is_global, false);
    return;
  }
  ThrowError(&ctx, "????");
}
// Emits every definition of a variable declaration. Global definitions go to
// EmitGlobalVarDef; local ones go to EmitLocalVarDef, which also receives
// `is_const`. A null `ctx` is reported through ThrowError.
void IRGenImpl::EmitVarDecl(SysYParser::VarDeclContext* ctx, bool is_global,
                            bool is_const) {
  if (!ctx) {
    ThrowError(ctx, "??????");
  }
  const auto elem_type = ParseBType(ctx->bType());
  for (auto* var_def : ctx->varDef()) {
    if (!is_global) {
      EmitLocalVarDef(*var_def, elem_type, is_const);
      continue;
    }
    EmitGlobalVarDef(*var_def, elem_type);
  }
}
// Emits every definition of a constant declaration, routing each one to the
// global or the local const emitter according to `is_global`. A null `ctx`
// is reported through ThrowError.
void IRGenImpl::EmitConstDecl(SysYParser::ConstDeclContext* ctx, bool is_global) {
  if (!ctx) {
    ThrowError(ctx, "??????");
  }
  const auto elem_type = ParseBType(ctx->bType());
  for (auto* const_def : ctx->constDef()) {
    if (!is_global) {
      EmitLocalConstDef(*const_def, elem_type);
      continue;
    }
    EmitGlobalConstDef(*const_def, elem_type);
  }
}
// Fills in the initializer of a global variable whose symbol already exists.
//
// The symbol must already be registered with a GlobalValue as its IR value;
// otherwise an error is reported. Its metadata (kind, type, dims, ...) is
// refreshed from the definition. Scalars receive a folded constant
// initializer (zero when none is written). Arrays receive a constant array
// built from the flattened initializer, except that a missing or all-zero
// initializer leaves the initializer null.
void IRGenImpl::EmitGlobalVarDef(SysYParser::VarDefContext& ctx, SemanticType type) {
  const auto name = ExpectIdent(ctx, ctx.Ident());
  auto* symbol = symbols_.Lookup(name);
  const bool predeclared =
      symbol != nullptr && ir::isa<ir::GlobalValue>(symbol->ir_value);
  if (!predeclared) {
    ThrowError(&ctx, "??????????: " + name);
  }
  auto* global = static_cast<ir::GlobalValue*>(symbol->ir_value);

  symbol->kind = SymbolKind::Variable;
  symbol->type = type;
  symbol->is_const = false;
  symbol->is_array = !ctx.constExp().empty();
  symbol->dims = ParseArrayDims(ctx.constExp());

  auto* init_ctx = ctx.initVal();

  if (!symbol->is_array) {
    ConstantValue scalar_init = ZeroConst(type);
    if (init_ctx != nullptr) {
      if (init_ctx->exp() == nullptr) {
        ThrowError(init_ctx, "???????????????");
      }
      scalar_init = ConvertConst(EvalConstExp(*init_ctx->exp()), type);
    }
    global->SetInitializer(CreateTypedConstant(scalar_init));
    return;
  }

  // Leave uninitialized or explicitly all-zero globals as zeroinitializer
  // instead of materializing an explicit all-zero constant array, which can
  // explode memory usage.
  if (init_ctx == nullptr || IsExplicitZeroInitVal(init_ctx, type)) {
    global->SetInitializer(nullptr);
    return;
  }

  const auto flat = FlattenInitVal(init_ctx, type, symbol->dims);
  std::vector<ir::Value*> elements;
  elements.reserve(flat.size());
  for (size_t i = 0; i < flat.size(); ++i) {
    elements.push_back(CreateTypedConstant(flat[i]));
  }
  global->SetInitializer(builder_.CreateConstArray(
      BuildArrayType(type, symbol->dims), elements, {}));
}
// Fills in the initializer of a global constant whose symbol already exists.
//
// The symbol must already be registered with a GlobalValue as its IR value.
// Its metadata is refreshed, the global is marked constant, and:
//   * arrays: the flattened values are cached in `const_array` and installed
//     as a constant array. An all-zero (or missing) initializer instead
//     clears `const_array`, sets `const_array_all_zero`, and leaves the
//     initializer null;
//   * scalars: the folded value is cached in `const_scalar` and installed.
//
// Errors (via ThrowError): the symbol is missing or not a global, or a
// scalar constant whose initializer is not a single constant expression.
void IRGenImpl::EmitGlobalConstDef(SysYParser::ConstDefContext& ctx,
                                   SemanticType type) {
  const auto name = ExpectIdent(ctx, ctx.Ident());
  auto* symbol = symbols_.Lookup(name);
  if (symbol == nullptr || !ir::isa<ir::GlobalValue>(symbol->ir_value)) {
    ThrowError(&ctx, "??????????: " + name);
  }
  auto* global = static_cast<ir::GlobalValue*>(symbol->ir_value);
  symbol->kind = SymbolKind::Constant;
  symbol->type = type;
  symbol->is_const = true;
  symbol->is_array = !ctx.constExp().empty();
  symbol->dims = ParseArrayDims(ctx.constExp());
  global->SetConstant(true);

  if (symbol->is_array) {
    if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) {
      symbol->const_array.clear();
      symbol->const_array_all_zero = true;
      global->SetInitializer(nullptr);
      return;
    }
    symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
    symbol->const_array_all_zero = false;
    std::vector<ir::Value*> elements;
    elements.reserve(symbol->const_array.size());
    for (const auto& value : symbol->const_array) {
      elements.push_back(CreateTypedConstant(value));
    }
    global->SetInitializer(builder_.CreateConstArray(
        BuildArrayType(type, symbol->dims), elements, {}));
    return;
  }

  // A scalar constant must be initialized by one constant expression. A
  // braced list (or a missing initializer after parser error recovery) has
  // no constExp child, so check before dereferencing.
  auto* init_ctx = ctx.constInitVal();
  if (init_ctx == nullptr || init_ctx->constExp() == nullptr ||
      init_ctx->constExp()->addExp() == nullptr) {
    ThrowError(&ctx,
               "scalar constant requires a single constant-expression "
               "initializer: " + name);
  }
  const auto init =
      ConvertConst(EvalConstAddExp(*init_ctx->constExp()->addExp()), type);
  symbol->const_scalar = init;
  global->SetInitializer(CreateTypedConstant(init));
}
// Inserts an alloca of `allocated_type` named `name` into the entry block of
// the current function, directly after the run of allocas that already
// starts that block.
// Throws std::runtime_error when there is no current function or it has no
// entry block.
ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr<ir::Type> allocated_type,
                                             const std::string& name) {
  auto* entry =
      current_function_ != nullptr ? current_function_->GetEntryBlock() : nullptr;
  if (entry == nullptr) {
    throw std::runtime_error("CreateEntryAlloca requires an active function entry block");
  }

  // Count the leading allocas; the new one goes right behind them.
  const auto& instructions = entry->GetInstructions();
  size_t leading_allocas = 0;
  for (auto it = instructions.begin();
       it != instructions.end() && ir::isa<ir::AllocaInst>(it->get()); ++it) {
    ++leading_allocas;
  }

  return entry->Insert<ir::AllocaInst>(leading_allocas, std::move(allocated_type),
                                       nullptr, name);
}
// Emits a local variable definition.
//
// Rejects a redefinition in the current scope, allocates storage in the
// function entry block, and registers the symbol. A scalar is then stored
// with its initializer cast to `type` (or zero when none is written). An
// array is zero-filled first and then receives one store per explicit
// initializer element.
void IRGenImpl::EmitLocalVarDef(SysYParser::VarDefContext& ctx, SemanticType type,
                                bool is_const) {
  const auto name = ExpectIdent(ctx, ctx.Ident());
  if (symbols_.ContainsInCurrentScope(name)) {
    ThrowError(&ctx, "????????: " + name);
  }

  SymbolEntry entry;
  entry.kind = is_const ? SymbolKind::Constant : SymbolKind::Variable;
  entry.type = type;
  entry.is_const = is_const;
  entry.is_array = !ctx.constExp().empty();
  entry.dims = ParseArrayDims(ctx.constExp());
  entry.ir_value = CreateEntryAlloca(
      entry.is_array ? BuildArrayType(type, entry.dims) : GetIRScalarType(type),
      NextTemp());

  const bool inserted = symbols_.Insert(name, entry);
  if (!inserted) {
    ThrowError(&ctx, "????????: " + name);
  }
  auto* symbol = symbols_.Lookup(name);
  auto* init_ctx = ctx.initVal();

  if (entry.is_array) {
    ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
    if (init_ctx != nullptr) {
      auto init_slots = FlattenLocalInitVal(init_ctx, symbol->dims);
      StoreLocalArrayElements(symbol->ir_value, type, symbol->dims, init_slots);
    }
    return;
  }

  // Scalar: start from zero and overwrite with the written initializer.
  TypedValue init_value{ZeroIRValue(type), type, false, {}};
  if (init_ctx != nullptr) {
    if (init_ctx->exp() == nullptr) {
      ThrowError(init_ctx, "???????????????");
    }
    init_value = CastScalar(EmitExp(*init_ctx->exp()), type, init_ctx);
  }
  builder_.CreateStore(init_value.value, symbol->ir_value);
}
// Emits a local constant definition.
//
// Rejects a redefinition in the current scope, allocates storage in the
// function entry block, and registers the symbol. Then:
//   * scalars: the folded initializer is cached in `const_scalar` and stored;
//   * arrays: storage is zero-filled. Unless the initializer is all-zero, the
//     flattened values are cached in `const_array` and every element that is
//     not bit-exactly zero is stored individually.
//
// Errors (via ThrowError): redefinition, or a scalar constant whose
// initializer is not a single constant expression.
void IRGenImpl::EmitLocalConstDef(SysYParser::ConstDefContext& ctx,
                                  SemanticType type) {
  const auto name = ExpectIdent(ctx, ctx.Ident());
  if (symbols_.ContainsInCurrentScope(name)) {
    ThrowError(&ctx, "????????: " + name);
  }
  SymbolEntry entry;
  entry.kind = SymbolKind::Constant;
  entry.type = type;
  entry.is_const = true;
  entry.is_array = !ctx.constExp().empty();
  entry.dims = ParseArrayDims(ctx.constExp());
  entry.ir_value = CreateEntryAlloca(
      entry.is_array ? BuildArrayType(type, entry.dims) : GetIRScalarType(type),
      NextTemp());
  if (!symbols_.Insert(name, entry)) {
    ThrowError(&ctx, "????????: " + name);
  }
  auto* symbol = symbols_.Lookup(name);

  if (!entry.is_array) {
    // A scalar constant must be initialized by one constant expression. A
    // braced list (or a missing initializer after parser error recovery) has
    // no constExp child, so check before dereferencing.
    auto* init_ctx = ctx.constInitVal();
    if (init_ctx == nullptr || init_ctx->constExp() == nullptr ||
        init_ctx->constExp()->addExp() == nullptr) {
      ThrowError(&ctx,
                 "scalar constant requires a single constant-expression "
                 "initializer: " + name);
    }
    const auto init =
        ConvertConst(EvalConstAddExp(*init_ctx->constExp()->addExp()), type);
    symbol->const_scalar = init;
    builder_.CreateStore(CreateTypedConstant(init), symbol->ir_value);
    return;
  }

  if (IsExplicitZeroConstInitVal(ctx.constInitVal(), type)) {
    symbol->const_array.clear();
    symbol->const_array_all_zero = true;
    ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
    return;
  }

  symbol->const_array = FlattenConstInitVal(ctx.constInitVal(), type, symbol->dims);
  symbol->const_array_all_zero = false;
  ZeroInitializeLocalArray(symbol->ir_value, type, symbol->dims);
  for (size_t i = 0; i < symbol->const_array.size(); ++i) {
    const auto& element = symbol->const_array[i];
    // The zero fill already covers elements whose bits are all zero. Use the
    // bit-exact test rather than `== 0.0f`: -0.0f compares equal to zero but
    // is not what the zero fill wrote, so it still needs an explicit store.
    if (IsZeroConstant(element)) {
      continue;
    }
    const auto indices = ExpandLinearIndex(symbol->dims, i);
    std::vector<ir::Value*> index_values;
    index_values.reserve(indices.size());
    for (int index : indices) {
      index_values.push_back(builder_.CreateConstInt(index));
    }
    auto* addr = CreateArrayElementAddr(symbol->ir_value, false, type, symbol->dims,
                                        index_values, &ctx);
    builder_.CreateStore(CreateTypedConstant(element), addr);
  }
}
// Flattens a constant initializer into one ConstantValue per array element
// in row-major order. Elements not covered by the initializer remain zero of
// `base_type`. A null `ctx` yields an all-zero vector.
std::vector<ConstantValue> IRGenImpl::FlattenConstInitVal(
    SysYParser::ConstInitValContext* ctx, SemanticType base_type,
    const std::vector<int>& dims) {
  std::vector<ConstantValue> flat(CountArrayElements(dims), ZeroConst(base_type));
  if (ctx == nullptr) {
    return flat;
  }
  size_t write_pos = 0;
  FlattenConstInitValImpl(ctx, base_type, dims, 0, 0, flat.size(), write_pos, flat);
  return flat;
}
// Flattens a variable initializer whose expressions are compile-time
// evaluable into one ConstantValue per array element in row-major order.
// Elements not covered by the initializer remain zero of `base_type`. A null
// `ctx` yields an all-zero vector.
std::vector<ConstantValue> IRGenImpl::FlattenInitVal(
    SysYParser::InitValContext* ctx, SemanticType base_type,
    const std::vector<int>& dims) {
  std::vector<ConstantValue> flat(CountArrayElements(dims), ZeroConst(base_type));
  if (ctx == nullptr) {
    return flat;
  }
  size_t write_pos = 0;
  FlattenInitValImpl(ctx, base_type, dims, 0, 0, flat.size(), write_pos, flat);
  return flat;
}
// Collects the explicit initializer expressions of a local array together
// with the flat element index each one targets. Elements without an explicit
// expression produce no slot. A null `ctx` yields an empty list.
std::vector<IRGenImpl::InitExprSlot> IRGenImpl::FlattenLocalInitVal(
    SysYParser::InitValContext* ctx, const std::vector<int>& dims) {
  std::vector<InitExprSlot> slots;
  if (ctx == nullptr) {
    return slots;
  }
  size_t write_pos = 0;
  FlattenLocalInitValImpl(ctx, dims, 0, 0, CountArrayElements(dims), write_pos, slots);
  return slots;
}
// Recursive worker for FlattenConstInitVal.
//
// Writes the values of `ctx` into `out` starting at `cursor`, never at or
// beyond `object_end`. `[object_begin, object_end)` is the flat range of the
// aggregate being initialized and `depth` the dimension it corresponds to.
//   * A scalar initializer is evaluated, converted to `base_type`, written at
//     `cursor`, and the cursor advances by one.
//   * A braced child, while a deeper dimension exists, first aligns the
//     cursor to the next sub-aggregate boundary, fills that sub-aggregate,
//     and then skips to its end (leaving the remainder zero).
//   * Any other child is flattened in place within the current range.
void IRGenImpl::FlattenConstInitValImpl(SysYParser::ConstInitValContext* ctx,
                                        SemanticType base_type,
                                        const std::vector<int>& dims,
                                        size_t depth, size_t object_begin,
                                        size_t object_end, size_t& cursor,
                                        std::vector<ConstantValue>& out) {
  if (ctx == nullptr || cursor >= object_end) {
    return;
  }

  if (auto* scalar = ctx->constExp()) {
    out[cursor] = ConvertConst(EvalConstAddExp(*scalar->addExp()), base_type);
    ++cursor;
    return;
  }

  const bool has_deeper_dim = depth + 1 < dims.size();
  for (auto* child : ctx->constInitVal()) {
    if (cursor >= object_end) {
      break;
    }
    const bool is_braced = child->constExp() == nullptr;
    if (!(is_braced && has_deeper_dim)) {
      FlattenConstInitValImpl(child, base_type, dims, depth + 1, object_begin,
                              object_end, cursor, out);
      continue;
    }
    cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
    const size_t sub_begin = cursor;
    const size_t sub_end =
        std::min(object_end, sub_begin + CountArrayElements(dims, depth + 1));
    FlattenConstInitValImpl(child, base_type, dims, depth + 1, sub_begin, sub_end,
                            cursor, out);
    cursor = sub_end;
  }
}
// Recursive worker for FlattenInitVal.
//
// Writes the values of `ctx` into `out` starting at `cursor`, never at or
// beyond `object_end`. `[object_begin, object_end)` is the flat range of the
// aggregate being initialized and `depth` the dimension it corresponds to.
//   * A scalar initializer is constant-evaluated, converted to `base_type`,
//     written at `cursor`, and the cursor advances by one.
//   * A braced child, while a deeper dimension exists, first aligns the
//     cursor to the next sub-aggregate boundary, fills that sub-aggregate,
//     and then skips to its end (leaving the remainder zero).
//   * Any other child is flattened in place within the current range.
void IRGenImpl::FlattenInitValImpl(SysYParser::InitValContext* ctx,
                                   SemanticType base_type,
                                   const std::vector<int>& dims, size_t depth,
                                   size_t object_begin, size_t object_end,
                                   size_t& cursor,
                                   std::vector<ConstantValue>& out) {
  if (ctx == nullptr || cursor >= object_end) {
    return;
  }

  if (auto* scalar = ctx->exp()) {
    out[cursor] = ConvertConst(EvalConstExp(*scalar), base_type);
    ++cursor;
    return;
  }

  const bool has_deeper_dim = depth + 1 < dims.size();
  for (auto* child : ctx->initVal()) {
    if (cursor >= object_end) {
      break;
    }
    const bool is_braced = child->exp() == nullptr;
    if (!(is_braced && has_deeper_dim)) {
      FlattenInitValImpl(child, base_type, dims, depth + 1, object_begin,
                         object_end, cursor, out);
      continue;
    }
    cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
    const size_t sub_begin = cursor;
    const size_t sub_end =
        std::min(object_end, sub_begin + CountArrayElements(dims, depth + 1));
    FlattenInitValImpl(child, base_type, dims, depth + 1, sub_begin, sub_end,
                       cursor, out);
    cursor = sub_end;
  }
}
// Recursive worker for FlattenLocalInitVal.
//
// Walks `ctx` with the same cursor rules as the constant flatteners, but
// instead of evaluating expressions it appends a slot {flat index,
// expression node} to `out` for every scalar initializer encountered.
//   * A braced child, while a deeper dimension exists, first aligns the
//     cursor to the next sub-aggregate boundary, is walked within that
//     sub-aggregate, and then the cursor skips to its end.
//   * Any other child is walked in place within the current range.
void IRGenImpl::FlattenLocalInitValImpl(SysYParser::InitValContext* ctx,
                                        const std::vector<int>& dims,
                                        size_t depth, size_t object_begin,
                                        size_t object_end, size_t& cursor,
                                        std::vector<InitExprSlot>& out) {
  if (ctx == nullptr || cursor >= object_end) {
    return;
  }

  if (ctx->exp() != nullptr) {
    out.push_back({cursor, ctx->exp()});
    ++cursor;
    return;
  }

  const bool has_deeper_dim = depth + 1 < dims.size();
  for (auto* child : ctx->initVal()) {
    if (cursor >= object_end) {
      break;
    }
    const bool is_braced = child->exp() == nullptr;
    if (!(is_braced && has_deeper_dim)) {
      FlattenLocalInitValImpl(child, dims, depth + 1, object_begin, object_end,
                              cursor, out);
      continue;
    }
    cursor = AlignInitializerCursor(dims, depth, object_begin, cursor);
    const size_t sub_begin = cursor;
    const size_t sub_end =
        std::min(object_end, sub_begin + CountArrayElements(dims, depth + 1));
    FlattenLocalInitValImpl(child, dims, depth + 1, sub_begin, sub_end, cursor, out);
    cursor = sub_end;
  }
}
// Returns the product of the extents dims[start], dims[start + 1], ...
// The result is 1 when `start` is at or past the end of `dims`.
size_t IRGenImpl::CountArrayElements(const std::vector<int>& dims, size_t start) const {
  size_t total = 1;
  size_t axis = start;
  while (axis < dims.size()) {
    total *= static_cast<size_t>(dims[axis]);
    ++axis;
  }
  return total;
}
// Rounds `cursor` up to the start of the next sub-aggregate of the object
// that begins at `object_begin`. A sub-aggregate spans the product of the
// extents below `depth`. At the innermost dimension there is nothing to
// align to, so `cursor` is returned unchanged.
size_t IRGenImpl::AlignInitializerCursor(const std::vector<int>& dims,
                                         size_t depth, size_t object_begin,
                                         size_t cursor) const {
  const bool innermost = depth + 1 >= dims.size();
  if (innermost) {
    return cursor;
  }
  const size_t stride = CountArrayElements(dims, depth + 1);
  const size_t offset = cursor - object_begin;
  // Number of whole sub-aggregates needed to cover `offset` (ceiling division).
  const size_t blocks = (offset + stride - 1) / stride;
  return object_begin + blocks * stride;
}
// Converts per-dimension `indices` into a row-major flat offset for an array
// with extents `dims`. `indices` must provide at least dims.size() entries.
size_t IRGenImpl::FlattenIndices(const std::vector<int>& dims,
                                 const std::vector<int>& indices) const {
  size_t linear = 0;
  size_t axis = 0;
  for (const int extent : dims) {
    // Horner-style accumulation: shift by this extent, then add the index.
    linear = linear * static_cast<size_t>(extent) + static_cast<size_t>(indices[axis]);
    ++axis;
  }
  return linear;
}
// Reports whether `value` is a zero constant.
// Int: the integer value is 0. Float: the first four bytes of the stored
// value are all zero, so this is a bit-pattern test and not a `== 0.0f`
// comparison. Void (or any other type tag): never zero.
bool IRGenImpl::IsZeroConstant(const ConstantValue& value) const {
  if (value.type == SemanticType::Int) {
    return value.int_value == 0;
  }
  if (value.type != SemanticType::Float) {
    return false;
  }
  std::uint32_t raw_bits = 0;
  std::memcpy(&raw_bits, &value.float_value, sizeof(raw_bits));
  return raw_bits == 0;
}
// Reports whether a constant initializer consists solely of zero values.
// A null node and an empty braced list both count as zero. A scalar is
// evaluated, converted to `base_type`, and tested with IsZeroConstant. A
// braced list is zero when every child is, and the scan stops at the first
// non-zero child.
bool IRGenImpl::IsExplicitZeroConstInitVal(SysYParser::ConstInitValContext* ctx,
                                           SemanticType base_type) {
  if (ctx == nullptr) {
    return true;
  }
  if (auto* scalar = ctx->constExp()) {
    const auto folded = ConvertConst(EvalConstAddExp(*scalar->addExp()), base_type);
    return IsZeroConstant(folded);
  }
  const auto children = ctx->constInitVal();
  return std::all_of(children.begin(), children.end(),
                     [this, base_type](SysYParser::ConstInitValContext* child) {
                       return IsExplicitZeroConstInitVal(child, base_type);
                     });
}
// Reports whether a variable initializer consists solely of zero values.
// A null node and an empty braced list both count as zero. A scalar is
// constant-evaluated, converted to `base_type`, and tested with
// IsZeroConstant. A braced list is zero when every child is, and the scan
// stops at the first non-zero child.
bool IRGenImpl::IsExplicitZeroInitVal(SysYParser::InitValContext* ctx,
                                      SemanticType base_type) {
  if (ctx == nullptr) {
    return true;
  }
  if (auto* scalar = ctx->exp()) {
    const auto folded = ConvertConst(EvalConstExp(*scalar), base_type);
    return IsZeroConstant(folded);
  }
  const auto children = ctx->initVal();
  return std::all_of(children.begin(), children.end(),
                     [this, base_type](SysYParser::InitValContext* child) {
                       return IsExplicitZeroInitVal(child, base_type);
                     });
}
// Builds a ConstantValue tagged with `type` whose integer and float payloads
// are both zero.
ConstantValue IRGenImpl::ZeroConst(SemanticType type) const {
  ConstantValue zero;
  zero.type = type;
  // Payloads are written in the same order as before (int, then float).
  zero.int_value = 0;
  zero.float_value = 0.0f;
  return zero;
}
// Returns an IR constant zero of the given scalar type: integer 0 for Int,
// 0.0f for Float.
// Throws std::runtime_error for Void (and any other value), which has no
// zero constant.
ir::Value* IRGenImpl::ZeroIRValue(SemanticType type) {
  if (type == SemanticType::Int) {
    return builder_.CreateConstInt(0);
  }
  if (type == SemanticType::Float) {
    return builder_.CreateConstFloat(0.0f);
  }
  throw std::runtime_error("void type has no zero IR value");
}
// Materializes `value` as an IR constant, using its integer payload for Int
// and its float payload for Float.
// Throws std::runtime_error for Void (and any other type tag).
ir::Value* IRGenImpl::CreateTypedConstant(const ConstantValue& value) {
  if (value.type == SemanticType::Int) {
    return builder_.CreateConstInt(value.int_value);
  }
  if (value.type == SemanticType::Float) {
    return builder_.CreateConstFloat(value.float_value);
  }
  throw std::runtime_error("void type has no constant value");
}
// Emits a memset that zero-fills the local array at `addr`.
//
// The byte count is (number of elements in `dims`) * 4. The previous size
// expression selected 4 bytes for float and for non-float elements alike, so
// `base_type` does not influence the size.
//
// Throws std::runtime_error when the byte count does not fit in `int`, the
// type the memset length constant is built from. Previously such a size was
// silently truncated by the cast.
void IRGenImpl::ZeroInitializeLocalArray(ir::Value* addr, SemanticType base_type,
                                         const std::vector<int>& dims) {
  constexpr size_t kElementBytes = 4;
  static_cast<void>(base_type);  // every supported element type is 4 bytes wide

  const size_t elem_count = CountArrayElements(dims);
  constexpr size_t kMaxBytes = static_cast<size_t>(std::numeric_limits<int>::max());
  if (elem_count > kMaxBytes / kElementBytes) {
    throw std::runtime_error("local array is too large to zero-initialize");
  }

  const int bytes = static_cast<int>(elem_count * kElementBytes);
  builder_.CreateMemset(addr, builder_.CreateConstInt(0), builder_.CreateConstInt(bytes),
                        builder_.CreateConstBool(false));
}
// Emits one store per explicit initializer slot of a local array.
// For each slot the element address is computed first (from constant
// indices derived from the slot's flat index), then the initializer
// expression is emitted and cast to `base_type`, and finally the value is
// stored. This emission order is kept as-is.
void IRGenImpl::StoreLocalArrayElements(ir::Value* addr, SemanticType base_type,
                                        const std::vector<int>& dims,
                                        const std::vector<InitExprSlot>& init_slots) {
  for (size_t s = 0; s < init_slots.size(); ++s) {
    const auto& slot = init_slots[s];
    const auto coords = ExpandLinearIndex(dims, slot.index);

    std::vector<ir::Value*> coord_values;
    coord_values.reserve(coords.size());
    for (size_t k = 0; k < coords.size(); ++k) {
      coord_values.push_back(builder_.CreateConstInt(coords[k]));
    }

    auto* elem_addr = CreateArrayElementAddr(addr, false, base_type, dims,
                                             coord_values, slot.expr);
    auto casted = CastScalar(EmitExp(*slot.expr), base_type, slot.expr);
    builder_.CreateStore(casted.value, elem_addr);
  }
}