|
|
#include "irgen/IRGen.h"
|
|
|
|
|
|
#include <stdexcept>
|
|
|
|
|
|
#include "SysYParser.h"
|
|
|
#include "ir/IR.h"
|
|
|
#include "utils/Log.h"
|
|
|
|
|
|
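// Expression IR generation.
//
// The visitors in this file lower SysY expressions to IR values. Currently handled:
//   - integer and floating-point literals
//   - variable reads, including subscripted access (GEP + load)
//   - parenthesized expressions
//   - binary +, -, *, /, %
//   - unary +, -, !
//   - function calls (a callee not yet in the module is declared on first use
//     with an int return type)
//   - relational (<, >, <=, >=) and equality (==, !=) comparisons
//   - logical && and || (NOTE: all operands are evaluated; no short-circuiting)
//
// Not handled here:
//   - assignment expressions
//   - implicit int -> float conversion of non-constant values (comparisons only
//     promote integer constants)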
|
|
|
// Evaluates `expr` by dispatching it through this visitor and unwrapping the result.
// Expression visitors are expected to yield an ir::Value* wrapped in std::any;
// std::any_cast throws std::bad_any_cast if a visitor returns anything else.
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  std::any visited = expr.accept(this);
  return std::any_cast<ir::Value*>(visited);
}
|
|
|
|
|
|
|
|
|
// Lowers a primary expression: a parenthesized expression, an lvalue read, or a
// numeric literal. Any other shape is rejected with a runtime_error.
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法 primary 表达式"));
  }

  // '(' exp ')'
  if (auto* inner = ctx->exp()) {
    return EvalExpr(*inner);
  }

  if (auto* lval = ctx->lVal()) {
    return lval->accept(this);
  }

  if (auto* literal = ctx->number()) {
    return literal->accept(this);
  }

  throw std::runtime_error(FormatError("irgen", "不支持的 primary 表达式"));
}
|
|
|
|
|
|
|
|
|
// Lowers a numeric literal to an IR constant.
//
// Integer literals use C-style notation: decimal, octal with a leading `0`, and
// hexadecimal with a `0x`/`0X` prefix. std::stoll with base 0 detects all three.
// Parsing as long long also means `2147483648` (the operand of `-2147483648`) does
// not throw std::out_of_range; the value is then narrowed to int for the IR constant.
//
// Float literals are parsed with std::stof, which accepts both decimal and
// hexadecimal floating-point forms.
//
// Throws std::runtime_error for a null context or an unrecognized literal kind.
// std::invalid_argument / std::out_of_range can propagate from std::stoll / std::stof.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法数字字面量"));

  if (auto* int_token = ctx->IntConst()) {
    const std::string text = int_token->getText();
    // Base 0: auto-detect decimal / octal (leading 0) / hex (0x) like C.
    const long long parsed = std::stoll(text, nullptr, 0);
    return static_cast<ir::Value*>(builder_.CreateConstInt(static_cast<int>(parsed)));
  }

  if (auto* float_token = ctx->FloatConst()) {
    const std::string text = float_token->getText();
    return static_cast<ir::Value*>(builder_.CreateConstFloat(std::stof(text)));
  }

  throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量或浮点字面量"));
}
|
|
|
|
|
|
// 变量使用的处理流程:
|
|
|
// 1. 先通过语义分析结果把变量使用绑定回声明;
|
|
|
// 2. 再通过 storage_map_ 找到该声明对应的栈槽位;
|
|
|
// 3. 最后生成 load,把内存中的值读出来。
|
|
|
//
|
|
|
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
|
|
|
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
|
|
|
if (!ctx || !ctx->Ident()) {
|
|
|
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
|
|
|
}
|
|
|
// find storage by matching declaration node stored in Sema context
|
|
|
// Sema stores types/decl contexts in IRGenContext maps; here we search storage_map_ by name
|
|
|
std::string name = ctx->Ident()->getText();
|
|
|
// 优先使用按名称的快速映射
|
|
|
auto nit = name_map_.find(name);
|
|
|
if (nit != name_map_.end()) {
|
|
|
// 支持下标访问:若有索引表达式列表,则生成 GEP + load
|
|
|
if (ctx->exp().size() > 0) {
|
|
|
std::vector<ir::Value*> indices;
|
|
|
// 首个索引用于穿过数组对象
|
|
|
indices.push_back(builder_.CreateConstInt(0));
|
|
|
for (auto* e : ctx->exp()) {
|
|
|
indices.push_back(EvalExpr(*e));
|
|
|
}
|
|
|
auto* gep = builder_.CreateGEP(nit->second, indices, module_.GetContext().NextTemp());
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
|
|
|
}
|
|
|
// 如果映射到的是常量,直接返回常量值;否则按原来行为从槽位 load
|
|
|
if (nit->second->IsConstant()) return nit->second;
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(nit->second, module_.GetContext().NextTemp()));
|
|
|
}
|
|
|
for (auto& kv : storage_map_) {
|
|
|
if (!kv.first) continue;
|
|
|
if (auto* vdef = dynamic_cast<SysYParser::VarDefContext*>(kv.first)) {
|
|
|
if (vdef->Ident() && vdef->Ident()->getText() == name) {
|
|
|
if (ctx->exp().size() > 0) {
|
|
|
std::vector<ir::Value*> indices;
|
|
|
indices.push_back(builder_.CreateConstInt(0));
|
|
|
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
|
|
|
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
|
|
|
}
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
|
|
|
}
|
|
|
} else if (auto* fparam = dynamic_cast<SysYParser::FuncFParamContext*>(kv.first)) {
|
|
|
if (fparam->Ident() && fparam->Ident()->getText() == name) {
|
|
|
if (ctx->exp().size() > 0) {
|
|
|
std::vector<ir::Value*> indices;
|
|
|
indices.push_back(builder_.CreateConstInt(0));
|
|
|
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
|
|
|
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
|
|
|
}
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
|
|
|
}
|
|
|
} else if (auto* cdef = dynamic_cast<SysYParser::ConstDefContext*>(kv.first)) {
|
|
|
if (cdef->Ident() && cdef->Ident()->getText() == name) {
|
|
|
if (ctx->exp().size() > 0) {
|
|
|
std::vector<ir::Value*> indices;
|
|
|
indices.push_back(builder_.CreateConstInt(0));
|
|
|
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
|
|
|
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
|
|
|
}
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
throw std::runtime_error(FormatError("irgen", "变量声明缺少存储槽位: " + name));
|
|
|
}
|
|
|
|
|
|
|
|
|
// Lowers `m0 (+|-) m1 (+|-) m2 ...` left-associatively into Add/Sub instructions.
//
// The operator tokens are taken from this node's direct children (every child that
// is not a mulExp sub-tree). Scanning ctx->getText() for '+'/'-' would also pick up
// operators nested inside operands, e.g. the '-' in `a*(b-c)+d` or in `-a+b`.
//
// Any exception is logged with the expression text and rethrown.
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法加法表达式"));

  try {
    const auto operands = ctx->mulExp();
    if (operands.size() == 1) return operands[0]->accept(this);

    std::vector<std::string> ops;
    for (auto* child : ctx->children) {
      if (!dynamic_cast<SysYParser::MulExpContext*>(child)) ops.push_back(child->getText());
    }
    if (operands.empty() || ops.size() + 1 != operands.size()) {
      throw std::runtime_error(FormatError("irgen", "加法表达式运算符与操作数数量不匹配"));
    }

    ir::Value* cur = std::any_cast<ir::Value*>(operands[0]->accept(this));
    for (size_t i = 1; i < operands.size(); ++i) {
      ir::Value* rhs = std::any_cast<ir::Value*>(operands[i]->accept(this));
      const std::string& opch = ops[i - 1];
      ir::Opcode op;
      if (opch == "+") {
        op = ir::Opcode::Add;
      } else if (opch == "-") {
        op = ir::Opcode::Sub;
      } else {
        throw std::runtime_error(FormatError("irgen", "未知的加法运算符: " + opch));
      }
      cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp());
    }
    return static_cast<ir::Value*>(cur);
  } catch (const std::exception& e) {
    LogInfo(std::string("[irgen] exception in visitAddExp text=") + ctx->getText() + ", err=" + e.what(), std::cerr);
    throw;
  }
}
|
|
|
|
|
|
// Lowers `u0 (*|/|%) u1 ...` left-associatively into Mul/Div/Mod instructions.
//
// The operator tokens are taken from this node's direct children (every child that
// is not a unaryExp sub-tree). Scanning ctx->getText() would also pick up operators
// nested inside operands, e.g. the '/' in `(a/b)*c` or in `f(x/2)*y`.
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));

  const auto operands = ctx->unaryExp();
  if (operands.size() == 1) return operands[0]->accept(this);

  std::vector<std::string> ops;
  for (auto* child : ctx->children) {
    if (!dynamic_cast<SysYParser::UnaryExpContext*>(child)) ops.push_back(child->getText());
  }
  if (operands.empty() || ops.size() + 1 != operands.size()) {
    throw std::runtime_error(FormatError("irgen", "乘法表达式运算符与操作数数量不匹配"));
  }

  ir::Value* cur = std::any_cast<ir::Value*>(operands[0]->accept(this));
  for (size_t i = 1; i < operands.size(); ++i) {
    ir::Value* rhs = std::any_cast<ir::Value*>(operands[i]->accept(this));
    const std::string& opch = ops[i - 1];
    ir::Opcode op;
    if (opch == "*") {
      op = ir::Opcode::Mul;
    } else if (opch == "/") {
      op = ir::Opcode::Div;
    } else if (opch == "%") {
      op = ir::Opcode::Mod;
    } else {
      throw std::runtime_error(FormatError("irgen", "未知的乘法运算符: " + opch));
    }
    cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp());
  }
  return static_cast<ir::Value*>(cur);
}
|
|
|
|
|
|
// Lowers a unary expression. It is one of three forms:
//   - a primary expression (forwarded to its visitor);
//   - a function call `Ident '(' funcRParams? ')'`. A callee not yet present in the
//     module is declared on the fly: parameter types are inferred from the argument
//     values (float if IsFloat32(), else int) and the return type is int;
//   - a prefix operator applied to another unary expression:
//       '+' returns the operand unchanged,
//       '-' emits `0 - operand` with a zero of the operand's type,
//       '!' emits `operand == 0` (FCmp for floats, ICmp otherwise).
// Throws std::runtime_error for a null context or an unsupported form/operator.
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
  }

  if (auto* primary = ctx->primaryExp()) {
    return primary->accept(this);
  }

  const bool is_call =
      ctx->Ident() != nullptr && ctx->getText().find("(") != std::string::npos;
  if (is_call) {
    const std::string callee_name = ctx->Ident()->getText();

    std::vector<ir::Value*> call_args;
    if (auto* rparams = ctx->funcRParams()) {
      for (auto* arg_expr : rparams->exp()) {
        call_args.push_back(EvalExpr(*arg_expr));
      }
    }

    ir::Function* target = nullptr;
    for (auto& fn : module_.GetFunctions()) {
      if (fn && fn->GetName() == callee_name) {
        target = fn.get();
        break;
      }
    }

    if (target == nullptr) {
      std::vector<std::shared_ptr<ir::Type>> param_types;
      for (auto* arg : call_args) {
        const bool is_float_arg = arg && arg->IsFloat32();
        if (is_float_arg) {
          param_types.push_back(ir::Type::GetFloat32Type());
        } else {
          param_types.push_back(ir::Type::GetInt32Type());
        }
      }
      target = module_.CreateFunction(callee_name, ir::Type::GetInt32Type(), param_types);
    }

    return static_cast<ir::Value*>(
        builder_.CreateCall(target, call_args, module_.GetContext().NextTemp()));
  }

  if (auto* operand_ctx = ctx->unaryExp()) {
    ir::Value* operand = std::any_cast<ir::Value*>(operand_ctx->accept(this));
    const std::string op = ctx->unaryOp() ? ctx->unaryOp()->getText() : std::string();

    if (op == "+") {
      return static_cast<ir::Value*>(operand);
    }

    if (op == "-") {
      // Negation is lowered as 0 - operand, with a zero of the operand's type.
      ir::Value* zero = nullptr;
      if (operand->IsFloat32()) {
        zero = builder_.CreateConstFloat(0.0f);
      } else {
        zero = builder_.CreateConstInt(0);
      }
      return static_cast<ir::Value*>(
          builder_.CreateSub(zero, operand, module_.GetContext().NextTemp()));
    }

    if (op == "!") {
      // Logical not: compare the operand for equality with zero.
      if (operand->IsFloat32()) {
        ir::Value* zero_f = builder_.CreateConstFloat(0.0f);
        return static_cast<ir::Value*>(builder_.CreateFCmp(
            ir::CmpInst::EQ, operand, zero_f, module_.GetContext().NextTemp()));
      }
      ir::Value* zero_i = builder_.CreateConstInt(0);
      return static_cast<ir::Value*>(builder_.CreateICmp(
          ir::CmpInst::EQ, operand, zero_i, module_.GetContext().NextTemp()));
    }
  }

  throw std::runtime_error(FormatError("irgen", "不支持的一元运算"));
}
|
|
|
|
|
|
// Lowers `a op b [op c ...]` with op in {<, >, <=, >=}, folding left-associatively:
// `a < b < c` is `(a < b) < c`.
//
// Operator tokens are taken from this node's direct children (every child that is
// not an addExp sub-tree). Searching ctx->getText() would also match operators
// nested inside operands, e.g. the `<=` in `(a <= b) < c` or in `a[i <= j] > 0`.
//
// Mixed int/float operands: an integer *constant* is converted to a float constant.
// Any other int/float mix is rejected, since no conversion instruction is emitted.
// A comparison with a float operand uses FCmp; otherwise ICmp.
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法关系表达式"));

  const auto operands = ctx->addExp();
  if (operands.size() == 1) return operands[0]->accept(this);

  std::vector<std::string> ops;
  for (auto* child : ctx->children) {
    if (!dynamic_cast<SysYParser::AddExpContext*>(child)) ops.push_back(child->getText());
  }
  if (operands.empty() || ops.size() + 1 != operands.size()) {
    throw std::runtime_error(FormatError("irgen", "关系表达式运算符与操作数数量不匹配"));
  }

  // Converts an int constant to a float constant; other int values are unsupported.
  auto int_const_to_float = [this](ir::Value* v) -> ir::Value* {
    if (auto* ci = dynamic_cast<ir::ConstantInt*>(v)) {
      return builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
    }
    throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
  };

  auto predicate_for = [](const std::string& op) -> ir::CmpInst::Predicate {
    if (op == "<") return ir::CmpInst::LT;
    if (op == ">") return ir::CmpInst::GT;
    if (op == "<=") return ir::CmpInst::LE;
    if (op == ">=") return ir::CmpInst::GE;
    throw std::runtime_error(FormatError("irgen", "未知的关系运算符: " + op));
  };

  ir::Value* lhs = std::any_cast<ir::Value*>(operands[0]->accept(this));
  for (size_t i = 1; i < operands.size(); ++i) {
    ir::Value* rhs = std::any_cast<ir::Value*>(operands[i]->accept(this));

    // Type promotion (constants only).
    if (lhs->IsFloat32() && rhs->IsInt32()) {
      rhs = int_const_to_float(rhs);
    } else if (rhs->IsFloat32() && lhs->IsInt32()) {
      lhs = int_const_to_float(lhs);
    }

    const ir::CmpInst::Predicate pred = predicate_for(ops[i - 1]);
    if (lhs->IsFloat32() || rhs->IsFloat32()) {
      lhs = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    } else {
      lhs = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return static_cast<ir::Value*>(lhs);
}
|
|
|
|
|
|
// Lowers `a op b [op c ...]` with op in {==, !=}, folding left-associatively:
// `a == b != c` is `(a == b) != c`.
//
// Operator tokens are taken from this node's direct children (every child that is
// not a relExp sub-tree). Searching ctx->getText() would also match operators nested
// inside operands, e.g. `(a == b) != c` would wrongly be lowered as EQ.
//
// Mixed int/float operands: an integer *constant* is converted to a float constant.
// Any other int/float mix is rejected, since no conversion instruction is emitted.
// A comparison with a float operand uses FCmp; otherwise ICmp.
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法相等表达式"));

  const auto operands = ctx->relExp();
  if (operands.size() == 1) return operands[0]->accept(this);

  std::vector<std::string> ops;
  for (auto* child : ctx->children) {
    if (!dynamic_cast<SysYParser::RelExpContext*>(child)) ops.push_back(child->getText());
  }
  if (operands.empty() || ops.size() + 1 != operands.size()) {
    throw std::runtime_error(FormatError("irgen", "相等表达式运算符与操作数数量不匹配"));
  }

  // Converts an int constant to a float constant; other int values are unsupported.
  auto int_const_to_float = [this](ir::Value* v) -> ir::Value* {
    if (auto* ci = dynamic_cast<ir::ConstantInt*>(v)) {
      return builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
    }
    throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
  };

  auto predicate_for = [](const std::string& op) -> ir::CmpInst::Predicate {
    if (op == "==") return ir::CmpInst::EQ;
    if (op == "!=") return ir::CmpInst::NE;
    throw std::runtime_error(FormatError("irgen", "未知的相等运算符: " + op));
  };

  ir::Value* lhs = std::any_cast<ir::Value*>(operands[0]->accept(this));
  for (size_t i = 1; i < operands.size(); ++i) {
    ir::Value* rhs = std::any_cast<ir::Value*>(operands[i]->accept(this));

    // Type promotion (constants only).
    if (lhs->IsFloat32() && rhs->IsInt32()) {
      rhs = int_const_to_float(rhs);
    } else if (rhs->IsFloat32() && lhs->IsInt32()) {
      lhs = int_const_to_float(lhs);
    }

    const ir::CmpInst::Predicate pred = predicate_for(ops[i - 1]);
    if (lhs->IsFloat32() || rhs->IsFloat32()) {
      lhs = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    } else {
      lhs = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return static_cast<ir::Value*>(lhs);
}
|
|
|
|
|
|
// Lowers `a && b [&& c ...]` to an integer truth value.
//
// Each operand is compared against zero with NE (FCmp against 0.0f for float
// operands, ICmp against 0 otherwise). The comparison results are multiplied
// together, so the product is non-zero iff every operand is non-zero (assuming the
// comparison results are 0/1 integers).
//
// NOTE(review): every operand is evaluated unconditionally; there is no
// short-circuiting, so side effects in later operands (e.g. calls) always happen.
// TODO: lower with branches if C-style short-circuit semantics are required here.
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));

  const auto operands = ctx->eqExp();
  if (operands.size() == 1) return operands[0]->accept(this);
  if (operands.empty()) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));

  // operand != 0, using a zero of the operand's type.
  auto is_nonzero = [this](ir::Value* v) -> ir::Value* {
    if (v->IsFloat32()) {
      ir::Value* zero_f = builder_.CreateConstFloat(0.0f);
      return builder_.CreateFCmp(ir::CmpInst::NE, v, zero_f, module_.GetContext().NextTemp());
    }
    ir::Value* zero = builder_.CreateConstInt(0);
    return builder_.CreateICmp(ir::CmpInst::NE, v, zero, module_.GetContext().NextTemp());
  };

  ir::Value* acc = is_nonzero(std::any_cast<ir::Value*>(operands[0]->accept(this)));
  for (size_t i = 1; i < operands.size(); ++i) {
    ir::Value* rhs_ne = is_nonzero(std::any_cast<ir::Value*>(operands[i]->accept(this)));
    acc = builder_.CreateMul(acc, rhs_ne, module_.GetContext().NextTemp());
  }
  return static_cast<ir::Value*>(acc);
}
|
|
|
|
|
|
// Lowers `a || b [|| c ...]` to an integer truth value.
//
// Each operand is compared against zero with NE (FCmp against 0.0f for float
// operands, ICmp against 0 otherwise). The comparison results are summed, and the
// result is `sum >= 1`, i.e. true iff at least one operand is non-zero (assuming the
// comparison results are 0/1 integers).
//
// NOTE(review): every operand is evaluated unconditionally; there is no
// short-circuiting, so side effects in later operands (e.g. calls) always happen.
// TODO: lower with branches if C-style short-circuit semantics are required here.
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
  if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));

  const auto operands = ctx->lAndExp();
  if (operands.size() == 1) return operands[0]->accept(this);
  if (operands.empty()) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));

  // operand != 0, using a zero of the operand's type.
  auto is_nonzero = [this](ir::Value* v) -> ir::Value* {
    if (v->IsFloat32()) {
      ir::Value* zero_f = builder_.CreateConstFloat(0.0f);
      return builder_.CreateFCmp(ir::CmpInst::NE, v, zero_f, module_.GetContext().NextTemp());
    }
    ir::Value* zero = builder_.CreateConstInt(0);
    return builder_.CreateICmp(ir::CmpInst::NE, v, zero, module_.GetContext().NextTemp());
  };

  // Count of non-zero operands; bounded by the operand count, so no overflow.
  ir::Value* count = is_nonzero(std::any_cast<ir::Value*>(operands[0]->accept(this)));
  for (size_t i = 1; i < operands.size(); ++i) {
    ir::Value* rhs_ne = is_nonzero(std::any_cast<ir::Value*>(operands[i]->accept(this)));
    count = builder_.CreateAdd(count, rhs_ne, module_.GetContext().NextTemp());
  }

  ir::Value* one = builder_.CreateConstInt(1);
  return static_cast<ir::Value*>(
      builder_.CreateICmp(ir::CmpInst::GE, count, one, module_.GetContext().NextTemp()));
}
|
|
|
|