// IR 文本输出: // - 将 IR 打印为 .ll 风格的文本 // - 支撑调试与测试对比(diff) #include "ir/IR.h" #include #include #include #include #include #include #include "utils/Log.h" namespace ir { static std::string TypeToString(const Type& ty); static std::string ArrayTypeToStringFrom(const Type& base_ty, const std::vector& dims, size_t begin) { std::string s = TypeToString(base_ty); for (size_t i = dims.size(); i-- > begin;) { s = "[" + std::to_string(dims[i]) + " x " + s + "]"; } return s; } static bool IsZeroConstant(const ConstantValue* value) { if (!value) { return true; } if (auto* ci = dynamic_cast(value)) { return ci->GetValue() == 0; } if (auto* cf = dynamic_cast(value)) { return cf->GetValue() == 0.0f; } if (dynamic_cast(value) || dynamic_cast(value)) { return true; } if (auto* arr = dynamic_cast(value)) { for (auto* elem : arr->GetElements()) { if (!IsZeroConstant(elem)) { return false; } } return true; } return false; } static size_t AggregateSpan(const std::vector& dims, size_t level) { size_t span = 1; for (size_t i = level; i < dims.size(); ++i) { span *= static_cast(dims[i]); } return span; } static bool IsZeroRange(const std::vector& init, size_t begin, size_t count) { for (size_t i = 0; i < count; ++i) { const size_t index = begin + i; if (index >= init.size()) { continue; } if (!IsZeroConstant(init[index])) { return false; } } return true; } static void PrintFlatArrayBody(std::ostream& os, const Type& base_ty, const std::vector& dims, size_t level, const std::vector& init, size_t& flat_index) { const size_t span = AggregateSpan(dims, level); if (IsZeroRange(init, flat_index, span)) { os << "zeroinitializer"; flat_index += span; return; } os << "["; for (int i = 0; i < dims[level]; ++i) { if (i > 0) os << ", "; if (level + 1 < dims.size()) { os << ArrayTypeToStringFrom(base_ty, dims, level + 1) << " "; PrintFlatArrayBody(os, base_ty, dims, level + 1, init, flat_index); continue; } os << TypeToString(base_ty) << " "; if (flat_index < init.size() && init[flat_index]) { if (auto* ci = dynamic_cast(init[flat_index])) { os << ci->GetValue(); } else if (auto* cf = dynamic_cast(init[flat_index])) { os << cf->GetValue(); } else if (IsZeroConstant(init[flat_index])) { os << "0"; } else { os << "0"; } } else { os << "0"; } ++flat_index; } os << "]"; } static std::string TypeToString(const Type& ty) { switch (ty.GetKind()) { case Type::Kind::Void: return "void"; case Type::Kind::Int32: return "i32"; case Type::Kind::Float: return "float"; case Type::Kind::PtrInt32: return "i32*"; case Type::Kind::PtrFloat: return "float*"; case Type::Kind::Label: return "label"; case Type::Kind::Function: return "function"; case Type::Kind::Int1: return "i1"; case Type::Kind::PtrInt1: return "i1*"; case Type::Kind::Array: { // 打印数组类型为 LLVM 风格,如 [4 x [2 x i32]] auto* at = dynamic_cast(&ty); if (!at) return "array"; // 递归构建类型字符串 std::string elem = TypeToString(*at->GetElementType()); const auto& dims = at->GetDimensions(); // 从外到内构建 std::string s = elem; for (auto it = dims.rbegin(); it != dims.rend(); ++it) { s = "[" + std::to_string(*it) + " x " + s + "]"; } return s; } default: return "unknown"; } throw std::runtime_error(FormatError("ir", "未知类型")); } static const char* OpcodeToString(Opcode op) { switch (op) { case Opcode::Add: return "add"; case Opcode::Sub: return "sub"; case Opcode::Mul: return "mul"; case Opcode::Alloca: return "alloca"; case Opcode::Load: return "load"; case Opcode::Store: return "store"; case Opcode::Ret: return "ret"; case Opcode::Call: return "call"; case Opcode::Br: return "br"; case Opcode::CondBr: return "condbr"; case Opcode::Icmp: return "icmp"; case Opcode::Div: return "sdiv"; case Opcode::Mod: return "srem"; case Opcode::ZExt: return "zext"; case Opcode::Trunc: return "trunc"; case Opcode::And: return "and"; case Opcode::Or: return "or"; case Opcode::Not: return "not"; case Opcode::GEP: return "getelementptr"; case Opcode::FAdd: return "fadd"; case Opcode::FSub: return "fsub"; case Opcode::FMul: return "fmul"; case Opcode::FDiv: return "fdiv"; case Opcode::FCmp: return "fcmp"; case Opcode::SIToFP: return "sitofp"; case Opcode::FPToSI: return "fptosi"; case Opcode::FPExt: return "fpext"; case Opcode::FPTrunc: return "fptrunc"; } return "?"; } // 将 float 值转为 LLVM IR 接受的 64-bit 十六进制浮点格式 static std::string FloatToLLVMHex(float f) { double d = static_cast(f); uint64_t bits; memcpy(&bits, &d, sizeof(bits)); char buf[20]; snprintf(buf, sizeof(buf), "0x%016llX", (unsigned long long)bits); return buf; } static std::string ValueToString(const Value* v) { if (!v) { return ""; } if (dynamic_cast(v) || dynamic_cast(v)) { return "zeroinitializer"; } if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } if (auto* cf = dynamic_cast(v)) { return FloatToLLVMHex(cf->GetValue()); } const auto& name = v->GetName(); if (name.empty()) { return ""; } if (name[0] == '%' || name[0] == '@') { return name; } if (dynamic_cast(v)) { return "@" + name; } return "%" + name; } static std::string MemoryTypeToString(const Type& ty) { std::string text = TypeToString(ty); if (ty.IsArray()) { text += "*"; } return text; } void IRPrinter::Print(const Module& module, std::ostream& os) { for (const auto& global : module.GetGlobals()) { if (!global) continue; os << "@" << global->GetName() << " = " << (global->IsConstant() ? "constant " : "global "); if (global->GetType()->IsPtrInt32()) { os << "i32 "; if (global->HasInitializer()) { auto* ci = dynamic_cast(global->GetInitializer().front()); os << (ci ? ci->GetValue() : 0); } else { os << "0"; } os << "\n"; continue; } if (global->GetType()->IsPtrFloat()) { os << "float "; if (global->HasInitializer()) { auto* cf = dynamic_cast(global->GetInitializer().front()); os << (cf ? ValueToString(cf) : FloatToLLVMHex(0.0f)); } else { os << FloatToLLVMHex(0.0f); } os << "\n"; continue; } if (global->GetType()->IsArray()) { auto* at = dynamic_cast(global->GetType().get()); os << TypeToString(*global->GetType()) << " "; if (!at || !global->HasInitializer() || IsZeroRange(global->GetInitializer(), 0, AggregateSpan(at->GetDimensions(), 0))) { os << "zeroinitializer\n"; continue; } size_t flat_index = 0; PrintFlatArrayBody(os, *at->GetElementType(), at->GetDimensions(), 0, global->GetInitializer(), flat_index); os << "\n"; continue; } os << TypeToString(*global->GetType()) << " zeroinitializer\n"; } auto print_func_params = [&](const Function* func, const FunctionType* func_ty) { bool first = true; if (!func->GetArguments().empty()) { for (const auto& arg : func->GetArguments()) { if (!first) os << ", "; first = false; os << TypeToString(*arg->GetType()) << " %" << arg->GetName(); } return; } for (const auto& pty : func_ty->GetParamTypes()) { if (!first) os << ", "; first = false; os << TypeToString(*pty); } }; auto is_declaration_only = [](const Function* func) { const auto& blocks = func->GetBlocks(); if (blocks.size() != 1) return false; const auto& only = blocks.front(); if (!only) return false; return only->GetInstructions().empty(); }; for (const auto& func : module.GetFunctions()) { auto* func_ty = static_cast(func->GetType().get()); if (is_declaration_only(func.get())) { os << "declare " << TypeToString(*func_ty->GetReturnType()) << " @" << func->GetName() << "("; print_func_params(func.get(), func_ty); os << ")\n"; continue; } os << "define " << TypeToString(*func_ty->GetReturnType()) << " @" << func->GetName() << "("; print_func_params(func.get(), func_ty); os << ") {\n"; for (const auto& bb : func->GetBlocks()) { if (!bb) { continue; } os << bb->GetName() << ":\n"; for (const auto& instPtr : bb->GetInstructions()) { const auto* inst = instPtr.get(); switch (inst->GetOpcode()) { case Opcode::Add: case Opcode::Sub: case Opcode::Mul: case Opcode::Div: case Opcode::Mod: case Opcode::And: case Opcode::Not: case Opcode::Or: case Opcode::FAdd: case Opcode::FSub: case Opcode::FMul: case Opcode::FDiv: { auto* bin = static_cast(inst); os << " " << bin->GetName() << " = " << OpcodeToString(bin->GetOpcode()) << " " << TypeToString(*bin->GetLhs()->GetType()) << " " << ValueToString(bin->GetLhs()) << ", " << ValueToString(bin->GetRhs()) << "\n"; break; } case Opcode::Alloca: { auto* alloca = static_cast(inst); std::string elem_ty_str; if (alloca->GetType()->IsPtrInt32()) { elem_ty_str = "i32"; } else if (alloca->GetType()->IsPtrFloat()) { elem_ty_str = "float"; } else { elem_ty_str = TypeToString(*alloca->GetType()); } os << " " << alloca->GetName() << " = alloca " << elem_ty_str << "\n"; break; } case Opcode::Load: { auto* load = static_cast(inst); os << " " << load->GetName() << " = load " << TypeToString(*load->GetType()) << ", " << MemoryTypeToString(*load->GetPtr()->GetType()) << " " << ValueToString(load->GetPtr()) << "\n"; break; } case Opcode::Store: { auto* store = static_cast(inst); os << " store " << TypeToString(*store->GetValue()->GetType()) << " " << ValueToString(store->GetValue()) << ", " << MemoryTypeToString(*store->GetPtr()->GetType()) << " " << ValueToString(store->GetPtr()) << "\n"; break; } case Opcode::Ret: { auto* ret = static_cast(inst); if (!ret->GetValue()) { os << " ret void\n"; } else { os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " << ValueToString(ret->GetValue()) << "\n"; } break; } // CallInst类在 include/ir/IR.h 中定义 case Opcode::Call: { auto* call = static_cast(inst); os << " "; if (!call->GetType()->IsVoid()) { os << call->GetName() << " = "; } os << "call " << TypeToString(*call->GetType()) << " @" << call->GetCallee()->GetName() << "("; bool first = true; for (auto* arg : call->GetArgs()) { if (!first) os << ", "; first = false; os << TypeToString(*arg->GetType()) << " " << ValueToString(arg); } os << ")\n"; break; } // 在 IRPrinter.cpp 的 switch 语句中添加 case Opcode::Br: case Opcode::CondBr: { auto* br = static_cast(inst); if (!br->IsConditional()) { os << " br label %" << br->GetTarget()->GetName() << "\n"; } else { os << " br i1 " << ValueToString(br->GetCondition()) << ", label %" << br->GetTrueTarget()->GetName() << ", label %" << br->GetFalseTarget()->GetName() << "\n"; } break; } case Opcode::Icmp: { auto* icmp = static_cast(inst); os << " " << icmp->GetName() << " = icmp "; switch (icmp->GetPredicate()) { case IcmpInst::Predicate::EQ: os << "eq"; break; case IcmpInst::Predicate::NE: os << "ne"; break; case IcmpInst::Predicate::LT: os << "slt"; break; case IcmpInst::Predicate::LE: os << "sle"; break; case IcmpInst::Predicate::GT: os << "sgt"; break; case IcmpInst::Predicate::GE: os << "sge"; break; } os << " " << TypeToString(*icmp->GetLhs()->GetType()) << " " << ValueToString(icmp->GetLhs()) << ", " << ValueToString(icmp->GetRhs()) << "\n"; break; } case Opcode::ZExt: { auto* zext = static_cast(inst); os << " " << zext->GetName() << " = zext " << TypeToString(*zext->GetSourceType()) << " " << ValueToString(zext->GetValue()) << " to " << TypeToString(*zext->GetTargetType()) << "\n"; break; } case Opcode::Trunc: { auto* trunc = static_cast(inst); os << " " << trunc->GetName() << " = trunc " << TypeToString(*trunc->GetSourceType()) << " " << ValueToString(trunc->GetValue()) << " to " << TypeToString(*trunc->GetTargetType()) << "\n"; break; } case Opcode::GEP:{ // 打印为类似 LLVM 的 getelementptr 形式: // getelementptr , , i32 , i32 , ... os << " " << inst->GetName() << " = getelementptr "; // 基地址类型使用第一个操作数的类型 Value* base = inst->GetOperand(0); // GEP 的第一个类型参数应是基址指向的元素类型(pointee)。 std::string elem_ty; if (base->GetType()->IsPtrInt32()) elem_ty = "i32"; else if (base->GetType()->IsPtrFloat()) elem_ty = "float"; else if (base->GetType()->IsArray()) elem_ty = TypeToString(*base->GetType()); else elem_ty = TypeToString(*inst->GetType()); std::string base_ty = TypeToString(*base->GetType()); if (base->GetType()->IsArray()) { base_ty += "*"; } os << elem_ty << ", " << base_ty << " " << ValueToString(base); // 后续操作数为索引,按照 i32 打印 // 特殊处理:如果 base 是标量指针(i32*/float*)且第一个索引是常量 0 // 且后续还有索引,则丢弃第一个 0(对 T* 来说多余且会导致无效 IR)。 size_t start_idx = 1; if ((base->GetType()->IsPtrInt32() || base->GetType()->IsPtrFloat()) && inst->GetNumOperands() >= 3) { // 检查第一个索引是否为常量 0 auto* first_idx = inst->GetOperand(1); if (auto* ci = dynamic_cast(first_idx)) { if (ci->GetValue() == 0) { start_idx = 2; // 跳过第一个 0 } } } for (size_t i = start_idx; i < inst->GetNumOperands(); ++i) { os << ", i32 " << ValueToString(inst->GetOperand(i)); } os << "\n"; break; } case Opcode::FCmp: { auto* fcmp = static_cast(inst); os << " " << fcmp->GetName() << " = fcmp "; switch (fcmp->GetPredicate()) { case FcmpInst::Predicate::OEQ: os << "oeq"; break; case FcmpInst::Predicate::ONE: os << "one"; break; case FcmpInst::Predicate::OLT: os << "olt"; break; case FcmpInst::Predicate::OLE: os << "ole"; break; case FcmpInst::Predicate::OGT: os << "ogt"; break; case FcmpInst::Predicate::OGE: os << "oge"; break; default: os << "oeq"; break; } os << " " << TypeToString(*fcmp->GetLhs()->GetType()) << " " << ValueToString(fcmp->GetLhs()) << ", " << ValueToString(fcmp->GetRhs()) << "\n"; break; } case Opcode::SIToFP: { auto* sitofp = static_cast(inst); os << " " << sitofp->GetName() << " = sitofp " << TypeToString(*sitofp->GetValue()->GetType()) << " " << ValueToString(sitofp->GetValue()) << " to " << TypeToString(*sitofp->GetType()) << "\n"; break; } case Opcode::FPToSI: { auto* fptosi = static_cast(inst); os << " " << fptosi->GetName() << " = fptosi " << TypeToString(*fptosi->GetValue()->GetType()) << " " << ValueToString(fptosi->GetValue()) << " to " << TypeToString(*fptosi->GetType()) << "\n"; break; } default: { // 处理未知操作码 os << " ; 未知指令: " << OpcodeToString(inst->GetOpcode()) << "\n"; break; } } } } os << "}\n"; } } void IRPrinter::PrintConstant(const ConstantValue* constant, std::ostream& os) { if (auto* const_int = dynamic_cast(constant)) { os << const_int->GetValue(); } else if (auto* const_float = dynamic_cast(constant)) { os << const_float->GetValue(); } else if (auto* const_array = dynamic_cast(constant)) { os << "["; auto& elements = const_array->GetElements(); for (size_t i = 0; i < elements.size(); ++i) { if (i > 0) os << ", "; PrintConstant(elements[i], os); } os << "]"; } else if (dynamic_cast(constant)) { os << "zero"; } else if (dynamic_cast(constant)) { os << "zeroinitializer"; } } } // namespace ir