You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

309 lines
14 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "irgen/IRGen.h"
#include <stdexcept>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
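// Expression lowering for IRGen.
// Handled in this file:
// - integer and floating-point literals
// - variable reads, including subscripted element reads (GEP + load)
// - parenthesized expressions
// - binary + - * / % and unary + - !
// - function calls (a callee not yet in the module is declared on the fly
//   with an int return type)
// - relational / equality comparisons and && / ||
//   (operands of && / || are evaluated eagerly, without short-circuiting)
//
// Not handled in this file:
// - assignment expressions
// - implicit int -> float conversion of non-constant operands
// - ...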
// Evaluates an `exp` subtree and returns the IR value it produced.
// Expression visitors return an ir::Value* wrapped in std::any; std::any_cast
// throws std::bad_any_cast if a visitor returned anything else.
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  const std::any produced = expr.accept(this);
  return std::any_cast<ir::Value*>(produced);
}
// Lowers a primaryExp by delegating to whichever alternative is present:
// a parenthesized expression, an lVal, or a number literal.
// Throws std::runtime_error for a null context or an unrecognised shape.
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法 primary 表达式"));
  }
  if (auto* inner = ctx->exp()) {
    return EvalExpr(*inner);
  }
  if (auto* ref = ctx->lVal()) {
    return ref->accept(this);
  }
  if (auto* literal = ctx->number()) {
    return literal->accept(this);
  }
  throw std::runtime_error(FormatError("irgen", "不支持的 primary 表达式"));
}
// Lowers a number literal to an IR constant.
//
// Integer literals are parsed with base auto-detection (std::stoll with base 0):
// a "0x"/"0X" prefix means hexadecimal, a leading "0" means octal, and anything
// else is decimal. A plain base-10 parse would read "0x1F" as 0 and "017" as 17.
// Parsing goes through 64 bits so that 2147483648 (the literal inside
// `-2147483648`) does not throw; the value is then wrapped to 32 bits.
//
// Float literals are parsed with std::stof, which is built on strtof and
// therefore also accepts hexadecimal floating literals.
//
// Throws std::runtime_error for a null context or a literal of neither kind.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法数字字面量"));
  const std::string text = ctx->getText();
  if (ctx->IntConst()) {
    const long long wide = std::stoll(text, nullptr, 0);
    // Truncate to 32 bits (two's-complement wrap) so 2147483648 maps to INT_MIN.
    const int value = static_cast<int>(static_cast<unsigned int>(wide));
    return static_cast<ir::Value*>(builder_.CreateConstInt(value));
  }
  if (ctx->FloatConst()) {
    return static_cast<ir::Value*>(builder_.CreateConstFloat(std::stof(text)));
  }
  throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量或浮点字面量"));
}
// Lowers a variable reference used as an rvalue to an IR value.
//
// Lookup is by identifier text, in two stages:
//   1. name_map_ (name -> value) is consulted first;
//   2. otherwise storage_map_ is scanned for a VarDef / FuncFParam / ConstDef
//      node that declares the same identifier.
// Neither stage uses scope information.
// NOTE(review): if a name is declared in several scopes, which storage is
// picked depends on what these maps hold at this point (and, for storage_map_,
// on its iteration order) — confirm the maps are kept scope-accurate.
//
// A subscripted reference produces GEP(storage, 0, idx...) followed by a load.
// A plain reference loads from the storage. The exception is a name_map_ entry
// that is itself a constant, which is returned directly without a load.
//
// Throws std::runtime_error if the lVal has no identifier or no storage is found.
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
  if (!ctx || !ctx->Ident()) {
    throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
  }
  const std::string name = ctx->Ident()->getText();

  // Emits the read of `slot`: either a plain load, or GEP + load when the
  // reference has subscripts. The leading constant 0 index steps through the
  // array object itself before the source-level indices are applied.
  // NOTE(review): this assumes `slot` addresses a whole array object. If array
  // parameters are stored as plain pointers, the extra 0 index would be
  // wrong — confirm against how FuncFParam storage is created.
  auto load_from = [&](auto* slot) -> ir::Value* {
    const auto subscripts = ctx->exp();
    if (subscripts.empty()) {
      return builder_.CreateLoad(slot, module_.GetContext().NextTemp());
    }
    std::vector<ir::Value*> indices;
    indices.reserve(subscripts.size() + 1);
    indices.push_back(builder_.CreateConstInt(0));
    for (auto* index_expr : subscripts) indices.push_back(EvalExpr(*index_expr));
    auto* gep = builder_.CreateGEP(slot, indices, module_.GetContext().NextTemp());
    return builder_.CreateLoad(gep, module_.GetContext().NextTemp());
  };

  // 1. Fast path: direct name -> value mapping.
  if (auto found = name_map_.find(name); found != name_map_.end()) {
    ir::Value* value = found->second;
    // A constant bound to the name is used as-is instead of being loaded from.
    if (ctx->exp().empty() && value->IsConstant()) return value;
    return load_from(value);
  }

  // 2. Fallback: look for a declaration node with the same identifier.
  auto declares_name = [&name](auto* def) {
    return def != nullptr && def->Ident() != nullptr && def->Ident()->getText() == name;
  };
  for (auto& entry : storage_map_) {
    auto* decl = entry.first;
    if (!decl) continue;
    const bool match = declares_name(dynamic_cast<SysYParser::VarDefContext*>(decl)) ||
                       declares_name(dynamic_cast<SysYParser::FuncFParamContext*>(decl)) ||
                       declares_name(dynamic_cast<SysYParser::ConstDefContext*>(decl));
    // NOTE(review): unlike the name_map_ path, a matching ConstDef here is
    // always loaded from, even if its storage is a constant. Confirm that
    // ConstDef storage reached through this path is always addressable memory.
    if (match) return load_from(entry.second);
  }
  throw std::runtime_error(FormatError("irgen", "变量声明缺少存储槽位: " + name));
}
// Lowers `mulExp (('+' | '-') mulExp)*` as a left-associative chain of
// Add/Sub instructions. A single operand is passed through unchanged.
//
// The operator between operand i-1 and operand i is read from the parse tree.
// For this flat rule the children alternate operand, operator, operand, ...,
// so the operator is child 2*i-1. Scanning ctx->getText() for '+'/'-' is
// unreliable: the flattened text also contains those characters from nested
// parentheses, unary minus, call arguments and float exponents. For example,
// in `-a + b` the first '-' found belongs to the unary minus.
//
// Any exception is logged together with the expression text and then rethrown.
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
  try {
    const auto operands = ctx->mulExp();
    if (operands.size() == 1) return operands[0]->accept(this);
    if (operands.empty() || ctx->children.size() != 2 * operands.size() - 1) {
      throw std::runtime_error(FormatError("irgen", "加法表达式结构异常: " + ctx->getText()));
    }
    ir::Value* acc = std::any_cast<ir::Value*>(operands[0]->accept(this));
    for (size_t i = 1; i < operands.size(); ++i) {
      ir::Value* rhs = std::any_cast<ir::Value*>(operands[i]->accept(this));
      const std::string op = ctx->children[2 * i - 1]->getText();
      if (op != "+" && op != "-") {
        throw std::runtime_error(FormatError("irgen", "未知加法运算符: " + op));
      }
      const ir::Opcode opcode = (op == "-") ? ir::Opcode::Sub : ir::Opcode::Add;
      acc = builder_.CreateBinary(opcode, acc, rhs, module_.GetContext().NextTemp());
    }
    return acc;
  } catch (const std::exception& e) {
    LogInfo(std::string("[irgen] exception in visitAddExp text=") + ctx->getText() + ", err=" + e.what(), std::cerr);
    throw;
  }
}
// Lowers `unaryExp (('*' | '/' | '%') unaryExp)*` as a left-associative chain
// of Mul/Div/Mod instructions. A single operand is passed through unchanged.
//
// The operator between operand i-1 and operand i is read from the parse tree.
// Children alternate operand, operator, operand, ..., so the operator is
// child 2*i-1. Scanning ctx->getText() instead also picks up operators inside
// nested sub-expressions. For example, `a * (b / c) % d` would yield the
// sequence `* / %` and apply `/` where `%` was written.
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
  const auto operands = ctx->unaryExp();
  if (operands.size() == 1) return operands[0]->accept(this);
  if (operands.empty() || ctx->children.size() != 2 * operands.size() - 1) {
    throw std::runtime_error(FormatError("irgen", "乘法表达式结构异常: " + ctx->getText()));
  }
  ir::Value* acc = std::any_cast<ir::Value*>(operands[0]->accept(this));
  for (size_t i = 1; i < operands.size(); ++i) {
    ir::Value* rhs = std::any_cast<ir::Value*>(operands[i]->accept(this));
    const std::string op = ctx->children[2 * i - 1]->getText();
    ir::Opcode opcode = ir::Opcode::Mul;
    if (op == "/") {
      opcode = ir::Opcode::Div;
    } else if (op == "%") {
      opcode = ir::Opcode::Mod;
    } else if (op != "*") {
      throw std::runtime_error(FormatError("irgen", "未知乘法运算符: " + op));
    }
    acc = builder_.CreateBinary(opcode, acc, rhs, module_.GetContext().NextTemp());
  }
  return acc;
}
// Lowers a unaryExp. It is one of: a primary expression, a function call, or
// a unary operator (+, -, !) applied to another unaryExp.
//
// Calls: arguments are evaluated left to right. If the module has no function
// with the callee's name yet, one is declared with an int32 return type and
// parameter types taken from the argument values (float32 when the argument
// is float32, int32 otherwise).
// NOTE(review): a callee that returns float or void and is not already in the
// module would get the wrong signature this way — confirm such functions are
// declared before expression lowering.
//
// Unary operators:
//   '+' returns the operand unchanged;
//   '-' emits 0 - x, using a float or int zero to match x;
//   '!' emits x == 0 via FCmp (float32) or ICmp.
// Anything else throws std::runtime_error.
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
  }
  if (auto* primary = ctx->primaryExp()) {
    return primary->accept(this);
  }

  const bool is_call =
      ctx->Ident() != nullptr && ctx->getText().find("(") != std::string::npos;
  if (is_call) {
    std::string callee_name = ctx->Ident()->getText();

    std::vector<ir::Value*> call_args;
    if (auto* actuals = ctx->funcRParams()) {
      for (auto* arg : actuals->exp()) call_args.push_back(EvalExpr(*arg));
    }

    ir::Function* target = nullptr;
    for (auto& fn : module_.GetFunctions()) {
      if (fn && fn->GetName() == callee_name) {
        target = fn.get();
        break;
      }
    }

    if (target == nullptr) {
      std::vector<std::shared_ptr<ir::Type>> signature;
      for (auto* arg_value : call_args) {
        const bool as_float = arg_value != nullptr && arg_value->IsFloat32();
        if (as_float) {
          signature.push_back(ir::Type::GetFloat32Type());
        } else {
          signature.push_back(ir::Type::GetInt32Type());
        }
      }
      target = module_.CreateFunction(callee_name, ir::Type::GetInt32Type(), signature);
    }
    return static_cast<ir::Value*>(
        builder_.CreateCall(target, call_args, module_.GetContext().NextTemp()));
  }

  if (auto* operand_ctx = ctx->unaryExp()) {
    ir::Value* operand = std::any_cast<ir::Value*>(operand_ctx->accept(this));
    const std::string op = ctx->unaryOp() ? ctx->unaryOp()->getText() : std::string();
    if (op == "+") {
      return operand;
    }
    if (op == "-") {
      // Negation is 0 - x with a zero of the operand's kind.
      ir::Value* zero = nullptr;
      if (operand->IsFloat32()) {
        zero = builder_.CreateConstFloat(0.0f);
      } else {
        zero = builder_.CreateConstInt(0);
      }
      return static_cast<ir::Value*>(
          builder_.CreateSub(zero, operand, module_.GetContext().NextTemp()));
    }
    if (op == "!") {
      // Logical not: compare the operand against zero for equality.
      if (operand->IsFloat32()) {
        ir::Value* zero = builder_.CreateConstFloat(0.0f);
        return static_cast<ir::Value*>(builder_.CreateFCmp(
            ir::CmpInst::EQ, operand, zero, module_.GetContext().NextTemp()));
      }
      ir::Value* zero = builder_.CreateConstInt(0);
      return static_cast<ir::Value*>(builder_.CreateICmp(
          ir::CmpInst::EQ, operand, zero, module_.GetContext().NextTemp()));
    }
  }
  throw std::runtime_error(FormatError("irgen", "不支持的一元运算"));
}
// Lowers `addExp (('<' | '>' | '<=' | '>=') addExp)*` as a left-associative
// chain of comparisons. A single operand is passed through unchanged.
//
// Every operand is consumed, so `a < b < c` is evaluated as `(a < b) < c`
// rather than silently dropping `c`. The operator between operand i-1 and
// operand i is read from parse-tree child 2*i-1; children alternate operand,
// operator, operand, ...
// NOTE(review): chaining compares a comparison result with the next operand.
// This assumes comparison results are usable as int 0/1 values, which the
// && / || lowering in this file also assumes — confirm against the IR.
//
// Mixed int/float operands: an int *constant* is promoted to float. A
// non-constant int mixed with a float is rejected, because no int->float
// conversion instruction is emitted here. A float on either side selects
// FCmp; otherwise ICmp is used.
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  const auto operands = ctx->addExp();
  if (operands.size() == 1) return operands[0]->accept(this);
  if (operands.empty() || ctx->children.size() != 2 * operands.size() - 1) {
    throw std::runtime_error(FormatError("irgen", "关系表达式结构异常: " + ctx->getText()));
  }

  // Maps the operator token between two operands to a comparison predicate.
  auto predicate_for = [&](const std::string& op) -> ir::CmpInst::Predicate {
    if (op == "<") return ir::CmpInst::LT;
    if (op == ">") return ir::CmpInst::GT;
    if (op == "<=") return ir::CmpInst::LE;
    if (op == ">=") return ir::CmpInst::GE;
    throw std::runtime_error(FormatError("irgen", "未知关系运算符: " + op));
  };
  // Only integer constants can be promoted to float.
  auto promote_const_int = [&](ir::Value* v) -> ir::Value* {
    if (auto* ci = dynamic_cast<ir::ConstantInt*>(v)) {
      return builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
    }
    throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
  };

  ir::Value* acc = std::any_cast<ir::Value*>(operands[0]->accept(this));
  for (size_t i = 1; i < operands.size(); ++i) {
    ir::Value* lhs = acc;
    ir::Value* rhs = std::any_cast<ir::Value*>(operands[i]->accept(this));
    if (lhs->IsFloat32() && rhs->IsInt32()) {
      rhs = promote_const_int(rhs);
    } else if (rhs->IsFloat32() && lhs->IsInt32()) {
      lhs = promote_const_int(lhs);
    }
    const ir::CmpInst::Predicate pred = predicate_for(ctx->children[2 * i - 1]->getText());
    if (lhs->IsFloat32() || rhs->IsFloat32()) {
      acc = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    } else {
      acc = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return acc;
}
// Lowers `relExp (('==' | '!=') relExp)*` as a left-associative chain of
// equality comparisons. A single operand is passed through unchanged.
//
// Every operand is consumed, so `a == b == c` is evaluated as `(a == b) == c`
// rather than silently dropping `c`. The operator between operand i-1 and
// operand i is read from parse-tree child 2*i-1; children alternate operand,
// operator, operand, ...
// NOTE(review): chaining compares a comparison result with the next operand.
// This assumes comparison results are usable as int 0/1 values, which the
// && / || lowering in this file also assumes — confirm against the IR.
//
// Mixed int/float operands: an int *constant* is promoted to float. A
// non-constant int mixed with a float is rejected, because no int->float
// conversion instruction is emitted here. A float on either side selects
// FCmp; otherwise ICmp is used.
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  const auto operands = ctx->relExp();
  if (operands.size() == 1) return operands[0]->accept(this);
  if (operands.empty() || ctx->children.size() != 2 * operands.size() - 1) {
    throw std::runtime_error(FormatError("irgen", "相等表达式结构异常: " + ctx->getText()));
  }

  // Maps the operator token between two operands to a comparison predicate.
  auto predicate_for = [&](const std::string& op) -> ir::CmpInst::Predicate {
    if (op == "==") return ir::CmpInst::EQ;
    if (op == "!=") return ir::CmpInst::NE;
    throw std::runtime_error(FormatError("irgen", "未知相等运算符: " + op));
  };
  // Only integer constants can be promoted to float.
  auto promote_const_int = [&](ir::Value* v) -> ir::Value* {
    if (auto* ci = dynamic_cast<ir::ConstantInt*>(v)) {
      return builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
    }
    throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
  };

  ir::Value* acc = std::any_cast<ir::Value*>(operands[0]->accept(this));
  for (size_t i = 1; i < operands.size(); ++i) {
    ir::Value* lhs = acc;
    ir::Value* rhs = std::any_cast<ir::Value*>(operands[i]->accept(this));
    if (lhs->IsFloat32() && rhs->IsInt32()) {
      rhs = promote_const_int(rhs);
    } else if (rhs->IsFloat32() && lhs->IsInt32()) {
      lhs = promote_const_int(lhs);
    }
    const ir::CmpInst::Predicate pred = predicate_for(ctx->children[2 * i - 1]->getText());
    if (lhs->IsFloat32() || rhs->IsFloat32()) {
      acc = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    } else {
      acc = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return acc;
}
// Lowers `eqExp ('&&' eqExp)*` to a 0/1 value. A single operand is passed
// through unchanged.
//
// All operands are evaluated first, in source order. Each is then normalised
// with `x != 0`: FCmp against a float zero for float32 operands, ICmp against
// an int zero otherwise, matching how `!` handles floats. The normalised
// values are combined with Mul, which acts as logical AND on 0/1 values. Every
// operand takes part, so `a && b && c` does not drop `c`.
//
// NOTE(review): there is no short-circuiting here. Every operand is evaluated
// unconditionally, so side effects in later operands (e.g. calls) always run.
// If short-circuit semantics are required, conditions must be lowered with
// branches elsewhere — confirm.
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
  const auto operands = ctx->eqExp();
  if (operands.size() == 1) return operands[0]->accept(this);
  if (operands.empty()) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));

  std::vector<ir::Value*> values;
  values.reserve(operands.size());
  for (auto* operand : operands) {
    values.push_back(std::any_cast<ir::Value*>(operand->accept(this)));
  }

  // Normalises a value to its truth value (x != 0).
  auto truthy = [&](ir::Value* v) -> ir::Value* {
    if (v->IsFloat32()) {
      ir::Value* zero = builder_.CreateConstFloat(0.0f);
      return builder_.CreateFCmp(ir::CmpInst::NE, v, zero, module_.GetContext().NextTemp());
    }
    ir::Value* zero = builder_.CreateConstInt(0);
    return builder_.CreateICmp(ir::CmpInst::NE, v, zero, module_.GetContext().NextTemp());
  };

  ir::Value* acc = truthy(values[0]);
  for (size_t i = 1; i < values.size(); ++i) {
    // Compute the operand first so temp names are allocated in a fixed order.
    ir::Value* rhs_bool = truthy(values[i]);
    acc = builder_.CreateMul(acc, rhs_bool, module_.GetContext().NextTemp());
  }
  return acc;
}
// Lowers `lAndExp ('||' lAndExp)*` to a 0/1 value. A single operand is passed
// through unchanged.
//
// All operands are evaluated first, in source order. Each is then normalised
// with `x != 0`: FCmp against a float zero for float32 operands, ICmp against
// an int zero otherwise. The running result is folded pairwise as
// `(acc + rhs) >= 1`. Every operand takes part, so `a || b || c` does not
// drop `c`.
//
// NOTE(review): the Add-based OR assumes comparison results behave as int
// 0/1 values. If they are 1-bit, 1 + 1 would wrap to 0 — confirm the
// ICmp/FCmp result type.
// NOTE(review): there is no short-circuiting here. Every operand is evaluated
// unconditionally, so side effects in later operands always run — confirm
// conditions needing short-circuit are lowered with branches elsewhere.
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
  const auto operands = ctx->lAndExp();
  if (operands.size() == 1) return operands[0]->accept(this);
  if (operands.empty()) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));

  std::vector<ir::Value*> values;
  values.reserve(operands.size());
  for (auto* operand : operands) {
    values.push_back(std::any_cast<ir::Value*>(operand->accept(this)));
  }

  // Normalises a value to its truth value (x != 0).
  auto truthy = [&](ir::Value* v) -> ir::Value* {
    if (v->IsFloat32()) {
      ir::Value* zero = builder_.CreateConstFloat(0.0f);
      return builder_.CreateFCmp(ir::CmpInst::NE, v, zero, module_.GetContext().NextTemp());
    }
    ir::Value* zero = builder_.CreateConstInt(0);
    return builder_.CreateICmp(ir::CmpInst::NE, v, zero, module_.GetContext().NextTemp());
  };

  ir::Value* acc = truthy(values[0]);
  for (size_t i = 1; i < values.size(); ++i) {
    ir::Value* rhs_bool = truthy(values[i]);
    ir::Value* sum = builder_.CreateAdd(acc, rhs_bool, module_.GetContext().NextTemp());
    ir::Value* one = builder_.CreateConstInt(1);
    acc = builder_.CreateICmp(ir::CmpInst::GE, sum, one, module_.GetContext().NextTemp());
  }
  return acc;
}