// Expression lowering for the SysY front end: every visitor below returns a
// std::any that holds an `ir::Value*` for the evaluated (sub)expression.
#include "irgen/IRGen.h"

#include <any>
#include <stdexcept>
#include <string>
#include <vector>

#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"

// Visits an expression node and unwraps the resulting ir::Value*.
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  return std::any_cast<ir::Value*>(expr.accept(this));
}

// Visits a condition node and unwraps the resulting ir::Value*.
ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) {
  return std::any_cast<ir::Value*>(cond.accept(this));
}

// Returns `v` unchanged when it is already float32, emits an int->float
// conversion when it is int32, and throws std::runtime_error otherwise
// (including for a null value or a value without a type).
ir::Value* IRGenImpl::CastToFloat(ir::Value* v) {
  if (!v || !v->GetType()) {
    throw std::runtime_error(FormatError("irgen", "CastToFloat 输入为空"));
  }
  if (v->GetType()->IsFloat32()) {
    return v;
  }
  if (v->GetType()->IsInt32()) {
    return builder_.CreateSIToFP(v, module_.GetContext().NextTemp());
  }
  throw std::runtime_error(FormatError("irgen", "不支持转换到 float 的类型"));
}

// Returns `v` unchanged when it is already int32, emits a float->int
// conversion when it is float32, and throws std::runtime_error otherwise
// (including for a null value or a value without a type).
ir::Value* IRGenImpl::CastToInt(ir::Value* v) {
  if (!v || !v->GetType()) {
    throw std::runtime_error(FormatError("irgen", "CastToInt 输入为空"));
  }
  if (v->GetType()->IsInt32()) {
    return v;
  }
  if (v->GetType()->IsFloat32()) {
    return builder_.CreateFPToSI(v, module_.GetContext().NextTemp());
  }
  throw std::runtime_error(FormatError("irgen", "不支持转换到 i32 的类型"));
}

// Converts an arbitrary value into a condition value by comparing it against
// zero of the matching type. Throws std::runtime_error when `v` is null.
ir::Value* IRGenImpl::ToBoolValue(ir::Value* v) {
  if (!v) {
    throw std::runtime_error(FormatError("irgen", "条件值为空"));
  }
  if (v->GetType() &&
      (v->GetType()->IsPtrInt32() || v->GetType()->IsPtrFloat32())) {
    // Pointers produced by array-name decay are always non-null in the
    // current implementation, so they are treated as constant true.
    return builder_.CreateConstInt(1);
  }
  // A value that is already a comparison result is used as-is.
  // NOTE(review): the cast target was lost from the reviewed text; it is
  // presumed to be ir::CmpInst (the result of CreateCmp) -- confirm.
  if (dynamic_cast<ir::CmpInst*>(v) != nullptr) {
    return v;
  }
  // Guard the type lookup: the pointer check above tolerates a missing type,
  // so this path must not dereference it unconditionally either.
  const bool is_float = v->GetType() && v->GetType()->IsFloat32();
  ir::Value* zero =
      is_float
          ? static_cast<ir::Value*>(module_.GetContext().GetConstFloat(0.0f))
          : static_cast<ir::Value*>(builder_.CreateConstInt(0));
  return builder_.CreateCmp(ir::CmpOp::Ne, v, zero,
                            module_.GetContext().NextTemp());
}

// Builds a basic-block name "bb<N>" from the next temporary name, dropping a
// leading '%' if the temporary carries one.
std::string IRGenImpl::NextBlockName() {
  std::string temp = module_.GetContext().NextTemp();
  if (!temp.empty() && temp.front() == '%') {
    return "bb" + temp.substr(1);
  }
  return "bb" + temp;
}

