#include "irgen/IRGen.h"

#include <any>
#include <cstdint>
#include <stdexcept>
#include <string>

#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"

// Expression IR generation.
//
// Supported so far:
//   - integer literals (decimal / octal / hexadecimal) and float literals
//   - reads of plain (non-array) variables and named constants
//   - parenthesised expressions
//   - binary + - * / % and unary + / -
//   - function calls
//   - relational / equality comparisons, logical `!`, and short-circuit
//     `&&` / `||`
//
// Not supported yet:
//   - arrays, pointers and subscripts
//   - assignment expressions
//   - ...

/// Evaluates an expression subtree and returns the IR value holding its result.
/// The visited node must return an `ir::Value*` wrapped in std::any; any other
/// payload makes std::any_cast throw std::bad_any_cast.
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  return std::any_cast<ir::Value*>(expr.accept(this));
}

/// primaryExp : '(' exp ')' | lVal | number
///
/// Returns the `ir::Value*` for the sub-expression, wrapped in std::any.
/// Throws std::runtime_error for a null node or an unrecognised alternative.
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
  }

  // Parenthesised expression: the value is simply that of the inner exp.
  if (ctx->exp()) {
    return EvalExpr(*ctx->exp());
  }

  // Variable / constant use. visitLVal is called directly rather than via
  // accept(). A previous inline copy of its logic here skipped the FindConst
  // lookup, so named constants were read with a load from a storage slot
  // (or rejected when no slot had been recorded for them).
  if (ctx->lVal()) {
    return visitLVal(ctx->lVal());
  }

  if (ctx->number()) {
    return ctx->number()->accept(this);
  }

  throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型"));
}

