// Expression, condition, l-value, array-initializer and constant-folding
// lowering for IRGenImpl (SysY AST -> IR).
#include "irgen/IRGen.h"

#include <any>
#include <cstddef>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>

#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"

namespace {

/// Unwraps the pointer stored in a visitor's std::any result as ir::Value*.
/// Visitors may return several concrete pointer types, so each accepted
/// dynamic type is tried in turn.
/// @throws std::bad_any_cast if the std::any holds none of the accepted types.
ir::Value* AnyToValue(const std::any& val) {
  if (val.type() == typeid(ir::Value*)) {
    return std::any_cast<ir::Value*>(val);
  }
  if (val.type() == typeid(ir::ConstantInt*)) {
    return static_cast<ir::Value*>(std::any_cast<ir::ConstantInt*>(val));
  }
  if (val.type() == typeid(ir::ConstantFloat*)) {
    return static_cast<ir::Value*>(std::any_cast<ir::ConstantFloat*>(val));
  }
  if (val.type() == typeid(ir::Instruction*)) {
    return std::any_cast<ir::Instruction*>(val);
  }
  std::cerr << "Unknown type in AnyToValue: " << val.type().name() << std::endl;
  throw std::bad_any_cast();
}

/// Returns the module function named `name`, or nullptr if there is none.
ir::Function* FindFunctionByName(ir::Module& module, const std::string& name) {
  for (const auto& fn : module.GetFunctions()) {
    if (fn && fn->GetName() == name) return fn.get();
  }
  return nullptr;
}

/// Name-based fallback lookup: returns an iterator to the first entry of
/// `storage` whose declaration context has an ID() spelled `name`, or
/// storage.end(). The keys are (possibly const) parse-tree context pointers;
/// constness is dropped only because the ANTLR ID() accessor is non-const.
template <typename Storage>
auto FindDeclByName(const Storage& storage, const std::string& name) {
  using Decl = std::remove_const_t<
      std::remove_pointer_t<typename Storage::key_type>>;
  for (auto it = storage.begin(); it != storage.end(); ++it) {
    auto* def = const_cast<Decl*>(it->first);
    if (def && def->ID() && def->ID()->getText() == name) return it;
  }
  return storage.end();
}

/// Looks `key` up in `storage` without inserting; returns nullptr on a miss.
template <typename Storage, typename Key>
ir::Value* LookupOrNull(const Storage& storage, const Key& key) {
  auto it = storage.find(key);
  if (it == storage.end()) return nullptr;
  return it->second;
}

/// Turns a branch-emitting condition into an i1 value.
///
/// `emit_branches(true_bb, false_bb)` is expected to end the current block
/// with control flow that reaches `true_bb` when the condition holds and
/// `false_bb` otherwise. Both targets then branch to a merge block, where a
/// phi yields true/false. On return the builder points at the merge block.
template <typename FuncPtr, typename Builder, typename Context,
          typename EmitBranches>
ir::Value* MaterializeCondition(FuncPtr&& func, Builder& builder,
                                Context&& context,
                                EmitBranches&& emit_branches) {
  auto* true_bb = func->CreateBlock("cond.true");
  auto* false_bb = func->CreateBlock("cond.false");
  auto* merge_bb = func->CreateBlock("cond.merge");
  emit_branches(true_bb, false_bb);
  builder.SetInsertPoint(true_bb);
  builder.CreateBr(merge_bb);
  builder.SetInsertPoint(false_bb);
  builder.CreateBr(merge_bb);
  builder.SetInsertPoint(merge_bb);
  auto* phi = builder.CreatePhi(ir::Type::GetInt1Type(), context.NextTemp());
  phi->AddIncoming(builder.CreateConstBool(true), true_bb);
  phi->AddIncoming(builder.CreateConstBool(false), false_bb);
  return phi;
}

}  // namespace

// ---------------------------------------------------------------------------
// Expression visitors
// ---------------------------------------------------------------------------

/// Lowers an expression and returns its IR value (nullptr for a null ctx).
ir::Value* IRGenImpl::EvalExp(SysYParser::ExpContext* ctx) {
  if (!ctx) return nullptr;
  return AnyToValue(ctx->accept(this));
}

std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
  if (!ctx) return static_cast<ir::Value*>(nullptr);
  if (ctx->addExp()) return ctx->addExp()->accept(this);
  throw std::runtime_error(FormatError("irgen", "不支持的表达式"));
}

/// Lowers `mulExp (('+'|'-') mulExp)*` left to right. If either operand is
/// float, both are promoted to float; otherwise both are widened to i32
/// (operands such as `!x` are i1).
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
  auto muls = ctx->mulExp();
  if (muls.empty()) return static_cast<ir::Value*>(nullptr);
  ir::Value* lhs = AnyToValue(muls[0]->accept(this));
  for (size_t i = 1; i < muls.size(); ++i) {
    ir::Value* rhs = AnyToValue(muls[i]->accept(this));
    // Children alternate operand/operator, so operator i sits at 2*i-1.
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "+";
    if (lhs->IsFloat() || rhs->IsFloat()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
      if (text == "+") {
        lhs = builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp());
      } else {
        lhs = builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp());
      }
    } else {
      lhs = CastToInt(lhs);
      rhs = CastToInt(rhs);
      if (text == "+") {
        lhs = builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp());
      } else {
        lhs = builder_.CreateSub(lhs, rhs, module_.GetContext().NextTemp());
      }
    }
  }
  return static_cast<ir::Value*>(lhs);
}

/// Lowers `unaryExp (('*'|'/'|'%') unaryExp)*` left to right, using the same
/// promotion rules as visitAddExp. `%` on float operands is rejected.
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
  auto unaries = ctx->unaryExp();
  if (unaries.empty()) return static_cast<ir::Value*>(nullptr);
  ir::Value* lhs = AnyToValue(unaries[0]->accept(this));
  for (size_t i = 1; i < unaries.size(); ++i) {
    ir::Value* rhs = AnyToValue(unaries[i]->accept(this));
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "*";
    if (lhs->IsFloat() || rhs->IsFloat()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
      if (text == "*") {
        lhs = builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp());
      } else if (text == "/") {
        lhs = builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp());
      } else {
        throw std::runtime_error(FormatError("irgen", "float 不支持 %"));
      }
    } else {
      lhs = CastToInt(lhs);
      rhs = CastToInt(rhs);
      if (text == "*") {
        lhs = builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp());
      } else if (text == "/") {
        lhs = builder_.CreateSDiv(lhs, rhs, module_.GetContext().NextTemp());
      } else {
        lhs = builder_.CreateSRem(lhs, rhs, module_.GetContext().NextTemp());
      }
    }
  }
  return static_cast<ir::Value*>(lhs);
}

