#include "irgen/IRGen.h"

#include <any>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"

// Expression lowering for IRGenImpl.
//
// Every expression visitor in this file returns a std::any that holds an
// ir::Value*; EvalExpr() unwraps it for callers. Handled here: integer and
// float literals, variable reads (optionally subscripted), parenthesised
// expressions, + - * / %, unary + - !, function calls, relational and
// equality comparisons, and && / ||.
//
// Known gaps:
//   - && and || evaluate both operands (no short-circuit branches are
//     emitted);
//   - implicit int -> float conversion is only performed for integer
//     constants in comparisons; other mixed int/float operands are passed
//     through unchanged.

// Lowers `expr` and returns the resulting IR value.
// Throws std::bad_any_cast if the visitor did not yield an ir::Value*.
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  return std::any_cast<ir::Value*>(expr.accept(this));
}

// A primary expression is a parenthesised sub-expression, an lvalue read, or a
// numeric literal; each case is delegated to the matching visitor.
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法 primary 表达式"));
  if (ctx->exp()) return EvalExpr(*ctx->exp());
  if (ctx->lVal()) return ctx->lVal()->accept(this);
  if (ctx->number()) return ctx->number()->accept(this);
  throw std::runtime_error(FormatError("irgen", "不支持的 primary 表达式"));
}

