diff --git a/include/ir/IR.h b/include/ir/IR.h index a1a3329..ed38827 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -301,6 +301,8 @@ class IRBuilder { BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..24d0b3b 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -26,16 +26,16 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + std::any visitBlock(SysYParser::BlockContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitDecl(SysYParser::DeclContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitVarDef(SysYParser::VarDefContext* ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; - std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; - std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitLVal(SysYParser::LValContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; private: enum class BlockFlow { @@ -50,8 +50,8 @@ class IRGenImpl final : public SysYBaseVisitor { const SemanticContext& sema_; ir::Function* func_; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 - std::unordered_map storage_map_; + // 名称绑定由 Sema 负责;IRGen 只维护"变量名 -> 存储槽位"的代码生成状态。 + std::unordered_map storage_map_; }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 30efbb6..6d7f486 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -32,6 +32,10 @@ static const char* OpcodeToString(Opcode op) { return "sub"; case Opcode::Mul: return "mul"; + case Opcode::Div: + return "sdiv"; + case Opcode::Mod: + return "srem"; case Opcode::Alloca: return "alloca"; case Opcode::Load: @@ -65,7 +69,9 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { switch (inst->GetOpcode()) { case Opcode::Add: case Opcode::Sub: - case Opcode::Mul: { + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: { auto* bin = static_cast(inst); os << " " << bin->GetName() << " = " << OpcodeToString(bin->GetOpcode()) << " " diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 7928716..5abab19 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -61,8 +61,9 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name) : Instruction(op, std::move(ty), std::move(name)) { - if (op != Opcode::Add) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); + if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul && + op != Opcode::Div && op != Opcode::Mod) { + throw std::runtime_error(FormatError("ir", "BinaryInst 不支持的操作码")); } if (!lhs || !rhs) { throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..df7ccf9 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -6,18 +6,7 @@ #include "ir/IR.h" #include "utils/Log.h" -namespace { - -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); - } - return lvalue.ID()->getText(); -} - -} // namespace - -std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { +std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句块")); } @@ -63,14 +52,20 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量声明")); } - if (!ctx->btype() || !ctx->btype()->INT()) { + // 当前语法中 decl 包含 constDecl 或 varDecl,这里只支持 varDecl + auto* var_decl = ctx->varDecl(); + if (!var_decl) { + throw std::runtime_error(FormatError("irgen", "当前仅支持变量声明")); + } + if (!var_decl->bType() || !var_decl->bType()->INT()) { throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); + // 遍历所有 varDef + for (auto* var_def : var_decl->varDef()) { + if (var_def) { + var_def->accept(this); + } } - var_def->accept(this); return {}; } @@ -83,22 +78,26 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少变量定义")); } - if (!ctx->lValue()) { + if (!ctx->ID()) { throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); } - GetLValueName(*ctx->lValue()); - if (storage_map_.find(ctx) != storage_map_.end()) { + // 暂不支持数组声明(constIndex) + if (!ctx->constIndex().empty()) { + throw std::runtime_error(FormatError("irgen", "暂不支持数组声明")); + } + std::string var_name = ctx->ID()->getText(); + if (storage_map_.find(var_name) != storage_map_.end()) { throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); } auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[ctx] = slot; + storage_map_[var_name] = slot; ir::Value* init = nullptr; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { + if (auto* init_val = ctx->initVal()) { + if (!init_val->exp()) { throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); } - init = EvalExpr(*init_value->exp()); + init = EvalExpr(*init_val->exp()); } else { init = builder_.CreateConstInt(0); } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index f20609e..7b25d4a 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -25,20 +25,51 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { } -std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); +std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + if (!ctx) { + throw std::runtime_error(FormatError("irgen", "非法基本表达式")); + } + // 处理括号表达式:LPAREN exp RPAREN + if (ctx->exp()) { + return EvalExpr(*ctx->exp()); + } + // 处理 lVal(变量使用)- 交给 visitLVal 处理 + if (ctx->lVal()) { + // 直接在这里处理变量读取,避免 accept 调用可能导致的问题 + auto* lval_ctx = ctx->lVal(); + if (!lval_ctx || !lval_ctx->ID()) { + throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); + } + const auto* decl = sema_.ResolveObjectUse(lval_ctx); + if (!decl) { + throw std::runtime_error( + FormatError("irgen", + "变量使用缺少语义绑定:" + lval_ctx->ID()->getText())); + } + std::string var_name = lval_ctx->ID()->getText(); + auto it = storage_map_.find(var_name); + if (it == storage_map_.end()) { + throw std::runtime_error( + FormatError("irgen", + "变量声明缺少存储槽位:" + var_name)); + } + return static_cast( + builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); } - return EvalExpr(*ctx->exp()); + // 处理 number + if (ctx->number()) { + return ctx->number()->accept(this); + } + throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型")); } -std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { +std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { + if (!ctx || !ctx->intConst()) { throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); } return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); + builder_.CreateConstInt(std::stoi(ctx->intConst()->getText()))); } // 变量使用的处理流程: @@ -47,33 +78,46 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { // 3. 最后生成 load,把内存中的值读出来。 // // 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { +std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { + if (!ctx || !ctx->ID()) { throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); } - auto* decl = sema_.ResolveVarUse(ctx->var()); + const auto* decl = sema_.ResolveObjectUse(ctx); if (!decl) { throw std::runtime_error( FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); + "变量使用缺少语义绑定:" + ctx->ID()->getText())); } - auto it = storage_map_.find(decl); + // 使用变量名查找存储槽位 + std::string var_name = ctx->ID()->getText(); + auto it = storage_map_.find(var_name); if (it == storage_map_.end()) { throw std::runtime_error( FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + "变量声明缺少存储槽位:" + var_name)); } return static_cast( builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); } -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { +std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { + if (!ctx) { throw std::runtime_error(FormatError("irgen", "非法加减法表达式")); } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); + + // 如果是 mulExp 直接返回(addExp : mulExp) + if (ctx->mulExp() && ctx->addExp() == nullptr) { + return ctx->mulExp()->accept(this); + } + + // 处理 addExp op mulExp 的递归形式 + if (!ctx->addExp() || !ctx->mulExp()) { + throw std::runtime_error(FormatError("irgen", "非法加减法表达式结构")); + } + + ir::Value* lhs = std::any_cast(ctx->addExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->mulExp()->accept(this)); ir::Opcode op = ir::Opcode::Add; if (ctx->ADD()) { @@ -93,18 +137,18 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { throw std::runtime_error(FormatError("irgen", "非法乘除法表达式")); } - // 如果是 unaryExp 直接返回 - if (ctx->unaryExp()) { + // 如果是 unaryExp 直接返回(mulExp : unaryExp) + if (ctx->unaryExp() && ctx->mulExp() == nullptr) { return ctx->unaryExp()->accept(this); } - // 处理 MulExp op unaryExp 的递归形式 - if (!ctx->exp(0) || !ctx->unaryExp(0)) { + // 处理 mulExp op unaryExp 的递归形式 + if (!ctx->mulExp() || !ctx->unaryExp()) { throw std::runtime_error(FormatError("irgen", "非法乘除法表达式结构")); } - ir::Value* lhs = std::any_cast(ctx->exp(0)->accept(this)); - ir::Value* rhs = std::any_cast(ctx->unaryExp(0)->accept(this)); + ir::Value* lhs = std::any_cast(ctx->mulExp()->accept(this)); + ir::Value* rhs = std::any_cast(ctx->unaryExp()->accept(this)); ir::Opcode op = ir::Opcode::Mul; if (ctx->MUL()) { diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..4ee5b3e 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -29,7 +29,7 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) // 编译单元的 IR 生成当前只实现了最小功能: // - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR; +// - 当前会读取编译单元中的 topLevelItem,找到 funcDef 后生成函数 IR; // // 当前还没有实现: // - 多个函数定义的遍历与生成; @@ -38,12 +38,15 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少编译单元")); } - auto* func = ctx->funcDef(); - if (!func) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + // 遍历所有 topLevelItem,找到 funcDef + for (auto* item : ctx->topLevelItem()) { + if (item && item->funcDef()) { + item->funcDef()->accept(this); + // 当前只支持单个函数,找到第一个后就返回 + return {}; + } } - func->accept(this); - return {}; + throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } // 函数 IR 生成当前实现了: @@ -61,12 +64,11 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { // - 入口块中的参数初始化逻辑。 // ... -// 因此这里目前只支持最小的“无参 int 函数”生成。 std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } - if (!ctx->blockStmt()) { + if (!ctx->block()) { throw std::runtime_error(FormatError("irgen", "函数体为空")); } if (!ctx->ID()) { @@ -80,7 +82,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { builder_.SetInsertPoint(func_->GetEntry()); storage_map_.clear(); - ctx->blockStmt()->accept(this); + ctx->block()->accept(this); // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 VerifyFunctionStructure(*func_); return {}; diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..8f7a2a5 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -19,21 +19,14 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } - if (ctx->returnStmt()) { - return ctx->returnStmt()->accept(this); + // 检查是否是 return 语句:RETURN exp? SEMICOLON + if (ctx->RETURN()) { + if (!ctx->exp()) { + throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); + } + ir::Value* v = EvalExpr(*ctx->exp()); + builder_.CreateRet(v); + return BlockFlow::Terminated; } throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); } - - -std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); - } - if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); - } - ir::Value* v = EvalExpr(*ctx->exp()); - builder_.CreateRet(v); - return BlockFlow::Terminated; -} diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index f0b49e5..fc73f9d 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -434,9 +434,11 @@ class SemaVisitor final : public SysYBaseVisitor { if (!ctx) { ThrowSemaError(ctx, "非法乘法表达式"); } - if (ctx->unaryExp()) { + // 如果是 mulExp : unaryExp 形式(没有 MUL/DIV/MOD token),直接处理 unaryExp + if (!ctx->MUL() && !ctx->DIV() && !ctx->MOD()) { return EvalExpr(*ctx->unaryExp()); } + // 否则是 mulExp MUL/DIV/MOD unaryExp 形式 ExprInfo lhs = EvalExpr(*ctx->mulExp()); ExprInfo rhs = EvalExpr(*ctx->unaryExp()); return EvalArithmetic(*ctx, lhs, rhs, ctx->MUL() ? '*' : (ctx->DIV() ? '/' : '%')); @@ -446,9 +448,11 @@ class SemaVisitor final : public SysYBaseVisitor { if (!ctx) { ThrowSemaError(ctx, "非法加法表达式"); } - if (ctx->mulExp()) { + // 如果是 addExp : mulExp 形式(没有 ADD/SUB token),直接处理 mulExp + if (!ctx->ADD() && !ctx->SUB()) { return EvalExpr(*ctx->mulExp()); } + // 否则是 addExp ADD/SUB mulExp 形式 ExprInfo lhs = EvalExpr(*ctx->addExp()); ExprInfo rhs = EvalExpr(*ctx->mulExp()); return EvalArithmetic(*ctx, lhs, rhs, ctx->ADD() ? '+' : '-'); @@ -544,7 +548,7 @@ class SemaVisitor final : public SysYBaseVisitor { const std::string name = ctx.ID()->getText(); const ObjectBinding* symbol = symbols_.Lookup(name); if (!symbol) { - ThrowSemaError(&ctx, "使用了未声明的标识符: " + name); + ThrowSemaError(&ctx, "使用了未声明的标识符:" + name); } sema_.BindObjectUse(&ctx, *symbol);