/// Lowers a primary expression, a function call, or a prefix `+`/`-`/`!`.
///
/// For calls, each argument is adapted to the callee's declared parameter
/// type: array pointers are decayed (or un-decayed) with a `gep 0, 0`, and
/// scalar int/float/i1 arguments are converted. Only as many arguments as
/// both lists share are adapted. `!` yields an i1.
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
  if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);

  if (ctx->ID() && ctx->LPAREN()) {
    std::string name = ctx->ID()->getText();
    ir::Function* callee = FindFunctionByName(module_, name);
    if (!callee) {
      throw std::runtime_error(FormatError("irgen", "未定义函数: " + name));
    }

    std::vector<ir::Value*> args;
    if (ctx->funcRParams()) {
      for (auto* exp : ctx->funcRParams()->exp()) {
        args.push_back(EvalExp(exp));
      }
    }

    // Emits `gep ptr, 0, 0`, i.e. a pointer to the first element of the
    // array that `ptr` points to.
    auto first_element = [this](ir::Value* ptr) -> ir::Value* {
      std::vector<ir::Value*> idx = {builder_.CreateConstInt(0),
                                     builder_.CreateConstInt(0)};
      return builder_.CreateGep(ptr, std::move(idx),
                                module_.GetContext().NextTemp());
    };

    const auto& param_tys = callee->GetFunctionType()->GetParamTypes();
    for (size_t i = 0; i < args.size() && i < param_tys.size(); ++i) {
      auto* arg = args[i];
      const auto& pty = param_tys[i];
      if (pty->IsPointer()) {
        if (arg && arg->GetType() && arg->GetType()->IsPointer()) {
          bool param_elem_array = pty->GetElementType()->IsArray();
          bool arg_elem_array = arg->GetType()->GetElementType()->IsArray();
          if (!param_elem_array && arg_elem_array) {
            // Parameter wants T*, argument is [N x T]*: decay it.
            args[i] = first_element(arg);
            arg = args[i];
          } else if (param_elem_array && arg_elem_array) {
            // Decay only when the argument's element array matches the
            // parameter's pointee exactly.
            auto* param_elem = pty->GetElementType().get();
            auto* arg_elem = arg->GetType()->GetElementType().get();
            if (param_elem && arg_elem && arg_elem->IsArray() &&
                arg_elem->GetElementType()->Equals(*param_elem)) {
              args[i] = first_element(arg);
              arg = args[i];
            }
          } else if (param_elem_array && !arg_elem_array) {
            // The argument is an already-decayed `gep base, 0, 0`; pass the
            // array pointer `base` instead.
            if (auto* gep = dynamic_cast<ir::GepInst*>(arg)) {
              const auto& idx = gep->GetIndices();
              auto is_zero = [](ir::Value* v) {
                auto* ci = dynamic_cast<ir::ConstantInt*>(v);
                return ci && ci->GetValue() == 0;
              };
              if (idx.size() == 2 && is_zero(idx[0]) && is_zero(idx[1])) {
                auto* base = gep->GetBasePtr();
                if (base && base->GetType() && base->GetType()->IsPointer() &&
                    base->GetType()->GetElementType()->IsArray()) {
                  args[i] = base;
                  arg = base;
                }
              }
            }
          }
        } else if (arg && arg->GetType() && arg->GetType()->IsInt32()) {
          // A pointer parameter received an i32 loaded through a
          // pointer-to-pointer slot; pass the slot's pointer itself when its
          // pointee matches the parameter's pointee.
          if (auto* load = dynamic_cast<ir::LoadInst*>(arg)) {
            auto* base = load->GetPtr();
            if (base && base->GetType() && base->GetType()->IsPointer()) {
              auto* elem = base->GetType()->GetElementType().get();
              if (elem && elem->IsPointer() &&
                  elem->GetElementType()->Equals(*pty->GetElementType())) {
                args[i] = base;
                arg = base;
              }
            }
          }
        }
      }
      if (pty->IsFloat() && (arg->IsInt1() || arg->IsInt32())) {
        args[i] = CastToFloat(arg);
      } else if (pty->IsInt32() && (arg->IsInt1() || arg->IsFloat())) {
        args[i] = CastToInt(arg);
      }
    }

    // Void calls produce no value, so they get no result name.
    std::string tmp = callee->GetReturnType()->IsVoid()
                          ? std::string("")
                          : module_.GetContext().NextTemp();
    auto* call = builder_.CreateCall(callee, std::move(args), tmp);
    return static_cast<ir::Value*>(call);
  }

  if (ctx->unaryOp() && ctx->unaryExp()) {
    std::string op = ctx->unaryOp()->getText();
    ir::Value* val = AnyToValue(ctx->unaryExp()->accept(this));
    if (op == "+") return static_cast<ir::Value*>(val);
    if (op == "-") {
      if (val->IsFloat()) {
        auto* zero = builder_.CreateConstFloat(0.0f);
        return static_cast<ir::Value*>(
            builder_.CreateFSub(zero, val, module_.GetContext().NextTemp()));
      }
      // The operand may be i1 (e.g. `-!x`); negate at i32.
      auto* zero = builder_.CreateConstInt(0);
      return static_cast<ir::Value*>(builder_.CreateSub(
          zero, CastToInt(val), module_.GetContext().NextTemp()));
    }
    if (op == "!") {
      ir::Value* b = MakeBool(val);
      auto* zero = builder_.CreateConstBool(false);
      return static_cast<ir::Value*>(builder_.CreateICmp(
          ir::ICmpPredicate::Eq, b, zero, module_.GetContext().NextTemp()));
    }
  }
  return static_cast<ir::Value*>(nullptr);
}

std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
  ir::Value* val = EmitRelEq(ctx);
  return static_cast<ir::Value*>(val);
}

std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
  ir::Value* val = EmitEq(ctx);
  return static_cast<ir::Value*>(val);
}