/// number : intConst | floatConst
///
/// Integer literals are parsed with base auto-detection (`0x`/`0X` hex,
/// leading-`0` octal, otherwise decimal) into a 64-bit unsigned value and then
/// truncated to 32 bits. This accepts 2147483648 and hex literals above
/// INT32_MAX, which std::stoi rejects with std::out_of_range.
///
/// Assumption (TODO confirm against the lexer): the sign is not part of
/// intConst. Unary minus is handled by visitUnaryExp. Under that assumption,
/// `-2147483648` reaches here as 2147483648 and is truncated to INT32_MIN. The
/// enclosing `0 - x` then yields INT32_MIN again under two's-complement
/// wrapping.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "缺少字面量节点"));
  }
  if (ctx->intConst()) {
    const std::string text = ctx->intConst()->getText();
    // Base 0 lets the parser pick hex / octal / decimal from the prefix.
    const unsigned long long parsed = std::stoull(text, nullptr, 0);
    // Keep the low 32 bits; the conversion to unsigned is modular.
    const auto low_bits = static_cast<std::uint32_t>(parsed);
    return static_cast<ir::Value*>(
        builder_.CreateConstInt(static_cast<std::int32_t>(low_bits)));
  }
  if (ctx->floatConst()) {
    const std::string text = ctx->floatConst()->getText();
    return static_cast<ir::Value*>(
        module_.GetContext().GetConstFloat(std::stof(text)));
  }
  throw std::runtime_error(FormatError("irgen", "不支持的字面量"));
}
先通过语义分析结果把变量使用绑定回声明; // 2. 再通过 storage_map_ 找到该声明对应的栈槽位; // 3. 最后生成 load,把内存中的值读出来。 // // 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { if (!ctx || !ctx->ID()) { throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); } std::string var_name = ctx->ID()->getText(); // 优先检查是否为已记录的常量,如果是则直接返回常量值,不再生成 Load 指令 ir::ConstantValue* const_val = FindConst(var_name); if (const_val) { return static_cast(const_val); } const auto* decl = sema_.ResolveObjectUse(ctx); if (!decl) { throw std::runtime_error( FormatError("irgen", "变量使用缺少语义绑定:" + ctx->ID()->getText())); } // 使用变量名查找存储槽位 ir::Value* slot = FindStorage(var_name); if (!slot) { throw std::runtime_error( FormatError("irgen", "变量声明缺少存储槽位:" + var_name)); } return static_cast( builder_.CreateLoad(slot, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法加减法表达式")); } // 如果是 mulExp 直接返回(addExp : mulExp) if (ctx->mulExp() && ctx->addExp() == nullptr) { return ctx->mulExp()->accept(this); } // 处理 addExp op mulExp 的递归形式 if (!ctx->addExp() || !ctx->mulExp()) { throw std::runtime_error(FormatError("irgen", "非法加减法表达式结构")); } ir::Value* lhs = std::any_cast(ctx->addExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->mulExp()->accept(this)); ir::Opcode op = ir::Opcode::Add; if (ctx->ADD()) { op = ir::Opcode::Add; } else if (ctx->SUB()) { op = ir::Opcode::Sub; } else { throw std::runtime_error(FormatError("irgen", "未知的加减运算符")); } return static_cast( builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法一元表达式")); } // 如果是 primaryExp 直接返回(unaryExp : primaryExp) if (ctx->primaryExp()) { return ctx->primaryExp()->accept(this); } // 处理函数调用(unaryExp : ID LPAREN funcRParams? 
RPAREN) if (ctx->ID()) { std::string func_name = ctx->ID()->getText(); // 从 Sema 或 Module 中查找函数 // 目前简化处理,直接从 Module 中查找(如果是当前文件内定义的) // 或者依赖 Sema 给出解析结果 const FunctionBinding* func_binding = sema_.ResolveFunctionCall(ctx); if (!func_binding) { throw std::runtime_error(FormatError("irgen", "未找到函数声明:" + func_name)); } // 假设 func_binding 能够找到对应的 ir::Function* // 这里如果 sema 不提供直接拿 ir::Function 的方式,需要遍历 module_.GetFunctions() 查找 ir::Function* target_func = nullptr; for (const auto& f : module_.GetFunctions()) { if (f->GetName() == func_name) { target_func = f.get(); break; } } if (!target_func) { // 可能是外部函数如 putint, getint 等 // 如果没有在 module_ 中,则需要创建一个只有声明的 Function std::shared_ptr ret_ty; if (func_binding->return_type == SemanticType::Int) { ret_ty = ir::Type::GetInt32Type(); } else if (func_binding->return_type == SemanticType::Float) { ret_ty = ir::Type::GetFloatType(); } else { ret_ty = ir::Type::GetVoidType(); } target_func = module_.CreateFunction(func_name, ret_ty); // 对于外部函数,如果需要传递参数,可能还需要在 target_func 中 AddArgument,不过外部声明通常不需要形参块。 } std::vector args; if (ctx->funcRParams()) { args = std::any_cast>(ctx->funcRParams()->accept(this)); } return static_cast(builder_.CreateCall(target_func, args, module_.GetContext().NextTemp())); } // 处理一元运算符(unaryExp : addUnaryOp unaryExp) if (ctx->addUnaryOp() && ctx->unaryExp()) { ir::Value* operand = std::any_cast(ctx->unaryExp()->accept(this)); // 判断是正号还是负号 if (ctx->addUnaryOp()->SUB()) { // 负号:如果是整数生成 sub 0, operand,浮点数生成 fsub 0.0, operand if (operand->GetType()->IsFloat()) { ir::Value* zero = module_.GetContext().GetConstFloat(0.0f); // 此处暂且假设 CreateSub 可以处理浮点数(如果底层有 fsub 则更好) return static_cast( builder_.CreateSub(zero, operand, module_.GetContext().NextTemp())); } else { ir::Value* zero = builder_.CreateConstInt(0); return static_cast( builder_.CreateSub(zero, operand, module_.GetContext().NextTemp())); } } else if (ctx->addUnaryOp()->ADD()) { // 正号:直接返回操作数(+x 等价于 x) return operand; } else { throw 
std::runtime_error(FormatError("irgen", "未知的一元运算符")); } } throw std::runtime_error(FormatError("irgen", "不支持的一元表达式类型")); } std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法乘除法表达式")); } // 如果是 unaryExp 直接返回(mulExp : unaryExp) if (ctx->unaryExp() && ctx->mulExp() == nullptr) { return ctx->unaryExp()->accept(this); } // 处理 mulExp op unaryExp 的递归形式 if (!ctx->mulExp() || !ctx->unaryExp()) { throw std::runtime_error(FormatError("irgen", "非法乘除法表达式结构")); } ir::Value* lhs = std::any_cast(ctx->mulExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->unaryExp()->accept(this)); ir::Opcode op = ir::Opcode::Mul; if (ctx->MUL()) { op = ir::Opcode::Mul; } else if (ctx->DIV()) { op = ir::Opcode::Div; } else if (ctx->MOD()) { op = ir::Opcode::Mod; } else { throw std::runtime_error(FormatError("irgen", "未知的乘除运算符")); } return static_cast( builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { if (ctx->addExp() && ctx->relExp() == nullptr) { return ctx->addExp()->accept(this); } ir::Value* lhs = std::any_cast(ctx->relExp()->accept(this)); ir::Value* rhs = std::any_cast(ctx->addExp()->accept(this)); if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); ir::CmpOp op; if (ctx->LT()) op = ir::CmpOp::Lt; else if (ctx->GT()) op = ir::CmpOp::Gt; else if (ctx->LE()) op = ir::CmpOp::Le; else if (ctx->GE()) op = ir::CmpOp::Ge; else throw std::runtime_error(FormatError("irgen", "未知的关系运算符")); return static_cast(builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { if (ctx->relExp() && ctx->eqExp() == nullptr) { return ctx->relExp()->accept(this); } ir::Value* lhs = std::any_cast(ctx->eqExp()->accept(this)); ir::Value* rhs 
= std::any_cast(ctx->relExp()->accept(this)); if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); ir::CmpOp op; if (ctx->EQ()) op = ir::CmpOp::Eq; else if (ctx->NE()) op = ir::CmpOp::Ne; else throw std::runtime_error(FormatError("irgen", "未知的相等运算符")); return static_cast(builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) { if (ctx->eqExp()) { return ctx->eqExp()->accept(this); } if (ctx->NOT()) { ir::Value* operand = std::any_cast(ctx->condUnaryExp()->accept(this)); if (operand->GetType()->IsInt1()) { operand = builder_.CreateZext(operand, module_.GetContext().NextTemp()); } ir::Value* zero = builder_.CreateConstInt(0); return static_cast(builder_.CreateCmp(ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp())); } throw std::runtime_error(FormatError("irgen", "非法条件一元表达式")); } std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { if (ctx->condUnaryExp() && ctx->lAndExp() == nullptr) { return ctx->condUnaryExp()->accept(this); } ir::AllocaInst* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); ir::Value* zero = builder_.CreateConstInt(0); builder_.CreateStore(zero, res_ptr); ir::BasicBlock* rhs_bb = func_->CreateBlock(NextBlockName("land_rhs")); ir::BasicBlock* end_bb = func_->CreateBlock(NextBlockName("land_end")); ir::Value* lhs = std::any_cast(ctx->lAndExp()->accept(this)); if (lhs->GetType()->IsInt1()) { lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); } ir::Value* lhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, lhs, zero, module_.GetContext().NextTemp()); builder_.CreateCondBr(lhs_cond, rhs_bb, end_bb); builder_.SetInsertPoint(rhs_bb); ir::Value* rhs = std::any_cast(ctx->condUnaryExp()->accept(this)); if (rhs->GetType()->IsInt1()) { rhs = builder_.CreateZext(rhs, 
module_.GetContext().NextTemp()); } ir::Value* rhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, rhs, zero, module_.GetContext().NextTemp()); ir::Value* rhs_res = builder_.CreateZext(rhs_cond, module_.GetContext().NextTemp()); builder_.CreateStore(rhs_res, res_ptr); builder_.CreateBr(end_bb); builder_.SetInsertPoint(end_bb); return static_cast(builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { if (ctx->lAndExp() && ctx->lOrExp() == nullptr) { return ctx->lAndExp()->accept(this); } ir::AllocaInst* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); ir::Value* one = builder_.CreateConstInt(1); builder_.CreateStore(one, res_ptr); ir::BasicBlock* rhs_bb = func_->CreateBlock(NextBlockName("lor_rhs")); ir::BasicBlock* end_bb = func_->CreateBlock(NextBlockName("lor_end")); ir::Value* lhs = std::any_cast(ctx->lOrExp()->accept(this)); ir::Value* zero = builder_.CreateConstInt(0); if (lhs->GetType()->IsInt1()) { lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp()); } ir::Value* lhs_cond = builder_.CreateCmp(ir::CmpOp::Eq, lhs, zero, module_.GetContext().NextTemp()); builder_.CreateCondBr(lhs_cond, rhs_bb, end_bb); builder_.SetInsertPoint(rhs_bb); ir::Value* rhs = std::any_cast(ctx->lAndExp()->accept(this)); if (rhs->GetType()->IsInt1()) { rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp()); } ir::Value* rhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, rhs, zero, module_.GetContext().NextTemp()); ir::Value* rhs_res = builder_.CreateZext(rhs_cond, module_.GetContext().NextTemp()); builder_.CreateStore(rhs_res, res_ptr); builder_.CreateBr(end_bb); builder_.SetInsertPoint(end_bb); return static_cast(builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { if (!ctx || !ctx->lOrExp()) { throw std::runtime_error(FormatError("irgen", "非法条件表达式")); } return ctx->lOrExp()->accept(this); } std::any 
IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) { std::vector args; for (auto* exp : ctx->exp()) { args.push_back(EvalExpr(*exp)); } return args; }