#include "include/irgen/IRGen.h"

#include <any>
#include <stdexcept>
#include <string>

#include "SysYParser.h"
#include "include/ir/IR.h"
#include "include/utils/Log.h"

// Returns `prefix` followed by the next value of the block counter.
// The counter is pre-incremented, so every call yields a new suffix.
std::string IRGenImpl::NextBlockName(const std::string& prefix) {
  return prefix + std::to_string(++block_index_);
}

// Lowers a `cond` node into control flow: execution continues at `true_bb`
// when the condition holds and at `false_bb` otherwise.
// Throws std::runtime_error if the node has no logical-or child.
void IRGenImpl::EmitCondBranch(SysYParser::CondContext& cond,
                               ir::BasicBlock* true_bb,
                               ir::BasicBlock* false_bb) {
  if (!cond.lOrExp()) {
    throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
  }
  EmitLOrBranch(*cond.lOrExp(), true_bb, false_bb);
}

// Lowers `lhs || rhs` with short-circuit control flow: a true lhs jumps
// straight to `true_bb`; only a false lhs falls into a fresh block that
// evaluates rhs. A node without a nested logical-or is a plain logical-and.
// Throws std::runtime_error if the logical-and operand is missing.
void IRGenImpl::EmitLOrBranch(SysYParser::LOrExpContext& expr,
                              ir::BasicBlock* true_bb,
                              ir::BasicBlock* false_bb) {
  // Both grammar alternatives require a logical-and operand.
  if (!expr.lAndExp()) {
    throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
  }
  if (!expr.lOrExp()) {
    EmitLAndBranch(*expr.lAndExp(), true_bb, false_bb);
    return;
  }

  auto* rhs_bb = func_->CreateBlock(NextBlockName("lor.rhs."));
  EmitLOrBranch(*expr.lOrExp(), true_bb, rhs_bb);
  builder_.SetInsertPoint(rhs_bb);
  EmitLAndBranch(*expr.lAndExp(), true_bb, false_bb);
}

// Lowers `lhs && rhs` with short-circuit control flow: a false lhs jumps
// straight to `false_bb`; only a true lhs falls into a fresh block that
// evaluates rhs. A node without a nested logical-and is a plain equality
// expression.
// Throws std::runtime_error if the equality operand is missing.
void IRGenImpl::EmitLAndBranch(SysYParser::LAndExpContext& expr,
                               ir::BasicBlock* true_bb,
                               ir::BasicBlock* false_bb) {
  // Both grammar alternatives require an equality operand.
  if (!expr.eqExp()) {
    throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
  }
  if (!expr.lAndExp()) {
    EmitEqBranch(*expr.eqExp(), true_bb, false_bb);
    return;
  }

  auto* rhs_bb = func_->CreateBlock(NextBlockName("land.rhs."));
  EmitLAndBranch(*expr.lAndExp(), rhs_bb, false_bb);
  builder_.SetInsertPoint(rhs_bb);
  EmitEqBranch(*expr.eqExp(), true_bb, false_bb);
}

// Evaluates a relational expression to a value. Chained comparisons are
// handled by recursing on the left operand.
// Throws std::runtime_error on a malformed node or an unknown operator.
ir::Value* IRGenImpl::EvalRelValue(SysYParser::RelExpContext& expr) {
  // Both grammar alternatives require an additive operand.
  if (!expr.addExp()) {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }
  if (!expr.relExp()) {
    return std::any_cast<ir::Value*>(expr.addExp()->accept(this));
  }

  // Left operand is evaluated before the right one.
  auto* lhs = EvalRelValue(*expr.relExp());
  auto* rhs = std::any_cast<ir::Value*>(expr.addExp()->accept(this));

  ir::Opcode cmp = ir::Opcode::Lt;
  if (expr.LtOp()) {
    cmp = ir::Opcode::Lt;
  } else if (expr.GtOp()) {
    cmp = ir::Opcode::Gt;
  } else if (expr.LeOp()) {
    cmp = ir::Opcode::Le;
  } else if (expr.GeOp()) {
    cmp = ir::Opcode::Ge;
  } else {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }
  return EvalBinaryOrFold(cmp, lhs, rhs);
}

// Evaluates an equality expression (`==` / `!=`) to a value, recursing on
// the left operand for chains.
// Throws std::runtime_error on a malformed node or an unknown operator.
ir::Value* IRGenImpl::EvalEqValue(SysYParser::EqExpContext& expr) {
  // Both grammar alternatives require a relational operand.
  if (!expr.relExp()) {
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  }
  if (!expr.eqExp()) {
    return EvalRelValue(*expr.relExp());
  }

  // Left operand is evaluated before the right one.
  auto* lhs = EvalEqValue(*expr.eqExp());
  auto* rhs = EvalRelValue(*expr.relExp());

  ir::Opcode cmp = ir::Opcode::Eq;
  if (expr.EqOp()) {
    cmp = ir::Opcode::Eq;
  } else if (expr.NeOp()) {
    cmp = ir::Opcode::Ne;
  } else {
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  }
  return EvalBinaryOrFold(cmp, lhs, rhs);
}