// --- Array dimension lookup -------------------------------------------------
// Returns the recorded dimensions for `name`, or nullptr when `name` is not a
// known array. Local arrays win over globals.
const std::vector<int>* IRGenImpl::FindArrayDims(
    const std::string& name) const {
  const auto local_it = local_array_dims_.find(name);
  if (local_it != local_array_dims_.end()) {
    return &local_it->second;
  }
  // A local scalar of the same name (parameter or local variable) must hide
  // the dimension information of a global array.
  if (named_storage_.find(name) != named_storage_.end()) {
    return nullptr;
  }
  const auto global_it = global_array_dims_.find(name);
  if (global_it != global_array_dims_.end()) {
    return &global_it->second;
  }
  return nullptr;
}

// --- Linear index computation -----------------------------------------------
// Given the dimensions `dims` and the subscript expressions `subs`, emits
// linear = sum(subs[k] * stride[k]) and returns it.
//
// For dims = [d0, d1, ..., dn-1], stride[k] = d_{k+1} * ... * d_{n-1}.
// dims[0] may be -1 (unknown first dimension of an array parameter); it never
// takes part in any stride.
//
// Throws std::runtime_error when more subscripts than dimensions are given
// (previously the surplus subscripts were silently ignored).
ir::Value* IRGenImpl::ComputeLinearIndex(
    const std::vector<int>& dims,
    const std::vector<SysYParser::ExpContext*>& subs) {
  if (subs.size() > dims.size()) {
    throw std::runtime_error(
        FormatError("irgen", "数组下标个数超过维度数"));
  }

  ir::Value* linear = builder_.CreateConstInt(0);
  for (std::size_t k = 0; k < subs.size(); ++k) {
    int stride = 1;
    for (std::size_t j = k + 1; j < dims.size(); ++j) {
      stride *= dims[j];
    }

    ir::Value* idx = CastToInt(EvalExpr(*subs[k]));
    if (stride != 1) {
      auto* stride_value = builder_.CreateConstInt(stride);
      idx = builder_.CreateMul(idx, stride_value,
                               module_.GetContext().NextTemp());
    }

    // Skip the redundant "0 + idx" for the common single-term case (last
    // subscript, unit stride, nothing accumulated yet).
    auto* accumulated = dynamic_cast<ir::ConstantInt*>(linear);
    const bool is_sole_term = stride == 1 && k + 1 == subs.size() &&
                              accumulated != nullptr &&
                              accumulated->GetValue() == 0;
    linear = is_sole_term
                 ? idx
                 : builder_.CreateAdd(linear, idx,
                                      module_.GetContext().NextTemp());
  }
  return linear;
}

// exp : addExp -- forwards to the additive expression.
std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
  if (!ctx || !ctx->addExp()) {
    throw std::runtime_error(FormatError("irgen", "非法表达式"));
  }
  return ctx->addExp()->accept(this);
}

// cond : lOrExp -- forwards to the logical-or expression.
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
  if (!ctx || !ctx->lOrExp()) {
    throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
  }
  return ctx->lOrExp()->accept(this);
}

// primaryExp : '(' exp ')' | number | lValue
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
  }
  if (ctx->exp()) {
    return EvalExpr(*ctx->exp());
  }
  if (ctx->number()) {
    return ctx->number()->accept(this);
  }
  if (ctx->lValue()) {
    return ctx->lValue()->accept(this);
  }
  throw std::runtime_error(FormatError("irgen", "不支持的基本表达式"));
}

// Lowers an integer or floating-point literal to a constant.
//
// Integer literals may be decimal, octal (leading 0) or hexadecimal (0x/0X).
// They are parsed as unsigned 64-bit and wrapped to 32 bits so that
// `2147483648` (as written in `-2147483648`) and hex literals with the top
// bit set (e.g. 0x80000000, 0xFFFFFFFF) no longer make std::stoi throw.
// Literals that do not fit in 32 bits raise std::runtime_error.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "缺少数字字面量"));
  }

  // Floating-point literal.
  if (ctx->FLITERAL()) {
    const std::string text = ctx->getText();
    const float val = std::stof(text);
    return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(val));
  }

  // Integer literal.
  if (!ctx->ILITERAL()) {
    throw std::runtime_error(
        FormatError("irgen", "当前仅支持整数和浮点字面量"));
  }

  const std::string text = ctx->getText();
  // Base 0 selects hex for a 0x/0X prefix, octal for a leading 0, and decimal
  // otherwise -- the same classification the hand-written checks performed.
  const unsigned long long raw = std::stoull(text, nullptr, 0);
  if (raw > 0xFFFFFFFFull) {
    throw std::runtime_error(
        FormatError("irgen", "整数字面量超出 32 位范围: " + text));
  }
  // Two's-complement wrap into the signed 32-bit range.
  const int val = static_cast<int>(static_cast<unsigned int>(raw));
  return static_cast<ir::Value*>(builder_.CreateConstInt(val));
}