/// `eqExp ('&&' eqExp)*` as a value. A single operand is lowered directly.
/// Otherwise the expression must short-circuit, so it is lowered through
/// EmitLAndCond and the outcome is merged into an i1 phi.
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
  auto eqs = ctx->eqExp();
  if (eqs.size() == 1) return eqs.front()->accept(this);
  ir::Value* val = MaterializeCondition(
      func_, builder_, module_.GetContext(),
      [&](auto* t, auto* f) { EmitLAndCond(ctx, t, f); });
  return val;
}

/// `lAndExp ('||' lAndExp)*` as a value; see visitLAndExp.
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
  auto ands = ctx->lAndExp();
  if (ands.size() == 1) return ands.front()->accept(this);
  ir::Value* val = MaterializeCondition(
      func_, builder_, module_.GetContext(),
      [&](auto* t, auto* f) { EmitLOrCond(ctx, t, f); });
  return val;
}

std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
  if (ctx->lOrExp()) return ctx->lOrExp()->accept(this);
  return static_cast<ir::Value*>(nullptr);
}

/// Arguments are lowered by visitUnaryExp, so this visitor yields nothing.
std::any IRGenImpl::visitFuncRParams(SysYParser::FuncRParamsContext* ctx) {
  (void)ctx;
  return static_cast<ir::Value*>(nullptr);
}

std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
  if (!ctx) return static_cast<ir::Value*>(nullptr);
  if (ctx->LPAREN() && ctx->exp()) return EvalExp(ctx->exp());
  if (ctx->lVal()) return ctx->lVal()->accept(this);
  if (ctx->number()) return ctx->number()->accept(this);
  throw std::runtime_error(FormatError("irgen", "不支持的 PrimaryExp"));
}

/// Materializes an integer or float literal. Integer text is parsed with
/// base 0 (so `0x..` is hex and a leading `0` is octal); trailing unparsed
/// characters are rejected.
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "非法常量"));
  }
  if (ctx->INT_CONST()) {
    const std::string text = ctx->getText();
    size_t idx = 0;
    long long val = std::stoll(text, &idx, 0);
    if (idx != text.size()) {
      throw std::runtime_error(FormatError("irgen", "非法整数常量"));
    }
    return static_cast<ir::Value*>(builder_.CreateConstInt(val));
  }
  if (ctx->FLOAT_CONST()) {
    float val = std::stof(ctx->getText());
    return static_cast<ir::Value*>(builder_.CreateConstFloat(val));
  }
  throw std::runtime_error(FormatError("irgen", "不支持的常量类型"));
}

/// Lowers an l-value used as an expression.
///
/// A scalar, or an array element indexed in every dimension, is loaded. A
/// partially indexed or unindexed array yields its address, e.g. for passing
/// to a call. The declared type comes from the semantic binding; if that
/// fails, the storage maps are searched by name.
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
  if (!ctx || !ctx->ID()) {
    throw std::runtime_error(FormatError("irgen", "非法左值"));
  }
  ir::Value* addr = GetLValAddress(ctx);

  BoundDecl bound = sema_.ResolveVarUse(ctx);
  const TypeDesc* ty = nullptr;
  if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
    ty = sema_.GetVarType(bound.var_decl);
  } else if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) {
    ty = sema_.GetConstType(bound.const_decl);
  } else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) {
    ty = sema_.GetParamType(bound.param_decl);
  }

  if (!ty) {
    const std::string name = ctx->ID()->getText();
    // Takes the type of the first declaration named `name` in `storage`,
    // unless an earlier storage already produced one.
    auto try_storage = [&](const auto& storage, auto get_type) {
      if (ty) return;
      auto it = FindDeclByName(storage, name);
      if (it != storage.end()) ty = get_type(it->first);
    };
    try_storage(var_storage_, [this](auto* d) { return sema_.GetVarType(d); });
    try_storage(const_storage_,
                [this](auto* d) { return sema_.GetConstType(d); });
    try_storage(param_storage_,
                [this](auto* d) { return sema_.GetParamType(d); });
    try_storage(global_var_storage_,
                [this](auto* d) { return sema_.GetVarType(d); });
    try_storage(global_const_storage_,
                [this](auto* d) { return sema_.GetConstType(d); });
  }
  if (!ty) {
    throw std::runtime_error(FormatError("irgen", "无法解析左值类型"));
  }

  bool as_rvalue = true;
  if (!ty->dims.empty()) {
    const size_t index_count = ctx->exp().size();
    if (index_count == 0 || index_count < ty->dims.size()) {
      as_rvalue = false;
    }
  }
  return static_cast<ir::Value*>(LoadIfNeeded(addr, *ty, as_rvalue));
}

std::any IRGenImpl::visitConstExp(SysYParser::ConstExpContext* ctx) {
  if (!ctx) return static_cast<ir::Value*>(nullptr);
  return ctx->addExp()->accept(this);
}

/// Const initializers are handled by FillConstArrayValues and related code.
std::any IRGenImpl::visitConstInitVal(SysYParser::ConstInitValContext* ctx) {
  (void)ctx;
  return static_cast<ir::Value*>(nullptr);
}

/// Initializers are handled by FillArrayValues and related code.
std::any IRGenImpl::visitInitVal(SysYParser::InitValContext* ctx) {
  (void)ctx;
  return static_cast<ir::Value*>(nullptr);
}

// ---------------------------------------------------------------------------
// Scalar conversions
// ---------------------------------------------------------------------------

/// Converts an i1/i32 value to float.
/// An i1 is zero-extended first, so `true` becomes 1.0; a signed conversion
/// of an i1 would treat `true` as -1.
ir::Value* IRGenImpl::CastToFloat(ir::Value* v) {
  if (v->IsFloat()) return v;
  if (v->IsInt1()) v = CastToInt(v);
  if (v->IsInt32()) {
    return builder_.CreateSIToFP(v, ir::Type::GetFloatType(),
                                 module_.GetContext().NextTemp());
  }
  throw std::runtime_error(FormatError("irgen", "无法转换为 float"));
}

/// Converts an i1 (zero-extended) or a float (truncated via fptosi) to i32.
ir::Value* IRGenImpl::CastToInt(ir::Value* v) {
  if (v->IsInt32()) return v;
  if (v->IsInt1()) {
    return builder_.CreateZExt(v, ir::Type::GetInt32Type(),
                               module_.GetContext().NextTemp());
  }
  if (v->IsFloat()) {
    return builder_.CreateFPToSI(v, ir::Type::GetInt32Type(),
                                 module_.GetContext().NextTemp());
  }
  throw std::runtime_error(FormatError("irgen", "无法转换为 int"));
}

