#include "irgen/IRGen.h"

#include <any>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/func.h"
#include "utils/Log.h"

// ─── Helpers ─────────────────────────────────────────────────────────────────

// Converts a scalar value to an i1 truth value.
//   i1    -> returned unchanged
//   float -> fcmp one v, 0.0  (an integer compare against a float operand
//            would be ill-typed; ONE matches the predicate visitEqExp uses for !=)
//   other -> icmp ne v, 0
// Throws std::runtime_error on a null value.
ir::Value* IRGenImpl::ToI1(ir::Value* v) {
  if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value"));
  if (v->IsInt1()) return v;
  if (v->IsFloat32()) {
    return builder_.CreateFCmp(ir::FCmpPredicate::ONE, v,
                               builder_.CreateConstFloat(0.0f),
                               module_.GetContext().NextTemp());
  }
  return builder_.CreateICmp(ir::ICmpPredicate::NE, v, builder_.CreateConstInt(0),
                             module_.GetContext().NextTemp());
}

// Widens an i1 to i32 with a zero extension; an i32 is returned unchanged.
// Any other type (e.g. float) is rejected instead of emitting an ill-typed zext.
// Throws std::runtime_error on a null value or an unsupported type.
ir::Value* IRGenImpl::ToI32(ir::Value* v) {
  if (!v) throw std::runtime_error(FormatError("irgen", "ToI32: null value"));
  if (v->IsInt32()) return v;
  if (v->IsInt1()) return builder_.CreateZExt(v, module_.GetContext().NextTemp());
  throw std::runtime_error(FormatError("irgen", "ToI32: 不支持的类型"));
}

// Converts an integer value (i32, or i1 via ToI32) to float with sitofp;
// a float is returned unchanged. Throws on null or any other type.
ir::Value* IRGenImpl::ToFloat(ir::Value* v) {
  if (!v) throw std::runtime_error(FormatError("irgen", "ToFloat: null value"));
  if (v->IsFloat32()) return v;
  if (v->IsInt32()) return builder_.CreateSIToFP(v, module_.GetContext().NextTemp());
  if (v->IsInt1()) {
    auto* i32 = ToI32(v);
    return builder_.CreateSIToFP(i32, module_.GetContext().NextTemp());
  }
  throw std::runtime_error(FormatError("irgen", "ToFloat: 不支持的类型"));
}

// Converts a value to i32: float via fptosi, i1 via zext; an i32 is returned
// unchanged. Throws on null or any other type.
ir::Value* IRGenImpl::ToInt(ir::Value* v) {
  if (!v) throw std::runtime_error(FormatError("irgen", "ToInt: null value"));
  if (v->IsInt32()) return v;
  if (v->IsFloat32()) return builder_.CreateFPToSI(v, module_.GetContext().NextTemp());
  if (v->IsInt1()) return ToI32(v);
  throw std::runtime_error(FormatError("irgen", "ToInt: 不支持的类型"));
}

// Usual arithmetic conversion for binary operators: if exactly one operand is
// float, the other one is converted to float in place. No-op if either is null.
void IRGenImpl::ImplicitConvert(ir::Value*& lhs, ir::Value*& rhs) {
  if (!lhs || !rhs) return;
  bool lhs_float = lhs->IsFloat32();
  bool rhs_float = rhs->IsFloat32();
  if (lhs_float && !rhs_float) {
    rhs = ToFloat(rhs);
  } else if (!lhs_float && rhs_float) {
    lhs = ToFloat(lhs);
  }
}