// --- Storage slot lookup (with subscript GEP) -------------------------------
// Returns a pointer to the storage designated by `lvalue`:
//   - no subscripts: the alloca / argument / global slot itself;
//   - subscripts:    an element pointer produced by a GEP on the linear index.
// Returns nullptr when the l-value is malformed or no slot is known for it.
// Throws std::runtime_error when subscripts are used on a name without
// recorded array dimensions.
ir::Value* IRGenImpl::ResolveStorage(SysYParser::LValueContext* lvalue) {
  if (!lvalue || !lvalue->ID()) {
    return nullptr;
  }
  const std::string name = lvalue->ID()->getText();

  // Three-level lookup for the base slot.
  ir::Value* base = nullptr;

  // 1. Semantic binding (handles shadowing between same-named variables).
  if (auto* decl = sema_.ResolveVarUse(lvalue)) {
    const auto it = storage_map_.find(decl);
    if (it != storage_map_.end()) {
      base = it->second;
    }
  }
  // 2. Function-local storage keyed by name.
  if (!base) {
    const auto it = named_storage_.find(name);
    if (it != named_storage_.end()) {
      base = it->second;
    }
  }
  // 3. Global storage keyed by name.
  if (!base) {
    const auto it = global_storage_.find(name);
    if (it != global_storage_.end()) {
      base = it->second;
    }
  }
  if (!base) {
    return nullptr;
  }

  // No subscripts: hand back the slot directly.
  if (lvalue->exp().empty()) {
    return base;
  }

  // Subscripts: compute the linear offset and emit a GEP.
  const std::vector<int>* dims = FindArrayDims(name);
  if (!dims) {
    throw std::runtime_error(
        FormatError("irgen", "未找到数组维度信息: " + name));
  }
  ir::Value* linear = ComputeLinearIndex(*dims, lvalue->exp());
  return builder_.CreateGep(base, linear, module_.GetContext().NextTemp());
}

// --- lValue access ----------------------------------------------------------
// Lowers an l-value used as an r-value:
//   - a named compile-time constant becomes a constant;
//   - an array name (or a partially indexed array) becomes a pointer;
//   - anything else is loaded from its storage slot.
std::any IRGenImpl::visitLValue(SysYParser::LValueContext* ctx) {
  if (!ctx || !ctx->ID()) {
    throw std::runtime_error(FormatError("irgen", "非法左值"));
  }
  const std::string name = ctx->ID()->getText();

  if (ctx->exp().empty()) {
    // NOTE(review): both constant tables are keyed by name only; presumably
    // the declaration code keeps them in sync with scoping so that a local
    // variable is not shadowed by a same-named constant -- verify.
    const auto float_it = const_float_env_.find(name);
    if (float_it != const_float_env_.end()) {
      return static_cast<ir::Value*>(
          module_.GetContext().GetConstFloat(float_it->second));
    }
    const auto int_it = const_env_.find(name);
    if (int_it != const_env_.end()) {
      return static_cast<ir::Value*>(builder_.CreateConstInt(int_it->second));
    }

    // No subscripts: scalar read or array base-address reference.
    ir::Value* slot = ResolveStorage(ctx);
    if (!slot) {
      throw std::runtime_error(
          FormatError("irgen", "变量未找到存储槽位: " + name));
    }

    // An array name yields its base pointer (used for argument passing).
    // A global array is first decayed to a pointer to its first element so
    // that a [N x i32]* is not passed where an i32* parameter is expected.
    if (FindArrayDims(name) != nullptr) {
      // NOTE(review): the cast target was lost from the reviewed text; it is
      // presumed to be ir::GlobalVariable -- confirm against ir/IR.h.
      if (auto* gv = dynamic_cast<ir::GlobalVariable*>(slot);
          gv && gv->IsArray()) {
        return static_cast<ir::Value*>(
            builder_.CreateGep(slot, builder_.CreateConstInt(0),
                               module_.GetContext().NextTemp()));
      }
      return static_cast<ir::Value*>(slot);
    }

    // Scalar: load the value.
    return static_cast<ir::Value*>(
        builder_.CreateLoad(slot, module_.GetContext().NextTemp()));
  }

  // Subscripts present: GEP, then load.
  ir::Value* elem_ptr = ResolveStorage(ctx);
  if (!elem_ptr) {
    throw std::runtime_error(
        FormatError("irgen", "数组元素指针解析失败: " + name));
  }
  const auto* dims = FindArrayDims(name);
  if (dims && ctx->exp().size() < dims->size()) {
    // E.g. A[i] where A is two-dimensional: decays to a pointer so it can be
    // passed as an argument.
    return static_cast<ir::Value*>(elem_ptr);
  }
  return static_cast<ir::Value*>(
      builder_.CreateLoad(elem_ptr, module_.GetContext().NextTemp()));
}