/// Produces an i1 that is true iff `v` is non-zero.
ir::Value* IRGenImpl::MakeBool(ir::Value* v) {
  if (v->IsInt1()) return v;
  if (v->IsFloat()) {
    auto* zero = builder_.CreateConstFloat(0.0f);
    return builder_.CreateFCmp(ir::FCmpPredicate::One, v, zero,
                               module_.GetContext().NextTemp());
  }
  auto* zero = builder_.CreateConstInt(0);
  return builder_.CreateICmp(ir::ICmpPredicate::Ne, v, zero,
                             module_.GetContext().NextTemp());
}

// ---------------------------------------------------------------------------
// Relational / equality / short-circuit conditions
// ---------------------------------------------------------------------------

/// Lowers `addExp (('<'|'>'|'<='|'>=') addExp)*` left to right.
/// Returns the lone operand unchanged when there is no comparison; otherwise
/// returns an i1. Integer operands are compared at i32 (signed), so chained
/// forms like `a < b < c` compare the 0/1 result of `a < b` with `c`.
ir::Value* IRGenImpl::EmitRelEq(SysYParser::RelExpContext* ctx) {
  auto exps = ctx->addExp();
  if (exps.empty()) return nullptr;
  ir::Value* lhs = AnyToValue(exps[0]->accept(this));
  for (size_t i = 1; i < exps.size(); ++i) {
    ir::Value* rhs = AnyToValue(exps[i]->accept(this));
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "<";
    if (lhs->IsFloat() || rhs->IsFloat()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
      ir::FCmpPredicate pred = ir::FCmpPredicate::Oge;
      if (text == "<") {
        pred = ir::FCmpPredicate::Olt;
      } else if (text == "<=") {
        pred = ir::FCmpPredicate::Ole;
      } else if (text == ">") {
        pred = ir::FCmpPredicate::Ogt;
      }
      lhs = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    } else {
      lhs = CastToInt(lhs);
      rhs = CastToInt(rhs);
      ir::ICmpPredicate pred = ir::ICmpPredicate::Sge;
      if (text == "<") {
        pred = ir::ICmpPredicate::Slt;
      } else if (text == "<=") {
        pred = ir::ICmpPredicate::Sle;
      } else if (text == ">") {
        pred = ir::ICmpPredicate::Sgt;
      }
      lhs = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return lhs;
}

/// Lowers `relExp (('=='|'!=') relExp)*` left to right.
/// Returns the lone operand unchanged when there is no comparison. Mixed
/// i1/i32 operands are widened to i32 before comparing.
ir::Value* IRGenImpl::EmitEq(SysYParser::EqExpContext* ctx) {
  auto rels = ctx->relExp();
  if (rels.empty()) return nullptr;
  ir::Value* lhs = EmitRelEq(rels[0]);
  for (size_t i = 1; i < rels.size(); ++i) {
    ir::Value* rhs = EmitRelEq(rels[i]);
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "==";
    if (lhs->IsFloat() || rhs->IsFloat()) {
      lhs = CastToFloat(lhs);
      rhs = CastToFloat(rhs);
      ir::FCmpPredicate pred =
          text == "==" ? ir::FCmpPredicate::Oeq : ir::FCmpPredicate::One;
      lhs = builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    } else {
      // Two i1 operands can be compared as-is; otherwise compare at i32.
      if (!(lhs->IsInt1() && rhs->IsInt1())) {
        lhs = CastToInt(lhs);
        rhs = CastToInt(rhs);
      }
      ir::ICmpPredicate pred =
          text == "==" ? ir::ICmpPredicate::Eq : ir::ICmpPredicate::Ne;
      lhs = builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp());
    }
  }
  return lhs;
}

/// Evaluates a condition to an i1 value, using short-circuit branches merged
/// by a phi. Leaves the builder in the merge block.
ir::Value* IRGenImpl::EvalCondValue(SysYParser::CondContext* ctx) {
  if (!ctx) return nullptr;
  return MaterializeCondition(func_, builder_, module_.GetContext(),
                              [&](auto* t, auto* f) { EmitCondBr(ctx, t, f); });
}

/// Emits control flow that jumps to `true_bb` or `false_bb` according to `ctx`.
void IRGenImpl::EmitCondBr(SysYParser::CondContext* ctx,
                           ir::BasicBlock* true_bb, ir::BasicBlock* false_bb) {
  if (!ctx || !ctx->lOrExp()) {
    throw std::runtime_error(FormatError("irgen", "非法 cond"));
  }
  EmitLOrCond(ctx->lOrExp(), true_bb, false_bb);
}

/// Short-circuit `||`: each operand except the last jumps to `true_bb` on
/// success and falls through to a fresh "lor.next" block otherwise.
void IRGenImpl::EmitLOrCond(SysYParser::LOrExpContext* ctx,
                            ir::BasicBlock* true_bb, ir::BasicBlock* false_bb) {
  auto ands = ctx->lAndExp();
  if (ands.empty()) {
    builder_.CreateBr(false_bb);
    return;
  }
  for (size_t i = 0; i < ands.size(); ++i) {
    if (i == ands.size() - 1) {
      EmitLAndCond(ands[i], true_bb, false_bb);
    } else {
      auto* next = func_->CreateBlock("lor.next");
      EmitLAndCond(ands[i], true_bb, next);
      builder_.SetInsertPoint(next);
    }
  }
}

/// Short-circuit `&&`: each operand except the last jumps to `false_bb` on
/// failure and continues in a fresh "land.next" block otherwise.
void IRGenImpl::EmitLAndCond(SysYParser::LAndExpContext* ctx,
                             ir::BasicBlock* true_bb, ir::BasicBlock* false_bb) {
  auto eqs = ctx->eqExp();
  if (eqs.empty()) {
    builder_.CreateBr(false_bb);
    return;
  }
  for (size_t i = 0; i < eqs.size(); ++i) {
    ir::Value* cond = MakeBool(EmitEq(eqs[i]));
    if (i == eqs.size() - 1) {
      builder_.CreateCondBr(cond, true_bb, false_bb);
    } else {
      auto* next = func_->CreateBlock("land.next");
      builder_.CreateCondBr(cond, next, false_bb);
      builder_.SetInsertPoint(next);
    }
  }
}

// ---------------------------------------------------------------------------
// L-value addressing and types
// ---------------------------------------------------------------------------

/// Returns the address that an l-value designates.
///
/// Storage is located from the semantic binding (local, then global maps),
/// falling back to a by-name search. Array parameters hold a pointer that is
/// loaded first. Indices are lowered as i32 and applied with one GEP, with a
/// leading 0 when the base points to an array.
/// @throws std::runtime_error if no storage or type can be found.
ir::Value* IRGenImpl::GetLValAddress(SysYParser::LValContext* ctx) {
  BoundDecl bound = sema_.ResolveVarUse(ctx);
  ir::Value* base_ptr = nullptr;
  const TypeDesc* ty = nullptr;

  // find()-based lookups, so a miss does not insert null entries.
  if (bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
    ty = sema_.GetVarType(bound.var_decl);
    base_ptr = LookupOrNull(var_storage_, bound.var_decl);
  } else if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) {
    ty = sema_.GetConstType(bound.const_decl);
    base_ptr = LookupOrNull(const_storage_, bound.const_decl);
  } else if (bound.kind == BoundDecl::Kind::Param && bound.param_decl) {
    ty = sema_.GetParamType(bound.param_decl);
    base_ptr = LookupOrNull(param_storage_, bound.param_decl);
  }

  if (!base_ptr && bound.kind == BoundDecl::Kind::Var && bound.var_decl) {
    auto it = global_var_storage_.find(bound.var_decl);
    if (it != global_var_storage_.end()) {
      ty = sema_.GetVarType(bound.var_decl);
      base_ptr = it->second;
    }
  }
  if (!base_ptr && bound.kind == BoundDecl::Kind::Const && bound.const_decl) {
    auto it = global_const_storage_.find(bound.const_decl);
    if (it != global_const_storage_.end()) {
      ty = sema_.GetConstType(bound.const_decl);
      base_ptr = it->second;
    }
  }

  if (!base_ptr && ctx && ctx->ID()) {
    const std::string name = ctx->ID()->getText();
    // Takes type and address from the first declaration named `name` in
    // `storage`, unless an earlier storage already produced an address.
    auto try_storage = [&](const auto& storage, auto get_type) {
      if (base_ptr) return;
      auto it = FindDeclByName(storage, name);
      if (it == storage.end()) return;
      ty = get_type(it->first);
      base_ptr = it->second;
    };
    try_storage(var_storage_, [this](auto* d) { return sema_.GetVarType(d); });
    try_storage(const_storage_,
                [this](auto* d) { return sema_.GetConstType(d); });
    try_storage(param_storage_,
                [this](auto* d) { return sema_.GetParamType(d); });
    try_storage(global_var_storage_,
                [this](auto* d) { return sema_.GetVarType(d); });
    try_storage(global_const_storage_,
                [this](auto* d) { return sema_.GetConstType(d); });
  }

  if (!base_ptr || !ty) {
    throw std::runtime_error(FormatError("irgen", "左值未绑定"));
  }

  // An array parameter's slot stores the incoming pointer; load it.
  if (bound.kind == BoundDecl::Kind::Param && !ty->dims.empty()) {
    base_ptr = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp());
  }

  std::vector<ir::Value*> indices;
  const auto exps = ctx->exp();
  if (!ty->dims.empty() && !exps.empty()) {
    // Pointers to whole arrays need a leading 0 to step into the array.
    // Unsized leading dimensions (negative, e.g. parameters) do not.
    bool need_leading_zero =
        base_ptr->GetType()->GetElementType()->IsArray() && ty->dims[0] >= 0;
    if (need_leading_zero) {
      indices.push_back(builder_.CreateConstInt(0));
    }
  }
  for (auto* exp : exps) {
    indices.push_back(CastToInt(EvalExp(exp)));
  }

  if (indices.empty()) return base_ptr;

  // Still a pointer-to-pointer: load the inner pointer before indexing.
  if (base_ptr->GetType() && base_ptr->GetType()->IsPointer() &&
      base_ptr->GetType()->GetElementType()->IsPointer()) {
    base_ptr = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp());
  }
  return builder_.CreateGep(base_ptr, std::move(indices),
                            module_.GetContext().NextTemp());
}

