You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1070 lines
38 KiB

#include "irgen/IRGen.h"
#include <iostream>
#include <stdexcept>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
namespace {
// Unwraps a visitor result (std::any) into an ir::Value*.
// The visitors in this file return several pointer types through std::any;
// each accepted payload type is upcast to ir::Value*. Any other payload is
// reported on stderr and rejected with std::bad_any_cast.
ir::Value* AnyToValue(const std::any& val) {
  const auto& held = val.type();
  if (held == typeid(ir::Value*)) {
    return std::any_cast<ir::Value*>(val);
  }
  if (held == typeid(ir::ConstantInt*)) {
    ir::ConstantInt* as_int = std::any_cast<ir::ConstantInt*>(val);
    return static_cast<ir::Value*>(as_int);
  }
  if (held == typeid(ir::ConstantFloat*)) {
    ir::ConstantFloat* as_float = std::any_cast<ir::ConstantFloat*>(val);
    return static_cast<ir::Value*>(as_float);
  }
  if (held == typeid(ir::Instruction*)) {
    ir::Instruction* as_inst = std::any_cast<ir::Instruction*>(val);
    return as_inst;
  }
  std::cerr << "Unknown type in AnyToValue: " << held.name() << std::endl;
  throw std::bad_any_cast();
}

// Linear search of the module's function list by name.
// Returns nullptr when no (non-null) function has that name.
ir::Function* FindFunctionByName(ir::Module& module, const std::string& name) {
  const auto& fns = module.GetFunctions();
  for (auto it = fns.begin(); it != fns.end(); ++it) {
    const auto& fn = *it;
    if (!fn) continue;
    if (fn->GetName() == name) return fn.get();
  }
  return nullptr;
}
}  // namespace
// Evaluates an expression subtree and unwraps the visitor result into an
// ir::Value*. A null context yields nullptr.
ir::Value* IRGenImpl::EvalExp(SysYParser::ExpContext* ctx) {
  if (ctx == nullptr) {
    return nullptr;
  }
  const auto result = ctx->accept(this);
  return AnyToValue(result);
}
// exp -> addExp. A null context produces a null value; an exp without an
// addExp child is rejected.
std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
  if (ctx == nullptr) return static_cast<ir::Value*>(nullptr);
  auto* additive = ctx->addExp();
  if (additive == nullptr) {
    throw std::runtime_error(FormatError("irgen", "不支持的表达式"));
  }
  return additive->accept(this);
}
// Lowers an additive chain `m0 (('+'|'-') m1)*` left to right.
// If either side of a step is float, both sides are converted to float and
// FAdd/FSub is emitted. Otherwise Add/Sub on i32 is emitted.
// i1 operands (e.g. the result of `!x`) are first widened to i32, so an
// arithmetic instruction never mixes i1 with i32, and a later int->float
// conversion sees 0/1 rather than a 1-bit value.
// The operator token is read from child 2*i-1, which assumes operands and
// operators alternate as flat children of this context.
// A single operand is returned unchanged.
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
  auto muls = ctx->mulExp();
  if (muls.empty()) return static_cast<ir::Value*>(nullptr);
  ir::Value* lhs = AnyToValue(muls[0]->accept(this));
  for (size_t i = 1; i < muls.size(); ++i) {
    ir::Value* rhs = AnyToValue(muls[i]->accept(this));
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "+";
    if (lhs->IsInt1()) lhs = CastToInt(lhs);
    if (rhs->IsInt1()) rhs = CastToInt(rhs);
    if (lhs->IsFloat() || rhs->IsFloat()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
      if (text == "+") {
        lhs = builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp());
      } else {
        lhs = builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp());
      }
    } else if (text == "+") {
      lhs = builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp());
    } else {
      lhs = builder_.CreateSub(lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return static_cast<ir::Value*>(lhs);
}
// Lowers a multiplicative chain `u0 (('*'|'/'|'%') u1)*` left to right.
// Float steps use FMul/FDiv; '%' on floats is rejected.
// Integer steps use Mul/SDiv/SRem on i32.
// i1 operands (e.g. from `!x`) are widened to i32 before the step, so that
// operand types always agree and int->float conversion sees 0/1.
// The operator token is read from child 2*i-1 (flat alternating children).
// A single operand is returned unchanged.
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
  auto unaries = ctx->unaryExp();
  if (unaries.empty()) return static_cast<ir::Value*>(nullptr);
  ir::Value* lhs = AnyToValue(unaries[0]->accept(this));
  for (size_t i = 1; i < unaries.size(); ++i) {
    ir::Value* rhs = AnyToValue(unaries[i]->accept(this));
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "*";
    if (lhs->IsInt1()) lhs = CastToInt(lhs);
    if (rhs->IsInt1()) rhs = CastToInt(rhs);
    if (lhs->IsFloat() || rhs->IsFloat()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
      if (text == "*") {
        lhs = builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp());
      } else if (text == "/") {
        lhs = builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp());
      } else {
        throw std::runtime_error(FormatError("irgen", "float 不支持 %"));
      }
    } else if (text == "*") {
      lhs = builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp());
    } else if (text == "/") {
      lhs = builder_.CreateSDiv(lhs, rhs, module_.GetContext().NextTemp());
    } else {
      lhs = builder_.CreateSRem(lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return static_cast<ir::Value*>(lhs);
}
// Lowers a unary expression.
//
// * primaryExp: forwarded to the primary-expression visitor.
// * ID '(' funcRParams? ')': a call to a function already present in
//   module_. Before the call, each argument is adapted to the callee's
//   parameter type:
//   - array-typed pointers decay through a `gep 0, 0`;
//   - a `gep base, 0, 0` decay is undone when the callee expects a pointer
//     to an array;
//   - scalars are converted between i1, i32 and float.
// * unaryOp unaryExp: '+' returns the operand; '-' negates it; '!' produces
//   an i1 that is true when the operand equals zero.
//
// Any other shape yields a null ir::Value*.
// Throws std::runtime_error when the callee is not defined.
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
  if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
  if (ctx->ID() && ctx->LPAREN()) {
    std::string name = ctx->ID()->getText();
    ir::Function* callee = FindFunctionByName(module_, name);
    if (!callee) {
      throw std::runtime_error(FormatError("irgen", "未定义函数: " + name));
    }
    std::vector<ir::Value*> args;
    if (ctx->funcRParams()) {
      for (auto* exp : ctx->funcRParams()->exp()) {
        args.push_back(EvalExp(exp));
      }
    }
    const auto& param_tys = callee->GetFunctionType()->GetParamTypes();
    for (size_t i = 0; i < args.size() && i < param_tys.size(); ++i) {
      auto* arg = args[i];
      const auto& pty = param_tys[i];
      if (pty->IsPointer()) {
        if (arg && arg->GetType() && arg->GetType()->IsPointer()) {
          bool param_elem_array = pty->GetElementType()->IsArray();
          bool arg_elem_array = arg->GetType()->GetElementType()->IsArray();
          if (!param_elem_array && arg_elem_array) {
            // Pointer to an array passed where a pointer to its element type
            // is expected: decay to the address of the first element.
            std::vector<ir::Value*> idx = {builder_.CreateConstInt(0),
                                           builder_.CreateConstInt(0)};
            args[i] = builder_.CreateGep(arg, std::move(idx),
                                         module_.GetContext().NextTemp());
            arg = args[i];
          } else if (param_elem_array && arg_elem_array) {
            // Both point at arrays: decay one level when the argument's inner
            // element type equals the parameter's pointee type.
            auto* param_elem = pty->GetElementType().get();
            auto* arg_elem = arg->GetType()->GetElementType().get();
            if (param_elem && arg_elem && arg_elem->IsArray() &&
                arg_elem->GetElementType()->Equals(*param_elem)) {
              std::vector<ir::Value*> idx = {builder_.CreateConstInt(0),
                                             builder_.CreateConstInt(0)};
              args[i] = builder_.CreateGep(arg, std::move(idx),
                                           module_.GetContext().NextTemp());
              arg = args[i];
            }
          } else if (param_elem_array && !arg_elem_array) {
            // The callee wants a pointer to an array but the argument is a
            // `gep base, 0, 0`; pass `base` itself when it points at an array.
            if (auto* gep = dynamic_cast<ir::GepInst*>(arg)) {
              const auto& idx = gep->GetIndices();
              auto is_zero = [](ir::Value* v) {
                auto* ci = dynamic_cast<ir::ConstantInt*>(v);
                return ci && ci->GetValue() == 0;
              };
              if (idx.size() == 2 && is_zero(idx[0]) && is_zero(idx[1])) {
                auto* base = gep->GetBasePtr();
                if (base && base->GetType() && base->GetType()->IsPointer() &&
                    base->GetType()->GetElementType()->IsArray()) {
                  args[i] = base;
                  arg = base;
                }
              }
            }
          }
        } else if (arg && arg->GetType() && arg->GetType()->IsInt32()) {
          // An i32 obtained by a load is replaced with the load's address
          // operand when that address points at a pointer whose element type
          // matches the parameter's element type.
          if (auto* load = dynamic_cast<ir::LoadInst*>(arg)) {
            auto* base = load->GetPtr();
            if (base && base->GetType() && base->GetType()->IsPointer()) {
              auto* elem = base->GetType()->GetElementType().get();
              if (elem && elem->IsPointer() &&
                  elem->GetElementType()->Equals(*pty->GetElementType())) {
                args[i] = base;
                arg = base;
              }
            }
          }
        }
      }
      if (pty->IsFloat() && (arg->IsInt1() || arg->IsInt32())) {
        // Zero-extend i1 first so `true` becomes 1.0 rather than -1.0.
        args[i] = CastToFloat(arg->IsInt1() ? CastToInt(arg) : arg);
      } else if (pty->IsInt32() && (arg->IsInt1() || arg->IsFloat())) {
        args[i] = CastToInt(arg);
      }
    }
    std::string tmp = callee->GetReturnType()->IsVoid()
                          ? std::string("")
                          : module_.GetContext().NextTemp();
    auto* call = builder_.CreateCall(callee, std::move(args), tmp);
    return static_cast<ir::Value*>(call);
  }
  if (ctx->unaryOp() && ctx->unaryExp()) {
    std::string op = ctx->unaryOp()->getText();
    ir::Value* val = AnyToValue(ctx->unaryExp()->accept(this));
    if (op == "+") return static_cast<ir::Value*>(val);
    if (op == "-") {
      if (val->IsFloat()) {
        // `-0.0 - x` is the IEEE negation of x. `0.0 - x` would turn +0.0
        // into +0.0 instead of -0.0.
        auto* neg_zero = builder_.CreateConstFloat(-0.0f);
        return static_cast<ir::Value*>(
            builder_.CreateFSub(neg_zero, val, module_.GetContext().NextTemp()));
      }
      // Operands such as `!x` are i1; widen them so the subtraction is i32.
      val = CastToInt(val);
      auto* zero = builder_.CreateConstInt(0);
      return static_cast<ir::Value*>(
          builder_.CreateSub(zero, val, module_.GetContext().NextTemp()));
    }
    if (op == "!") {
      ir::Value* b = MakeBool(val);
      auto* zero = builder_.CreateConstBool(false);
      return static_cast<ir::Value*>(
          builder_.CreateICmp(ir::ICmpPredicate::Eq, b, zero,
                              module_.GetContext().NextTemp()));
    }
  }
  return static_cast<ir::Value*>(nullptr);
}
// Relational expressions are lowered by the shared EmitRelEq helper.
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
  return static_cast<ir::Value*>(EmitRelEq(ctx));
}
// Equality expressions are lowered by the shared EmitEq helper.
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
  return static_cast<ir::Value*>(EmitEq(ctx));
}
// Evaluates `eq0 && eq1 && ...` as a value.
// A single operand is forwarded unchanged. With several operands, the
// expression is lowered with short-circuit branches (EmitLAndCond) into fresh
// true/false blocks that join in a merge block. There a phi selects the i1
// constant true or false, so every operand participates instead of only the
// first.
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
  auto eqs = ctx->eqExp();
  if (eqs.empty()) return static_cast<ir::Value*>(nullptr);
  if (eqs.size() == 1) return eqs[0]->accept(this);
  if (!func_) {
    throw std::runtime_error(FormatError("irgen", "逻辑表达式只能出现在函数内"));
  }
  auto* true_bb = func_->CreateBlock("land.true");
  auto* false_bb = func_->CreateBlock("land.false");
  auto* merge_bb = func_->CreateBlock("land.merge");
  EmitLAndCond(ctx, true_bb, false_bb);
  builder_.SetInsertPoint(true_bb);
  builder_.CreateBr(merge_bb);
  builder_.SetInsertPoint(false_bb);
  builder_.CreateBr(merge_bb);
  builder_.SetInsertPoint(merge_bb);
  auto* phi = builder_.CreatePhi(ir::Type::GetInt1Type(),
                                 module_.GetContext().NextTemp());
  phi->AddIncoming(builder_.CreateConstBool(true), true_bb);
  phi->AddIncoming(builder_.CreateConstBool(false), false_bb);
  return static_cast<ir::Value*>(phi);
}
// Evaluates `and0 || and1 || ...` as a value.
// A single operand is forwarded (to visitLAndExp). With several operands, the
// expression is lowered with short-circuit branches (EmitLOrCond) and the
// i1 result is materialized by a phi in a merge block.
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
  auto ands = ctx->lAndExp();
  if (ands.empty()) return static_cast<ir::Value*>(nullptr);
  if (ands.size() == 1) return ands[0]->accept(this);
  if (!func_) {
    throw std::runtime_error(FormatError("irgen", "逻辑表达式只能出现在函数内"));
  }
  auto* true_bb = func_->CreateBlock("lor.true");
  auto* false_bb = func_->CreateBlock("lor.false");
  auto* merge_bb = func_->CreateBlock("lor.merge");
  EmitLOrCond(ctx, true_bb, false_bb);
  builder_.SetInsertPoint(true_bb);
  builder_.CreateBr(merge_bb);
  builder_.SetInsertPoint(false_bb);
  builder_.CreateBr(merge_bb);
  builder_.SetInsertPoint(merge_bb);
  auto* phi = builder_.CreatePhi(ir::Type::GetInt1Type(),
                                 module_.GetContext().NextTemp());
  phi->AddIncoming(builder_.CreateConstBool(true), true_bb);
  phi->AddIncoming(builder_.CreateConstBool(false), false_bb);
  return static_cast<ir::Value*>(phi);
}
// cond -> lOrExp; a missing child produces a null value.
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
  auto* disjunction = ctx->lOrExp();
  if (disjunction == nullptr) return static_cast<ir::Value*>(nullptr);
  return disjunction->accept(this);
}
// Call arguments are evaluated directly by visitUnaryExp, so visiting the
// argument list on its own always yields a null value.
std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
  (void)ctx;
  return static_cast<ir::Value*>(nullptr);
}
// primaryExp -> '(' exp ')' | lVal | number.
// A null context produces a null value; any other shape is rejected.
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (ctx == nullptr) return static_cast<ir::Value*>(nullptr);
  if (ctx->LPAREN() != nullptr && ctx->exp() != nullptr) {
    ir::Value* inner = EvalExp(ctx->exp());
    return inner;
  }
  if (auto* lval = ctx->lVal()) return lval->accept(this);
  if (auto* literal = ctx->number()) return literal->accept(this);
  throw std::runtime_error(FormatError("irgen", "不支持的 PrimaryExp"));
}
// Turns a numeric literal into an IR constant.
// Integer literals are parsed with std::stoll in base 0, which accepts
// decimal, 0-prefixed octal and 0x hex. Trailing unparsed characters make
// the literal invalid. Float literals are parsed with std::stof.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法常量"));
  }
  if (ctx->INT_CONST() != nullptr) {
    const std::string literal = ctx->getText();
    size_t consumed = 0;
    const long long parsed = std::stoll(literal, &consumed, 0);
    if (consumed == literal.size()) {
      return static_cast<ir::Value*>(builder_.CreateConstInt(parsed));
    }
    throw std::runtime_error(FormatError("irgen", "非法整数常量"));
  }
  if (ctx->FLOAT_CONST() != nullptr) {
    const float parsed = std::stof(ctx->getText());
    return static_cast<ir::Value*>(builder_.CreateConstFloat(parsed));
  }
  throw std::runtime_error(FormatError("irgen", "不支持的常量类型"));
}
// Lowers an lvalue used as an expression.
// The address comes from GetLValAddress. The declared TypeDesc is then
// resolved again to decide whether to load: first via sema_'s binding, then
// by a by-name scan of the local, parameter and global storage maps.
// Scalars and fully indexed array elements are loaded (rvalue). An array
// named with no subscripts, or with fewer subscripts than dimensions, is
// returned as an address (presumably so it can decay when passed to a call).
// Throws std::runtime_error if the lvalue is malformed or its type cannot be
// resolved.
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
ir::Value* addr = GetLValAddress(ctx);
BoundDecl bound = sema_.ResolveVarUse(ctx);
const TypeDesc* ty = nullptr;
if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
ty = sema_.GetVarType(bound.var_decl);
} else if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) {
ty = sema_.GetConstType(bound.const_decl);
} else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) {
ty = sema_.GetParamType(bound.param_decl);
}
// Fallback: match declarations by identifier text. This ignores scoping;
// with shadowed names, the first entry met while iterating a map wins.
if (!ty && ctx->ID()) {
const auto name = ctx->ID()->getText();
for (const auto& kv : var_storage_) {
auto* def = const_cast<SysYParser::VarDefContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
ty = sema_.GetVarType(kv.first);
break;
}
}
if (!ty) {
for (const auto& kv : const_storage_) {
auto* def = const_cast<SysYParser::ConstDefContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
ty = sema_.GetConstType(kv.first);
break;
}
}
}
if (!ty) {
for (const auto& kv : param_storage_) {
auto* def = const_cast<SysYParser::FuncFParamContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
ty = sema_.GetParamType(kv.first);
break;
}
}
}
if (!ty) {
for (const auto& kv : global_var_storage_) {
auto* def = const_cast<SysYParser::VarDefContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
ty = sema_.GetVarType(kv.first);
break;
}
}
}
if (!ty) {
for (const auto& kv : global_const_storage_) {
auto* def = const_cast<SysYParser::ConstDefContext*>(kv.first);
if (def && def->ID() && def->ID()->getText() == name) {
ty = sema_.GetConstType(kv.first);
break;
}
}
}
}
if (!ty) {
throw std::runtime_error(FormatError("irgen", "无法解析左值类型"));
}
// Arrays that are unindexed or only partially indexed stay as addresses;
// scalars and fully indexed elements are loaded.
bool as_rvalue = true;
if (!ty->dims.empty()) {
const size_t index_count = ctx->exp().size();
if (index_count == 0 || index_count < ty->dims.size()) {
as_rvalue = false;
}
}
return static_cast<ir::Value*>(LoadIfNeeded(addr, *ty, as_rvalue));
}
// constExp -> addExp, lowered like an ordinary expression.
// A null context produces a null value.
std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) {
  if (ctx != nullptr) return ctx->addExp()->accept(this);
  return static_cast<ir::Value*>(nullptr);
}
// Initializer contexts are consumed by the array-filling helpers in this file
// (e.g. FillArrayValues / FillConstArrayValues), which do not dispatch through
// these visitors. Both always yield a null ir::Value*.
std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
  (void)ctx;
  return static_cast<ir::Value*>(nullptr);
}
std::any IRGenImpl::visitInitVal(SysYParser::InitValContext* ctx) {
  (void)ctx;
  return static_cast<ir::Value*>(nullptr);
}
// Converts a scalar to float.
// float is returned unchanged; i32 goes through a signed int-to-float
// conversion (SIToFP). An i1 is first zero-extended to i32: a signed
// conversion of a 1-bit value treats `true` as -1 (LLVM `sitofp i1 1` is
// -1.0), while C/SysY semantics require a true value to convert to 1.0.
// Throws std::runtime_error for any other type.
ir::Value* IRGenImpl::CastToFloat(ir::Value* v) {
  if (v->IsFloat()) return v;
  if (v->IsInt1()) v = CastToInt(v);
  if (v->IsInt32()) {
    return builder_.CreateSIToFP(v, ir::Type::GetFloatType(),
                                 module_.GetContext().NextTemp());
  }
  throw std::runtime_error(FormatError("irgen", "无法转换为 float"));
}
// Converts a scalar to i32.
// i32 passes through; float is converted with FPToSI; i1 is zero-extended.
// Throws std::runtime_error for any other type.
ir::Value* IRGenImpl::CastToInt(ir::Value* v) {
  if (v->IsInt32()) {
    return v;
  }
  if (v->IsFloat()) {
    return builder_.CreateFPToSI(v, ir::Type::GetInt32Type(),
                                 module_.GetContext().NextTemp());
  }
  if (v->IsInt1()) {
    return builder_.CreateZExt(v, ir::Type::GetInt32Type(),
                               module_.GetContext().NextTemp());
  }
  throw std::runtime_error(FormatError("irgen", "无法转换为 int"));
}
// Produces an i1 truth value: i1 passes through; other integers compare
// `!= 0`; floats compare with fcmp `one` against 0.0.
// NOTE(review): `one` is false for NaN, whereas C treats NaN as true; an
// unordered predicate would match C if the IR offers one.
ir::Value* IRGenImpl::MakeBool(ir::Value* v) {
  if (v->IsInt1()) return v;
  if (!v->IsFloat()) {
    auto* int_zero = builder_.CreateConstInt(0);
    return builder_.CreateICmp(ir::ICmpPredicate::Ne, v, int_zero,
                               module_.GetContext().NextTemp());
  }
  auto* float_zero = builder_.CreateConstFloat(0.0f);
  return builder_.CreateFCmp(ir::FCmpPredicate::One, v, float_zero,
                             module_.GetContext().NextTemp());
}
// Lowers a relational chain `a0 (('<'|'<='|'>'|'>=') a1)*` left to right.
// Each step yields an i1. Float steps use ordered fcmp; integer steps use
// signed icmp.
// In a chain such as `1 < 2 < 3`, the left side of the second step is the i1
// result of the first, so i1 operands are widened to i32 before comparing.
// This prevents comparing i1 with i32 and ensures 0/1 is used when the other
// side is float.
// Returns nullptr for an empty context, or the single operand unchanged.
ir::Value* IRGenImpl::EmitRelEq(SysYParser::RelExpContext* ctx) {
  auto exps = ctx->addExp();
  if (exps.empty()) return nullptr;
  ir::Value* lhs = AnyToValue(exps[0]->accept(this));
  if (exps.size() == 1) return lhs;
  for (size_t i = 1; i < exps.size(); ++i) {
    ir::Value* rhs = AnyToValue(exps[i]->accept(this));
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "<";
    if (lhs->IsInt1()) lhs = CastToInt(lhs);
    if (rhs->IsInt1()) rhs = CastToInt(rhs);
    if (lhs->IsFloat() || rhs->IsFloat()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
      ir::FCmpPredicate pred = ir::FCmpPredicate::Olt;
      if (text == "<") pred = ir::FCmpPredicate::Olt;
      else if (text == "<=") pred = ir::FCmpPredicate::Ole;
      else if (text == ">") pred = ir::FCmpPredicate::Ogt;
      else pred = ir::FCmpPredicate::Oge;
      lhs = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    } else {
      ir::ICmpPredicate pred = ir::ICmpPredicate::Slt;
      if (text == "<") pred = ir::ICmpPredicate::Slt;
      else if (text == "<=") pred = ir::ICmpPredicate::Sle;
      else if (text == ">") pred = ir::ICmpPredicate::Sgt;
      else pred = ir::ICmpPredicate::Sge;
      lhs = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return lhs;
}
// Lowers an equality chain `r0 (('=='|'!=') r1)*` left to right, where each
// operand is a relational expression lowered by EmitRelEq. Float steps use
// fcmp oeq/one; integer steps use icmp eq/ne. Each step yields an i1.
// i1 operands (results of comparisons or `!`) are widened to i32 first, so
// mixed forms such as `(a < b) == 1` compare like-typed values.
// NOTE(review): fcmp `one` is false when either side is NaN, whereas C's `!=`
// is true for NaN; an unordered predicate would match C if the IR has one.
ir::Value* IRGenImpl::EmitEq(SysYParser::EqExpContext* ctx) {
  auto rels = ctx->relExp();
  if (rels.empty()) return nullptr;
  ir::Value* lhs = EmitRelEq(rels[0]);
  if (rels.size() == 1) return lhs;
  for (size_t i = 1; i < rels.size(); ++i) {
    ir::Value* rhs = EmitRelEq(rels[i]);
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "==";
    if (lhs->IsInt1()) lhs = CastToInt(lhs);
    if (rhs->IsInt1()) rhs = CastToInt(rhs);
    if (lhs->IsFloat() || rhs->IsFloat()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
      ir::FCmpPredicate pred = text == "==" ? ir::FCmpPredicate::Oeq
                                            : ir::FCmpPredicate::One;
      lhs = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    } else {
      ir::ICmpPredicate pred = text == "==" ? ir::ICmpPredicate::Eq
                                            : ir::ICmpPredicate::Ne;
      lhs = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return lhs;
}
// Materializes a condition as an i1 value.
// The condition is lowered with short-circuit branches into a true and a
// false block. Both jump to a merge block, where a phi yields true or false.
// Returns nullptr for a null context.
ir::Value* IRGenImpl::EvalCondValue(SysYParser::CondContext* ctx) {
  if (ctx == nullptr) return nullptr;
  auto* on_true = func_->CreateBlock("cond.true");
  auto* on_false = func_->CreateBlock("cond.false");
  auto* join = func_->CreateBlock("cond.merge");
  EmitCondBr(ctx, on_true, on_false);
  auto jump_to_join = [&](ir::BasicBlock* from) {
    builder_.SetInsertPoint(from);
    builder_.CreateBr(join);
  };
  jump_to_join(on_true);
  jump_to_join(on_false);
  builder_.SetInsertPoint(join);
  auto* result = builder_.CreatePhi(ir::Type::GetInt1Type(),
                                    module_.GetContext().NextTemp());
  result->AddIncoming(builder_.CreateConstBool(true), on_true);
  result->AddIncoming(builder_.CreateConstBool(false), on_false);
  return result;
}
// Emits a conditional branch for `cond -> lOrExp` to true_bb / false_bb.
// Throws std::runtime_error when the context or its lOrExp is missing.
void IRGenImpl::EmitCondBr(SysYParser::CondContext* ctx, ir::BasicBlock* true_bb,
                           ir::BasicBlock* false_bb) {
  SysYParser::LOrExpContext* disjunction = ctx ? ctx->lOrExp() : nullptr;
  if (disjunction == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法 cond"));
  }
  EmitLOrCond(disjunction, true_bb, false_bb);
}
// Short-circuit lowering of `a0 || a1 || ...`.
// Each non-final operand branches to true_bb when true, and otherwise to a
// fresh "lor.next" block where the next operand is evaluated. The final
// operand branches to true_bb / false_bb. An empty list jumps to false_bb.
void IRGenImpl::EmitLOrCond(SysYParser::LOrExpContext* ctx,
                            ir::BasicBlock* true_bb,
                            ir::BasicBlock* false_bb) {
  auto ands = ctx->lAndExp();
  if (ands.empty()) {
    builder_.CreateBr(false_bb);
    return;
  }
  const size_t last = ands.size() - 1;
  for (size_t k = 0; k < last; ++k) {
    auto* next = func_->CreateBlock("lor.next");
    EmitLAndCond(ands[k], true_bb, next);
    builder_.SetInsertPoint(next);
  }
  EmitLAndCond(ands[last], true_bb, false_bb);
}
// Short-circuit lowering of `e0 && e1 && ...`.
// Each operand is converted to i1. Every non-final operand branches to
// false_bb when false, and otherwise to a fresh "land.next" block. The final
// operand branches to true_bb / false_bb. An empty list jumps to false_bb.
void IRGenImpl::EmitLAndCond(SysYParser::LAndExpContext* ctx,
                             ir::BasicBlock* true_bb,
                             ir::BasicBlock* false_bb) {
  auto eqs = ctx->eqExp();
  if (eqs.empty()) {
    builder_.CreateBr(false_bb);
    return;
  }
  for (size_t k = 0; k + 1 < eqs.size(); ++k) {
    ir::Value* truth = MakeBool(EmitEq(eqs[k]));
    auto* next = func_->CreateBlock("land.next");
    builder_.CreateCondBr(truth, next, false_bb);
    builder_.SetInsertPoint(next);
  }
  ir::Value* final_truth = MakeBool(EmitEq(eqs.back()));
  builder_.CreateCondBr(final_truth, true_bb, false_bb);
}
// Computes the address denoted by an lvalue `ID ('[' exp ']')*`.
//
// Storage lookup order:
//   1. The declaration resolved by sema_, looked up in the local
//      var/const/param maps.
//   2. The global var/const maps.
//   3. As a last resort, a by-name scan of all storage maps.
//
// For array parameters, the stored slot holds a pointer, which is loaded
// first. When subscripts are present, a GEP is emitted: a leading 0 index is
// added when the base points at an array object (not for an unsized leading
// dimension), and each subscript is converted to i32.
//
// Throws std::runtime_error when no storage or type can be found.
ir::Value* IRGenImpl::GetLValAddress(SysYParser::LValContext* ctx) {
  BoundDecl bound = sema_.ResolveVarUse(ctx);
  ir::Value* base_ptr = nullptr;
  const TypeDesc* ty = nullptr;
  // find() rather than operator[]. operator[] would insert a null entry into
  // the local maps for every global (or unregistered) declaration looked up
  // here, and those entries would also be visited by the by-name fallback.
  if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
    ty = sema_.GetVarType(bound.var_decl);
    auto it = var_storage_.find(bound.var_decl);
    if (it != var_storage_.end()) base_ptr = it->second;
  } else if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) {
    ty = sema_.GetConstType(bound.const_decl);
    auto it = const_storage_.find(bound.const_decl);
    if (it != const_storage_.end()) base_ptr = it->second;
  } else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) {
    ty = sema_.GetParamType(bound.param_decl);
    auto it = param_storage_.find(bound.param_decl);
    if (it != param_storage_.end()) base_ptr = it->second;
  }
  if (!base_ptr && bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
    auto it = global_var_storage_.find(bound.var_decl);
    if (it != global_var_storage_.end()) {
      ty = sema_.GetVarType(bound.var_decl);
      base_ptr = it->second;
    }
  }
  if (!base_ptr && bound.kind == BoundDecl::Kind::Const && bound.const_decl) {
    auto it = global_const_storage_.find(bound.const_decl);
    if (it != global_const_storage_.end()) {
      ty = sema_.GetConstType(bound.const_decl);
      base_ptr = it->second;
    }
  }
  // By-name fallback. This ignores scoping; with shadowed names the first
  // entry met while iterating a map wins.
  if (!base_ptr && ctx && ctx->ID()) {
    const auto name = ctx->ID()->getText();
    for (const auto& kv : var_storage_) {
      auto* def = const_cast<SysYParser::VarDefContext*>(kv.first);
      if (def && def->ID() && def->ID()->getText() == name) {
        ty = sema_.GetVarType(kv.first);
        base_ptr = kv.second;
        break;
      }
    }
    if (!base_ptr) {
      for (const auto& kv : const_storage_) {
        auto* def = const_cast<SysYParser::ConstDefContext*>(kv.first);
        if (def && def->ID() && def->ID()->getText() == name) {
          ty = sema_.GetConstType(kv.first);
          base_ptr = kv.second;
          break;
        }
      }
    }
    if (!base_ptr) {
      for (const auto& kv : param_storage_) {
        auto* def = const_cast<SysYParser::FuncFParamContext*>(kv.first);
        if (def && def->ID() && def->ID()->getText() == name) {
          ty = sema_.GetParamType(kv.first);
          base_ptr = kv.second;
          break;
        }
      }
    }
    if (!base_ptr) {
      for (const auto& kv : global_var_storage_) {
        auto* def = const_cast<SysYParser::VarDefContext*>(kv.first);
        if (def && def->ID() && def->ID()->getText() == name) {
          ty = sema_.GetVarType(kv.first);
          base_ptr = kv.second;
          break;
        }
      }
    }
    if (!base_ptr) {
      for (const auto& kv : global_const_storage_) {
        auto* def = const_cast<SysYParser::ConstDefContext*>(kv.first);
        if (def && def->ID() && def->ID()->getText() == name) {
          ty = sema_.GetConstType(kv.first);
          base_ptr = kv.second;
          break;
        }
      }
    }
  }
  if (!base_ptr || !ty) {
    throw std::runtime_error(FormatError("irgen", "左值未绑定"));
  }
  // Array parameters: the slot stores the incoming pointer; load it.
  if (bound.kind == BoundDecl::Kind::Param && !ty->dims.empty()) {
    base_ptr = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp());
  }
  std::vector<ir::Value*> indices;
  const auto exps = ctx->exp();
  if (!ty->dims.empty() && !exps.empty()) {
    bool need_leading_zero = base_ptr->GetType()->GetElementType()->IsArray();
    if (!ty->dims.empty() && ty->dims[0] < 0) {
      need_leading_zero = false;
    }
    if (need_leading_zero) {
      indices.push_back(builder_.CreateConstInt(0));
    }
  }
  for (auto* exp : exps) {
    indices.push_back(CastToInt(EvalExp(exp)));
  }
  if (!indices.empty()) {
    if (base_ptr->GetType() && base_ptr->GetType()->IsPointer() &&
        base_ptr->GetType()->GetElementType()->IsPointer()) {
      base_ptr = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp());
    }
    return builder_.CreateGep(base_ptr, std::move(indices),
                              module_.GetContext().NextTemp());
  }
  return base_ptr;
}
// Loads from `addr_or_val` when the lvalue is used as an rvalue; otherwise
// returns the address unchanged.
ir::Value* IRGenImpl::LoadIfNeeded(ir::Value* addr_or_val, const TypeDesc& ty,
                                   bool as_rvalue) {
  (void)ty;  // not consulted: CreateLoad is given only the address
  if (as_rvalue) {
    return builder_.CreateLoad(addr_or_val, module_.GetContext().NextTemp());
  }
  return addr_or_val;
}
// Maps a semantic TypeDesc to an IR type.
// The base is i32 for Int, float for Float, and void otherwise. Array
// dimensions are wrapped from the innermost outwards; negative (unsized)
// dimensions are skipped.
std::shared_ptr<ir::Type> IRGenImpl::ToIRType(const TypeDesc& ty) {
  std::shared_ptr<ir::Type> result;
  switch (ty.base) {
    case BaseTypeKind::Int:
      result = ir::Type::GetInt32Type();
      break;
    case BaseTypeKind::Float:
      result = ir::Type::GetFloatType();
      break;
    default:
      result = ir::Type::GetVoidType();
      break;
  }
  for (size_t k = ty.dims.size(); k > 0; --k) {
    const auto extent = ty.dims[k - 1];
    if (extent < 0) continue;
    result = ir::Type::GetArrayType(result, static_cast<size_t>(extent));
  }
  return result;
}
std::shared_ptr<ir::Type> IRGenImpl::ToIRParamType(const TypeDesc& ty) {
if (ty.dims.empty()) return ToIRType(ty);
TypeDesc elem = ty;
if (!elem.dims.empty() && elem.dims.front() < 0) {
elem.dims.erase(elem.dims.begin());
}
return ir::Type::GetPointerType(ToIRType(elem));
}
// Zero of the base type: 0.0f for Float, integer 0 for everything else.
ir::Value* IRGenImpl::DefaultValue(const TypeDesc& ty) {
  if (ty.base != BaseTypeKind::Float) return builder_.CreateConstInt(0);
  return builder_.CreateConstFloat(0.0f);
}
// Creates an AllocaInst of type `ty` in the current function's entry block,
// inserting it with the entry block's Prepend.
// Throws std::runtime_error when there is no current function or entry block.
ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
                                             const std::string& name) {
  if (func_ == nullptr) {
    throw std::runtime_error(FormatError("irgen", "CreateEntryAlloca 缺少函数"));
  }
  auto* entry_block = func_->GetEntry();
  if (entry_block == nullptr) {
    throw std::runtime_error(FormatError("irgen", "CreateEntryAlloca 缺少入口块"));
  }
  return entry_block->Prepend<ir::AllocaInst>(std::move(ty), name);
}
// Number of scalar elements spanned by one step along dimension `dim`: the
// product of all later dimensions (1 for the last dimension).
size_t IRGenImpl::ArrayStride(const TypeDesc& ty, size_t dim) const {
  size_t product = 1;
  for (size_t k = ty.dims.size(); k > dim + 1; --k) {
    product *= static_cast<size_t>(ty.dims[k - 1]);
  }
  return product;
}
// Total number of scalar elements: the product of every dimension
// (1 for a scalar).
size_t IRGenImpl::ArrayTotalSize(const TypeDesc& ty) const {
  size_t count = 1;
  for (auto it = ty.dims.begin(); it != ty.dims.end(); ++it) {
    const int extent = *it;
    count *= static_cast<size_t>(extent);
  }
  return count;
}
// Rounds `index` up to the next multiple of `align`; align == 0 leaves the
// index unchanged.
static size_t AlignIndex(size_t index, size_t align) {
  if (align == 0) return index;
  const size_t remainder = index % align;
  return remainder == 0 ? index : index + (align - remainder);
}
// Flattens a (possibly nested) brace initializer into `values` in row-major
// order, with element positions offset by `base`.
// `idx` is the current flat position and `dim` the dimension that the
// brace list at this level describes. Returns the flat position after the
// consumed elements.
// - A scalar initializer stores EvalExp(...) at base + idx when in range
//   (out-of-range positions are ignored) and advances by one.
// - An empty `{}` aligns idx to ArrayStride(ty, dim) and skips one
//   sub-array.
// - A nested brace list is aligned to ArrayStride(ty, dim), filled
//   recursively at dim + 1, and always advances idx by exactly one
//   sub-array, however many elements it held.
// NOTE(review): alignment always uses the current level's stride; confirm
// this matches the intended SysY rule for braces that follow a partially
// filled sub-array.
size_t IRGenImpl::FillArrayValues(const TypeDesc& ty,
SysYParser::InitValContext* init,
std::vector<ir::Value*>& values, size_t base,
size_t idx, size_t dim) {
if (!init) return idx;
if (init->exp()) {
if (base + idx < values.size()) {
values[base + idx] = EvalExp(init->exp());
}
return idx + 1;
}
size_t sub_size = ArrayStride(ty, dim);
if (init->initVal().empty()) {
idx = AlignIndex(idx, sub_size);
return idx + sub_size;
}
for (auto* child : init->initVal()) {
if (!child) continue;
if (child->exp()) {
idx = FillArrayValues(ty, child, values, base, idx, dim);
} else {
size_t aligned = AlignIndex(idx, sub_size);
idx = aligned;
idx = FillArrayValues(ty, child, values, base, idx, dim + 1);
// The nested list occupies exactly one sub-array regardless of its length.
idx = aligned + sub_size;
}
}
idx = AlignIndex(idx, sub_size);
return idx;
}
// Const-initializer counterpart of FillArrayValues. It flattens a nested
// constInitVal into `values` (row-major, offset by `base`) with the same
// alignment rules. Each scalar constExp is lowered through the visitor
// (AnyToValue(constExp->accept(this))), so the stored entries are IR values
// rather than folded constants.
// Returns the flat position after the consumed elements.
size_t IRGenImpl::FillConstArrayValues(
const TypeDesc& ty, SysYParser::ConstInitValContext* init,
std::vector<ir::Value*>& values, size_t base, size_t idx, size_t dim) {
if (!init) return idx;
if (init->constExp()) {
if (base + idx < values.size()) {
values[base + idx] = AnyToValue(init->constExp()->accept(this));
}
return idx + 1;
}
size_t sub_size = ArrayStride(ty, dim);
if (init->constInitVal().empty()) {
idx = AlignIndex(idx, sub_size);
return idx + sub_size;
}
for (auto* child : init->constInitVal()) {
if (!child) continue;
if (child->constExp()) {
idx = FillConstArrayValues(ty, child, values, base, idx, dim);
} else {
size_t aligned = AlignIndex(idx, sub_size);
idx = aligned;
idx = FillConstArrayValues(ty, child, values, base, idx, dim + 1);
// The nested list occupies exactly one sub-array regardless of its length.
idx = aligned + sub_size;
}
}
idx = AlignIndex(idx, sub_size);
return idx;
}
// Initializes every element of the array at `base_ptr`.
// All slots start as DefaultValue(ty) and are overwritten by the flattened
// initializer. One GEP + store is then emitted per element, with the value
// converted to the array's base type (i1 is zero-extended before any
// int->float conversion).
void IRGenImpl::InitArray(ir::Value* base_ptr, const TypeDesc& ty,
                          SysYParser::InitValContext* init) {
  const size_t total = ArrayTotalSize(ty);
  std::vector<ir::Value*> values(total, DefaultValue(ty));
  FillArrayValues(ty, init, values, 0, 0, 0);
  // Per-dimension strides do not change across elements; compute them once.
  std::vector<size_t> strides;
  strides.reserve(ty.dims.size());
  for (size_t d = 0; d < ty.dims.size(); ++d) {
    strides.push_back(ArrayStride(ty, d));
  }
  for (size_t flat = 0; flat < total; ++flat) {
    std::vector<ir::Value*> gep_idx;
    gep_idx.push_back(builder_.CreateConstInt(0));
    size_t rest = flat;
    for (size_t stride : strides) {
      const size_t coord = rest / stride;
      rest %= stride;
      gep_idx.push_back(builder_.CreateConstInt(static_cast<int>(coord)));
    }
    ir::Value* slot = builder_.CreateGep(base_ptr, std::move(gep_idx),
                                         module_.GetContext().NextTemp());
    ir::Value* elem = values[flat];
    if (ty.base == BaseTypeKind::Float) {
      if (elem->IsInt1() || elem->IsInt32()) {
        ir::Value* as_int = elem->IsInt1() ? CastToInt(elem) : elem;
        elem = CastToFloat(as_int);
      }
    } else if (ty.base == BaseTypeKind::Int &&
               (elem->IsFloat() || elem->IsInt1())) {
      elem = CastToInt(elem);
    }
    builder_.CreateStore(elem, slot);
  }
}
// Const-initializer counterpart of InitArray.
// Slots default to DefaultValue(ty), are overwritten from the flattened
// constInitVal, and are stored one GEP + store per element after conversion
// to the array's base type.
void IRGenImpl::InitConstArray(ir::Value* base_ptr, const TypeDesc& ty,
                               SysYParser::ConstInitValContext* init) {
  const size_t total = ArrayTotalSize(ty);
  std::vector<ir::Value*> values(total, DefaultValue(ty));
  FillConstArrayValues(ty, init, values, 0, 0, 0);
  // Per-dimension strides do not change across elements; compute them once.
  std::vector<size_t> strides;
  strides.reserve(ty.dims.size());
  for (size_t d = 0; d < ty.dims.size(); ++d) {
    strides.push_back(ArrayStride(ty, d));
  }
  for (size_t flat = 0; flat < total; ++flat) {
    std::vector<ir::Value*> gep_idx;
    gep_idx.push_back(builder_.CreateConstInt(0));
    size_t rest = flat;
    for (size_t stride : strides) {
      const size_t coord = rest / stride;
      rest %= stride;
      gep_idx.push_back(builder_.CreateConstInt(static_cast<int>(coord)));
    }
    ir::Value* slot = builder_.CreateGep(base_ptr, std::move(gep_idx),
                                         module_.GetContext().NextTemp());
    ir::Value* elem = values[flat];
    if (ty.base == BaseTypeKind::Float) {
      if (elem->IsInt1() || elem->IsInt32()) {
        ir::Value* as_int = elem->IsInt1() ? CastToInt(elem) : elem;
        elem = CastToFloat(as_int);
      }
    } else if (ty.base == BaseTypeKind::Int &&
               (elem->IsFloat() || elem->IsInt1())) {
      elem = CastToInt(elem);
    }
    builder_.CreateStore(elem, slot);
  }
}
// Records the `break` / `continue` targets of a newly entered loop.
void IRGenImpl::PushLoop(ir::BasicBlock* break_bb, ir::BasicBlock* cont_bb) {
  loop_stack_.push_back({break_bb, cont_bb});
}
// Leaves the innermost loop; a no-op when no loop is active.
void IRGenImpl::PopLoop() {
  if (loop_stack_.empty()) return;
  loop_stack_.pop_back();
}
// `break` target of the innermost loop, or nullptr outside any loop.
ir::BasicBlock* IRGenImpl::CurrentBreak() const {
  return loop_stack_.empty() ? nullptr : loop_stack_.back().first;
}
// `continue` target of the innermost loop, or nullptr outside any loop.
ir::BasicBlock* IRGenImpl::CurrentContinue() const {
  return loop_stack_.empty() ? nullptr : loop_stack_.back().second;
}
namespace {
// A folded scalar constant: an integer in `i`, or a floating value in `f`
// when `is_float` is set.
struct ConstNumber {
  bool is_float = false;
  double f = 0.0;
  long long i = 0;
};

// Extracts the numeric payload of a ConstantInt / ConstantFloat. Any other
// constant kind (or nullptr) yields a default ConstNumber (integer zero).
ConstNumber ToConstNumber(ir::ConstantValue* v) {
  ConstNumber num;
  if (auto* as_int = dynamic_cast<ir::ConstantInt*>(v)) {
    num.i = as_int->GetValue();
  } else if (auto* as_float = dynamic_cast<ir::ConstantFloat*>(v)) {
    num.is_float = true;
    num.f = as_float->GetValue();
  }
  return num;
}

// Returns `v` in floating form; values that are already float are returned
// unchanged. A converted result has `i` reset to 0.
ConstNumber PromoteToFloat(const ConstNumber& v) {
  if (v.is_float) return v;
  ConstNumber promoted;
  promoted.is_float = true;
  promoted.f = static_cast<double>(v.i);
  return promoted;
}
}  // namespace
// Folds `exp -> addExp` at compile time.
// Throws std::runtime_error when the context or its addExp is missing.
ir::ConstantValue* IRGenImpl::EvalConstScalar(SysYParser::ExpContext* ctx) {
  SysYParser::AddExpContext* additive = ctx ? ctx->addExp() : nullptr;
  if (additive == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法常量表达式"));
  }
  return EvalConstAdd(additive);
}
// Folds `constExp -> addExp` at compile time.
// Throws std::runtime_error when the context or its addExp is missing.
ir::ConstantValue* IRGenImpl::EvalConstScalar(SysYParser::ConstExpContext* ctx) {
  SysYParser::AddExpContext* additive = ctx ? ctx->addExp() : nullptr;
  if (additive == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法常量表达式"));
  }
  return EvalConstAdd(additive);
}
// Constant-folds an additive chain left to right.
// A step involving a float promotes both sides to float. The result becomes
// a float constant (narrowed to float) or an int constant (cast to int).
// An empty chain folds to integer 0.
ir::ConstantValue* IRGenImpl::EvalConstAdd(SysYParser::AddExpContext* ctx) {
  const auto operands = ctx->mulExp();
  if (operands.empty()) return module_.GetContext().GetConstInt(0);
  ConstNumber acc = ToConstNumber(EvalConstMul(operands[0]));
  for (size_t k = 1; k < operands.size(); ++k) {
    const ConstNumber next = ToConstNumber(EvalConstMul(operands[k]));
    auto* op_node = ctx->children.at(2 * k - 1);
    const std::string op = op_node ? op_node->getText() : "+";
    const bool is_plus = (op == "+");
    if (!acc.is_float && !next.is_float) {
      acc.i = is_plus ? acc.i + next.i : acc.i - next.i;
      continue;
    }
    acc = PromoteToFloat(acc);
    const ConstNumber rhs = PromoteToFloat(next);
    acc.f = is_plus ? acc.f + rhs.f : acc.f - rhs.f;
    acc.is_float = true;
  }
  if (!acc.is_float) {
    return module_.GetContext().GetConstInt(static_cast<int>(acc.i));
  }
  return module_.GetContext().GetConstFloat(static_cast<float>(acc.f));
}
// Constant-folds a multiplicative chain `u0 (('*'|'/'|'%') u1)*` left to
// right.
// - Integer steps are wrapped to 32 bits after every operation, so folding
//   matches run-time i32 arithmetic (e.g. `a * b / c` with an overflowing
//   product) and long `*` chains cannot overflow the 64-bit accumulator
//   (signed overflow is undefined behavior).
// - Integer division or remainder by zero is reported as an error instead of
//   crashing the compiler.
// - '%' requires integer operands; other steps involving a float are done
//   in floating point.
// An empty chain folds to integer 0.
ir::ConstantValue* IRGenImpl::EvalConstMul(SysYParser::MulExpContext* ctx) {
  auto unaries = ctx->unaryExp();
  if (unaries.empty()) return module_.GetContext().GetConstInt(0);
  ConstNumber lhs = ToConstNumber(EvalConstUnary(unaries[0]));
  for (size_t i = 1; i < unaries.size(); ++i) {
    ConstNumber rhs = ToConstNumber(EvalConstUnary(unaries[i]));
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "*";
    if (text == "%") {
      if (lhs.is_float || rhs.is_float) {
        throw std::runtime_error(FormatError("irgen", "const % 仅支持整数"));
      }
      if (rhs.i == 0) {
        throw std::runtime_error(FormatError("irgen", "常量表达式除数为 0"));
      }
      lhs.i = static_cast<int>(lhs.i % rhs.i);
      continue;
    }
    if (lhs.is_float || rhs.is_float) {
      lhs = PromoteToFloat(lhs);
      rhs = PromoteToFloat(rhs);
      if (text == "*") lhs.f = lhs.f * rhs.f;
      else lhs.f = lhs.f / rhs.f;
      lhs.is_float = true;
    } else {
      // Both operands are within int range here, so the 64-bit product or
      // quotient cannot overflow before it is wrapped back to 32 bits.
      long long result = 0;
      if (text == "*") {
        result = lhs.i * rhs.i;
      } else {
        if (rhs.i == 0) {
          throw std::runtime_error(FormatError("irgen", "常量表达式除数为 0"));
        }
        result = lhs.i / rhs.i;
      }
      lhs.i = static_cast<int>(result);
    }
  }
  if (lhs.is_float) return module_.GetContext().GetConstFloat(static_cast<float>(lhs.f));
  return module_.GetContext().GetConstInt(static_cast<int>(lhs.i));
}
// Constant-folds a unary expression.
// primaryExp is folded directly. '+' and '-' keep the operand's int/float
// kind; '!' yields integer 1 when the operand is zero and 0 otherwise.
// Anything else, including a function call, is rejected.
ir::ConstantValue* IRGenImpl::EvalConstUnary(SysYParser::UnaryExpContext* ctx) {
  if (auto* primary = ctx->primaryExp()) return EvalConstPrimary(primary);
  auto* op_ctx = ctx->unaryOp();
  auto* operand = ctx->unaryExp();
  if (op_ctx && operand) {
    const ConstNumber val = ToConstNumber(EvalConstUnary(operand));
    const std::string op = op_ctx->getText();
    if (op == "!") {
      const bool is_zero = val.is_float ? (val.f == 0.0) : (val.i == 0);
      return module_.GetContext().GetConstInt(is_zero ? 1 : 0);
    }
    if (op == "+" || op == "-") {
      const bool negate = (op == "-");
      if (val.is_float) {
        return module_.GetContext().GetConstFloat(
            static_cast<float>(negate ? -val.f : val.f));
      }
      return module_.GetContext().GetConstInt(
          static_cast<int>(negate ? -val.i : val.i));
    }
  }
  throw std::runtime_error(FormatError("irgen", "const 不支持函数调用"));
}
// Constant-folds '(' exp ')' | lVal | number; any other shape folds to
// integer 0.
ir::ConstantValue* IRGenImpl::EvalConstPrimary(SysYParser::PrimaryExpContext* ctx) {
  if (auto* inner = ctx->exp()) return EvalConstScalar(inner);
  if (auto* named = ctx->lVal()) return EvalConstLVal(named);
  if (auto* literal = ctx->number()) return EvalConstNumber(literal);
  return module_.GetContext().GetConstInt(0);
}
// Folds a numeric literal to a module constant.
// Integer literals are parsed with std::stoll in base 0 (decimal, octal,
// hex) and must be fully consumed; the value is then cast to int. Float
// literals are parsed with std::stof.
ir::ConstantValue* IRGenImpl::EvalConstNumber(SysYParser::NumberContext* ctx) {
  if (ctx->INT_CONST()) {
    const std::string literal = ctx->getText();
    size_t consumed = 0;
    const long long parsed = std::stoll(literal, &consumed, 0);
    if (consumed != literal.size()) {
      throw std::runtime_error(FormatError("irgen", "非法整数常量"));
    }
    return module_.GetContext().GetConstInt(static_cast<int>(parsed));
  }
  if (!ctx->FLOAT_CONST()) {
    throw std::runtime_error(FormatError("irgen", "非法常量"));
  }
  return module_.GetContext().GetConstFloat(std::stof(ctx->getText()));
}
// Folds a reference to a named constant during constant evaluation.
// Only constants found in global_const_storage_ are supported: the global's
// initializer is returned as-is. Anything else (e.g. a local const, or a
// global without an initializer) raises std::runtime_error.
// NOTE(review): the lvalue's subscripts (ctx->exp()) are not consulted, so an
// indexed const array yields its whole initializer rather than the selected
// element. Confirm what GetInitializer returns for arrays.
ir::ConstantValue* IRGenImpl::EvalConstLVal(SysYParser::LValContext* ctx) {
BoundDecl bound = sema_.ResolveVarUse(ctx);
if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) {
auto it = global_const_storage_.find(bound.const_decl);
if (it != global_const_storage_.end()) {
auto* init = it->second->GetInitializer();
if (init) return init;
}
}
throw std::runtime_error(FormatError("irgen", "constExp 使用了非常量"));
}
// Flattens a global variable's brace initializer into compile-time constants.
// It walks the initializer like FillArrayValues (row-major, offset by `base`,
// nested lists aligned to ArrayStride(ty, dim) and advancing one sub-array).
// Each scalar is folded with EvalConstScalar and coerced to the array's base
// type:
// - a float constant in an Int array becomes static_cast<int>(value);
// - an int constant in a Float array becomes static_cast<float>(value).
// Positions outside `values` are ignored. Returns the flat position after
// the consumed elements.
size_t IRGenImpl::InitGlobalArray(const TypeDesc& ty,
SysYParser::InitValContext* init,
std::vector<ir::ConstantValue*>& values,
size_t base, size_t idx, size_t dim) {
if (!init) return idx;
if (init->exp()) {
if (base + idx < values.size()) {
auto* v = EvalConstScalar(init->exp());
if (ty.base == BaseTypeKind::Int && dynamic_cast<ir::ConstantFloat*>(v)) {
auto* cf = static_cast<ir::ConstantFloat*>(v);
v = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
} else if (ty.base == BaseTypeKind::Float &&
dynamic_cast<ir::ConstantInt*>(v)) {
auto* ci = static_cast<ir::ConstantInt*>(v);
v = module_.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
values[base + idx] = v;
}
return idx + 1;
}
size_t sub_size = ArrayStride(ty, dim);
if (init->initVal().empty()) {
idx = AlignIndex(idx, sub_size);
return idx + sub_size;
}
for (auto* child : init->initVal()) {
if (!child) continue;
if (child->exp()) {
idx = InitGlobalArray(ty, child, values, base, idx, dim);
} else {
size_t aligned = AlignIndex(idx, sub_size);
idx = aligned;
idx = InitGlobalArray(ty, child, values, base, idx, dim + 1);
// The nested list occupies exactly one sub-array regardless of its length.
idx = aligned + sub_size;
}
}
idx = AlignIndex(idx, sub_size);
return idx;
}
// Const-initializer counterpart of InitGlobalArray. It flattens a nested
// constInitVal into folded constants (row-major, offset by `base`, same
// alignment rules), coercing each EvalConstScalar result to the array's base
// type:
// - a float constant in an Int array becomes static_cast<int>(value);
// - an int constant in a Float array becomes static_cast<float>(value).
// Positions outside `values` are ignored. Returns the flat position after
// the consumed elements.
size_t IRGenImpl::InitGlobalConstArray(const TypeDesc& ty,
SysYParser::ConstInitValContext* init,
std::vector<ir::ConstantValue*>& values,
size_t base, size_t idx, size_t dim) {
if (!init) return idx;
if (init->constExp()) {
if (base + idx < values.size()) {
auto* v = EvalConstScalar(init->constExp());
if (ty.base == BaseTypeKind::Int && dynamic_cast<ir::ConstantFloat*>(v)) {
auto* cf = static_cast<ir::ConstantFloat*>(v);
v = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
} else if (ty.base == BaseTypeKind::Float &&
dynamic_cast<ir::ConstantInt*>(v)) {
auto* ci = static_cast<ir::ConstantInt*>(v);
v = module_.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
values[base + idx] = v;
}
return idx + 1;
}
size_t sub_size = ArrayStride(ty, dim);
if (init->constInitVal().empty()) {
idx = AlignIndex(idx, sub_size);
return idx + sub_size;
}
for (auto* child : init->constInitVal()) {
if (!child) continue;
if (child->constExp()) {
idx = InitGlobalConstArray(ty, child, values, base, idx, dim);
} else {
size_t aligned = AlignIndex(idx, sub_size);
idx = aligned;
idx = InitGlobalConstArray(ty, child, values, base, idx, dim + 1);
// The nested list occupies exactly one sub-array regardless of its length.
idx = aligned + sub_size;
}
}
idx = AlignIndex(idx, sub_size);
return idx;
}