// unaryExp : primaryExp | unaryOp unaryExp | ID '(' funcRParams? ')'
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
  }

  if (ctx->primaryExp()) {
    return ctx->primaryExp()->accept(this);
  }

  if (ctx->unaryOp() && ctx->unaryExp()) {
    ir::Value* v = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));

    // Zero constant of the operand's type (float32 or int32).
    const auto zero_like = [this](ir::Value* operand) -> ir::Value* {
      return operand->GetType()->IsFloat32()
                 ? static_cast<ir::Value*>(
                       module_.GetContext().GetConstFloat(0.0f))
                 : static_cast<ir::Value*>(builder_.CreateConstInt(0));
    };

    if (ctx->unaryOp()->SUB()) {
      // -v is emitted as (0 - v).
      ir::Value* zero = zero_like(v);
      return static_cast<ir::Value*>(
          builder_.CreateSub(zero, v, module_.GetContext().NextTemp()));
    }
    if (ctx->unaryOp()->ADD()) {
      return v;
    }
    if (ctx->unaryOp()->NOT()) {
      // !v is emitted as (v == 0).
      ir::Value* zero = zero_like(v);
      return static_cast<ir::Value*>(builder_.CreateCmp(
          ir::CmpOp::Eq, v, zero, module_.GetContext().NextTemp()));
    }
    throw std::runtime_error(FormatError("irgen", "未知一元运算符"));
  }

  if (ctx->ID()) {
    // Function call: ID '(' funcRParams? ')'
    const std::string callee_name = ctx->ID()->getText();
    ir::Function* callee = module_.FindFunction(callee_name);
    if (!callee) {
      throw std::runtime_error(
          FormatError("irgen", "未定义的函数: " + callee_name));
    }

    std::vector<ir::Value*> args;
    if (auto* rparams = ctx->funcRParams()) {
      const auto& param_types = callee->GetParamTypes();
      std::size_t i = 0;
      for (auto* ep : rparams->exp()) {
        ir::Value* arg = EvalExpr(*ep);
        // Convert scalar arguments to the declared parameter type; pointer
        // parameters (and arguments beyond the declared list) pass through.
        if (i < param_types.size()) {
          if (param_types[i]->IsFloat32()) {
            arg = CastToFloat(arg);
          } else if (param_types[i]->IsInt32()) {
            arg = CastToInt(arg);
          }
        }
        args.push_back(arg);
        ++i;
      }
    }

    // A call whose type is void gets no result name.
    const std::string name =
        callee->GetType()->IsVoid() ? "" : module_.GetContext().NextTemp();
    return static_cast<ir::Value*>(builder_.CreateCall(callee, args, name));
  }

  throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
}