/// Returns `addr_or_val` untouched when `as_rvalue` is false, else loads it.
/// `ty` is currently unused.
ir::Value* IRGenImpl::LoadIfNeeded(ir::Value* addr_or_val, const TypeDesc& ty,
                                   bool as_rvalue) {
  (void)ty;
  if (!as_rvalue) {
    return addr_or_val;
  }
  return builder_.CreateLoad(addr_or_val, module_.GetContext().NextTemp());
}

/// Maps a semantic type to an IR type. Dimensions are wrapped innermost
/// first, so dims {2, 3} become an array of 2 arrays of 3 elements.
/// Negative (unsized) dimensions are skipped.
std::shared_ptr<ir::Type> IRGenImpl::ToIRType(const TypeDesc& ty) {
  std::shared_ptr<ir::Type> base;
  if (ty.base == BaseTypeKind::Int) {
    base = ir::Type::GetInt32Type();
  } else if (ty.base == BaseTypeKind::Float) {
    base = ir::Type::GetFloatType();
  } else {
    base = ir::Type::GetVoidType();
  }
  for (auto it = ty.dims.rbegin(); it != ty.dims.rend(); ++it) {
    if (*it < 0) continue;
    base = ir::Type::GetArrayType(base, static_cast<size_t>(*it));
  }
  return base;
}

/// Parameter types: scalars map directly. Arrays become a pointer to the
/// element type left after dropping an unsized leading dimension.
std::shared_ptr<ir::Type> IRGenImpl::ToIRParamType(const TypeDesc& ty) {
  if (ty.dims.empty()) return ToIRType(ty);
  TypeDesc elem = ty;
  if (!elem.dims.empty() && elem.dims.front() < 0) {
    elem.dims.erase(elem.dims.begin());
  }
  return ir::Type::GetPointerType(ToIRType(elem));
}

/// Zero constant of the base type (0.0f for float, 0 otherwise).
ir::Value* IRGenImpl::DefaultValue(const TypeDesc& ty) {
  if (ty.base == BaseTypeKind::Float) return builder_.CreateConstFloat(0.0f);
  return builder_.CreateConstInt(0);
}

