You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

252 lines
9.1 KiB

#include "irgen/IRGen.h"
#include <any>
#include <stdexcept>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
// Visits the items of a block in order and reports how the block ended.
//
// Returns a std::any holding BlockFlow::Terminated as soon as one item reports
// termination (the remaining items are not visited); otherwise, or when `ctx`
// is null, it holds BlockFlow::Continue. Null items are skipped.
std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
  if (ctx == nullptr) {
    return BlockFlow::Continue;
  }
  for (auto* entry : ctx->blockItem()) {
    if (entry == nullptr) {
      continue;
    }
    if (VisitBlockItemResult(*entry) == BlockFlow::Terminated) {
      return BlockFlow::Terminated;
    }
  }
  return BlockFlow::Continue;
}
// Dispatches `item` through this visitor and unwraps the result as BlockFlow.
// std::any_cast throws std::bad_any_cast if the visit produced another type.
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(SysYParser::BlockItemContext& item) {
  std::any visited = item.accept(this);
  return std::any_cast<BlockFlow>(visited);
}
// Visits a single block item.
//
// A declaration is visited for its side effects and yields
// BlockFlow::Continue. A statement's visit result is forwarded unchanged.
// A null `ctx`, or an item with neither child, yields BlockFlow::Continue.
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
  if (ctx != nullptr) {
    if (auto* declaration = ctx->decl()) {
      declaration->accept(this);
    } else if (auto* statement = ctx->stmt()) {
      return statement->accept(this);
    }
  }
  return BlockFlow::Continue;
}
// Visits every definition inside a declaration.
//
// A const declaration takes precedence: when present, only its constDef
// children are visited. Otherwise the varDef children of a var declaration
// are visited. Always returns an empty std::any.
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
  if (ctx == nullptr) {
    return {};
  }
  if (auto* constants = ctx->constDecl()) {
    for (auto* definition : constants->constDef()) {
      definition->accept(this);
    }
  } else if (auto* variables = ctx->varDecl()) {
    for (auto* definition : variables->varDef()) {
      definition->accept(this);
    }
  }
  return {};
}
// Generates storage (and initialization) for one variable definition.
//
// When no function is being generated (`func_` is unset) the definition is
// treated as a global: a constant initializer is evaluated, a global variable
// is created in the module and recorded in `global_var_storage_`.
// Otherwise an entry-block alloca is created, recorded in `var_storage_`, and
// initialization code is emitted through `builder_`.
//
// Returns an empty std::any. A null `ctx` is ignored.
// Throws std::runtime_error when the name or type is missing, when storage
// for `ctx` was already generated, or when a scalar has a non-expression
// initializer.
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
  if (ctx == nullptr) {
    return {};
  }
  if (ctx->ID() == nullptr) {
    throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
  }

  auto& context = module_.GetContext();

  // ---- Global variable -----------------------------------------------------
  if (!func_) {
    const TypeDesc* type = sema_.GetVarType(ctx);
    if (type == nullptr) {
      throw std::runtime_error(FormatError("irgen", "全局变量类型缺失"));
    }
    auto existing = global_var_storage_.find(ctx);
    if (existing != global_var_storage_.end()) {
      throw std::runtime_error(FormatError("irgen", "重复生成全局变量"));
    }

    // Stays null when the definition carries no initializer.
    ir::ConstantValue* initializer = nullptr;
    if (auto* init_node = ctx->initVal()) {
      if (type->dims.empty()) {
        if (init_node->exp() == nullptr) {
          throw std::runtime_error(FormatError("irgen", "全局变量初始化非法"));
        }
        initializer = EvalConstScalar(init_node->exp());
        // Convert the folded constant to the declared base type.
        if (type->base == BaseTypeKind::Int) {
          if (auto* as_float = dynamic_cast<ir::ConstantFloat*>(initializer)) {
            initializer = context.GetConstInt(static_cast<int>(as_float->GetValue()));
          }
        } else if (type->base == BaseTypeKind::Float) {
          if (auto* as_int = dynamic_cast<ir::ConstantInt*>(initializer)) {
            initializer = context.GetConstFloat(static_cast<float>(as_int->GetValue()));
          }
        }
      } else {
        // Start from an all-zero flat element list, then let InitGlobalArray
        // overwrite the entries covered by the initializer.
        size_t element_count = ArrayTotalSize(*type);
        ir::ConstantValue* zero = nullptr;
        if (type->base == BaseTypeKind::Float) {
          zero = context.GetConstFloat(0.0f);
        } else {
          zero = context.GetConstInt(0);
        }
        std::vector<ir::ConstantValue*> elements(element_count, zero);
        InitGlobalArray(*type, init_node, elements, 0, 0, 0);
        initializer = context.CreateConstArray(ToIRType(*type), elements);
      }
    }

    auto* global = module_.CreateGlobalVariable(ctx->ID()->getText(),
                                                ToIRType(*type), initializer, false);
    global_var_storage_[ctx] = global;
    return {};
  }

  // ---- Local variable ------------------------------------------------------
  if (var_storage_.find(ctx) != var_storage_.end()) {
    throw std::runtime_error(FormatError("irgen", "重复生成存储槽位"));
  }
  const TypeDesc* type = sema_.GetVarType(ctx);
  if (type == nullptr) {
    throw std::runtime_error(FormatError("irgen", "变量类型缺失"));
  }
  auto* storage = CreateEntryAlloca(ToIRType(*type), context.NextTemp());
  var_storage_[ctx] = storage;

  if (type->dims.empty()) {
    // Scalar: evaluate the initializer (or take the default) and store it.
    ir::Value* value = nullptr;
    if (auto* init_node = ctx->initVal()) {
      if (init_node->exp() == nullptr) {
        throw std::runtime_error(FormatError("irgen", "标量初始化非法"));
      }
      value = EvalExp(init_node->exp());
    } else {
      value = DefaultValue(*type);
    }

    // Coerce the value to the declared base type before storing.
    const bool wants_float = type->base == BaseTypeKind::Float;
    const bool wants_int = type->base == BaseTypeKind::Int;
    if (wants_float && (value->IsInt1() || value->IsInt32())) {
      // An i1 goes through CastToInt first, then to float.
      if (value->IsInt1()) {
        value = CastToInt(value);
      }
      value = CastToFloat(value);
    } else if (wants_int && (value->IsFloat() || value->IsInt1())) {
      value = CastToInt(value);
    }
    builder_.CreateStore(value, storage);
    return {};
  }

  // Array: only an uninitialized 1-D array with at least 1024 elements gets
  // the emitted fill loop; everything else is delegated to InitArray.
  const bool use_fill_loop =
      !ctx->initVal() && type->dims.size() == 1 && type->dims[0] >= 1024;
  if (!use_fill_loop) {
    InitArray(storage, *type, ctx->initVal());
    return {};
  }

  // Loop counter lives in its own entry-block slot, starting at 0.
  auto* counter_slot = CreateEntryAlloca(ir::Type::GetInt32Type(), context.NextTemp());
  builder_.CreateStore(builder_.CreateConstInt(0), counter_slot);

  auto* check_block = func_->CreateBlock("arr.zero.cond");
  auto* fill_block = func_->CreateBlock("arr.zero.body");
  auto* exit_block = func_->CreateBlock("arr.zero.end");
  builder_.CreateBr(check_block);

  // Condition: counter < dims[0].
  builder_.SetInsertPoint(check_block);
  auto* counter = builder_.CreateLoad(counter_slot, context.NextTemp());
  auto* limit = builder_.CreateConstInt(type->dims[0]);
  auto* in_range =
      builder_.CreateICmp(ir::ICmpPredicate::Slt, counter, limit, context.NextTemp());
  builder_.CreateCondBr(in_range, fill_block, exit_block);

  // Body: store DefaultValue(*type) into element [0][counter], then increment.
  builder_.SetInsertPoint(fill_block);
  std::vector<ir::Value*> gep_indices;
  gep_indices.push_back(builder_.CreateConstInt(0));
  gep_indices.push_back(counter);
  auto* element_ptr =
      builder_.CreateGep(storage, std::move(gep_indices), context.NextTemp());
  builder_.CreateStore(DefaultValue(*type), element_ptr);
  auto* incremented =
      builder_.CreateAdd(counter, builder_.CreateConstInt(1), context.NextTemp());
  builder_.CreateStore(incremented, counter_slot);
  builder_.CreateBr(check_block);

  // Subsequent code is emitted after the loop.
  builder_.SetInsertPoint(exit_block);
  return {};
}
// Generates storage (and initialization) for one constant definition.
//
// When no function is being generated (`func_` is unset) a global is created
// in the module (last argument `true`, unlike the variable path) and recorded
// in `global_const_storage_`. Otherwise an entry-block alloca is created,
// recorded in `const_storage_`, and initialization code is emitted.
//
// Returns an empty std::any.
// Throws std::runtime_error when `ctx` or its name is missing, when the type
// is missing, when storage for `ctx` was already generated, or when a scalar
// has a non-expression initializer.
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
  if (ctx == nullptr || ctx->ID() == nullptr) {
    throw std::runtime_error(FormatError("irgen", "常量声明缺少名称"));
  }

  auto& context = module_.GetContext();

  // ---- Global constant -----------------------------------------------------
  if (!func_) {
    const TypeDesc* type = sema_.GetConstType(ctx);
    if (type == nullptr) {
      throw std::runtime_error(FormatError("irgen", "全局常量类型缺失"));
    }
    auto existing = global_const_storage_.find(ctx);
    if (existing != global_const_storage_.end()) {
      throw std::runtime_error(FormatError("irgen", "重复生成全局常量"));
    }

    // Stays null when the definition carries no initializer.
    ir::ConstantValue* initializer = nullptr;
    if (auto* init_node = ctx->constInitVal()) {
      if (type->dims.empty()) {
        if (init_node->constExp() == nullptr) {
          throw std::runtime_error(FormatError("irgen", "全局常量初始化非法"));
        }
        initializer = EvalConstScalar(init_node->constExp());
        // Convert the folded constant to the declared base type.
        if (type->base == BaseTypeKind::Int) {
          if (auto* as_float = dynamic_cast<ir::ConstantFloat*>(initializer)) {
            initializer = context.GetConstInt(static_cast<int>(as_float->GetValue()));
          }
        } else if (type->base == BaseTypeKind::Float) {
          if (auto* as_int = dynamic_cast<ir::ConstantInt*>(initializer)) {
            initializer = context.GetConstFloat(static_cast<float>(as_int->GetValue()));
          }
        }
      } else {
        // Start from an all-zero flat element list, then let
        // InitGlobalConstArray overwrite the initialized entries.
        size_t element_count = ArrayTotalSize(*type);
        ir::ConstantValue* zero = nullptr;
        if (type->base == BaseTypeKind::Float) {
          zero = context.GetConstFloat(0.0f);
        } else {
          zero = context.GetConstInt(0);
        }
        std::vector<ir::ConstantValue*> elements(element_count, zero);
        InitGlobalConstArray(*type, init_node, elements, 0, 0, 0);
        initializer = context.CreateConstArray(ToIRType(*type), elements);
      }
    }

    auto* global = module_.CreateGlobalVariable(ctx->ID()->getText(),
                                                ToIRType(*type), initializer, true);
    global_const_storage_[ctx] = global;
    return {};
  }

  // ---- Local constant ------------------------------------------------------
  if (const_storage_.find(ctx) != const_storage_.end()) {
    throw std::runtime_error(FormatError("irgen", "重复生成常量存储"));
  }
  const TypeDesc* type = sema_.GetConstType(ctx);
  if (type == nullptr) {
    throw std::runtime_error(FormatError("irgen", "常量类型缺失"));
  }
  auto* storage = CreateEntryAlloca(ToIRType(*type), context.NextTemp());
  const_storage_[ctx] = storage;

  // Arrays are handled entirely by InitConstArray.
  if (!type->dims.empty()) {
    InitConstArray(storage, *type, ctx->constInitVal());
    return {};
  }

  // Scalar: obtain the value by visiting the constant expression (or take the
  // default when there is no initializer).
  ir::Value* value = nullptr;
  if (auto* init_node = ctx->constInitVal()) {
    if (init_node->constExp() == nullptr) {
      throw std::runtime_error(FormatError("irgen", "常量初始化非法"));
    }
    value = std::any_cast<ir::Value*>(init_node->constExp()->accept(this));
  } else {
    value = DefaultValue(*type);
  }

  // Coerce the value to the declared base type before storing.
  const bool wants_float = type->base == BaseTypeKind::Float;
  const bool wants_int = type->base == BaseTypeKind::Int;
  if (wants_float && (value->IsInt1() || value->IsInt32())) {
    // An i1 goes through CastToInt first, then to float.
    if (value->IsInt1()) {
      value = CastToInt(value);
    }
    value = CastToFloat(value);
  } else if (wants_int && (value->IsFloat() || value->IsInt1())) {
    value = CastToInt(value);
  }
  builder_.CreateStore(value, storage);
  return {};
}