You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

401 lines
15 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "irgen/IRGen.h"
#include <stdexcept>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
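// Expression lowering currently covers only a subset of the language.
// Supported by the visitors in this part of the file:
// - integer and floating-point literals
// - reads of plain variables and of recorded named constants
// - parenthesized expressions
// - binary + - * / %
// - unary + and -
// - function calls
// - relational / equality comparisons and logical ! && ||
//
// Not yet supported here:
// - assignment expressions
// - arrays, pointers and subscript access
// - ...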
// Visits `expr` with this generator and unwraps the visitor result.
// The visitor result must hold an `ir::Value*`; otherwise std::any_cast
// throws std::bad_any_cast.
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  std::any visited = expr.accept(this);
  return std::any_cast<ir::Value*>(visited);
}
// Lowers a primary expression: a parenthesized expression, a variable use,
// or a numeric literal. The result is a std::any holding an `ir::Value*`.
//
// Throws std::runtime_error when `ctx` is null or matches none of the three
// forms; errors raised by the delegated visitors propagate unchanged.
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
  }
  // Parenthesized expression: LPAREN exp RPAREN.
  if (auto* inner = ctx->exp()) {
    return EvalExpr(*inner);
  }
  // Variable use. Call visitLVal directly (still bypassing accept()) instead
  // of keeping a second copy of its logic here, so this path performs the same
  // checks and the same named-constant lookup as every other variable read.
  if (auto* lval_ctx = ctx->lVal()) {
    return visitLVal(lval_ctx);
  }
  // Numeric literal.
  if (auto* number_ctx = ctx->number()) {
    return number_ctx->accept(this);
  }
  throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型"));
}
// Lowers a numeric literal to an IR constant (std::any holding `ir::Value*`).
//
// Integer literals are parsed with base auto-detection (decimal, 0x/0X hex,
// leading-0 octal). The text is parsed as an unsigned 64-bit magnitude and then
// wrapped to 32 bits, so literals above INT_MAX but within 32 bits, such as
// `2147483648` (the operand of `-2147483648`) or `0xFFFFFFFF`, are accepted
// instead of making std::stoi throw std::out_of_range.
//
// Throws std::runtime_error when `ctx` is null, when the literal cannot be
// parsed, when an integer literal does not fit in 32 bits, or when the node is
// neither an integer nor a float constant.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "缺少字面量节点"));
  }
  if (ctx->intConst()) {
    const std::string text = ctx->intConst()->getText();
    unsigned long long magnitude = 0;
    try {
      magnitude = std::stoull(text, nullptr, 0);
    } catch (const std::logic_error&) {
      // std::invalid_argument / std::out_of_range: report in the same style as
      // every other irgen diagnostic.
      throw std::runtime_error(FormatError("irgen", "非法整数字面量:" + text));
    }
    if (magnitude > 0xFFFFFFFFULL) {
      throw std::runtime_error(
          FormatError("irgen", "整数字面量超出 32 位范围:" + text));
    }
    // Two's-complement wrap: 2147483648 -> INT_MIN, 0xFFFFFFFF -> -1.
    const int value = static_cast<int>(static_cast<unsigned int>(magnitude));
    return static_cast<ir::Value*>(builder_.CreateConstInt(value));
  }
  if (ctx->floatConst()) {
    const std::string text = ctx->floatConst()->getText();
    float value = 0.0f;
    try {
      value = std::stof(text);
    } catch (const std::logic_error&) {
      throw std::runtime_error(FormatError("irgen", "非法浮点字面量:" + text));
    }
    return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(value));
  }
  throw std::runtime_error(FormatError("irgen", "不支持的字面量"));
}
// Lowers a read of a plain variable. Steps, in order:
//   1. If the name is a recorded constant (FindConst), return that constant
//      directly; no load is emitted.
//   2. Otherwise require that semantic analysis bound this use to a
//      declaration (sema_.ResolveObjectUse).
//   3. Look up the storage slot for the name (FindStorage) and emit a load.
//
// Returns a std::any holding an `ir::Value*`.
// Throws std::runtime_error when the node or its identifier is missing, when
// the use has no semantic binding, or when no storage slot is found.
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
  if (ctx == nullptr || ctx->ID() == nullptr) {
    throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
  }

  std::string name = ctx->ID()->getText();

  // Named constants short-circuit everything else.
  if (ir::ConstantValue* known_const = FindConst(name)) {
    return static_cast<ir::Value*>(known_const);
  }

  const auto* binding = sema_.ResolveObjectUse(ctx);
  if (binding == nullptr) {
    throw std::runtime_error(
        FormatError("irgen", "变量使用缺少语义绑定:" + ctx->ID()->getText()));
  }

  // The slot is located by variable name.
  ir::Value* storage = FindStorage(name);
  if (storage == nullptr) {
    throw std::runtime_error(
        FormatError("irgen", "变量声明缺少存储槽位:" + name));
  }

  return static_cast<ir::Value*>(
      builder_.CreateLoad(storage, module_.GetContext().NextTemp()));
}
// Lowers an additive expression.
//   - Single-operand form (only a mulExp child): forwards to that child.
//   - `addExp op mulExp` form: evaluates the left then the right operand and
//     emits an Add or Sub binary instruction.
// Returns a std::any holding an `ir::Value*` for the two-operand form.
// Throws std::runtime_error on a null node, a missing operand, or an
// operator that is neither ADD nor SUB.
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法加减法表达式"));
  }

  auto* const left_node = ctx->addExp();
  auto* const right_node = ctx->mulExp();

  // addExp : mulExp
  if (right_node != nullptr && left_node == nullptr) {
    return right_node->accept(this);
  }
  // addExp : addExp op mulExp — both operands are required.
  if (left_node == nullptr || right_node == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法加减法表达式结构"));
  }

  ir::Value* left = std::any_cast<ir::Value*>(left_node->accept(this));
  ir::Value* right = std::any_cast<ir::Value*>(right_node->accept(this));

  // The operator is inspected only after both operands have been lowered.
  const bool is_addition = ctx->ADD() != nullptr;
  if (!is_addition && ctx->SUB() == nullptr) {
    throw std::runtime_error(FormatError("irgen", "未知的加减运算符"));
  }
  const ir::Opcode opcode = is_addition ? ir::Opcode::Add : ir::Opcode::Sub;

  return static_cast<ir::Value*>(builder_.CreateBinary(
      opcode, left, right, module_.GetContext().NextTemp()));
}
// Lowers a unary expression. Three forms are handled, tested in this order:
//   1. primaryExp                    -> forwarded to the child visitor.
//   2. ID LPAREN funcRParams? RPAREN -> function call.
//   3. addUnaryOp unaryExp           -> unary plus / minus.
// Returns a std::any holding an `ir::Value*` for forms 2 and 3.
// Throws std::runtime_error on a null node, a call with no function binding
// from semantic analysis, an unknown unary operator, or an unrecognized form.
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
  }

  // Form 1: plain primary expression.
  if (auto* primary = ctx->primaryExp()) {
    return primary->accept(this);
  }

  // Form 2: function call.
  if (auto* callee_id = ctx->ID()) {
    std::string callee_name = callee_id->getText();

    // Semantic analysis must have resolved this call.
    const FunctionBinding* binding = sema_.ResolveFunctionCall(ctx);
    if (binding == nullptr) {
      throw std::runtime_error(
          FormatError("irgen", "未找到函数声明:" + callee_name));
    }

    // Search the functions already present in the module by name.
    const auto find_in_module =
        [this](const std::string& wanted) -> ir::Function* {
      for (const auto& candidate : module_.GetFunctions()) {
        if (candidate->GetName() == wanted) {
          return candidate.get();
        }
      }
      return nullptr;
    };

    ir::Function* callee = find_in_module(callee_name);
    if (callee == nullptr) {
      // Not in the module yet (e.g. an external routine such as putint or
      // getint): create a function with the return type recorded by the
      // binding. No parameters are added to it here.
      std::shared_ptr<ir::Type> result_type;
      if (binding->return_type == SemanticType::Int) {
        result_type = ir::Type::GetInt32Type();
      } else if (binding->return_type == SemanticType::Float) {
        result_type = ir::Type::GetFloatType();
      } else {
        result_type = ir::Type::GetVoidType();
      }
      callee = module_.CreateFunction(callee_name, result_type);
    }

    // Lower the actual arguments, if any.
    std::vector<ir::Value*> call_args;
    if (auto* actuals = ctx->funcRParams()) {
      call_args =
          std::any_cast<std::vector<ir::Value*>>(actuals->accept(this));
    }

    return static_cast<ir::Value*>(builder_.CreateCall(
        callee, call_args, module_.GetContext().NextTemp()));
  }

  // Form 3: unary plus / minus applied to a nested unary expression.
  auto* const sign = ctx->addUnaryOp();
  auto* const nested = ctx->unaryExp();
  if (sign != nullptr && nested != nullptr) {
    ir::Value* operand = std::any_cast<ir::Value*>(nested->accept(this));

    if (sign->SUB()) {
      // Negation is emitted as `0 - operand`, with a float zero for float
      // operands and an integer zero otherwise.
      // NOTE(review): this relies on CreateSub accepting float operands —
      // confirm against the IR builder.
      ir::Value* zero = nullptr;
      if (operand->GetType()->IsFloat()) {
        zero = module_.GetContext().GetConstFloat(0.0f);
      } else {
        zero = builder_.CreateConstInt(0);
      }
      return static_cast<ir::Value*>(builder_.CreateSub(
          zero, operand, module_.GetContext().NextTemp()));
    }
    if (sign->ADD()) {
      // +x is just x.
      return operand;
    }
    throw std::runtime_error(FormatError("irgen", "未知的一元运算符"));
  }

  throw std::runtime_error(FormatError("irgen", "不支持的一元表达式类型"));
}
// Lowers a multiplicative expression.
//   - Single-operand form (only a unaryExp child): forwards to that child.
//   - `mulExp op unaryExp` form: evaluates the left then the right operand and
//     emits a Mul, Div or Mod binary instruction.
// Returns a std::any holding an `ir::Value*` for the two-operand form.
// Throws std::runtime_error on a null node, a missing operand, or an
// operator that is none of MUL / DIV / MOD.
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法乘除法表达式"));
  }

  auto* const left_node = ctx->mulExp();
  auto* const right_node = ctx->unaryExp();

  // mulExp : unaryExp
  if (right_node != nullptr && left_node == nullptr) {
    return right_node->accept(this);
  }
  // mulExp : mulExp op unaryExp — both operands are required.
  if (left_node == nullptr || right_node == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法乘除法表达式结构"));
  }

  ir::Value* left = std::any_cast<ir::Value*>(left_node->accept(this));
  ir::Value* right = std::any_cast<ir::Value*>(right_node->accept(this));

  // The operator is inspected only after both operands have been lowered.
  ir::Opcode opcode = ir::Opcode::Mul;
  if (ctx->MUL() != nullptr) {
    opcode = ir::Opcode::Mul;
  } else if (ctx->DIV() != nullptr) {
    opcode = ir::Opcode::Div;
  } else if (ctx->MOD() != nullptr) {
    opcode = ir::Opcode::Mod;
  } else {
    throw std::runtime_error(FormatError("irgen", "未知的乘除运算符"));
  }

  return static_cast<ir::Value*>(builder_.CreateBinary(
      opcode, left, right, module_.GetContext().NextTemp()));
}
// Lowers a relational expression.
//   - Single-operand form (only an addExp child): forwards to that child.
//   - `relExp op addExp` form with op in { <, >, <=, >= }: evaluates the left
//     then the right operand, widens any i1 operand with zext, and emits a
//     compare instruction.
// Returns a std::any holding an `ir::Value*` for the two-operand form.
// Throws std::runtime_error on a null node, a missing operand, or an
// unrecognized operator. The null/structure guards mirror the ones in
// visitAddExp / visitMulExp so a malformed node is reported instead of
// dereferencing a null child.
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }
  if (ctx->addExp() && ctx->relExp() == nullptr) {
    return ctx->addExp()->accept(this);
  }
  if (!ctx->relExp() || !ctx->addExp()) {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式结构"));
  }

  ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
  ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));

  // An i1 operand is zero-extended first so both sides of the compare are
  // widened values.
  if (lhs->GetType()->IsInt1()) {
    lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
  }
  if (rhs->GetType()->IsInt1()) {
    rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
  }

  ir::CmpOp op = ir::CmpOp::Lt;
  if (ctx->LT()) {
    op = ir::CmpOp::Lt;
  } else if (ctx->GT()) {
    op = ir::CmpOp::Gt;
  } else if (ctx->LE()) {
    op = ir::CmpOp::Le;
  } else if (ctx->GE()) {
    op = ir::CmpOp::Ge;
  } else {
    throw std::runtime_error(FormatError("irgen", "未知的关系运算符"));
  }

  return static_cast<ir::Value*>(
      builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp()));
}
// Lowers an equality expression.
//   - Single-operand form (only a relExp child): forwards to that child.
//   - `eqExp op relExp` form with op in { ==, != }: evaluates the left then
//     the right operand, widens any i1 operand with zext, and emits a compare
//     instruction.
// Returns a std::any holding an `ir::Value*` for the two-operand form.
// Throws std::runtime_error on a null node, a missing operand, or an
// unrecognized operator. The null/structure guards mirror the ones in
// visitAddExp / visitMulExp so a malformed node is reported instead of
// dereferencing a null child.
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  }
  if (ctx->relExp() && ctx->eqExp() == nullptr) {
    return ctx->relExp()->accept(this);
  }
  if (!ctx->eqExp() || !ctx->relExp()) {
    throw std::runtime_error(FormatError("irgen", "非法相等表达式结构"));
  }

  ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
  ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));

  // An i1 operand is zero-extended first so both sides of the compare are
  // widened values.
  if (lhs->GetType()->IsInt1()) {
    lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
  }
  if (rhs->GetType()->IsInt1()) {
    rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
  }

  ir::CmpOp op = ir::CmpOp::Eq;
  if (ctx->EQ()) {
    op = ir::CmpOp::Eq;
  } else if (ctx->NE()) {
    op = ir::CmpOp::Ne;
  } else {
    throw std::runtime_error(FormatError("irgen", "未知的相等运算符"));
  }

  return static_cast<ir::Value*>(
      builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp()));
}
// Lowers a condition-level unary expression.
//   - eqExp child present: forwards to that child.
//   - NOT with a nested condUnaryExp: lowers the operand, widens it with zext
//     if it is i1, and emits `operand == 0`.
// Returns a std::any holding an `ir::Value*` for the NOT form.
// Throws std::runtime_error on a null node or when neither form matches;
// a NOT token without an operand is reported the same way instead of
// dereferencing a null child.
std::any IRGenImpl::visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法条件一元表达式"));
  }
  if (ctx->eqExp()) {
    return ctx->eqExp()->accept(this);
  }
  if (ctx->NOT() && ctx->condUnaryExp()) {
    ir::Value* operand =
        std::any_cast<ir::Value*>(ctx->condUnaryExp()->accept(this));
    if (operand->GetType()->IsInt1()) {
      operand = builder_.CreateZext(operand, module_.GetContext().NextTemp());
    }
    // Logical negation: the result is true exactly when the operand is zero.
    ir::Value* zero = builder_.CreateConstInt(0);
    return static_cast<ir::Value*>(builder_.CreateCmp(
        ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp()));
  }
  throw std::runtime_error(FormatError("irgen", "非法条件一元表达式"));
}
// Lowers a logical-AND expression with short-circuit evaluation.
//   - Single-operand form (only a condUnaryExp child): forwards to that child.
//   - `lAndExp && condUnaryExp`: a result slot is pre-set to 0; the right
//     operand is evaluated (in its own block) only when the left operand is
//     non-zero, and then the slot receives `rhs != 0` widened by zext. The
//     value loaded from the slot in the join block is returned.
// Returns a std::any holding an `ir::Value*` for the two-operand form.
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
  if (ctx->condUnaryExp() && ctx->lAndExp() == nullptr) {
    return ctx->condUnaryExp()->accept(this);
  }

  // Zero-extends i1 values; every other value passes through untouched.
  const auto widen_if_bool = [this](ir::Value* value) -> ir::Value* {
    if (value->GetType()->IsInt1()) {
      return builder_.CreateZext(value, module_.GetContext().NextTemp());
    }
    return value;
  };

  // Result slot, defaulting to "false".
  ir::AllocaInst* result_slot =
      builder_.CreateAllocaI32(module_.GetContext().NextTemp());
  ir::Value* const_zero = builder_.CreateConstInt(0);
  builder_.CreateStore(const_zero, result_slot);

  ir::BasicBlock* eval_rhs_block = func_->CreateBlock(NextBlockName("land_rhs"));
  ir::BasicBlock* join_block = func_->CreateBlock(NextBlockName("land_end"));

  // Left operand: go on to the right operand only when it is non-zero.
  ir::Value* left = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
  left = widen_if_bool(left);
  ir::Value* left_is_true = builder_.CreateCmp(
      ir::CmpOp::Ne, left, const_zero, module_.GetContext().NextTemp());
  builder_.CreateCondBr(left_is_true, eval_rhs_block, join_block);

  // Right operand: its truth value becomes the result.
  builder_.SetInsertPoint(eval_rhs_block);
  ir::Value* right =
      std::any_cast<ir::Value*>(ctx->condUnaryExp()->accept(this));
  right = widen_if_bool(right);
  ir::Value* right_is_true = builder_.CreateCmp(
      ir::CmpOp::Ne, right, const_zero, module_.GetContext().NextTemp());
  ir::Value* right_as_int =
      builder_.CreateZext(right_is_true, module_.GetContext().NextTemp());
  builder_.CreateStore(right_as_int, result_slot);
  builder_.CreateBr(join_block);

  // Join: read back whichever value ended up in the slot.
  builder_.SetInsertPoint(join_block);
  return static_cast<ir::Value*>(
      builder_.CreateLoad(result_slot, module_.GetContext().NextTemp()));
}
// Lowers a logical-OR expression with short-circuit evaluation.
//   - Single-operand form (only an lAndExp child): forwards to that child.
//   - `lOrExp || lAndExp`: a result slot is pre-set to 1; the right operand is
//     evaluated (in its own block) only when the left operand equals zero, and
//     then the slot receives `rhs != 0` widened by zext. The value loaded from
//     the slot in the join block is returned.
// Returns a std::any holding an `ir::Value*` for the two-operand form.
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
  if (ctx->lAndExp() && ctx->lOrExp() == nullptr) {
    return ctx->lAndExp()->accept(this);
  }

  // Zero-extends i1 values; every other value passes through untouched.
  const auto widen_if_bool = [this](ir::Value* value) -> ir::Value* {
    if (value->GetType()->IsInt1()) {
      return builder_.CreateZext(value, module_.GetContext().NextTemp());
    }
    return value;
  };

  // Result slot, defaulting to "true".
  ir::AllocaInst* result_slot =
      builder_.CreateAllocaI32(module_.GetContext().NextTemp());
  ir::Value* const_one = builder_.CreateConstInt(1);
  builder_.CreateStore(const_one, result_slot);

  ir::BasicBlock* eval_rhs_block = func_->CreateBlock(NextBlockName("lor_rhs"));
  ir::BasicBlock* join_block = func_->CreateBlock(NextBlockName("lor_end"));

  // Left operand: go on to the right operand only when it is zero.
  ir::Value* left = std::any_cast<ir::Value*>(ctx->lOrExp()->accept(this));
  ir::Value* const_zero = builder_.CreateConstInt(0);
  left = widen_if_bool(left);
  ir::Value* left_is_false = builder_.CreateCmp(
      ir::CmpOp::Eq, left, const_zero, module_.GetContext().NextTemp());
  builder_.CreateCondBr(left_is_false, eval_rhs_block, join_block);

  // Right operand: its truth value becomes the result.
  builder_.SetInsertPoint(eval_rhs_block);
  ir::Value* right = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
  right = widen_if_bool(right);
  ir::Value* right_is_true = builder_.CreateCmp(
      ir::CmpOp::Ne, right, const_zero, module_.GetContext().NextTemp());
  ir::Value* right_as_int =
      builder_.CreateZext(right_is_true, module_.GetContext().NextTemp());
  builder_.CreateStore(right_as_int, result_slot);
  builder_.CreateBr(join_block);

  // Join: read back whichever value ended up in the slot.
  builder_.SetInsertPoint(join_block);
  return static_cast<ir::Value*>(
      builder_.CreateLoad(result_slot, module_.GetContext().NextTemp()));
}
// Lowers a condition node by forwarding to its lOrExp child.
// Throws std::runtime_error when the node or that child is missing.
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
  auto* const root = (ctx != nullptr) ? ctx->lOrExp() : nullptr;
  if (root == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
  }
  return root->accept(this);
}
// Lowers the actual arguments of a call, left to right.
// Returns a std::any holding a std::vector<ir::Value*> with one entry per
// argument expression, in source order.
std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
  const auto operands = ctx->exp();
  std::vector<ir::Value*> lowered;
  for (auto it = operands.begin(); it != operands.end(); ++it) {
    lowered.push_back(EvalExpr(**it));
  }
  return lowered;
}