/// Creates an alloca at the front of the current function's entry block.
ir::AllocaInst* IRGenImpl::CreateEntryAlloca(std::shared_ptr<ir::Type> ty,
                                             const std::string& name) {
  if (!func_) {
    throw std::runtime_error(FormatError("irgen", "CreateEntryAlloca 缺少函数"));
  }
  auto* entry = func_->GetEntry();
  if (!entry) {
    throw std::runtime_error(
        FormatError("irgen", "CreateEntryAlloca 缺少入口块"));
  }
  return entry->Prepend<ir::AllocaInst>(std::move(ty), name);
}

// ---------------------------------------------------------------------------
// Array initialization (flattened, row-major)
// ---------------------------------------------------------------------------

/// Number of scalars one step in dimension `dim` spans
/// (the product of all later dimensions).
size_t IRGenImpl::ArrayStride(const TypeDesc& ty, size_t dim) const {
  size_t stride = 1;
  for (size_t i = dim + 1; i < ty.dims.size(); ++i) {
    stride *= static_cast<size_t>(ty.dims[i]);
  }
  return stride;
}

/// Total scalar count (the product of all dimensions).
size_t IRGenImpl::ArrayTotalSize(const TypeDesc& ty) const {
  size_t total = 1;
  for (int d : ty.dims) total *= static_cast<size_t>(d);
  return total;
}

/// Rounds `index` up to a multiple of `align` (align == 0 leaves it unchanged).
static size_t AlignIndex(size_t index, size_t align) {
  if (align == 0) return index;
  return (index + align - 1) / align * align;
}

/// Scatters a (possibly nested) brace initializer into the flat buffer
/// `values`, starting at flat position `idx`.
///
/// A bare expression fills one slot. A nested brace list is aligned to the
/// stride of dimension `dim` and consumes exactly one such sub-aggregate.
/// Writes outside `values` are dropped.
/// @return the flat position after this initializer.
size_t IRGenImpl::FillArrayValues(const TypeDesc& ty,
                                  SysYParser::InitValContext* init,
                                  std::vector<ir::Value*>& values, size_t base,
                                  size_t idx, size_t dim) {
  if (!init) return idx;
  if (init->exp()) {
    if (base + idx < values.size()) {
      values[base + idx] = EvalExp(init->exp());
    }
    return idx + 1;
  }
  size_t sub_size = ArrayStride(ty, dim);
  if (init->initVal().empty()) {
    idx = AlignIndex(idx, sub_size);
    return idx + sub_size;
  }
  for (auto* child : init->initVal()) {
    if (!child) continue;
    if (child->exp()) {
      idx = FillArrayValues(ty, child, values, base, idx, dim);
    } else {
      size_t aligned = AlignIndex(idx, sub_size);
      FillArrayValues(ty, child, values, base, aligned, dim + 1);
      idx = aligned + sub_size;
    }
  }
  return AlignIndex(idx, sub_size);
}

/// Const-initializer counterpart of FillArrayValues (same layout rules).
size_t IRGenImpl::FillConstArrayValues(const TypeDesc& ty,
                                       SysYParser::ConstInitValContext* init,
                                       std::vector<ir::Value*>& values,
                                       size_t base, size_t idx, size_t dim) {
  if (!init) return idx;
  if (init->constExp()) {
    if (base + idx < values.size()) {
      values[base + idx] = AnyToValue(init->constExp()->accept(this));
    }
    return idx + 1;
  }
  size_t sub_size = ArrayStride(ty, dim);
  if (init->constInitVal().empty()) {
    idx = AlignIndex(idx, sub_size);
    return idx + sub_size;
  }
  for (auto* child : init->constInitVal()) {
    if (!child) continue;
    if (child->constExp()) {
      idx = FillConstArrayValues(ty, child, values, base, idx, dim);
    } else {
      size_t aligned = AlignIndex(idx, sub_size);
      FillConstArrayValues(ty, child, values, base, aligned, dim + 1);
      idx = aligned + sub_size;
    }
  }
  return AlignIndex(idx, sub_size);
}

/// Stores every element of a local array: initializer values where given,
/// the zero default elsewhere. Each element is converted to the array's
/// base type.
void IRGenImpl::InitArray(ir::Value* base_ptr, const TypeDesc& ty,
                          SysYParser::InitValContext* init) {
  size_t total = ArrayTotalSize(ty);
  std::vector<ir::Value*> values(total, DefaultValue(ty));
  FillArrayValues(ty, init, values, 0, 0, 0);
  for (size_t idx = 0; idx < total; ++idx) {
    // Convert the flat position into (0, i0, i1, ...) GEP indices.
    std::vector<ir::Value*> indices;
    indices.reserve(ty.dims.size() + 1);
    indices.push_back(builder_.CreateConstInt(0));
    size_t remain = idx;
    for (size_t dim = 0; dim < ty.dims.size(); ++dim) {
      size_t stride = ArrayStride(ty, dim);
      size_t cur = remain / stride;
      remain %= stride;
      indices.push_back(builder_.CreateConstInt(static_cast<int>(cur)));
    }
    ir::Value* addr = builder_.CreateGep(base_ptr, std::move(indices),
                                         module_.GetContext().NextTemp());
    ir::Value* value = values[idx];
    if (ty.base == BaseTypeKind::Float) {
      if (value->IsInt1() || value->IsInt32()) value = CastToFloat(value);
    } else if (ty.base == BaseTypeKind::Int &&
               (value->IsFloat() || value->IsInt1())) {
      value = CastToInt(value);
    }
    builder_.CreateStore(value, addr);
  }
}

/// Const-initializer counterpart of InitArray.
void IRGenImpl::InitConstArray(ir::Value* base_ptr, const TypeDesc& ty,
                               SysYParser::ConstInitValContext* init) {
  size_t total = ArrayTotalSize(ty);
  std::vector<ir::Value*> values(total, DefaultValue(ty));
  FillConstArrayValues(ty, init, values, 0, 0, 0);
  for (size_t idx = 0; idx < total; ++idx) {
    std::vector<ir::Value*> indices;
    indices.reserve(ty.dims.size() + 1);
    indices.push_back(builder_.CreateConstInt(0));
    size_t remain = idx;
    for (size_t dim = 0; dim < ty.dims.size(); ++dim) {
      size_t stride = ArrayStride(ty, dim);
      size_t cur = remain / stride;
      remain %= stride;
      indices.push_back(builder_.CreateConstInt(static_cast<int>(cur)));
    }
    ir::Value* addr = builder_.CreateGep(base_ptr, std::move(indices),
                                         module_.GetContext().NextTemp());
    ir::Value* value = values[idx];
    if (ty.base == BaseTypeKind::Float) {
      if (value->IsInt1() || value->IsInt32()) value = CastToFloat(value);
    } else if (ty.base == BaseTypeKind::Int &&
               (value->IsFloat() || value->IsInt1())) {
      value = CastToInt(value);
    }
    builder_.CreateStore(value, addr);
  }
}