// Evaluates an equality expression and branches on whether the result is
// non-zero. The zero constant matches the operand type (float vs. int).
void IRGenImpl::EmitEqBranch(SysYParser::EqExpContext& expr,
                             ir::BasicBlock* true_bb,
                             ir::BasicBlock* false_bb) {
  auto* cond_value = EvalEqValue(expr);
  ir::Value* zero =
      cond_value->GetType()->IsFloat32()
          ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0))
          : static_cast<ir::Value*>(builder_.CreateConstInt(0));
  auto* cond = builder_.CreateBinary(ir::Opcode::Ne, cond_value, zero,
                                     module_.GetContext().NextTemp());
  builder_.CreateCondBr(cond, true_bb, false_bb);
}

// Branches directly on a relational expression.
//  - Bare operand: branches on `operand != 0`, using a zero constant of the
//    operand's type (same rule as EmitEqBranch).
//  - Single comparison `a op b`: branches on the comparison result.
//  - Chained comparisons are rejected here (EvalRelValue does support them).
// Throws std::runtime_error on a malformed node, a chained comparison, or an
// unknown operator.
void IRGenImpl::EmitRelBranch(SysYParser::RelExpContext& expr,
                              ir::BasicBlock* true_bb,
                              ir::BasicBlock* false_bb) {
  if (!expr.relExp()) {
    if (!expr.addExp()) {
      throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
    }
    auto* value = std::any_cast<ir::Value*>(expr.addExp()->accept(this));
    // Pick a zero of the operand's type so a float operand is not compared
    // against an integer constant.
    ir::Value* zero =
        value->GetType()->IsFloat32()
            ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0))
            : static_cast<ir::Value*>(builder_.CreateConstInt(0));
    auto* cond = builder_.CreateBinary(ir::Opcode::Ne, value, zero,
                                       module_.GetContext().NextTemp());
    builder_.CreateCondBr(cond, true_bb, false_bb);
    return;
  }

  if (!expr.addExp() || !expr.relExp()->addExp() || expr.relExp()->relExp()) {
    throw std::runtime_error(
        FormatError("irgen", "当前不支持链式关系比较表达式"));
  }

  auto* lhs =
      std::any_cast<ir::Value*>(expr.relExp()->addExp()->accept(this));
  auto* rhs = std::any_cast<ir::Value*>(expr.addExp()->accept(this));

  ir::Opcode cmp = ir::Opcode::Lt;
  if (expr.LtOp()) {
    cmp = ir::Opcode::Lt;
  } else if (expr.GtOp()) {
    cmp = ir::Opcode::Gt;
  } else if (expr.LeOp()) {
    cmp = ir::Opcode::Le;
  } else if (expr.GeOp()) {
    cmp = ir::Opcode::Ge;
  } else {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }

  auto* cond =
      builder_.CreateBinary(cmp, lhs, rhs, module_.GetContext().NextTemp());
  builder_.CreateCondBr(cond, true_bb, false_bb);
}