// mulExp : unaryExp | mulExp ('*' | '/' | '%') unaryExp
// Operands are promoted to float when either side is float; '%' always
// operates on int operands.
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
  }

  if (ctx->mulExp()) {
    if (!ctx->unaryExp()) {
      throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
    }
    ir::Value* lhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
    ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));

    const bool has_float =
        lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32();
    if (has_float) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
    }

    if (ctx->MUL()) {
      return static_cast<ir::Value*>(
          builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp()));
    }
    if (ctx->DIV()) {
      return static_cast<ir::Value*>(
          builder_.CreateDiv(lhs, rhs, module_.GetContext().NextTemp()));
    }
    if (ctx->MOD()) {
      lhs = CastToInt(lhs);
      rhs = CastToInt(rhs);
      return static_cast<ir::Value*>(
          builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
    }
    throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
  }

  if (ctx->unaryExp()) {
    return ctx->unaryExp()->accept(this);
  }
  throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}

// addExp : mulExp | addExp ('+' | '-') mulExp
// Operands are promoted to float when either side is float.
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
  }

  if (ctx->addExp()) {
    if (!ctx->mulExp()) {
      throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
    }
    ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
    ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));

    if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
    }

    if (ctx->ADD()) {
      return static_cast<ir::Value*>(
          builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp()));
    }
    if (ctx->SUB()) {
      return static_cast<ir::Value*>(
          builder_.CreateSub(lhs, rhs, module_.GetContext().NextTemp()));
    }
    throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
  }

  if (ctx->mulExp()) {
    return ctx->mulExp()->accept(this);
  }
  throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}

// relExp : addExp | relExp ('<' | '<=' | '>' | '>=') addExp
// Operands are promoted to float when either side is float.
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }

  if (ctx->relExp()) {
    if (!ctx->addExp()) {
      throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
    }
    ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
    ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));

    if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
    }

    if (ctx->LT()) {
      return static_cast<ir::Value*>(builder_.CreateCmp(
          ir::CmpOp::Lt, lhs, rhs, module_.GetContext().NextTemp()));
    }
    if (ctx->LE()) {
      return static_cast<ir::Value*>(builder_.CreateCmp(
          ir::CmpOp::Le, lhs, rhs, module_.GetContext().NextTemp()));
    }
    if (ctx->GT()) {
      return static_cast<ir::Value*>(builder_.CreateCmp(
          ir::CmpOp::Gt, lhs, rhs, module_.GetContext().NextTemp()));
    }
    if (ctx->GE()) {
      return static_cast<ir::Value*>(builder_.CreateCmp(
          ir::CmpOp::Ge, lhs, rhs, module_.GetContext().NextTemp()));
    }
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }

  if (ctx->addExp()) {
    return ctx->addExp()->accept(this);
  }
  throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}

// eqExp : relExp | eqExp ('==' | '!=') relExp
// Operands are promoted to float when either side is float.
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  }

  if (ctx->eqExp()) {
    if (!ctx->relExp()) {
      throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
    }
    ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
    ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));

    if (lhs->GetType()->IsFloat32() || rhs->GetType()->IsFloat32()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
    }

    if (ctx->EQ()) {
      return static_cast<ir::Value*>(builder_.CreateCmp(
          ir::CmpOp::Eq, lhs, rhs, module_.GetContext().NextTemp()));
    }
    if (ctx->NE()) {
      return static_cast<ir::Value*>(builder_.CreateCmp(
          ir::CmpOp::Ne, lhs, rhs, module_.GetContext().NextTemp()));
    }
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  }

  if (ctx->relExp()) {
    return ctx->relExp()->accept(this);
  }
  throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}

