You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/irgen/IRGenExp.cpp

605 lines
22 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "irgen/IRGen.h"

#include <cstdlib>
#include <stdexcept>
#include <string>

#include "SysYParser.h"
#include "ir/IR.h"
#include "sem/func.h"
#include "utils/Log.h"
// ─── 辅助 ─────────────────────────────────────────────────────────────────────
// Converts a value to an i1 truth value (non-zero => true).
//   i1    -> returned unchanged
//   f32   -> fcmp one v, 0.0
//   other -> icmp ne v, 0   (i32 in practice)
// Throws std::runtime_error when v is null.
// NOTE(review): `fcmp one` treats NaN as false, while the `!` lowering in
// visitUnaryExp (fcmp oeq) treats NaN as true; an unordered `une` predicate
// would make both agree if ir::FCmpPredicate provides one — TODO confirm.
ir::Value* IRGenImpl::ToI1(ir::Value* v) {
  if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value"));
  if (v->IsInt1()) return v;
  if (v->IsFloat32()) {
    // An integer compare against i32 0 would be ill-typed for a float operand
    // (e.g. `if (f)` or `f && x`), so compare against 0.0 instead.
    return builder_.CreateFCmp(ir::FCmpPredicate::ONE, v,
                               builder_.CreateConstFloat(0.0f),
                               module_.GetContext().NextTemp());
  }
  return builder_.CreateICmp(ir::ICmpPredicate::NE, v,
                             builder_.CreateConstInt(0),
                             module_.GetContext().NextTemp());
}
// Widens a truth value to i32 by zero extension (true -> 1, false -> 0).
// A value that is already i32 is returned as-is; every other input is handed
// to CreateZExt, so callers are expected to pass only i1 or i32 values.
// Throws std::runtime_error when v is null.
ir::Value* IRGenImpl::ToI32(ir::Value* v) {
  if (v == nullptr) {
    throw std::runtime_error(FormatError("irgen", "ToI32: null value"));
  }
  if (!v->IsInt32()) {
    const std::string widened_name = module_.GetContext().NextTemp();
    return builder_.CreateZExt(v, widened_name);
  }
  return v;
}
// Converts an integer-like value to f32.
//   f32 -> returned unchanged
//   i32 -> sitofp
//   i1  -> zero-extended to i32 first, then sitofp (true -> 1.0f)
// Any other type, or a null value, raises std::runtime_error.
ir::Value* IRGenImpl::ToFloat(ir::Value* v) {
  if (v == nullptr) {
    throw std::runtime_error(FormatError("irgen", "ToFloat: null value"));
  }
  if (v->IsFloat32()) {
    return v;
  }
  ir::Value* as_i32 = nullptr;
  if (v->IsInt32()) {
    as_i32 = v;
  } else if (v->IsInt1()) {
    as_i32 = ToI32(v);
  } else {
    throw std::runtime_error(FormatError("irgen", "ToFloat: 不支持的类型"));
  }
  return builder_.CreateSIToFP(as_i32, module_.GetContext().NextTemp());
}
// Converts a value to i32.
//   i32 -> returned unchanged
//   f32 -> fptosi (truncating toward zero)
//   i1  -> zero-extended via ToI32
// Any other type, or a null value, raises std::runtime_error.
ir::Value* IRGenImpl::ToInt(ir::Value* v) {
  if (v == nullptr) {
    throw std::runtime_error(FormatError("irgen", "ToInt: null value"));
  }
  if (v->IsInt32()) {
    return v;
  }
  if (v->IsFloat32()) {
    const std::string truncated_name = module_.GetContext().NextTemp();
    return builder_.CreateFPToSI(v, truncated_name);
  }
  if (v->IsInt1()) {
    return ToI32(v);
  }
  throw std::runtime_error(FormatError("irgen", "ToInt: 不支持的类型"));
}
// Usual arithmetic conversion for a binary operator: when exactly one operand
// is f32, the other one is converted to f32 (via ToFloat), in place.
// Nothing happens if either pointer is null or both operands already agree
// on being / not being f32.
void IRGenImpl::ImplicitConvert(ir::Value*& lhs, ir::Value*& rhs) {
  if (lhs == nullptr || rhs == nullptr) return;
  const bool lhs_is_float = lhs->IsFloat32();
  if (lhs_is_float == rhs->IsFloat32()) return;
  if (lhs_is_float) {
    rhs = ToFloat(rhs);
  } else {
    lhs = ToFloat(lhs);
  }
}
// Evaluates an `exp` subtree and returns the IR value it produced
// (i32 or f32 in practice). std::any_cast throws std::bad_any_cast if the
// visitor did not yield an ir::Value*.
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  return std::any_cast<ir::Value*>(expr.accept(this));
}
// Evaluates an `addExp` subtree and returns the IR value it produced
// (i32 or f32 in practice). std::any_cast throws std::bad_any_cast if the
// visitor did not yield an ir::Value*.
ir::Value* IRGenImpl::EvalExprAdd(SysYParser::AddExpContext& expr) {
  return std::any_cast<ir::Value*>(expr.accept(this));
}
// Evaluates a `cond` subtree and normalizes the result to an i1 truth value
// with ToI1 (the visitor itself may return i1, i32 or f32).
ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) {
  ir::Value* raw = std::any_cast<ir::Value*>(cond.accept(this));
  return ToI1(raw);
}
// Declares a SysY runtime-library function in the module the first time it is
// called. Idempotent: returns immediately when `name` already has an external
// declaration or is defined in this module.
//
// Signatures used:
//   i32  name()      getint, getch, and any name not listed below
//   void name()      getfloat, putfloat
//   void name(i32)   putint, putch, putarray, starttime, stoptime
//
// NOTE(review): several of these are approximations of the SysY runtime:
// getfloat actually returns a float (declared void here, so call sites only
// see a placeholder), putfloat takes a float argument, and putarray takes
// (int, int[]). Other runtime functions (e.g. getarray, putfarray) fall into
// the `i32 name()` default. Fixing this needs f32/pointer type getters on
// ir::Type that this file does not use — TODO confirm their names in ir/IR.h.
void IRGenImpl::EnsureExternalDecl(const std::string& name) {
  if (module_.HasExternalDecl(name) || module_.GetFunction(name)) {
    return;
  }
  const bool void_no_params = name == "getfloat" || name == "putfloat";
  const bool void_one_i32 = name == "putint" || name == "putch" ||
                            name == "putarray" || name == "starttime" ||
                            name == "stoptime";
  if (void_no_params) {
    module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {});
  } else if (void_one_i32) {
    module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
                                {ir::Type::GetInt32Type()});
  } else {
    // getint, getch, and every unknown external: `i32 name()`.
    module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {});
  }
}
// ─── 表达式访问器(返回 ir::Value*i32 ─────────────────────────────────────
// exp : addExp
// Forwards to the single addExp child; a missing node is a hard error.
std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
  auto* add = (ctx != nullptr) ? ctx->addExp() : nullptr;
  if (add == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法表达式"));
  }
  return add->accept(this);
}
// addExp : mulExp (AddOp mulExp)*
// Folds the operands left to right. At each step the two operands are unified
// with ImplicitConvert and lowered to fadd/fsub when the (converted) left
// operand is f32, otherwise to add/sub. Any operator text other than "+" is
// lowered as a subtraction. Returns the folded ir::Value*.
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法加减表达式"));
  }
  const auto operands = ctx->mulExp();
  if (operands.empty()) {
    throw std::runtime_error(FormatError("irgen", "addExp 缺少操作数"));
  }
  ir::Value* acc = std::any_cast<ir::Value*>(operands.front()->accept(this));
  const auto operators = ctx->AddOp();
  for (size_t k = 0; k < operators.size(); ++k) {
    ir::Value* next = std::any_cast<ir::Value*>(operands[k + 1]->accept(this));
    ImplicitConvert(acc, next);
    const std::string name = module_.GetContext().NextTemp();
    const bool is_add = operators[k]->getText() == "+";
    if (!acc->IsFloat32()) {
      acc = is_add ? builder_.CreateAdd(acc, next, name)
                   : builder_.CreateSub(acc, next, name);
    } else {
      acc = is_add ? builder_.CreateFAdd(acc, next, name)
                   : builder_.CreateFSub(acc, next, name);
    }
  }
  return static_cast<ir::Value*>(acc);
}
// mulExp : unaryExp (MulOp unaryExp)*
// Folds the operands left to right after ImplicitConvert.
//   f32: "*" -> fmul, "/" -> fdiv, anything else is rejected (no float modulo)
//   int: "*" -> mul,  "/" -> div,  anything else -> mod
// Returns the folded ir::Value*.
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法乘除表达式"));
  }
  const auto operands = ctx->unaryExp();
  if (operands.empty()) {
    throw std::runtime_error(FormatError("irgen", "mulExp 缺少操作数"));
  }
  ir::Value* acc = std::any_cast<ir::Value*>(operands.front()->accept(this));
  const auto operators = ctx->MulOp();
  for (size_t k = 0; k < operators.size(); ++k) {
    ir::Value* next = std::any_cast<ir::Value*>(operands[k + 1]->accept(this));
    ImplicitConvert(acc, next);
    const std::string name = module_.GetContext().NextTemp();
    const std::string op = operators[k]->getText();
    const bool is_mul = (op == "*");
    const bool is_div = (op == "/");
    if (acc->IsFloat32()) {
      if (is_mul) {
        acc = builder_.CreateFMul(acc, next, name);
      } else if (is_div) {
        acc = builder_.CreateFDiv(acc, next, name);
      } else {
        throw std::runtime_error(FormatError("irgen", "float 不支持取模"));
      }
    } else if (is_mul) {
      acc = builder_.CreateMul(acc, next, name);
    } else if (is_div) {
      acc = builder_.CreateDiv(acc, next, name);
    } else {
      acc = builder_.CreateMod(acc, next, name);
    }
  }
  return static_cast<ir::Value*>(acc);
}
// unaryExp : primaryExp
//          | Ident L_PAREN (funcRParams)? R_PAREN
//          | unaryOp unaryExp
//
// Lowering:
//   `+x`  -> x unchanged
//   `-x`  -> fsub 0.0, x (f32) or sub 0, x (otherwise)
//   `!x`  -> fcmp oeq x, 0.0 / icmp eq x, 0, zero-extended to i32
//   calls -> `call` to a function defined in this module, otherwise to an
//            external declaration created on demand by EnsureExternalDecl.
//            A call returning void yields the constant i32 0 as a placeholder.
//   primaryExp -> delegated to its own visitor.
// NOTE(review): arguments are passed exactly as evaluated; no int<->float
// conversion to the callee's parameter types happens here.
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
  }

  // ── Prefix operator ─────────────────────────────────────────────────────
  if (ctx->unaryOp() != nullptr && ctx->unaryExp() != nullptr) {
    ir::Value* value =
        std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
    const std::string op = ctx->unaryOp()->getText();

    if (op == "+") {
      return static_cast<ir::Value*>(value);
    }

    if (op == "-") {
      ir::Value* negated = nullptr;
      if (value->IsFloat32()) {
        negated = builder_.CreateFSub(builder_.CreateConstFloat(0.0f), value,
                                      module_.GetContext().NextTemp());
      } else {
        negated = builder_.CreateSub(builder_.CreateConstInt(0), value,
                                     module_.GetContext().NextTemp());
      }
      return negated;
    }

    if (op == "!") {
      ir::Value* is_zero = nullptr;
      if (value->IsFloat32()) {
        is_zero = builder_.CreateFCmp(ir::FCmpPredicate::OEQ, value,
                                      builder_.CreateConstFloat(0.0f),
                                      module_.GetContext().NextTemp());
      } else {
        // Widen first (separate statement) so the zext temp is numbered
        // before the compare's temp.
        ir::Value* as_int = ToI32(value);
        is_zero = builder_.CreateICmp(ir::ICmpPredicate::EQ, as_int,
                                      builder_.CreateConstInt(0),
                                      module_.GetContext().NextTemp());
      }
      ir::Value* widened = ToI32(is_zero);
      return widened;
    }

    throw std::runtime_error(FormatError("irgen", "不支持的一元运算符: " + op));
  }

  // ── Function call ───────────────────────────────────────────────────────
  if (ctx->Ident() != nullptr && ctx->L_PAREN() != nullptr) {
    const std::string callee_name = ctx->Ident()->getText();

    // Arguments are evaluated left to right before the call is emitted.
    std::vector<ir::Value*> args;
    if (auto* params = ctx->funcRParams()) {
      for (auto* arg_exp : params->exp()) {
        args.push_back(EvalExpr(*arg_exp));
      }
    }

    if (ir::Function* callee = module_.GetFunction(callee_name)) {
      const bool returns_void = callee->IsVoidReturn();
      const std::string ret_name =
          returns_void ? std::string() : module_.GetContext().NextTemp();
      ir::Value* call = builder_.CreateCall(callee, std::move(args), ret_name);
      if (returns_void) {
        return static_cast<ir::Value*>(builder_.CreateConstInt(0));
      }
      return call;
    }

    // Not defined here: treat as a runtime-library / external function.
    EnsureExternalDecl(callee_name);
    std::shared_ptr<ir::Type> ret_type = ir::Type::GetInt32Type();
    for (const auto& ext : module_.GetExternalDecls()) {
      if (ext.name != callee_name) continue;
      ret_type = ext.ret_type;
      break;
    }
    const bool returns_void = ret_type->IsVoid();
    const std::string ret_name =
        returns_void ? std::string() : module_.GetContext().NextTemp();
    ir::Value* call = builder_.CreateCallExternal(callee_name, ret_type,
                                                  std::move(args), ret_name);
    if (returns_void) {
      return static_cast<ir::Value*>(builder_.CreateConstInt(0));
    }
    return call;
  }

  // ── primaryExp ──────────────────────────────────────────────────────────
  if (auto* primary = ctx->primaryExp()) {
    return primary->accept(this);
  }
  throw std::runtime_error(FormatError("irgen", "非法一元表达式结构"));
}
// primaryExp : L_PAREN exp R_PAREN | lVar | number
// Dispatches to whichever alternative is present, in that order of checking.
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
  }
  if (auto* inner = ctx->exp()) return inner->accept(this);
  if (auto* var = ctx->lVar()) return var->accept(this);
  if (auto* num = ctx->number()) return num->accept(this);
  throw std::runtime_error(FormatError("irgen", "primaryExp 结构非法"));
}
// Computes the address of an lVar (`Ident (L_BRAKT exp R_BRAKT)*`).
//
// The storage slot and declared dimensions are looked up among locals first
// (storage_map_ / array_dims_), then among globals (global_storage_map_ /
// global_array_dims_). Without subscripts the slot itself is returned.
//
// With subscripts the array is addressed as flattened row-major storage: the
// element offset is sum(idx_i * stride_i), where stride_i is the product of
// the declared dimensions after position i, and a single gep of the base by
// that offset is emitted. Fewer subscripts than dimensions yield the address
// of the first element of the selected sub-array.
//
// Array parameters record -1 as their unknown leading dimension (dims[0]).
// Because a stride only multiplies dimensions *after* the index position, that
// -1 never participates, so plain arrays and array parameters share one path.
//
// Throws std::runtime_error for an unbound name, a name without storage, or
// more subscripts than declared dimensions.
ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
  if (!ctx || !ctx->Ident()) {
    throw std::runtime_error(FormatError("irgen", "非法变量引用"));
  }
  auto* decl = sema_.ResolveVarUse(ctx->Ident());
  if (!decl) {
    throw std::runtime_error(
        FormatError("irgen", "变量未绑定: " + ctx->Ident()->getText()));
  }

  // Locate the storage slot and declared dimensions (empty for scalars).
  ir::Value* base = nullptr;
  std::vector<int> dims;
  auto it = storage_map_.find(decl);
  if (it != storage_map_.end()) {
    base = it->second;
    auto dit = array_dims_.find(decl);
    if (dit != array_dims_.end()) dims = dit->second;
  } else {
    auto git = global_storage_map_.find(decl);
    if (git == global_storage_map_.end()) {
      throw std::runtime_error(
          FormatError("irgen", "变量无存储槽位: " + ctx->Ident()->getText()));
    }
    base = git->second;
    auto gdit = global_array_dims_.find(decl);
    if (gdit != global_array_dims_.end()) dims = gdit->second;
  }

  const auto indices = ctx->exp();
  if (indices.empty()) return base;

  // dims.size() is the full rank for both plain arrays and array parameters
  // (the latter include their -1 leading dimension), so this bound applies to
  // both; previously array parameters were never checked.
  if (indices.size() > dims.size()) {
    throw std::runtime_error(FormatError("irgen", "数组索引维度过多"));
  }

  // Subscripts are evaluated left to right so that side effects inside index
  // expressions (e.g. `a[getint()][getint()]`) happen in source order.
  ir::Value* offset = builder_.CreateConstInt(0);
  for (size_t i = 0; i < indices.size(); ++i) {
    ir::Value* idx = EvalExpr(*indices[i]);
    int stride = 1;
    for (size_t j = i + 1; j < dims.size(); ++j) {
      stride *= dims[j];
    }
    ir::Value* term = idx;
    if (stride != 1) {
      term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
                                module_.GetContext().NextTemp());
    }
    offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
  }
  return builder_.CreateGep(base, offset, module_.GetContext().NextTemp());
}
// lVar : Ident (L_BRAKT exp R_BRAKT)*
// Produces the value of an lVar in expression context.
//   * A scalar constant (ConstDefContext, no subscripts) that irgen stored as
//     an ir::ConstantInt / ir::ConstantFloat in storage_map_ or
//     global_storage_map_ is returned directly, without a load.
//   * An array named with fewer subscripts than its declared rank (e.g. `a`,
//     or `a[i]` for `int a[3][4]`) decays to an address: the address computed
//     by EvalLVarAddr is returned without a load, so it can be passed to an
//     array parameter. Previously it was loaded, passing an element instead.
//   * Everything else is loaded from the address computed by EvalLVarAddr.
// Throws std::runtime_error for an invalid node or an unbound name.
std::any IRGenImpl::visitLVar(SysYParser::LVarContext* ctx) {
  if (!ctx || !ctx->Ident()) {
    throw std::runtime_error(FormatError("irgen", "非法变量引用"));
  }
  auto* decl = sema_.ResolveVarUse(ctx->Ident());
  if (!decl) {
    throw std::runtime_error(
        FormatError("irgen", "变量未绑定: " + ctx->Ident()->getText()));
  }
  const size_t num_indices = ctx->exp().size();

  // Scalar constants that were folded to IR constants: return them directly.
  if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(decl)) {
    if (num_indices == 0) {
      // Locals first.
      auto it = storage_map_.find(const_def);
      if (it != storage_map_.end()) {
        if (auto* ci = dynamic_cast<ir::ConstantInt*>(it->second)) {
          return static_cast<ir::Value*>(ci);
        }
        if (auto* cf = dynamic_cast<ir::ConstantFloat*>(it->second)) {
          return static_cast<ir::Value*>(cf);
        }
      }
      // Then globals. A ConstantInt has no address, so falling through to the
      // load below would be invalid; return it directly like the float case.
      auto git = global_storage_map_.find(const_def);
      if (git != global_storage_map_.end()) {
        if (auto* ci = dynamic_cast<ir::ConstantInt*>(git->second)) {
          return static_cast<ir::Value*>(ci);
        }
        if (auto* cf = dynamic_cast<ir::ConstantFloat*>(git->second)) {
          return static_cast<ir::Value*>(cf);
        }
      }
    }
  }

  // Declared rank, using the same local-before-global lookup as EvalLVarAddr
  // (0 for scalars).
  size_t rank = 0;
  if (storage_map_.find(decl) != storage_map_.end()) {
    auto dit = array_dims_.find(decl);
    if (dit != array_dims_.end()) rank = dit->second.size();
  } else {
    auto gdit = global_array_dims_.find(decl);
    if (gdit != global_array_dims_.end()) rank = gdit->second.size();
  }

  ir::Value* addr = EvalLVarAddr(ctx);
  if (num_indices < rank) {
    // Array-to-pointer decay: the (sub-)array's address is the value.
    return addr;
  }
  return static_cast<ir::Value*>(
      builder_.CreateLoad(addr, module_.GetContext().NextTemp()));
}
// number : IntConst | FloatConst
// Lowers a numeric literal to an IR constant.
//   IntConst   — decimal, octal (leading 0) or hex (0x/0X), auto-detected with
//                base 0. Values in [2^31, 2^32), such as the 2147483648 in
//                `-2147483648` or 0x80000000, wrap to the corresponding 32-bit
//                two's-complement int; std::stoi used to reject them.
//                Larger values are an error.
//   FloatConst — parsed with std::strtof (decimal or hex float). std::stof
//                throws std::out_of_range whenever strtof reports ERANGE,
//                which some C libraries do for literals that merely underflow
//                to a subnormal; strtof still returns the rounded value.
// Throws std::runtime_error for unparsable or out-of-range literals.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法数字"));
  }
  if (ctx->IntConst()) {
    const std::string text = ctx->IntConst()->getText();
    long long parsed = 0;
    try {
      parsed = std::stoll(text, nullptr, 0);
    } catch (const std::exception&) {
      throw std::runtime_error(
          FormatError("irgen", "整数字面量解析失败: " + text));
    }
    if (parsed < 0 || parsed > 0xFFFFFFFFLL) {
      throw std::runtime_error(
          FormatError("irgen", "整数字面量解析失败: " + text));
    }
    // Wrap to 32-bit two's complement (e.g. 2147483648 -> INT_MIN).
    const int val = static_cast<int>(static_cast<unsigned int>(parsed));
    return static_cast<ir::Value*>(builder_.CreateConstInt(val));
  }
  if (ctx->FloatConst()) {
    const std::string text = ctx->FloatConst()->getText();
    const char* begin = text.c_str();
    char* end = nullptr;
    const float val = std::strtof(begin, &end);
    if (end == begin) {
      throw std::runtime_error(
          FormatError("irgen", "浮点字面量解析失败: " + text));
    }
    return static_cast<ir::Value*>(builder_.CreateConstFloat(val));
  }
  throw std::runtime_error(FormatError("irgen", "非法数字节点"));
}
// ─── 条件表达式访问器(返回 ir::Value*i1 ──────────────────────────────────
// cond : lOrExp
// Forwards to the lOrExp child. The result may be i1, i32 or f32 (a bare
// arithmetic expression passes through unchanged); EvalCond normalizes it.
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
  auto* lor = (ctx != nullptr) ? ctx->lOrExp() : nullptr;
  if (lor == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法条件"));
  }
  return lor->accept(this);
}
// lOrExp : lAndExp ('||' lAndExp)*
// Short-circuit `||`, folded left to right; the running result is kept as i1.
// For each additional operand:
//   1. store zext(result) into a fresh i32 slot;
//   2. branch to `.or.end` when result is true, otherwise to `.or.rhs`;
//   3. in `.or.rhs` evaluate the operand, normalize it to 0/1, and (unless the
//      current block already ended) store it and jump to `.or.end`;
//   4. continue in `.or.end` with result = (load slot) != 0.
// Folding stops early when the insertion block already has a terminator.
// NOTE(review): the slot is allocated at the current insertion point; if
// CreateAllocaI32 does not place it in the entry block, an `||` inside a loop
// allocates a new slot every iteration — confirm the builder's behavior.
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法 lOrExp"));
  }
  const auto operands = ctx->lAndExp();
  if (operands.empty()) {
    throw std::runtime_error(FormatError("irgen", "lOrExp 空"));
  }
  ir::Value* acc = std::any_cast<ir::Value*>(operands[0]->accept(this));
  acc = ToI1(acc);
  for (size_t k = 1; k < operands.size(); ++k) {
    auto* current = builder_.GetInsertBlock();
    if (current != nullptr && current->HasTerminator()) break;

    auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
    builder_.CreateStore(ToI32(acc), slot);

    ir::BasicBlock* rhs_bb =
        func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs");
    ir::BasicBlock* end_bb =
        func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end");
    builder_.CreateCondBr(acc, end_bb, rhs_bb);

    // Right-hand side: only reached when everything so far was false.
    builder_.SetInsertPoint(rhs_bb);
    ir::Value* rhs_raw = std::any_cast<ir::Value*>(operands[k]->accept(this));
    ir::Value* rhs_bit = ToI32(ToI1(rhs_raw));
    if (!builder_.GetInsertBlock()->HasTerminator()) {
      builder_.CreateStore(rhs_bit, slot);
      builder_.CreateBr(end_bb);
    }

    // Merge point: reload the combined truth value.
    builder_.SetInsertPoint(end_bb);
    ir::Value* merged =
        builder_.CreateLoad(slot, module_.GetContext().NextTemp());
    acc = ToI1(merged);
  }
  return static_cast<ir::Value*>(acc);
}
// lAndExp : eqExp ('&&' eqExp)*
// Short-circuit `&&`, folded left to right; the running result is kept as i1.
// For each additional operand:
//   1. store zext(result) into a fresh i32 slot;
//   2. branch to `.and.rhs` when result is true, otherwise to `.and.end`;
//   3. in `.and.rhs` evaluate the operand, normalize it to 0/1, and (unless
//      the current block already ended) store it and jump to `.and.end`;
//   4. continue in `.and.end` with result = (load slot) != 0.
// Folding stops early when the insertion block already has a terminator.
// NOTE(review): the slot is allocated at the current insertion point; if
// CreateAllocaI32 does not place it in the entry block, an `&&` inside a loop
// allocates a new slot every iteration — confirm the builder's behavior.
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法 lAndExp"));
  }
  const auto operands = ctx->eqExp();
  if (operands.empty()) {
    throw std::runtime_error(FormatError("irgen", "lAndExp 空"));
  }
  ir::Value* acc = std::any_cast<ir::Value*>(operands[0]->accept(this));
  acc = ToI1(acc);
  for (size_t k = 1; k < operands.size(); ++k) {
    auto* current = builder_.GetInsertBlock();
    if (current != nullptr && current->HasTerminator()) break;

    auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
    builder_.CreateStore(ToI32(acc), slot);

    ir::BasicBlock* rhs_bb =
        func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs");
    ir::BasicBlock* end_bb =
        func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end");
    builder_.CreateCondBr(acc, rhs_bb, end_bb);

    // Right-hand side: only reached when everything so far was true.
    builder_.SetInsertPoint(rhs_bb);
    ir::Value* rhs_raw = std::any_cast<ir::Value*>(operands[k]->accept(this));
    ir::Value* rhs_bit = ToI32(ToI1(rhs_raw));
    if (!builder_.GetInsertBlock()->HasTerminator()) {
      builder_.CreateStore(rhs_bit, slot);
      builder_.CreateBr(end_bb);
    }

    // Merge point: reload the combined truth value.
    builder_.SetInsertPoint(end_bb);
    ir::Value* merged =
        builder_.CreateLoad(slot, module_.GetContext().NextTemp());
    acc = ToI1(merged);
  }
  return static_cast<ir::Value*>(acc);
}
// eqExp : relExp (EqOp relExp)*
// Folds equality tests left to right. Operands are unified with
// ImplicitConvert; f32 pairs use fcmp oeq/one, everything else is widened to
// i32 and compared with icmp eq/ne. Any operator text other than "==" is
// treated as "!=". Each comparison yields an i1, which becomes the left
// operand of the next one.
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法 eqExp"));
  }
  const auto operands = ctx->relExp();
  if (operands.empty()) {
    throw std::runtime_error(FormatError("irgen", "eqExp 空"));
  }
  ir::Value* acc = std::any_cast<ir::Value*>(operands[0]->accept(this));
  const auto operators = ctx->EqOp();
  for (size_t k = 0; k < operators.size(); ++k) {
    ir::Value* right = std::any_cast<ir::Value*>(operands[k + 1]->accept(this));
    ir::Value* left = acc;
    ImplicitConvert(left, right);
    const bool is_eq = operators[k]->getText() == "==";
    if (!left->IsFloat32()) {
      left = ToI32(left);
      right = ToI32(right);
      const ir::ICmpPredicate ipred =
          is_eq ? ir::ICmpPredicate::EQ : ir::ICmpPredicate::NE;
      acc = builder_.CreateICmp(ipred, left, right,
                                module_.GetContext().NextTemp());
    } else {
      const ir::FCmpPredicate fpred =
          is_eq ? ir::FCmpPredicate::OEQ : ir::FCmpPredicate::ONE;
      acc = builder_.CreateFCmp(fpred, left, right,
                                module_.GetContext().NextTemp());
    }
  }
  return static_cast<ir::Value*>(acc);
}
// relExp : addExp (RelOp addExp)*
// Folds relational tests left to right. Operands are unified with
// ImplicitConvert; f32 pairs use ordered fcmp (olt/ogt/ole/oge), everything
// else is widened to i32 and compared with signed icmp (slt/sgt/sle/sge).
// Any operator text other than "<", ">" or "<=" is treated as ">=". Each
// comparison yields an i1, which becomes the left operand of the next one.
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法 relExp"));
  }
  const auto operands = ctx->addExp();
  if (operands.empty()) {
    throw std::runtime_error(FormatError("irgen", "relExp 空"));
  }
  ir::Value* acc = std::any_cast<ir::Value*>(operands[0]->accept(this));
  const auto operators = ctx->RelOp();
  for (size_t k = 0; k < operators.size(); ++k) {
    ir::Value* right = std::any_cast<ir::Value*>(operands[k + 1]->accept(this));
    ir::Value* left = acc;
    ImplicitConvert(left, right);
    const std::string op = operators[k]->getText();
    const bool is_lt = (op == "<");
    const bool is_gt = (op == ">");
    const bool is_le = (op == "<=");
    if (!left->IsFloat32()) {
      left = ToI32(left);
      right = ToI32(right);
      ir::ICmpPredicate ipred = ir::ICmpPredicate::SGE;
      if (is_lt) {
        ipred = ir::ICmpPredicate::SLT;
      } else if (is_gt) {
        ipred = ir::ICmpPredicate::SGT;
      } else if (is_le) {
        ipred = ir::ICmpPredicate::SLE;
      }
      acc = builder_.CreateICmp(ipred, left, right,
                                module_.GetContext().NextTemp());
    } else {
      ir::FCmpPredicate fpred = ir::FCmpPredicate::OGE;
      if (is_lt) {
        fpred = ir::FCmpPredicate::OLT;
      } else if (is_gt) {
        fpred = ir::FCmpPredicate::OGT;
      } else if (is_le) {
        fpred = ir::FCmpPredicate::OLE;
      }
      acc = builder_.CreateFCmp(fpred, left, right,
                                module_.GetContext().NextTemp());
    }
  }
  return static_cast<ir::Value*>(acc);
}