// Lowers a numeric literal to an IR constant.
//
// Integer literals may be written in decimal ("42"), octal ("052") or
// hexadecimal ("0x2A"); base 0 makes std::stoull honour those prefixes (a
// plain base-10 std::stoi would read "0x2A" as 0 and "052" as 52). The value
// is then wrapped to 32 bits (two's complement), so 2147483648 -- the
// magnitude written in `-2147483648` -- becomes -2147483648 instead of
// throwing std::out_of_range.
// Float literals are parsed with std::stof, which also accepts hexadecimal
// floating forms.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法数字字面量"));
  const std::string text = ctx->getText();
  if (ctx->IntConst()) {
    const unsigned long long raw = std::stoull(text, nullptr, 0);
    const int value = static_cast<int>(static_cast<unsigned int>(raw));
    return static_cast<ir::Value*>(builder_.CreateConstInt(value));
  }
  if (ctx->FloatConst()) {
    return static_cast<ir::Value*>(builder_.CreateConstFloat(std::stof(text)));
  }
  throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量或浮点字面量"));
}
最后生成 load,把内存中的值读出来。 // // 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { if (!ctx || !ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); } // find storage by matching declaration node stored in Sema context // Sema stores types/decl contexts in IRGenContext maps; here we search storage_map_ by name std::string name = ctx->Ident()->getText(); // 优先使用按名称的快速映射 auto nit = name_map_.find(name); if (nit != name_map_.end()) { // 支持下标访问:若有索引表达式列表,则生成 GEP + load if (ctx->exp().size() > 0) { std::vector indices; // 首个索引用于穿过数组对象 indices.push_back(builder_.CreateConstInt(0)); for (auto* e : ctx->exp()) { indices.push_back(EvalExpr(*e)); } auto* gep = builder_.CreateGEP(nit->second, indices, module_.GetContext().NextTemp()); return static_cast(builder_.CreateLoad(gep, module_.GetContext().NextTemp())); } // 如果映射到的是常量,直接返回常量值;否则按原来行为从槽位 load if (nit->second->IsConstant()) return nit->second; return static_cast(builder_.CreateLoad(nit->second, module_.GetContext().NextTemp())); } for (auto& kv : storage_map_) { if (!kv.first) continue; if (auto* vdef = dynamic_cast(kv.first)) { if (vdef->Ident() && vdef->Ident()->getText() == name) { if (ctx->exp().size() > 0) { std::vector indices; indices.push_back(builder_.CreateConstInt(0)); for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e)); auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp()); return static_cast(builder_.CreateLoad(gep, module_.GetContext().NextTemp())); } return static_cast(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp())); } } else if (auto* fparam = dynamic_cast(kv.first)) { if (fparam->Ident() && fparam->Ident()->getText() == name) { if (ctx->exp().size() > 0) { std::vector indices; indices.push_back(builder_.CreateConstInt(0)); for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e)); auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp()); return 
static_cast(builder_.CreateLoad(gep, module_.GetContext().NextTemp())); } return static_cast(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp())); } } else if (auto* cdef = dynamic_cast(kv.first)) { if (cdef->Ident() && cdef->Ident()->getText() == name) { if (ctx->exp().size() > 0) { std::vector indices; indices.push_back(builder_.CreateConstInt(0)); for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e)); auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp()); return static_cast(builder_.CreateLoad(gep, module_.GetContext().NextTemp())); } return static_cast(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp())); } } } throw std::runtime_error(FormatError("irgen", "变量声明缺少存储槽位: " + name)); } std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法加法表达式")); try { // left-associative: fold across all mulExp operands if (ctx->mulExp().size() == 1) return ctx->mulExp(0)->accept(this); ir::Value* cur = std::any_cast(ctx->mulExp(0)->accept(this)); // extract operator sequence from text (in-order '+' or '-') std::string text = ctx->getText(); std::vector ops; for (char c : text) if (c == '+' || c == '-') ops.push_back(c); for (size_t i = 1; i < ctx->mulExp().size(); ++i) { ir::Value* rhs = std::any_cast(ctx->mulExp(i)->accept(this)); char opch = (i - 1 < ops.size()) ? ops[i - 1] : '+'; ir::Opcode op = (opch == '-') ? 
ir::Opcode::Sub : ir::Opcode::Add; cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp()); } return static_cast(cur); } catch (const std::exception& e) { LogInfo(std::string("[irgen] exception in visitAddExp text=") + ctx->getText() + ", err=" + e.what(), std::cerr); throw; } } std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); if (ctx->unaryExp().size() == 1) return ctx->unaryExp(0)->accept(this); ir::Value* cur = std::any_cast(ctx->unaryExp(0)->accept(this)); // extract operator sequence for '*', '/', '%' std::string text = ctx->getText(); std::vector ops; for (char c : text) if (c == '*' || c == '/' || c == '%') ops.push_back(c); for (size_t i = 1; i < ctx->unaryExp().size(); ++i) { ir::Value* rhs = std::any_cast(ctx->unaryExp(i)->accept(this)); char opch = (i - 1 < ops.size()) ? ops[i - 1] : '*'; ir::Opcode op = ir::Opcode::Mul; if (opch == '/') op = ir::Opcode::Div; else if (opch == '%') op = ir::Opcode::Mod; cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp()); } return static_cast(cur); } std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法一元表达式")); if (ctx->primaryExp()) return ctx->primaryExp()->accept(this); // function call: Ident '(' funcRParams? 
')' if (ctx->Ident() && ctx->getText().find("(") != std::string::npos) { std::string fname = ctx->Ident()->getText(); std::vector args; if (ctx->funcRParams()) { for (auto* e : ctx->funcRParams()->exp()) { args.push_back(EvalExpr(*e)); } } // find existing function or create an external declaration (assume int return) ir::Function* callee = nullptr; for (auto &fup : module_.GetFunctions()) { if (fup && fup->GetName() == fname) { callee = fup.get(); break; } } if (!callee) { std::vector> param_types; for (auto* a : args) { if (a && a->IsFloat32()) param_types.push_back(ir::Type::GetFloat32Type()); else param_types.push_back(ir::Type::GetInt32Type()); } callee = module_.CreateFunction(fname, ir::Type::GetInt32Type(), param_types); } return static_cast(builder_.CreateCall(callee, args, module_.GetContext().NextTemp())); } if (ctx->unaryExp()) { ir::Value* val = std::any_cast(ctx->unaryExp()->accept(this)); if (ctx->unaryOp() && ctx->unaryOp()->getText() == "+") return static_cast(val); else if (ctx->unaryOp() && ctx->unaryOp()->getText() == "-") { // 负号:0 - val,区分整型/浮点 if (val->IsFloat32()) { ir::Value* zero = builder_.CreateConstFloat(0.0f); return static_cast(builder_.CreateSub(zero, val, module_.GetContext().NextTemp())); } else { ir::Value* zero = builder_.CreateConstInt(0); return static_cast(builder_.CreateSub(zero, val, module_.GetContext().NextTemp())); } } if (ctx->unaryOp() && ctx->unaryOp()->getText() == "!") { // logical not: produce int 1 if val == 0, else 0 if (val->IsFloat32()) { ir::Value* zerof = builder_.CreateConstFloat(0.0f); return static_cast(builder_.CreateFCmp(ir::CmpInst::EQ, val, zerof, module_.GetContext().NextTemp())); } else { ir::Value* zero = builder_.CreateConstInt(0); return static_cast(builder_.CreateICmp(ir::CmpInst::EQ, val, zero, module_.GetContext().NextTemp())); } } } throw std::runtime_error(FormatError("irgen", "不支持的一元运算")); } std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { if (!ctx) throw 
std::runtime_error(FormatError("irgen", "非法关系表达式")); if (ctx->addExp().size() == 1) return ctx->addExp(0)->accept(this); ir::Value* lhs = std::any_cast(ctx->addExp(0)->accept(this)); ir::Value* rhs = std::any_cast(ctx->addExp(1)->accept(this)); // 类型提升 if (lhs->IsFloat32() && rhs->IsInt32()) { if (auto* ci = dynamic_cast(rhs)) { rhs = builder_.CreateConstFloat(static_cast(ci->GetValue())); } else { throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换")); } } else if (rhs->IsFloat32() && lhs->IsInt32()) { if (auto* ci = dynamic_cast(lhs)) { lhs = builder_.CreateConstFloat(static_cast(ci->GetValue())); } else { throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换")); } } ir::CmpInst::Predicate pred = ir::CmpInst::EQ; std::string text = ctx->getText(); if (text.find("<=") != std::string::npos) pred = ir::CmpInst::LE; else if (text.find(">=") != std::string::npos) pred = ir::CmpInst::GE; else if (text.find("<") != std::string::npos) pred = ir::CmpInst::LT; else if (text.find(">") != std::string::npos) pred = ir::CmpInst::GT; if (lhs->IsFloat32() || rhs->IsFloat32()) { return static_cast( builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp())); } return static_cast( builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法相等表达式")); if (ctx->relExp().size() == 1) return ctx->relExp(0)->accept(this); ir::Value* lhs = std::any_cast(ctx->relExp(0)->accept(this)); ir::Value* rhs = std::any_cast(ctx->relExp(1)->accept(this)); // 类型提升 if (lhs->IsFloat32() && rhs->IsInt32()) { if (auto* ci = dynamic_cast(rhs)) { rhs = builder_.CreateConstFloat(static_cast(ci->GetValue())); } else { throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换")); } } else if (rhs->IsFloat32() && lhs->IsInt32()) { if (auto* ci = dynamic_cast(lhs)) { lhs = 
builder_.CreateConstFloat(static_cast(ci->GetValue())); } else { throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换")); } } ir::CmpInst::Predicate pred = ir::CmpInst::EQ; std::string text = ctx->getText(); if (text.find("==") != std::string::npos) pred = ir::CmpInst::EQ; else if (text.find("!=") != std::string::npos) pred = ir::CmpInst::NE; if (lhs->IsFloat32() || rhs->IsFloat32()) { return static_cast( builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp())); } return static_cast( builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); if (ctx->eqExp().size() == 1) return ctx->eqExp(0)->accept(this); // For simplicity, treat as int (0 or 1) ir::Value* lhs = std::any_cast(ctx->eqExp(0)->accept(this)); ir::Value* rhs = std::any_cast(ctx->eqExp(1)->accept(this)); // lhs && rhs : (lhs != 0) && (rhs != 0) ir::Value* zero = builder_.CreateConstInt(0); ir::Value* lhs_ne = builder_.CreateICmp(ir::CmpInst::NE, lhs, zero, module_.GetContext().NextTemp()); ir::Value* rhs_ne = builder_.CreateICmp(ir::CmpInst::NE, rhs, zero, module_.GetContext().NextTemp()); return static_cast( builder_.CreateMul(lhs_ne, rhs_ne, module_.GetContext().NextTemp())); } std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); if (ctx->lAndExp().size() == 1) return ctx->lAndExp(0)->accept(this); ir::Value* lhs = std::any_cast(ctx->lAndExp(0)->accept(this)); ir::Value* rhs = std::any_cast(ctx->lAndExp(1)->accept(this)); // lhs || rhs : (lhs != 0) || (rhs != 0) ir::Value* zero = builder_.CreateConstInt(0); ir::Value* lhs_ne = builder_.CreateICmp(ir::CmpInst::NE, lhs, zero, module_.GetContext().NextTemp()); ir::Value* rhs_ne = builder_.CreateICmp(ir::CmpInst::NE, rhs, zero, module_.GetContext().NextTemp()); ir::Value* or_val = 
builder_.CreateAdd(lhs_ne, rhs_ne, module_.GetContext().NextTemp()); ir::Value* one = builder_.CreateConstInt(1); return static_cast( builder_.CreateICmp(ir::CmpInst::GE, or_val, one, module_.GetContext().NextTemp())); }