// Generates IR for one statement and reports, as a BlockFlow wrapped in
// std::any, whether control can continue past it (Continue) or the statement
// ended the current path (Terminated). Nested blocks return whatever the
// block visitor returns.
//
// Handled forms: nested block, assignment, if / if-else, while, break,
// continue, return, empty statement, expression statement.
// Throws std::runtime_error for a null context, malformed statements,
// break/continue outside a loop, return-value mismatches, and unsupported
// statement kinds.
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "缺少语句"));
  }

  // --- Nested block --------------------------------------------------------
  if (ctx->block()) {
    return ctx->block()->accept(this);
  }

  // --- Assignment: lVal = exp ; --------------------------------------------
  if (ctx->lVal() && ctx->Assign()) {
    if (!ctx->exp()) {
      throw std::runtime_error(FormatError("irgen", "赋值语句缺少右值表达式"));
    }
    bool is_array = false;
    const auto extents = GetArrayExtentsForLVal(*ctx->lVal(), is_array);
    // Fewer subscripts than dimensions means the target is still an array
    // (sub-)object, which cannot be assigned as a whole.
    if (is_array && ctx->lVal()->exp().size() < extents.size()) {
      throw std::runtime_error(FormatError("irgen", "不能给数组对象赋值"));
    }
    auto* slot = GetLValAddress(*ctx->lVal());
    ir::Value* rhs = EvalExpr(*ctx->exp());
    // Convert the right-hand side to the element type of the destination.
    auto target_ty = slot->GetType()->IsPtrFloat32()
                         ? ir::Type::GetFloat32Type()
                         : ir::Type::GetInt32Type();
    rhs = CastValueTo(rhs, target_ty);
    builder_.CreateStore(rhs, slot);
    return BlockFlow::Continue;
  }

  // --- if / if-else ---------------------------------------------------------
  if (ctx->If()) {
    const bool has_else = ctx->Else() != nullptr;
    // An `else` keyword without a second statement would otherwise lead to a
    // null dereference on ctx->stmt(1) below.
    if (!ctx->cond() || ctx->stmt().empty() ||
        (has_else && ctx->stmt().size() < 2)) {
      throw std::runtime_error(FormatError("irgen", "if 语句缺少必要组成部分"));
    }

    auto* then_bb = func_->CreateBlock(NextBlockName("if.then."));
    ir::BasicBlock* else_bb = nullptr;
    if (has_else) {
      else_bb = func_->CreateBlock(NextBlockName("if.else."));
    }
    auto* merge_bb = func_->CreateBlock(NextBlockName("if.end."));

    // Without an else arm a false condition goes straight to the merge block.
    EmitCondBranch(*ctx->cond(), then_bb, has_else ? else_bb : merge_bb);

    builder_.SetInsertPoint(then_bb);
    auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
    if (then_flow == BlockFlow::Continue &&
        !builder_.GetInsertBlock()->HasTerminator()) {
      builder_.CreateBr(merge_bb);
    }

    BlockFlow else_flow = BlockFlow::Continue;
    if (has_else) {
      builder_.SetInsertPoint(else_bb);
      else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
      if (else_flow == BlockFlow::Continue &&
          !builder_.GetInsertBlock()->HasTerminator()) {
        builder_.CreateBr(merge_bb);
      }
    }

    builder_.SetInsertPoint(merge_bb);
    if (has_else && then_flow == BlockFlow::Terminated &&
        else_flow == BlockFlow::Terminated) {
      // Neither arm branches to merge_bb in this case. It still receives a
      // return matching the function's type, presumably so the block is not
      // left without a terminator.
      if (func_->GetType()->IsInt32()) {
        builder_.CreateRet(builder_.CreateConstInt(0));
      } else if (func_->GetType()->IsFloat32()) {
        builder_.CreateRet(builder_.CreateConstFloat(0.0));
      } else {
        builder_.CreateRetVoid();
      }
      return BlockFlow::Terminated;
    }
    return BlockFlow::Continue;
  }

  // --- while ----------------------------------------------------------------
  if (ctx->While()) {
    if (!ctx->cond() || ctx->stmt().empty()) {
      throw std::runtime_error(
          FormatError("irgen", "while 语句缺少必要组成部分"));
    }

    auto* cond_bb = func_->CreateBlock(NextBlockName("while.cond."));
    auto* body_bb = func_->CreateBlock(NextBlockName("while.body."));
    auto* exit_bb = func_->CreateBlock(NextBlockName("while.end."));

    builder_.CreateBr(cond_bb);
    builder_.SetInsertPoint(cond_bb);
    EmitCondBranch(*ctx->cond(), body_bb, exit_bb);

    // Loop targets: `first` is where `continue` jumps (re-test the
    // condition), `second` is where `break` jumps (leave the loop).
    loop_stack_.push_back({cond_bb, exit_bb});
    builder_.SetInsertPoint(body_bb);
    auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
    if (body_flow == BlockFlow::Continue &&
        !builder_.GetInsertBlock()->HasTerminator()) {
      builder_.CreateBr(cond_bb);  // back edge
    }
    loop_stack_.pop_back();

    builder_.SetInsertPoint(exit_bb);
    return BlockFlow::Continue;
  }

  // --- break ----------------------------------------------------------------
  if (ctx->Break()) {
    if (loop_stack_.empty()) {
      throw std::runtime_error(FormatError("irgen", "break 不在循环体内"));
    }
    builder_.CreateBr(loop_stack_.back().second);
    return BlockFlow::Terminated;
  }

  // --- continue -------------------------------------------------------------
  if (ctx->Continue()) {
    if (loop_stack_.empty()) {
      throw std::runtime_error(FormatError("irgen", "continue 不在循环体内"));
    }
    builder_.CreateBr(loop_stack_.back().first);
    return BlockFlow::Terminated;
  }

  // --- return ---------------------------------------------------------------
  if (ctx->Return()) {
    if (func_->GetType()->IsVoid()) {
      if (ctx->exp()) {
        throw std::runtime_error(FormatError("irgen", "void 函数不应返回值"));
      }
      builder_.CreateRetVoid();
      return BlockFlow::Terminated;
    }
    if (!ctx->exp()) {
      throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
    }
    // Convert the returned value to the type reported by the function.
    ir::Value* v = CastValueTo(EvalExpr(*ctx->exp()), func_->GetType());
    builder_.CreateRet(v);
    return BlockFlow::Terminated;
  }

  // --- empty statement: ";" -------------------------------------------------
  if (ctx->Semi() && !ctx->exp()) {
    return BlockFlow::Continue;
  }

  // --- expression statement: evaluated for side effects, result discarded ---
  if (ctx->exp() && ctx->Semi()) {
    (void)EvalExpr(*ctx->exp());
    return BlockFlow::Continue;
  }

  throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}