diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..f20609e 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -70,11 +70,53 @@ std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("irgen", "非法加法表达式")); + throw std::runtime_error(FormatError("irgen", "非法加减法表达式")); } ir::Value* lhs = EvalExpr(*ctx->exp(0)); ir::Value* rhs = EvalExpr(*ctx->exp(1)); + + ir::Opcode op = ir::Opcode::Add; + if (ctx->ADD()) { + op = ir::Opcode::Add; + } else if (ctx->SUB()) { + op = ir::Opcode::Sub; + } else { + throw std::runtime_error(FormatError("irgen", "未知的加减运算符")); + } + return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp())); } + +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法乘除法表达式")); + } + + // 如果是 unaryExp 直接返回 + if (ctx->unaryExp()) { + return ctx->unaryExp()->accept(this); + } + + // 处理 MulExp op unaryExp 的递归形式 + if (!ctx->exp(0) || !ctx->unaryExp(0)) { + throw std::runtime_error(FormatError("irgen", "非法乘除法表达式结构")); + } + + ir::Value* lhs = std::any_cast(ctx->exp(0)->accept(this)); + ir::Value* rhs = std::any_cast(ctx->unaryExp(0)->accept(this)); + + ir::Opcode op = ir::Opcode::Mul; + if (ctx->MUL()) { + op = ir::Opcode::Mul; + } else if (ctx->DIV()) { + op = ir::Opcode::Div; + } else if (ctx->MOD()) { + op = ir::Opcode::Mod; + } else { + throw std::runtime_error(FormatError("irgen", "未知的乘除运算符")); + } + + return static_cast( + builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp())); +} \ No newline at end of file