// ---------------------------------------------------------------------------
// Loop targets for break/continue
// ---------------------------------------------------------------------------

void IRGenImpl::PushLoop(ir::BasicBlock* break_bb, ir::BasicBlock* cont_bb) {
  loop_stack_.push_back({break_bb, cont_bb});
}

void IRGenImpl::PopLoop() {
  if (!loop_stack_.empty()) loop_stack_.pop_back();
}

/// Break target of the innermost loop, or nullptr outside any loop.
ir::BasicBlock* IRGenImpl::CurrentBreak() const {
  if (loop_stack_.empty()) return nullptr;
  return loop_stack_.back().first;
}

/// Continue target of the innermost loop, or nullptr outside any loop.
ir::BasicBlock* IRGenImpl::CurrentContinue() const {
  if (loop_stack_.empty()) return nullptr;
  return loop_stack_.back().second;
}

// ---------------------------------------------------------------------------
// Compile-time constant folding
// ---------------------------------------------------------------------------

namespace {

/// Folding accumulator: integers are kept as long long, floats as double,
/// and the result is narrowed when converted back to an IR constant.
struct ConstNumber {
  bool is_float = false;
  double f = 0.0;
  long long i = 0;
};

/// Reads an IR int/float constant; anything else reads as integer 0.
ConstNumber ToConstNumber(ir::ConstantValue* v) {
  ConstNumber num;
  if (auto* ci = dynamic_cast<ir::ConstantInt*>(v)) {
    num.is_float = false;
    num.i = ci->GetValue();
    return num;
  }
  if (auto* cf = dynamic_cast<ir::ConstantFloat*>(v)) {
    num.is_float = true;
    num.f = cf->GetValue();
    return num;
  }
  return num;
}

ConstNumber PromoteToFloat(const ConstNumber& v) {
  if (v.is_float) return v;
  ConstNumber n;
  n.is_float = true;
  n.f = static_cast<double>(v.i);
  return n;
}

}  // namespace

ir::ConstantValue* IRGenImpl::EvalConstScalar(SysYParser::ExpContext* ctx) {
  if (!ctx || !ctx->addExp()) {
    throw std::runtime_error(FormatError("irgen", "非法常量表达式"));
  }
  return EvalConstAdd(ctx->addExp());
}

ir::ConstantValue* IRGenImpl::EvalConstScalar(SysYParser::ConstExpContext* ctx) {
  if (!ctx || !ctx->addExp()) {
    throw std::runtime_error(FormatError("irgen", "非法常量表达式"));
  }
  return EvalConstAdd(ctx->addExp());
}

/// Folds `+`/`-` chains; any float operand makes the result float.
ir::ConstantValue* IRGenImpl::EvalConstAdd(SysYParser::AddExpContext* ctx) {
  auto muls = ctx->mulExp();
  if (muls.empty()) return module_.GetContext().GetConstInt(0);
  ConstNumber lhs = ToConstNumber(EvalConstMul(muls[0]));
  for (size_t i = 1; i < muls.size(); ++i) {
    ConstNumber rhs = ToConstNumber(EvalConstMul(muls[i]));
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "+";
    if (lhs.is_float || rhs.is_float) {
      lhs = PromoteToFloat(lhs);
      rhs = PromoteToFloat(rhs);
      lhs.f = (text == "+") ? lhs.f + rhs.f : lhs.f - rhs.f;
      lhs.is_float = true;
    } else {
      lhs.i = (text == "+") ? lhs.i + rhs.i : lhs.i - rhs.i;
    }
  }
  if (lhs.is_float) {
    return module_.GetContext().GetConstFloat(static_cast<float>(lhs.f));
  }
  return module_.GetContext().GetConstInt(static_cast<int>(lhs.i));
}

/// Folds `*`/`/`/`%` chains. `%` requires integer operands.
/// Integer division or remainder by a zero constant is diagnosed, because
/// folding it would be undefined behaviour inside the compiler.
ir::ConstantValue* IRGenImpl::EvalConstMul(SysYParser::MulExpContext* ctx) {
  auto unaries = ctx->unaryExp();
  if (unaries.empty()) return module_.GetContext().GetConstInt(0);
  ConstNumber lhs = ToConstNumber(EvalConstUnary(unaries[0]));
  for (size_t i = 1; i < unaries.size(); ++i) {
    ConstNumber rhs = ToConstNumber(EvalConstUnary(unaries[i]));
    auto* node = ctx->children.at(2 * i - 1);
    std::string text = node ? node->getText() : "*";
    if (text == "%") {
      if (lhs.is_float || rhs.is_float) {
        throw std::runtime_error(FormatError("irgen", "const % 仅支持整数"));
      }
      if (rhs.i == 0) {
        throw std::runtime_error(FormatError("irgen", "常量表达式除数为零"));
      }
      lhs.i = lhs.i % rhs.i;
      continue;
    }
    if (lhs.is_float || rhs.is_float) {
      lhs = PromoteToFloat(lhs);
      rhs = PromoteToFloat(rhs);
      lhs.f = (text == "*") ? lhs.f * rhs.f : lhs.f / rhs.f;
      lhs.is_float = true;
    } else if (text == "*") {
      lhs.i = lhs.i * rhs.i;
    } else {
      if (rhs.i == 0) {
        throw std::runtime_error(FormatError("irgen", "常量表达式除数为零"));
      }
      lhs.i = lhs.i / rhs.i;
    }
  }
  if (lhs.is_float) {
    return module_.GetContext().GetConstFloat(static_cast<float>(lhs.f));
  }
  return module_.GetContext().GetConstInt(static_cast<int>(lhs.i));
}