// Evaluates an `exp` node and returns its IR value (i32 or float).
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  auto result = expr.accept(this);
  return std::any_cast<ir::Value*>(result);
}
求值 addExp(i32) ir::Value* IRGenImpl::EvalExprAdd(SysYParser::AddExpContext& expr) { auto result = expr.accept(this); return std::any_cast(result); } // 求值 cond(i1) ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { auto result = cond.accept(this); auto* v = std::any_cast(result); return ToI1(v); } // 注册外部函数声明(幂等) void IRGenImpl::EnsureExternalDecl(const std::string& name) { if (module_.HasExternalDecl(name) || module_.GetFunction(name)) return; // SysY 标准运行库函数签名 if (name == "getint") { module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {}); } else if (name == "getch") { module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {}); } else if (name == "getfloat") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); // 近似 } else if (name == "putint") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); } else if (name == "putch") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); } else if (name == "putfloat") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); } else if (name == "putarray") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); } else if (name == "starttime" || name == "stoptime") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); } else { // 未知外部函数,按 i32 返回声明 module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {}); } } // ─── 表达式访问器(返回 ir::Value*,i32) ───────────────────────────────────── std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { if (!ctx || !ctx->addExp()) { throw std::runtime_error(FormatError("irgen", "非法表达式")); } return ctx->addExp()->accept(this); } // addExp : mulExp (AddOp mulExp)* std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法加减表达式")); } auto muls = ctx->mulExp(); if (muls.empty()) { throw std::runtime_error(FormatError("irgen", "addExp 
缺少操作数")); } ir::Value* result = std::any_cast(muls[0]->accept(this)); auto ops = ctx->AddOp(); for (size_t i = 0; i < ops.size(); ++i) { ir::Value* rhs = std::any_cast(muls[i + 1]->accept(this)); ImplicitConvert(result, rhs); std::string tmp = module_.GetContext().NextTemp(); std::string op = ops[i]->getText(); if (result->IsFloat32()) { result = (op == "+") ? builder_.CreateFAdd(result, rhs, tmp) : builder_.CreateFSub(result, rhs, tmp); } else { result = (op == "+") ? builder_.CreateAdd(result, rhs, tmp) : builder_.CreateSub(result, rhs, tmp); } } return static_cast(result); } // mulExp : unaryExp (MulOp unaryExp)* std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法乘除表达式")); } auto unarys = ctx->unaryExp(); if (unarys.empty()) { throw std::runtime_error(FormatError("irgen", "mulExp 缺少操作数")); } ir::Value* result = std::any_cast(unarys[0]->accept(this)); auto ops = ctx->MulOp(); for (size_t i = 0; i < ops.size(); ++i) { ir::Value* rhs = std::any_cast(unarys[i + 1]->accept(this)); ImplicitConvert(result, rhs); std::string tmp = module_.GetContext().NextTemp(); std::string op = ops[i]->getText(); if (result->IsFloat32()) { if (op == "*") result = builder_.CreateFMul(result, rhs, tmp); else if (op == "/") result = builder_.CreateFDiv(result, rhs, tmp); else throw std::runtime_error(FormatError("irgen", "float 不支持取模")); } else { if (op == "*") result = builder_.CreateMul(result, rhs, tmp); else if (op == "/") result = builder_.CreateDiv(result, rhs, tmp); else result = builder_.CreateMod(result, rhs, tmp); } } return static_cast(result); } // unaryExp : primaryExp // | Ident L_PAREN (funcRParams)? 
R_PAREN // | unaryOp unaryExp std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法一元表达式")); } // ── 一元运算符 ───────────────────────────────────────────────────────────── if (ctx->unaryOp() && ctx->unaryExp()) { ir::Value* operand = std::any_cast(ctx->unaryExp()->accept(this)); std::string op = ctx->unaryOp()->getText(); if (op == "-") { if (operand->IsFloat32()) { return static_cast( builder_.CreateFSub(builder_.CreateConstFloat(0.0f), operand, module_.GetContext().NextTemp())); } else { return static_cast( builder_.CreateSub(builder_.CreateConstInt(0), operand, module_.GetContext().NextTemp())); } } else if (op == "+") { return static_cast(operand); } else if (op == "!") { ir::Value* cmp; if (operand->IsFloat32()) { cmp = builder_.CreateFCmp(ir::FCmpPredicate::OEQ, operand, builder_.CreateConstFloat(0.0f), module_.GetContext().NextTemp()); } else { operand = ToI32(operand); cmp = builder_.CreateICmp(ir::ICmpPredicate::EQ, operand, builder_.CreateConstInt(0), module_.GetContext().NextTemp()); } return static_cast(ToI32(cmp)); } throw std::runtime_error(FormatError("irgen", "不支持的一元运算符: " + op)); } // ── 函数调用 ────────────────────────────────────────────────────────────── if (ctx->Ident() && ctx->L_PAREN()) { std::string callee_name = ctx->Ident()->getText(); // 收集实参 std::vector args; if (ctx->funcRParams()) { for (auto* exp : ctx->funcRParams()->exp()) { args.push_back(EvalExpr(*exp)); } } // 模块内已知函数? ir::Function* callee = module_.GetFunction(callee_name); if (callee) { std::string ret_name = callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp(); auto* call = builder_.CreateCall(callee, std::move(args), ret_name); return static_cast( callee->IsVoidReturn() ? 
static_cast( builder_.CreateConstInt(0)) : call); } // 外部函数 EnsureExternalDecl(callee_name); // 获取返回类型 std::shared_ptr ret_type = ir::Type::GetInt32Type(); for (const auto& decl : module_.GetExternalDecls()) { if (decl.name == callee_name) { ret_type = decl.ret_type; break; } } bool is_void = ret_type->IsVoid(); std::string ret_name = is_void ? "" : module_.GetContext().NextTemp(); auto* call = builder_.CreateCallExternal(callee_name, ret_type, std::move(args), ret_name); // void 调用返回 0 占位 return static_cast( is_void ? static_cast(builder_.CreateConstInt(0)) : call); } // ── primaryExp ──────────────────────────────────────────────────────────── if (ctx->primaryExp()) { return ctx->primaryExp()->accept(this); } throw std::runtime_error(FormatError("irgen", "非法一元表达式结构")); } // primaryExp : L_PAREN exp R_PAREN | lVar | number std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法基本表达式")); } if (ctx->exp()) { return ctx->exp()->accept(this); } if (ctx->lVar()) { return ctx->lVar()->accept(this); } if (ctx->number()) { return ctx->number()->accept(this); } throw std::runtime_error(FormatError("irgen", "primaryExp 结构非法")); } // EvalLVarAddr:计算 lVar 的地址(支持数组索引) ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) { if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法变量引用")); } auto* decl = sema_.ResolveVarUse(ctx->Ident()); if (!decl) { throw std::runtime_error( FormatError("irgen", "变量未绑定: " + ctx->Ident()->getText())); } // 查找存储槽位 ir::Value* base = nullptr; std::vector dims; auto it = storage_map_.find(decl); if (it != storage_map_.end()) { base = it->second; auto dit = array_dims_.find(decl); if (dit != array_dims_.end()) dims = dit->second; } else { auto git = global_storage_map_.find(decl); if (git == global_storage_map_.end()) { throw std::runtime_error( FormatError("irgen", "变量无存储槽位: " + ctx->Ident()->getText())); } base = git->second; auto gdit = 
global_array_dims_.find(decl); if (gdit != global_array_dims_.end()) dims = gdit->second; } // 无索引 → 返回基地址 if (ctx->exp().empty()) return base; // 有索引 → 计算扁平化偏移 auto indices = ctx->exp(); // 对于数组参数(第一维为-1),允许索引数等于维度数 bool is_array_param = !dims.empty() && dims[0] == -1; if (!is_array_param && indices.size() > dims.size()) { throw std::runtime_error(FormatError("irgen", "数组索引维度过多")); } ir::Value* offset = builder_.CreateConstInt(0); if (is_array_param) { // 数组参数:dims[0]=-1, dims[1..n]是已知维度 // 索引:indices[0]对应第一维,indices[1]对应第二维... for (size_t i = 0; i < indices.size(); ++i) { ir::Value* idx = EvalExpr(*indices[i]); if (i == 0) { // 第一维:stride = dims[1] * dims[2] * ... (如果有的话) int stride = 1; for (size_t j = 1; j < dims.size(); ++j) { stride *= dims[j]; } if (stride > 1) { ir::Value* scaled = builder_.CreateMul( idx, builder_.CreateConstInt(stride), module_.GetContext().NextTemp()); offset = builder_.CreateAdd(offset, scaled, module_.GetContext().NextTemp()); } else { offset = builder_.CreateAdd(offset, idx, module_.GetContext().NextTemp()); } } else { // 后续维度 int stride = 1; for (size_t j = i + 1; j < dims.size(); ++j) { stride *= dims[j]; } ir::Value* scaled = builder_.CreateMul( idx, builder_.CreateConstInt(stride), module_.GetContext().NextTemp()); offset = builder_.CreateAdd(offset, scaled, module_.GetContext().NextTemp()); } } } else { // 普通数组:从最后一维开始计算 int stride = 1; for (int i = (int)dims.size() - 1; i >= 0; --i) { stride = (i == (int)dims.size() - 1) ? 
1 : stride * dims[i + 1]; if (i < (int)indices.size()) { ir::Value* idx = EvalExpr(*indices[i]); ir::Value* scaled = builder_.CreateMul( idx, builder_.CreateConstInt(stride), module_.GetContext().NextTemp()); offset = builder_.CreateAdd(offset, scaled, module_.GetContext().NextTemp()); } } } return builder_.CreateGep(base, offset, module_.GetContext().NextTemp()); } // lVar : Ident (L_BRAKT exp R_BRAKT)* // 在表达式语境下:load 变量值(返回 i32) std::any IRGenImpl::visitLVar(SysYParser::LVarContext* ctx) { if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "非法变量引用")); } auto* decl = sema_.ResolveVarUse(ctx->Ident()); if (!decl) { throw std::runtime_error( FormatError("irgen", "变量未绑定: " + ctx->Ident()->getText())); } // 标量常量(ConstDefContext 且无索引) if (auto* const_def = dynamic_cast(decl)) { if (ctx->exp().empty()) { // 先查局部 auto it = storage_map_.find(const_def); if (it != storage_map_.end()) { if (auto* ci = dynamic_cast(it->second)) { return static_cast(ci); } if (auto* cf = dynamic_cast(it->second)) { return static_cast(cf); } } // 再查全局 auto git = global_storage_map_.find(const_def); if (git != global_storage_map_.end()) { if (auto* cf = dynamic_cast(git->second)) { return static_cast(cf); } } } } // 通用路径:计算地址并 load ir::Value* addr = EvalLVarAddr(ctx); return static_cast( builder_.CreateLoad(addr, module_.GetContext().NextTemp())); } // number : IntConst | FloatConst std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法数字")); } if (ctx->IntConst()) { std::string text = ctx->IntConst()->getText(); int val = 0; try { val = std::stoi(text, nullptr, 0); } catch (...) { throw std::runtime_error( FormatError("irgen", "整数字面量解析失败: " + text)); } return static_cast(builder_.CreateConstInt(val)); } if (ctx->FloatConst()) { std::string text = ctx->FloatConst()->getText(); float val = 0.0f; try { val = std::stof(text); } catch (...) 
{ throw std::runtime_error( FormatError("irgen", "浮点字面量解析失败: " + text)); } return static_cast(builder_.CreateConstFloat(val)); } throw std::runtime_error(FormatError("irgen", "非法数字节点")); } // ─── 条件表达式访问器(返回 ir::Value*,i1) ────────────────────────────────── std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { if (!ctx || !ctx->lOrExp()) { throw std::runtime_error(FormatError("irgen", "非法条件")); } return ctx->lOrExp()->accept(this); } // lOrExp : lAndExp ('||' lAndExp)* std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 lOrExp")); auto ands = ctx->lAndExp(); if (ands.empty()) throw std::runtime_error(FormatError("irgen", "lOrExp 空")); ir::Value* result = std::any_cast(ands[0]->accept(this)); result = ToI1(result); for (size_t i = 1; i < ands.size(); ++i) { // 检查当前块是否已终结 if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) { break; } // 短路:result || rhs auto* res_slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); ir::Value* res_ext = ToI32(result); builder_.CreateStore(res_ext, res_slot); ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs"); ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end"); builder_.CreateCondBr(result, end_bb, rhs_bb); builder_.SetInsertPoint(rhs_bb); ir::Value* rhs = std::any_cast(ands[i]->accept(this)); rhs = ToI32(ToI1(rhs)); if (!builder_.GetInsertBlock()->HasTerminator()) { builder_.CreateStore(rhs, res_slot); builder_.CreateBr(end_bb); } builder_.SetInsertPoint(end_bb); result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp())); } return static_cast(result); } // lAndExp : eqExp ('&&' eqExp)* std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 lAndExp")); auto eqs = ctx->eqExp(); if (eqs.empty()) throw std::runtime_error(FormatError("irgen", "lAndExp 空")); 
ir::Value* result = std::any_cast(eqs[0]->accept(this)); result = ToI1(result); for (size_t i = 1; i < eqs.size(); ++i) { // 检查当前块是否已终结 if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) { break; } auto* res_slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); ir::Value* res_ext = ToI32(result); builder_.CreateStore(res_ext, res_slot); ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs"); ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end"); builder_.CreateCondBr(result, rhs_bb, end_bb); builder_.SetInsertPoint(rhs_bb); ir::Value* rhs = std::any_cast(eqs[i]->accept(this)); rhs = ToI32(ToI1(rhs)); if (!builder_.GetInsertBlock()->HasTerminator()) { builder_.CreateStore(rhs, res_slot); builder_.CreateBr(end_bb); } builder_.SetInsertPoint(end_bb); result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp())); } return static_cast(result); } // eqExp : relExp (EqOp relExp)* std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 eqExp")); auto rels = ctx->relExp(); if (rels.empty()) throw std::runtime_error(FormatError("irgen", "eqExp 空")); ir::Value* result = std::any_cast(rels[0]->accept(this)); auto ops = ctx->EqOp(); for (size_t i = 0; i < ops.size(); ++i) { ir::Value* rhs = std::any_cast(rels[i + 1]->accept(this)); ir::Value* lhs = result; ImplicitConvert(lhs, rhs); std::string op = ops[i]->getText(); if (lhs->IsFloat32()) { ir::FCmpPredicate pred = (op == "==") ? ir::FCmpPredicate::OEQ : ir::FCmpPredicate::ONE; result = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()); } else { lhs = ToI32(lhs); rhs = ToI32(rhs); ir::ICmpPredicate pred = (op == "==") ? 
ir::ICmpPredicate::EQ : ir::ICmpPredicate::NE; result = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()); } } return static_cast(result); } // relExp : addExp (RelOp addExp)* std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 relExp")); auto adds = ctx->addExp(); if (adds.empty()) throw std::runtime_error(FormatError("irgen", "relExp 空")); ir::Value* result = std::any_cast(adds[0]->accept(this)); auto ops = ctx->RelOp(); for (size_t i = 0; i < ops.size(); ++i) { ir::Value* rhs = std::any_cast(adds[i + 1]->accept(this)); ir::Value* lhs = result; ImplicitConvert(lhs, rhs); std::string op = ops[i]->getText(); if (lhs->IsFloat32()) { ir::FCmpPredicate pred; if (op == "<") pred = ir::FCmpPredicate::OLT; else if (op == ">") pred = ir::FCmpPredicate::OGT; else if (op == "<=") pred = ir::FCmpPredicate::OLE; else pred = ir::FCmpPredicate::OGE; result = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()); } else { lhs = ToI32(lhs); rhs = ToI32(rhs); ir::ICmpPredicate pred; if (op == "<") pred = ir::ICmpPredicate::SLT; else if (op == ">") pred = ir::ICmpPredicate::SGT; else if (op == "<=") pred = ir::ICmpPredicate::SLE; else pred = ir::ICmpPredicate::SGE; result = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()); } } return static_cast(result); }