|
|
|
|
@ -24,21 +24,75 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
|
|
|
|
|
return std::any_cast<ir::Value*>(expr.accept(this));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::ConstantValue* IRGenImpl::EvaluateConst(antlr4::tree::ParseTree* tree) {
|
|
|
|
|
auto val = std::any_cast<ir::Value*>(tree->accept(this));
|
|
|
|
|
auto* cval = dynamic_cast<ir::ConstantValue*>(val);
|
|
|
|
|
if (!cval) throw std::runtime_error("Not a constant expression");
|
|
|
|
|
return cval;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int IRGenImpl::EvaluateConstInt(SysYParser::ConstExpContext* ctx) {
|
|
|
|
|
if (!ctx) return 0;
|
|
|
|
|
auto* val = EvaluateConst(ctx->addExp());
|
|
|
|
|
if (auto* ci = dynamic_cast<ir::ConstantInt*>(val)) return ci->GetValue();
|
|
|
|
|
if (auto* cf = dynamic_cast<ir::ConstantFloat*>(val)) return (int)cf->GetValue();
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int IRGenImpl::EvaluateConstInt(SysYParser::ExpContext* ctx) {
|
|
|
|
|
if (!ctx) return 0;
|
|
|
|
|
auto* val = EvaluateConst(ctx);
|
|
|
|
|
if (auto* ci = dynamic_cast<ir::ConstantInt*>(val)) return ci->GetValue();
|
|
|
|
|
if (auto* cf = dynamic_cast<ir::ConstantFloat*>(val)) return (int)cf->GetValue();
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
|
|
|
|
|
if (!ctx || !ctx->exp()) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法括号表达式"));
|
|
|
|
|
std::shared_ptr<ir::Type> IRGenImpl::GetGEPResultType(ir::Value* ptr, const std::vector<ir::Value*>& indices) {
|
|
|
|
|
auto cur_ty = ptr->GetType()->GetPointedType();
|
|
|
|
|
for (size_t i = 1; i < indices.size(); ++i) {
|
|
|
|
|
if (cur_ty->IsArray()) {
|
|
|
|
|
cur_ty = cur_ty->GetElementType();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return EvalExpr(*ctx->exp());
|
|
|
|
|
return ir::Type::GetPointerType(cur_ty);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
|
|
|
|
|
if (!ctx) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
|
|
|
|
|
}
|
|
|
|
|
// 处理括号表达式:LPAREN exp RPAREN
|
|
|
|
|
if (ctx->exp()) {
|
|
|
|
|
return EvalExpr(*ctx->exp());
|
|
|
|
|
}
|
|
|
|
|
// 处理 lVal(变量使用)
|
|
|
|
|
if (ctx->lVal()) {
|
|
|
|
|
return ctx->lVal()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
// 处理 number
|
|
|
|
|
if (ctx->number()) {
|
|
|
|
|
return ctx->number()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
|
|
|
|
|
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
|
|
|
|
|
if (!ctx) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "缺少字面量节点"));
|
|
|
|
|
}
|
|
|
|
|
return static_cast<ir::Value*>(
|
|
|
|
|
builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
|
|
|
|
|
if (ctx->intConst()) {
|
|
|
|
|
// 可能是 0x, 0X, 0 开头的八进制等,目前 std::stoi 会处理十进制,
|
|
|
|
|
// 为了支持 16 进制/8 进制建议使用 std::stoi(str, nullptr, 0)
|
|
|
|
|
std::string text = ctx->intConst()->getText();
|
|
|
|
|
return static_cast<ir::Value*>(
|
|
|
|
|
builder_.CreateConstInt(std::stoi(text, nullptr, 0)));
|
|
|
|
|
} else if (ctx->floatConst()) {
|
|
|
|
|
std::string text = ctx->floatConst()->getText();
|
|
|
|
|
return static_cast<ir::Value*>(
|
|
|
|
|
module_.GetContext().GetConstFloat(std::stof(text)));
|
|
|
|
|
}
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "不支持的字面量"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 变量使用的处理流程:
|
|
|
|
|
@ -47,34 +101,482 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
|
|
|
|
|
// 3. 最后生成 load,把内存中的值读出来。
|
|
|
|
|
//
|
|
|
|
|
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
|
|
|
|
|
std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
|
|
|
|
|
if (!ctx || !ctx->var() || !ctx->var()->ID()) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
|
|
|
|
|
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
|
|
|
|
|
if (!ctx || !ctx->ID()) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法左值"));
|
|
|
|
|
}
|
|
|
|
|
auto* decl = sema_.ResolveVarUse(ctx->var());
|
|
|
|
|
if (!decl) {
|
|
|
|
|
|
|
|
|
|
std::string var_name = ctx->ID()->getText();
|
|
|
|
|
|
|
|
|
|
// 优先检查是否为已记录的常量
|
|
|
|
|
ir::ConstantValue* const_val = FindConst(var_name);
|
|
|
|
|
if (const_val && ctx->exp().empty()) {
|
|
|
|
|
return static_cast<ir::Value*>(const_val);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto* binding = sema_.ResolveObjectUse(ctx);
|
|
|
|
|
if (!binding) {
|
|
|
|
|
throw std::runtime_error(
|
|
|
|
|
FormatError("irgen",
|
|
|
|
|
"变量使用缺少语义绑定: " + ctx->var()->ID()->getText()));
|
|
|
|
|
FormatError("irgen", "变量使用缺少语义绑定:" + var_name));
|
|
|
|
|
}
|
|
|
|
|
auto it = storage_map_.find(decl);
|
|
|
|
|
if (it == storage_map_.end()) {
|
|
|
|
|
|
|
|
|
|
ir::Value* slot = FindStorage(var_name);
|
|
|
|
|
if (!slot) {
|
|
|
|
|
throw std::runtime_error(
|
|
|
|
|
FormatError("irgen",
|
|
|
|
|
"变量声明缺少存储槽位: " + ctx->var()->ID()->getText()));
|
|
|
|
|
FormatError("irgen", "变量声明缺少存储槽位:" + var_name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::Value* ptr = slot;
|
|
|
|
|
auto ptr_ty = ptr->GetType();
|
|
|
|
|
bool is_param = false;
|
|
|
|
|
// If it's a pointer to a pointer (function parameter case), load the pointer value first
|
|
|
|
|
if (ptr_ty->IsPointer() && ptr_ty->GetPointedType()->IsPointer()) {
|
|
|
|
|
ptr = builder_.CreateLoad(ptr, module_.GetContext().NextTemp());
|
|
|
|
|
is_param = true;
|
|
|
|
|
} else if (ptr->IsArgument()) {
|
|
|
|
|
is_param = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Determine if the result of this LVal is a scalar or an array
|
|
|
|
|
bool result_is_scalar = (ctx->exp().size() == binding->dimensions.size());
|
|
|
|
|
|
|
|
|
|
if (!ctx->exp().empty()) {
|
|
|
|
|
std::vector<ir::Value*> indices;
|
|
|
|
|
// If it's a local array, we need leading 0
|
|
|
|
|
if (ptr->GetType()->IsPointer() && ptr->GetType()->GetPointedType()->IsArray()) {
|
|
|
|
|
if (!is_param) {
|
|
|
|
|
indices.push_back(builder_.CreateConstInt(0));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto* exp_ctx : ctx->exp()) {
|
|
|
|
|
indices.push_back(EvalExpr(*exp_ctx));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto res_ptr_ty = GetGEPResultType(ptr, indices);
|
|
|
|
|
ptr = builder_.CreateGEP(res_ptr_ty, ptr, indices, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (result_is_scalar) {
|
|
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
|
|
|
|
|
} else {
|
|
|
|
|
// Decay ptr to the first element of the sub-array
|
|
|
|
|
while (ptr->GetType()->GetPointedType()->IsArray()) {
|
|
|
|
|
std::vector<ir::Value*> d_indices;
|
|
|
|
|
d_indices.push_back(builder_.CreateConstInt(0));
|
|
|
|
|
d_indices.push_back(builder_.CreateConstInt(0));
|
|
|
|
|
auto d_res_ty = GetGEPResultType(ptr, d_indices);
|
|
|
|
|
ptr = builder_.CreateGEP(d_res_ty, ptr, d_indices, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
return ptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
|
|
|
|
|
if (!ctx) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法加减法表达式"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 如果是 mulExp 直接返回(addExp : mulExp)
|
|
|
|
|
if (ctx->mulExp() && ctx->addExp() == nullptr) {
|
|
|
|
|
return ctx->mulExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 处理 addExp op mulExp 的递归形式
|
|
|
|
|
if (!ctx->addExp() || !ctx->mulExp()) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法加减法表达式结构"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
|
|
|
|
|
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
|
|
|
|
|
|
|
|
|
|
if (lhs->IsConstant() && rhs->IsConstant()) {
|
|
|
|
|
auto* cl = static_cast<ir::ConstantValue*>(lhs);
|
|
|
|
|
auto* cr = static_cast<ir::ConstantValue*>(rhs);
|
|
|
|
|
if (auto* cil = dynamic_cast<ir::ConstantInt*>(cl)) {
|
|
|
|
|
if (auto* cir = dynamic_cast<ir::ConstantInt*>(cr)) {
|
|
|
|
|
if (ctx->ADD()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() + cir->GetValue()));
|
|
|
|
|
if (ctx->SUB()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() - cir->GetValue()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Implicit conversion
|
|
|
|
|
if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) {
|
|
|
|
|
if (rhs->IsConstant()) {
|
|
|
|
|
rhs = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(rhs)->GetValue());
|
|
|
|
|
} else {
|
|
|
|
|
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
} else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) {
|
|
|
|
|
if (lhs->IsConstant()) {
|
|
|
|
|
lhs = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(lhs)->GetValue());
|
|
|
|
|
} else {
|
|
|
|
|
lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (lhs->IsConstant() && rhs->IsConstant()) {
|
|
|
|
|
auto* cl = static_cast<ir::ConstantValue*>(lhs);
|
|
|
|
|
auto* cr = static_cast<ir::ConstantValue*>(rhs);
|
|
|
|
|
if (auto* cfl = dynamic_cast<ir::ConstantFloat*>(cl)) {
|
|
|
|
|
if (auto* cfr = dynamic_cast<ir::ConstantFloat*>(cr)) {
|
|
|
|
|
if (ctx->ADD()) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(cfl->GetValue() + cfr->GetValue()));
|
|
|
|
|
if (ctx->SUB()) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(cfl->GetValue() - cfr->GetValue()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::Opcode op = ir::Opcode::Add;
|
|
|
|
|
if (ctx->ADD()) {
|
|
|
|
|
op = ir::Opcode::Add;
|
|
|
|
|
} else if (ctx->SUB()) {
|
|
|
|
|
op = ir::Opcode::Sub;
|
|
|
|
|
} else {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "未知的加减运算符"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return static_cast<ir::Value*>(
|
|
|
|
|
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
|
|
|
|
|
builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
|
|
|
|
|
if (!ctx) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 如果是 primaryExp 直接返回(unaryExp : primaryExp)
|
|
|
|
|
if (ctx->primaryExp()) {
|
|
|
|
|
return ctx->primaryExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 处理函数调用(unaryExp : ID LPAREN funcRParams? RPAREN)
|
|
|
|
|
if (ctx->ID()) {
|
|
|
|
|
std::string func_name = ctx->ID()->getText();
|
|
|
|
|
|
|
|
|
|
// 从 Sema 或 Module 中查找函数
|
|
|
|
|
// 目前简化处理,直接从 Module 中查找(如果是当前文件内定义的)
|
|
|
|
|
// 或者依赖 Sema 给出解析结果
|
|
|
|
|
const FunctionBinding* func_binding = sema_.ResolveFunctionCall(ctx);
|
|
|
|
|
if (!func_binding) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "未找到函数声明:" + func_name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 假设 func_binding 能够找到对应的 ir::Function*
|
|
|
|
|
// 这里如果 sema 不提供直接拿 ir::Function 的方式,需要遍历 module_.GetFunctions() 查找
|
|
|
|
|
ir::Function* target_func = nullptr;
|
|
|
|
|
for (const auto& f : module_.GetFunctions()) {
|
|
|
|
|
if (f->GetName() == func_name) {
|
|
|
|
|
target_func = f.get();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!target_func) {
|
|
|
|
|
// 可能是外部函数如 putint, getint 等
|
|
|
|
|
// 如果没有在 module_ 中,则需要创建一个只有声明的 Function
|
|
|
|
|
std::shared_ptr<ir::Type> ret_ty;
|
|
|
|
|
if (func_binding->return_type == SemanticType::Int) {
|
|
|
|
|
ret_ty = ir::Type::GetInt32Type();
|
|
|
|
|
} else if (func_binding->return_type == SemanticType::Float) {
|
|
|
|
|
ret_ty = ir::Type::GetFloatType();
|
|
|
|
|
} else {
|
|
|
|
|
ret_ty = ir::Type::GetVoidType();
|
|
|
|
|
}
|
|
|
|
|
target_func = module_.CreateFunction(func_name, ret_ty);
|
|
|
|
|
// 对于外部函数,需要传递参数,可能还需要在 target_func 中 AddArgument
|
|
|
|
|
for (const auto& param : func_binding->params) {
|
|
|
|
|
std::shared_ptr<ir::Type> p_ty;
|
|
|
|
|
if (param.type == SemanticType::Int) {
|
|
|
|
|
p_ty = param.dimensions.empty() && !param.is_array_param ? ir::Type::GetInt32Type() : ir::Type::GetPtrInt32Type();
|
|
|
|
|
} else {
|
|
|
|
|
p_ty = param.dimensions.empty() && !param.is_array_param ? ir::Type::GetFloatType() : ir::Type::GetPtrFloatType();
|
|
|
|
|
}
|
|
|
|
|
target_func->AddArgument(p_ty, param.name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<ir::Value*> args;
|
|
|
|
|
if (ctx->funcRParams()) {
|
|
|
|
|
args = std::any_cast<std::vector<ir::Value*>>(ctx->funcRParams()->accept(this));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Implicit conversion for function arguments
|
|
|
|
|
const auto& formal_args = target_func->GetArgs();
|
|
|
|
|
for (size_t i = 0; i < std::min(args.size(), formal_args.size()); ++i) {
|
|
|
|
|
if (formal_args[i]->GetType()->IsFloat() && !args[i]->GetType()->IsFloat()) {
|
|
|
|
|
args[i] = builder_.CreateSIToFP(args[i], module_.GetContext().NextTemp());
|
|
|
|
|
} else if (formal_args[i]->GetType()->IsInt32() && args[i]->GetType()->IsFloat()) {
|
|
|
|
|
args[i] = builder_.CreateFPToSI(args[i], module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return static_cast<ir::Value*>(builder_.CreateCall(target_func, args, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 处理一元运算符(unaryExp : addUnaryOp unaryExp)
|
|
|
|
|
if (ctx->addUnaryOp() && ctx->unaryExp()) {
|
|
|
|
|
ir::Value* operand = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
|
|
|
|
|
|
|
|
|
|
// Constant folding for unary op
|
|
|
|
|
if (operand->IsConstant()) {
|
|
|
|
|
if (ctx->addUnaryOp()->SUB()) {
|
|
|
|
|
if (auto* ci = dynamic_cast<ir::ConstantInt*>(operand)) {
|
|
|
|
|
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(-ci->GetValue()));
|
|
|
|
|
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(operand)) {
|
|
|
|
|
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(-cf->GetValue()));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
return operand;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 判断是正号还是负号
|
|
|
|
|
if (ctx->addUnaryOp()->SUB()) {
|
|
|
|
|
// 负号:如果是整数生成 sub 0, operand,浮点数生成 fsub 0.0, operand
|
|
|
|
|
if (operand->GetType()->IsFloat()) {
|
|
|
|
|
ir::Value* zero = module_.GetContext().GetConstFloat(0.0f);
|
|
|
|
|
// 此处暂且假设 CreateSub 可以处理浮点数(如果底层有 fsub 则更好)
|
|
|
|
|
return static_cast<ir::Value*>(
|
|
|
|
|
builder_.CreateSub(zero, operand, module_.GetContext().NextTemp()));
|
|
|
|
|
} else {
|
|
|
|
|
ir::Value* zero = builder_.CreateConstInt(0);
|
|
|
|
|
return static_cast<ir::Value*>(
|
|
|
|
|
builder_.CreateSub(zero, operand, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
} else if (ctx->addUnaryOp()->ADD()) {
|
|
|
|
|
// 正号:直接返回操作数(+x 等价于 x)
|
|
|
|
|
return operand;
|
|
|
|
|
} else {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "未知的一元运算符"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "不支持的一元表达式类型"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
|
|
|
|
|
if (!ctx) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法乘除法表达式"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 如果是 unaryExp 直接返回(mulExp : unaryExp)
|
|
|
|
|
if (ctx->unaryExp() && ctx->mulExp() == nullptr) {
|
|
|
|
|
return ctx->unaryExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 处理 mulExp op unaryExp 的递归形式
|
|
|
|
|
if (!ctx->mulExp() || !ctx->unaryExp()) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法乘除法表达式结构"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
|
|
|
|
|
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
|
|
|
|
|
|
|
|
|
|
// Constant folding
|
|
|
|
|
if (lhs->IsConstant() && rhs->IsConstant()) {
|
|
|
|
|
auto* cl = static_cast<ir::ConstantValue*>(lhs);
|
|
|
|
|
auto* cr = static_cast<ir::ConstantValue*>(rhs);
|
|
|
|
|
if (auto* cil = dynamic_cast<ir::ConstantInt*>(cl)) {
|
|
|
|
|
if (auto* cir = dynamic_cast<ir::ConstantInt*>(cr)) {
|
|
|
|
|
if (ctx->MUL()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() * cir->GetValue()));
|
|
|
|
|
if (ctx->DIV()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() / cir->GetValue()));
|
|
|
|
|
if (ctx->MOD()) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(cil->GetValue() % cir->GetValue()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Implicit conversion
|
|
|
|
|
if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) {
|
|
|
|
|
if (rhs->IsConstant()) {
|
|
|
|
|
rhs = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(rhs)->GetValue());
|
|
|
|
|
} else {
|
|
|
|
|
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
} else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) {
|
|
|
|
|
if (lhs->IsConstant()) {
|
|
|
|
|
lhs = module_.GetContext().GetConstFloat((float)static_cast<ir::ConstantInt*>(lhs)->GetValue());
|
|
|
|
|
} else {
|
|
|
|
|
lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
|
|
|
|
|
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
|
|
|
|
|
if (lhs->IsConstant() && rhs->IsConstant()) {
|
|
|
|
|
auto* cl = static_cast<ir::ConstantValue*>(lhs);
|
|
|
|
|
auto* cr = static_cast<ir::ConstantValue*>(rhs);
|
|
|
|
|
if (auto* cfl = dynamic_cast<ir::ConstantFloat*>(cl)) {
|
|
|
|
|
if (auto* cfr = dynamic_cast<ir::ConstantFloat*>(cr)) {
|
|
|
|
|
if (ctx->MUL()) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(cfl->GetValue() * cfr->GetValue()));
|
|
|
|
|
if (ctx->DIV()) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(cfl->GetValue() / cfr->GetValue()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::Opcode op = ir::Opcode::Mul;
|
|
|
|
|
if (ctx->MUL()) {
|
|
|
|
|
op = ir::Opcode::Mul;
|
|
|
|
|
} else if (ctx->DIV()) {
|
|
|
|
|
op = ir::Opcode::Div;
|
|
|
|
|
} else if (ctx->MOD()) {
|
|
|
|
|
op = ir::Opcode::Mod;
|
|
|
|
|
} else {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "未知的乘除运算符"));
|
|
|
|
|
}
|
|
|
|
|
ir::Value* lhs = EvalExpr(*ctx->exp(0));
|
|
|
|
|
ir::Value* rhs = EvalExpr(*ctx->exp(1));
|
|
|
|
|
|
|
|
|
|
return static_cast<ir::Value*>(
|
|
|
|
|
builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
|
|
|
|
|
module_.GetContext().NextTemp()));
|
|
|
|
|
builder_.CreateBinary(op, lhs, rhs, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
|
|
|
|
|
if (ctx->addExp() && ctx->relExp() == nullptr) {
|
|
|
|
|
return ctx->addExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
|
|
|
|
|
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
|
|
|
|
|
if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
|
|
|
|
|
if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
|
|
|
|
|
|
|
|
|
|
// Implicit conversion
|
|
|
|
|
if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) {
|
|
|
|
|
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
|
|
|
|
|
} else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) {
|
|
|
|
|
lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::CmpOp op;
|
|
|
|
|
if (ctx->LT()) op = ir::CmpOp::Lt;
|
|
|
|
|
else if (ctx->GT()) op = ir::CmpOp::Gt;
|
|
|
|
|
else if (ctx->LE()) op = ir::CmpOp::Le;
|
|
|
|
|
else if (ctx->GE()) op = ir::CmpOp::Ge;
|
|
|
|
|
else throw std::runtime_error(FormatError("irgen", "未知的关系运算符"));
|
|
|
|
|
|
|
|
|
|
return static_cast<ir::Value*>(builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
|
|
|
|
|
if (ctx->relExp() && ctx->eqExp() == nullptr) {
|
|
|
|
|
return ctx->relExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
|
|
|
|
|
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
|
|
|
|
|
if (lhs->GetType()->IsInt1()) lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
|
|
|
|
|
if (rhs->GetType()->IsInt1()) rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
|
|
|
|
|
|
|
|
|
|
// Implicit conversion
|
|
|
|
|
if (lhs->GetType()->IsFloat() && !rhs->GetType()->IsFloat()) {
|
|
|
|
|
rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
|
|
|
|
|
} else if (!lhs->GetType()->IsFloat() && rhs->GetType()->IsFloat()) {
|
|
|
|
|
lhs = builder_.CreateSIToFP(lhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::CmpOp op;
|
|
|
|
|
if (ctx->EQ()) op = ir::CmpOp::Eq;
|
|
|
|
|
else if (ctx->NE()) op = ir::CmpOp::Ne;
|
|
|
|
|
else throw std::runtime_error(FormatError("irgen", "未知的相等运算符"));
|
|
|
|
|
|
|
|
|
|
return static_cast<ir::Value*>(builder_.CreateCmp(op, lhs, rhs, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitCondUnaryExp(SysYParser::CondUnaryExpContext* ctx) {
|
|
|
|
|
if (ctx->eqExp()) {
|
|
|
|
|
return ctx->eqExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
if (ctx->NOT()) {
|
|
|
|
|
ir::Value* operand = std::any_cast<ir::Value*>(ctx->condUnaryExp()->accept(this));
|
|
|
|
|
if (operand->GetType()->IsInt1()) {
|
|
|
|
|
operand = builder_.CreateZext(operand, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
if (operand->GetType()->IsFloat()) {
|
|
|
|
|
ir::Value* zero = module_.GetContext().GetConstFloat(0.0f);
|
|
|
|
|
return static_cast<ir::Value*>(builder_.CreateCmp(ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp()));
|
|
|
|
|
} else {
|
|
|
|
|
ir::Value* zero = builder_.CreateConstInt(0);
|
|
|
|
|
return static_cast<ir::Value*>(builder_.CreateCmp(ir::CmpOp::Eq, operand, zero, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法条件一元表达式"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
|
|
|
|
|
if (ctx->condUnaryExp() && ctx->lAndExp() == nullptr) {
|
|
|
|
|
return ctx->condUnaryExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::AllocaInst* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
|
|
|
|
|
ir::Value* zero = builder_.CreateConstInt(0);
|
|
|
|
|
builder_.CreateStore(zero, res_ptr);
|
|
|
|
|
|
|
|
|
|
ir::BasicBlock* rhs_bb = func_->CreateBlock(NextBlockName("land_rhs"));
|
|
|
|
|
ir::BasicBlock* end_bb = func_->CreateBlock(NextBlockName("land_end"));
|
|
|
|
|
|
|
|
|
|
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
|
|
|
|
|
if (lhs->GetType()->IsInt1()) {
|
|
|
|
|
lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
ir::Value* lhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, lhs, zero, module_.GetContext().NextTemp());
|
|
|
|
|
builder_.CreateCondBr(lhs_cond, rhs_bb, end_bb);
|
|
|
|
|
|
|
|
|
|
builder_.SetInsertPoint(rhs_bb);
|
|
|
|
|
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->condUnaryExp()->accept(this));
|
|
|
|
|
if (rhs->GetType()->IsInt1()) {
|
|
|
|
|
rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
ir::Value* rhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, rhs, zero, module_.GetContext().NextTemp());
|
|
|
|
|
ir::Value* rhs_res = builder_.CreateZext(rhs_cond, module_.GetContext().NextTemp());
|
|
|
|
|
builder_.CreateStore(rhs_res, res_ptr);
|
|
|
|
|
builder_.CreateBr(end_bb);
|
|
|
|
|
|
|
|
|
|
builder_.SetInsertPoint(end_bb);
|
|
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
|
|
|
|
|
if (ctx->lAndExp() && ctx->lOrExp() == nullptr) {
|
|
|
|
|
return ctx->lAndExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::AllocaInst* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
|
|
|
|
|
ir::Value* one = builder_.CreateConstInt(1);
|
|
|
|
|
builder_.CreateStore(one, res_ptr);
|
|
|
|
|
|
|
|
|
|
ir::BasicBlock* rhs_bb = func_->CreateBlock(NextBlockName("lor_rhs"));
|
|
|
|
|
ir::BasicBlock* end_bb = func_->CreateBlock(NextBlockName("lor_end"));
|
|
|
|
|
|
|
|
|
|
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->lOrExp()->accept(this));
|
|
|
|
|
ir::Value* zero = builder_.CreateConstInt(0);
|
|
|
|
|
if (lhs->GetType()->IsInt1()) {
|
|
|
|
|
lhs = builder_.CreateZext(lhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
ir::Value* lhs_cond = builder_.CreateCmp(ir::CmpOp::Eq, lhs, zero, module_.GetContext().NextTemp());
|
|
|
|
|
builder_.CreateCondBr(lhs_cond, rhs_bb, end_bb);
|
|
|
|
|
|
|
|
|
|
builder_.SetInsertPoint(rhs_bb);
|
|
|
|
|
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this));
|
|
|
|
|
if (rhs->GetType()->IsInt1()) {
|
|
|
|
|
rhs = builder_.CreateZext(rhs, module_.GetContext().NextTemp());
|
|
|
|
|
}
|
|
|
|
|
ir::Value* rhs_cond = builder_.CreateCmp(ir::CmpOp::Ne, rhs, zero, module_.GetContext().NextTemp());
|
|
|
|
|
ir::Value* rhs_res = builder_.CreateZext(rhs_cond, module_.GetContext().NextTemp());
|
|
|
|
|
builder_.CreateStore(rhs_res, res_ptr);
|
|
|
|
|
builder_.CreateBr(end_bb);
|
|
|
|
|
|
|
|
|
|
builder_.SetInsertPoint(end_bb);
|
|
|
|
|
return static_cast<ir::Value*>(builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
|
|
|
|
|
if (!ctx || !ctx->lOrExp()) {
|
|
|
|
|
throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
|
|
|
|
|
}
|
|
|
|
|
return ctx->lOrExp()->accept(this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
|
|
|
|
|
std::vector<ir::Value*> args;
|
|
|
|
|
for (auto* exp : ctx->exp()) {
|
|
|
|
|
args.push_back(EvalExpr(*exp));
|
|
|
|
|
}
|
|
|
|
|
return args;
|
|
|
|
|
}
|