/// Folds prefix `+`, `-`, `!` (`!` yields an int 0/1). Calls are rejected.
ir::ConstantValue* IRGenImpl::EvalConstUnary(SysYParser::UnaryExpContext* ctx) {
  if (ctx->primaryExp()) return EvalConstPrimary(ctx->primaryExp());
  if (ctx->unaryOp() && ctx->unaryExp()) {
    ConstNumber val = ToConstNumber(EvalConstUnary(ctx->unaryExp()));
    std::string op = ctx->unaryOp()->getText();
    if (op == "+") {
      if (val.is_float) {
        return module_.GetContext().GetConstFloat(static_cast<float>(val.f));
      }
      return module_.GetContext().GetConstInt(static_cast<int>(val.i));
    }
    if (op == "-") {
      if (val.is_float) {
        return module_.GetContext().GetConstFloat(static_cast<float>(-val.f));
      }
      return module_.GetContext().GetConstInt(static_cast<int>(-val.i));
    }
    if (op == "!") {
      bool is_zero = val.is_float ? (val.f == 0.0) : (val.i == 0);
      return module_.GetContext().GetConstInt(is_zero ? 1 : 0);
    }
  }
  throw std::runtime_error(FormatError("irgen", "const 不支持函数调用"));
}

ir::ConstantValue* IRGenImpl::EvalConstPrimary(
    SysYParser::PrimaryExpContext* ctx) {
  if (ctx->exp()) return EvalConstScalar(ctx->exp());
  if (ctx->lVal()) return EvalConstLVal(ctx->lVal());
  if (ctx->number()) return EvalConstNumber(ctx->number());
  return module_.GetContext().GetConstInt(0);
}

/// Folds a literal; integer text is parsed with base 0 as in visitNumber.
ir::ConstantValue* IRGenImpl::EvalConstNumber(SysYParser::NumberContext* ctx) {
  if (ctx->INT_CONST()) {
    const std::string text = ctx->getText();
    size_t idx = 0;
    long long val = std::stoll(text, &idx, 0);
    if (idx != text.size()) {
      throw std::runtime_error(FormatError("irgen", "非法整数常量"));
    }
    return module_.GetContext().GetConstInt(static_cast<int>(val));
  }
  if (ctx->FLOAT_CONST()) {
    return module_.GetContext().GetConstFloat(std::stof(ctx->getText()));
  }
  throw std::runtime_error(FormatError("irgen", "非法常量"));
}

/// Folds a reference to a global constant by returning its initializer.
/// NOTE(review): only global_const_storage_ is consulted, and any index
/// expressions on `ctx` are not applied. Local consts and indexed const
/// arrays are therefore not folded here.
ir::ConstantValue* IRGenImpl::EvalConstLVal(SysYParser::LValContext* ctx) {
  BoundDecl bound = sema_.ResolveVarUse(ctx);
  if (bound.kind == BoundDecl::Kind::Const && bound.const_decl) {
    auto it = global_const_storage_.find(bound.const_decl);
    if (it != global_const_storage_.end()) {
      auto* init = it->second->GetInitializer();
      if (init) return init;
    }
  }
  throw std::runtime_error(FormatError("irgen", "constExp 使用了非常量"));
}

/// Global-array counterpart of FillArrayValues: each element must fold to a
/// constant, which is converted to the array's base type.
size_t IRGenImpl::InitGlobalArray(const TypeDesc& ty,
                                  SysYParser::InitValContext* init,
                                  std::vector<ir::ConstantValue*>& values,
                                  size_t base, size_t idx, size_t dim) {
  if (!init) return idx;
  if (init->exp()) {
    if (base + idx < values.size()) {
      auto* v = EvalConstScalar(init->exp());
      if (ty.base == BaseTypeKind::Int && dynamic_cast<ir::ConstantFloat*>(v)) {
        auto* cf = static_cast<ir::ConstantFloat*>(v);
        v = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
      } else if (ty.base == BaseTypeKind::Float &&
                 dynamic_cast<ir::ConstantInt*>(v)) {
        auto* ci = static_cast<ir::ConstantInt*>(v);
        v = module_.GetContext().GetConstFloat(
            static_cast<float>(ci->GetValue()));
      }
      values[base + idx] = v;
    }
    return idx + 1;
  }
  size_t sub_size = ArrayStride(ty, dim);
  if (init->initVal().empty()) {
    idx = AlignIndex(idx, sub_size);
    return idx + sub_size;
  }
  for (auto* child : init->initVal()) {
    if (!child) continue;
    if (child->exp()) {
      idx = InitGlobalArray(ty, child, values, base, idx, dim);
    } else {
      size_t aligned = AlignIndex(idx, sub_size);
      InitGlobalArray(ty, child, values, base, aligned, dim + 1);
      idx = aligned + sub_size;
    }
  }
  return AlignIndex(idx, sub_size);
}

/// Global-const-array counterpart of InitGlobalArray.
size_t IRGenImpl::InitGlobalConstArray(const TypeDesc& ty,
                                       SysYParser::ConstInitValContext* init,
                                       std::vector<ir::ConstantValue*>& values,
                                       size_t base, size_t idx, size_t dim) {
  if (!init) return idx;
  if (init->constExp()) {
    if (base + idx < values.size()) {
      auto* v = EvalConstScalar(init->constExp());
      if (ty.base == BaseTypeKind::Int && dynamic_cast<ir::ConstantFloat*>(v)) {
        auto* cf = static_cast<ir::ConstantFloat*>(v);
        v = module_.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
      } else if (ty.base == BaseTypeKind::Float &&
                 dynamic_cast<ir::ConstantInt*>(v)) {
        auto* ci = static_cast<ir::ConstantInt*>(v);
        v = module_.GetContext().GetConstFloat(
            static_cast<float>(ci->GetValue()));
      }
      values[base + idx] = v;
    }
    return idx + 1;
  }
  size_t sub_size = ArrayStride(ty, dim);
  if (init->constInitVal().empty()) {
    idx = AlignIndex(idx, sub_size);
    return idx + sub_size;
  }
  for (auto* child : init->constInitVal()) {
    if (!child) continue;
    if (child->constExp()) {
      idx = InitGlobalConstArray(ty, child, values, base, idx, dim);
    } else {
      size_t aligned = AlignIndex(idx, sub_size);
      InitGlobalConstArray(ty, child, values, base, aligned, dim + 1);
      idx = aligned + sub_size;
    }
  }
  return AlignIndex(idx, sub_size);
}