// lAndExp : eqExp | lAndExp '&&' eqExp
//
// Short-circuit evaluation of `a && b`. The result is materialised through
// the function-level slot `short_circuit_slot_` (0 = false, 1 = true), which
// avoids phi nodes and avoids emitting an alloca inside loops.
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
  }

  if (ctx->lAndExp()) {
    if (!ctx->eqExp()) {
      throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
    }
    if (!short_circuit_slot_) {
      throw std::runtime_error(
          FormatError("irgen", "短路求值槽位未初始化"));
    }
    auto* slot = short_circuit_slot_;

    auto* lhs = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
    auto* lhs_bool = ToBoolValue(lhs);

    // Reset the slot only AFTER the left operand has been emitted. The left
    // operand may itself be an `&&` chain that writes this same slot and
    // leaves 1 in it; resetting before it made `1 && 1 && 0` read back that
    // stale 1 at the merge block.
    builder_.CreateStore(builder_.CreateConstInt(0), slot);

    auto* rhs_bb = func_->CreateBlock(NextBlockName());
    auto* true_bb = func_->CreateBlock(NextBlockName());
    auto* merge_bb = func_->CreateBlock(NextBlockName());

    builder_.CreateCondBr(lhs_bool, rhs_bb, merge_bb);

    // Right operand is only evaluated when the left one is true.
    builder_.SetInsertPoint(rhs_bb);
    auto* rhs = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
    auto* rhs_bool = ToBoolValue(rhs);
    builder_.CreateCondBr(rhs_bool, true_bb, merge_bb);

    builder_.SetInsertPoint(true_bb);
    builder_.CreateStore(builder_.CreateConstInt(1), slot);
    builder_.CreateBr(merge_bb);

    builder_.SetInsertPoint(merge_bb);
    return static_cast<ir::Value*>(
        builder_.CreateLoad(slot, module_.GetContext().NextTemp()));
  }

  if (ctx->eqExp()) {
    return ToBoolValue(std::any_cast<ir::Value*>(ctx->eqExp()->accept(this)));
  }
  throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
}

// lOrExp : lAndExp | lOrExp '||' lAndExp
//
// Short-circuit evaluation of `a || b`, using the same function-level slot
// as visitLAndExp (0 = false, 1 = true).
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
  }

  if (ctx->lOrExp()) {
    if (!ctx->lAndExp()) {
      throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
    }
    if (!short_circuit_slot_) {
      throw std::runtime_error(
          FormatError("irgen", "短路求值槽位未初始化"));
    }
    auto* slot = short_circuit_slot_;

    auto* lhs = std::any_cast<ir::Value*>(ctx->lOrExp()->accept(this));
    auto* lhs_bool = ToBoolValue(lhs);

    // Reset the slot after the left operand (which may use the same slot) so
    // the "false" result never depends on what a nested operand left behind.
    builder_.CreateStore(builder_.CreateConstInt(0), slot);

    auto* true_bb = func_->CreateBlock(NextBlockName());
    auto* rhs_bb = func_->CreateBlock(NextBlockName());
    auto* merge_bb = func_->CreateBlock(NextBlockName());

    builder_.CreateCondBr(lhs_bool, true_bb, rhs_bb);

    builder_.SetInsertPoint(true_bb);
    builder_.CreateStore(builder_.CreateConstInt(1), slot);
    builder_.CreateBr(merge_bb);

    // Right operand is only evaluated when the left one is false. It may be
    // an `&&` chain that writes the slot; that chain leaves 0 in the slot
    // when it is false, and the true case is stored explicitly below.
    builder_.SetInsertPoint(rhs_bb);
    auto* rhs = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
    auto* rhs_bool = ToBoolValue(rhs);
    auto* true2_bb = func_->CreateBlock(NextBlockName());
    builder_.CreateCondBr(rhs_bool, true2_bb, merge_bb);

    builder_.SetInsertPoint(true2_bb);
    builder_.CreateStore(builder_.CreateConstInt(1), slot);
    builder_.CreateBr(merge_bb);

    builder_.SetInsertPoint(merge_bb);
    return static_cast<ir::Value*>(
        builder_.CreateLoad(slot, module_.GetContext().NextTemp()));
  }

  if (ctx->lAndExp()) {
    return ToBoolValue(
        std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this)));
  }
  throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
}