From c33d36e0403ab05a8db0dfb8d3f7712fea28822a Mon Sep 17 00:00:00 2001 From: Shrink <1569629152@qq.com> Date: Tue, 24 Mar 2026 23:24:25 +0800 Subject: [PATCH] =?UTF-8?q?Shrink:=20Compile=20pass=20with=20IRGen=20fixed?= =?UTF-8?q?=20=E5=AE=9E=E7=8E=B0=E5=90=88=E5=B9=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/irgen/IRGen.h | 27 ++++- src/irgen/IRGenDecl.cpp | 25 +++- src/irgen/IRGenExp.cpp | 251 +++++++++++++++++++++++++++++++++++----- src/irgen/IRGenFunc.cpp | 10 +- src/irgen/IRGenStmt.cpp | 92 +++++++++++++++ 5 files changed, 366 insertions(+), 39 deletions(-) diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..53eb24d 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -29,13 +29,22 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; + std::any visitExp(SysYParser::ExpContext* ctx) override; + std::any visitCond(SysYParser::CondContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitLValue(SysYParser::LValueContext* ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; private: enum class BlockFlow { @@ -43,8 +52,16 @@ class IRGenImpl final : public SysYBaseVisitor { Terminated, }; + struct LoopTargets { + ir::BasicBlock* continue_target; + ir::BasicBlock* break_target; + }; + BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); ir::Value* EvalExpr(SysYParser::ExpContext& expr); + ir::Value* EvalCond(SysYParser::CondContext& cond); + ir::Value* ToBoolValue(ir::Value* v); + std::string NextBlockName(); ir::Module& module_; const SemanticContext& sema_; @@ -52,6 +69,8 @@ class IRGenImpl final : public SysYBaseVisitor { ir::IRBuilder builder_; // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 std::unordered_map storage_map_; + std::unordered_map named_storage_; + std::vector loop_stack_; }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..75cfdf0 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -60,17 +60,29 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { // - const、数组、全局变量等不同声明形态; // - 更丰富的类型系统。 std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "缺少变量声明")); + } + if (!ctx->varDecl()) { + // 当前先忽略 constDecl 与其它声明形态。 + return {}; + } + return ctx->varDecl()->accept(this); +} + +std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } if (!ctx->btype() || !ctx->btype()->INT()) { throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); + for (auto* var_def : ctx->varDef()) { + if (!var_def) { + throw std::runtime_error(FormatError("irgen", "非法变量声明")); + } + var_def->accept(this); } - var_def->accept(this); return {}; } @@ -83,15 +95,16 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量定义")); } - if (!ctx->lValue()) { + if (!ctx->ID()) { throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); } - GetLValueName(*ctx->lValue()); + const std::string name = ctx->ID()->getText(); if (storage_map_.find(ctx) != storage_map_.end()) { throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); } auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); storage_map_[ctx] = slot; + named_storage_[name] = slot; ir::Value* init = nullptr; if (auto* init_value = ctx->initValue()) { diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..7e22485 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -24,21 +24,62 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { return std::any_cast(expr.accept(this)); } +ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) { + return std::any_cast(cond.accept(this)); +} -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); +ir::Value* IRGenImpl::ToBoolValue(ir::Value* v) { + if (!v) { + throw std::runtime_error(FormatError("irgen", "条件值为空")); } - return EvalExpr(*ctx->exp()); + auto* zero = builder_.CreateConstInt(0); + return builder_.CreateCmp(ir::CmpOp::Ne, v, zero, module_.GetContext().NextTemp()); } +std::string IRGenImpl::NextBlockName() { + std::string temp = module_.GetContext().NextTemp(); + if (!temp.empty() && temp.front() == '%') { + return "bb" + temp.substr(1); + } + return "bb" + temp; +} -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { +std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) { + if (!ctx || !ctx->addExp()) { + throw std::runtime_error(FormatError("irgen", "非法表达式")); + } + return ctx->addExp()->accept(this); +} + +std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) { + if (!ctx || !ctx->lOrExp()) { + throw std::runtime_error(FormatError("irgen", "非法条件表达式")); + } + return ctx->lOrExp()->accept(this); +} + +std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法基本表达式")); + } + if (ctx->exp()) { + return EvalExpr(*ctx->exp()); + } + if (ctx->number()) { + return ctx->number()->accept(this); + } + if (ctx->lValue()) { + return ctx->lValue()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "不支持的基本表达式")); +} + +std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { + if (!ctx || !ctx->ILITERAL()) { throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); } return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); + builder_.CreateConstInt(std::stoi(ctx->getText()))); } // 变量使用的处理流程: @@ -47,34 +88,192 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { // 3. 最后生成 load,把内存中的值读出来。 // // 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { +std::any IRGenImpl::visitLValue(SysYParser::LValueContext* ctx) { + if (!ctx || !ctx->ID()) { throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); } - auto* decl = sema_.ResolveVarUse(ctx->var()); - if (!decl) { - throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); - } - auto it = storage_map_.find(decl); - if (it == storage_map_.end()) { + const std::string name = ctx->ID()->getText(); + auto it = named_storage_.find(name); + if (it == named_storage_.end()) { throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + FormatError("irgen", "变量声明缺少存储槽位: " + name)); } return static_cast( builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); } +std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法一元表达式")); + } + if (ctx->primaryExp()) { + return ctx->primaryExp()->accept(this); + } + if (ctx->unaryOp() && ctx->unaryExp()) { + ir::Value* v = std::any_cast(ctx->unaryExp()->accept(this)); + if (ctx->unaryOp()->SUB()) { + auto* zero = builder_.CreateConstInt(0); + return static_cast(builder_.CreateSub( + zero, v, module_.GetContext().NextTemp())); + } + if (ctx->unaryOp()->ADD()) { + return v; + } + throw std::runtime_error(FormatError("irgen", "当前不支持逻辑非运算")); + } + throw std::runtime_error(FormatError("irgen", "当前不支持函数调用表达式")); +} + +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + } + if (ctx->mulExp()) { + if (!ctx->unaryExp()) { + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + } + ir::Value* lhs = std::any_cast(ctx->mulExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->unaryExp()->accept(this)); + if (ctx->MUL()) { + return static_cast( + builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->DIV()) { + return static_cast( + builder_.CreateDiv(lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->MOD()) { + return static_cast( + builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + } + if (ctx->unaryExp()) { + return ctx->unaryExp()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); +} -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { +std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法加法表达式")); } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); - return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + if (ctx->addExp()) { + if (!ctx->mulExp()) { + throw std::runtime_error(FormatError("irgen", "非法加法表达式")); + } + ir::Value* lhs = std::any_cast(ctx->addExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->mulExp()->accept(this)); + if (ctx->ADD()) { + return static_cast( + builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->SUB()) { + return static_cast( + builder_.CreateSub(lhs, rhs, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法加法表达式")); + } + if (ctx->mulExp()) { + return ctx->mulExp()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "非法加法表达式")); +} + +std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + } + if (ctx->relExp()) { + if (!ctx->addExp()) { + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + } + ir::Value* lhs = std::any_cast(ctx->relExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->addExp()->accept(this)); + if (ctx->LT()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Lt, lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->LE()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Le, lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->GT()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Gt, lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->GE()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Ge, lhs, rhs, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + } + if (ctx->addExp()) { + return ctx->addExp()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "非法关系表达式")); +} + +std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + } + if (ctx->eqExp()) { + if (!ctx->relExp()) { + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + } + ir::Value* lhs = std::any_cast(ctx->eqExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->relExp()->accept(this)); + if (ctx->EQ()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Eq, lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->NE()) { + return static_cast(builder_.CreateCmp( + ir::CmpOp::Ne, lhs, rhs, module_.GetContext().NextTemp())); + } + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + } + if (ctx->relExp()) { + return ctx->relExp()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "非法相等表达式")); +} + +std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); + } + if (ctx->lAndExp()) { + if (!ctx->eqExp()) { + throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); + } + auto* lhs = ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); + auto* rhs = ToBoolValue(std::any_cast(ctx->eqExp()->accept(this))); + return static_cast( + builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp())); + } + if (ctx->eqExp()) { + return ToBoolValue(std::any_cast(ctx->eqExp()->accept(this))); + } + throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); +} + +std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); + } + if (ctx->lOrExp()) { + if (!ctx->lAndExp()) { + throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); + } + auto* lhs = ToBoolValue(std::any_cast(ctx->lOrExp()->accept(this))); + auto* rhs = ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); + auto* sum = builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp()); + return static_cast(ToBoolValue(sum)); + } + if (ctx->lAndExp()) { + return ToBoolValue(std::any_cast(ctx->lAndExp()->accept(this))); + } + throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..737563d 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -38,11 +38,14 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - auto* func = ctx->funcDef(); - if (!func) { + if (ctx->funcDef().empty()) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } - func->accept(this); + for (auto* func : ctx->funcDef()) { + if (func) { + func->accept(this); + } + } return {}; } @@ -79,6 +82,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); builder_.SetInsertPoint(func_->GetEntry()); storage_map_.clear(); + named_storage_.clear(); ctx->blockStmt()->accept(this); // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..61ad87e 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -19,9 +19,101 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } + if (ctx->lValue() && ctx->ASSIGN() && ctx->exp()) { + if (!ctx->lValue()->ID()) { + throw std::runtime_error(FormatError("irgen", "赋值语句左值非法")); + } + const std::string name = ctx->lValue()->ID()->getText(); + auto slot_it = named_storage_.find(name); + if (slot_it == named_storage_.end()) { + throw std::runtime_error(FormatError("irgen", "赋值目标未声明: " + name)); + } + ir::Value* rhs = EvalExpr(*ctx->exp()); + builder_.CreateStore(rhs, slot_it->second); + return BlockFlow::Continue; + } + if (ctx->blockStmt()) { + ctx->blockStmt()->accept(this); + return builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator() + ? BlockFlow::Terminated + : BlockFlow::Continue; + } + if (ctx->IF()) { + if (!ctx->cond() || ctx->stmt().empty()) { + throw std::runtime_error(FormatError("irgen", "if 语句不完整")); + } + auto* then_bb = func_->CreateBlock(NextBlockName()); + auto* merge_bb = func_->CreateBlock(NextBlockName()); + auto* else_bb = ctx->ELSE() ? func_->CreateBlock(NextBlockName()) : merge_bb; + + ir::Value* cond = ToBoolValue(EvalCond(*ctx->cond())); + builder_.CreateCondBr(cond, then_bb, else_bb); + + builder_.SetInsertPoint(then_bb); + auto then_flow = std::any_cast(ctx->stmt(0)->accept(this)); + if (then_flow != BlockFlow::Terminated) { + builder_.CreateBr(merge_bb); + } + + if (ctx->ELSE()) { + builder_.SetInsertPoint(else_bb); + auto else_flow = std::any_cast(ctx->stmt(1)->accept(this)); + if (else_flow != BlockFlow::Terminated) { + builder_.CreateBr(merge_bb); + } + } + + builder_.SetInsertPoint(merge_bb); + return BlockFlow::Continue; + } + if (ctx->WHILE()) { + if (!ctx->cond() || ctx->stmt().empty()) { + throw std::runtime_error(FormatError("irgen", "while 语句不完整")); + } + auto* cond_bb = func_->CreateBlock(NextBlockName()); + auto* body_bb = func_->CreateBlock(NextBlockName()); + auto* exit_bb = func_->CreateBlock(NextBlockName()); + + builder_.CreateBr(cond_bb); + builder_.SetInsertPoint(cond_bb); + ir::Value* cond = ToBoolValue(EvalCond(*ctx->cond())); + builder_.CreateCondBr(cond, body_bb, exit_bb); + + loop_stack_.push_back({cond_bb, exit_bb}); + builder_.SetInsertPoint(body_bb); + auto body_flow = std::any_cast(ctx->stmt(0)->accept(this)); + if (body_flow != BlockFlow::Terminated) { + builder_.CreateBr(cond_bb); + } + loop_stack_.pop_back(); + + builder_.SetInsertPoint(exit_bb); + return BlockFlow::Continue; + } + if (ctx->BREAK()) { + if (loop_stack_.empty()) { + throw std::runtime_error(FormatError("irgen", "break 不在循环中")); + } + builder_.CreateBr(loop_stack_.back().break_target); + return BlockFlow::Terminated; + } + if (ctx->CONTINUE()) { + if (loop_stack_.empty()) { + throw std::runtime_error(FormatError("irgen", "continue 不在循环中")); + } + builder_.CreateBr(loop_stack_.back().continue_target); + return BlockFlow::Terminated; + } if (ctx->returnStmt()) { return ctx->returnStmt()->accept(this); } + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + return BlockFlow::Continue; + } + if (ctx->SEMICOLON()) { + return BlockFlow::Continue; + } throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); }