You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/ir/IRPrinter.cpp

590 lines
20 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// IR text output:
// - Prints the IR as LLVM .ll-style text.
// - Supports debugging and diff-based comparison in tests.
#include "ir/IR.h"
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <ostream>
#include <stdexcept>
#include <string>
#include "utils/Log.h"
namespace ir {
static std::string TypeToString(const Type& ty);

// Spells an LLVM array type whose element type is `base_ty` and whose
// dimensions are dims[begin..] (outermost first). For example, base i32 with
// dims {4, 2} and begin 0 yields "[4 x [2 x i32]]". When begin >= dims.size()
// the result is just the spelling of `base_ty`.
static std::string ArrayTypeToStringFrom(const Type& base_ty,
                                         const std::vector<int>& dims,
                                         size_t begin) {
  std::string spelled = TypeToString(base_ty);
  // Wrap from the innermost (last) dimension outward.
  size_t idx = dims.size();
  while (idx > begin) {
    --idx;
    spelled = "[" + std::to_string(dims[idx]) + " x " + spelled + "]";
  }
  return spelled;
}
static bool IsZeroConstant(const ConstantValue* value) {
if (!value) {
return true;
}
if (auto* ci = dynamic_cast<const ConstantInt*>(value)) {
return ci->GetValue() == 0;
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(value)) {
return cf->GetValue() == 0.0f;
}
if (dynamic_cast<const ConstantZero*>(value) ||
dynamic_cast<const ConstantAggregateZero*>(value)) {
return true;
}
if (auto* arr = dynamic_cast<const ConstantArray*>(value)) {
for (auto* elem : arr->GetElements()) {
if (!IsZeroConstant(elem)) {
return false;
}
}
return true;
}
return false;
}
// Number of scalar elements covered by one aggregate at nesting `level`:
// the product of dims[level..]. Returns 1 when `level` is at or past the end.
static size_t AggregateSpan(const std::vector<int>& dims, size_t level) {
  size_t product = 1;
  // Multiply from the innermost dimension outward (same product: unsigned
  // multiplication is commutative and associative).
  for (size_t idx = dims.size(); idx > level; --idx) {
    product *= static_cast<size_t>(dims[idx - 1]);
  }
  return product;
}
// Returns true when every entry of init[begin, begin + count) is zero
// according to IsZeroConstant. Positions past the end of `init` are not
// inspected and therefore count as zero.
static bool IsZeroRange(const std::vector<ConstantValue*>& init,
                        size_t begin,
                        size_t count) {
  for (size_t offset = 0; offset != count; ++offset) {
    const size_t pos = begin + offset;
    if (pos < init.size() && !IsZeroConstant(init[pos])) {
      return false;
    }
  }
  return true;
}
static void PrintFlatArrayBody(std::ostream& os,
const Type& base_ty,
const std::vector<int>& dims,
size_t level,
const std::vector<ConstantValue*>& init,
size_t& flat_index) {
const size_t span = AggregateSpan(dims, level);
if (IsZeroRange(init, flat_index, span)) {
os << "zeroinitializer";
flat_index += span;
return;
}
os << "[";
for (int i = 0; i < dims[level]; ++i) {
if (i > 0) os << ", ";
if (level + 1 < dims.size()) {
os << ArrayTypeToStringFrom(base_ty, dims, level + 1) << " ";
PrintFlatArrayBody(os, base_ty, dims, level + 1, init, flat_index);
continue;
}
os << TypeToString(base_ty) << " ";
if (flat_index < init.size() && init[flat_index]) {
if (auto* ci = dynamic_cast<const ConstantInt*>(init[flat_index])) {
os << ci->GetValue();
} else if (auto* cf = dynamic_cast<const ConstantFloat*>(init[flat_index])) {
os << cf->GetValue();
} else if (IsZeroConstant(init[flat_index])) {
os << "0";
} else {
os << "0";
}
} else {
os << "0";
}
++flat_index;
}
os << "]";
}
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void: return "void";
case Type::Kind::Int32: return "i32";
case Type::Kind::Float: return "float";
case Type::Kind::PtrInt32: return "i32*";
case Type::Kind::PtrFloat: return "float*";
case Type::Kind::Label: return "label";
case Type::Kind::Function: return "function";
case Type::Kind::Int1: return "i1";
case Type::Kind::PtrInt1: return "i1*";
case Type::Kind::Array: {
// 打印数组类型为 LLVM 风格,如 [4 x [2 x i32]]
auto* at = dynamic_cast<const ArrayType*>(&ty);
if (!at) return "array";
// 递归构建类型字符串
std::string elem = TypeToString(*at->GetElementType());
const auto& dims = at->GetDimensions();
// 从外到内构建
std::string s = elem;
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
s = "[" + std::to_string(*it) + " x " + s + "]";
}
return s;
}
default: return "unknown";
}
throw std::runtime_error(FormatError("ir", "未知类型"));
}
// Maps an opcode to its textual mnemonic (mostly LLVM instruction names,
// e.g. Div -> "sdiv", Mod -> "srem"). Returns "?" for an opcode not listed.
static const char* OpcodeToString(Opcode op) {
  struct Entry {
    Opcode op;
    const char* text;
  };
  static constexpr Entry kMnemonics[] = {
      {Opcode::Add, "add"},         {Opcode::Sub, "sub"},
      {Opcode::Mul, "mul"},         {Opcode::Alloca, "alloca"},
      {Opcode::Load, "load"},       {Opcode::Store, "store"},
      {Opcode::Ret, "ret"},         {Opcode::Call, "call"},
      {Opcode::Br, "br"},           {Opcode::CondBr, "condbr"},
      {Opcode::Icmp, "icmp"},       {Opcode::Div, "sdiv"},
      {Opcode::Mod, "srem"},        {Opcode::ZExt, "zext"},
      {Opcode::Trunc, "trunc"},     {Opcode::And, "and"},
      {Opcode::Or, "or"},           {Opcode::Not, "not"},
      {Opcode::GEP, "getelementptr"},
      {Opcode::FAdd, "fadd"},       {Opcode::FSub, "fsub"},
      {Opcode::FMul, "fmul"},       {Opcode::FDiv, "fdiv"},
      {Opcode::FCmp, "fcmp"},       {Opcode::SIToFP, "sitofp"},
      {Opcode::FPToSI, "fptosi"},   {Opcode::FPExt, "fpext"},
      {Opcode::FPTrunc, "fptrunc"},
  };
  for (const Entry& entry : kMnemonics) {
    if (entry.op == op) {
      return entry.text;
    }
  }
  return "?";
}
// Converts a float to the 64-bit hexadecimal floating-point notation accepted
// by LLVM IR: the value is widened to double and its IEEE-754 bit pattern is
// printed as "0x" followed by 16 uppercase hex digits (1.0f -> 0x3FF0000000000000).
static std::string FloatToLLVMHex(float f) {
  const double widened = f;
  std::uint64_t raw = 0;
  std::memcpy(&raw, &widened, sizeof raw);
  char text[2 + 16 + 1];  // "0x" + 16 hex digits + NUL
  std::snprintf(text, sizeof text, "0x%016llX",
                static_cast<unsigned long long>(raw));
  return std::string(text);
}
static std::string ValueToString(const Value* v) {
if (!v) {
return "<null>";
}
if (dynamic_cast<const ConstantZero*>(v) ||
dynamic_cast<const ConstantAggregateZero*>(v)) {
return "zeroinitializer";
}
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return FloatToLLVMHex(cf->GetValue());
}
const auto& name = v->GetName();
if (name.empty()) {
return "<unnamed>";
}
if (name[0] == '%' || name[0] == '@') {
return name;
}
if (dynamic_cast<const GlobalValue*>(v)) {
return "@" + name;
}
return "%" + name;
}
// Spells the type of a memory operand (the pointer in load/store/GEP).
// Array-typed addresses are printed as a pointer to the array ("[N x T]*");
// every other type is printed as-is.
static std::string MemoryTypeToString(const Type& ty) {
  const std::string spelled = TypeToString(ty);
  return ty.IsArray() ? spelled + "*" : spelled;
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& global : module.GetGlobals()) {
if (!global) continue;
os << "@" << global->GetName() << " = "
<< (global->IsConstant() ? "constant " : "global ");
if (global->GetType()->IsPtrInt32()) {
os << "i32 ";
if (global->HasInitializer()) {
auto* ci = dynamic_cast<const ConstantInt*>(global->GetInitializer().front());
os << (ci ? ci->GetValue() : 0);
} else {
os << "0";
}
os << "\n";
continue;
}
if (global->GetType()->IsPtrFloat()) {
os << "float ";
if (global->HasInitializer()) {
auto* cf = dynamic_cast<const ConstantFloat*>(global->GetInitializer().front());
os << (cf ? ValueToString(cf) : FloatToLLVMHex(0.0f));
} else {
os << FloatToLLVMHex(0.0f);
}
os << "\n";
continue;
}
if (global->GetType()->IsArray()) {
auto* at = dynamic_cast<const ArrayType*>(global->GetType().get());
os << TypeToString(*global->GetType()) << " ";
if (!at || !global->HasInitializer() ||
IsZeroRange(global->GetInitializer(), 0, AggregateSpan(at->GetDimensions(), 0))) {
os << "zeroinitializer\n";
continue;
}
size_t flat_index = 0;
PrintFlatArrayBody(os,
*at->GetElementType(),
at->GetDimensions(),
0,
global->GetInitializer(),
flat_index);
os << "\n";
continue;
}
os << TypeToString(*global->GetType()) << " zeroinitializer\n";
}
auto print_func_params = [&](const Function* func,
const FunctionType* func_ty) {
bool first = true;
if (!func->GetArguments().empty()) {
for (const auto& arg : func->GetArguments()) {
if (!first) os << ", ";
first = false;
os << TypeToString(*arg->GetType()) << " %" << arg->GetName();
}
return;
}
for (const auto& pty : func_ty->GetParamTypes()) {
if (!first) os << ", ";
first = false;
os << TypeToString(*pty);
}
};
auto is_declaration_only = [](const Function* func) {
const auto& blocks = func->GetBlocks();
if (blocks.size() != 1) return false;
const auto& only = blocks.front();
if (!only) return false;
return only->GetInstructions().empty();
};
for (const auto& func : module.GetFunctions()) {
auto* func_ty = static_cast<const FunctionType*>(func->GetType().get());
if (is_declaration_only(func.get())) {
os << "declare " << TypeToString(*func_ty->GetReturnType()) << " @"
<< func->GetName() << "(";
print_func_params(func.get(), func_ty);
os << ")\n";
continue;
}
os << "define " << TypeToString(*func_ty->GetReturnType()) << " @"
<< func->GetName() << "(";
print_func_params(func.get(), func_ty);
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
continue;
}
os << bb->GetName() << ":\n";
for (const auto& instPtr : bb->GetInstructions()) {
const auto* inst = instPtr.get();
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::And:
case Opcode::Not:
case Opcode::Or:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
{
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
std::string elem_ty_str;
if (alloca->GetType()->IsPtrInt32()) {
elem_ty_str = "i32";
} else if (alloca->GetType()->IsPtrFloat()) {
elem_ty_str = "float";
} else {
elem_ty_str = TypeToString(*alloca->GetType());
}
os << " " << alloca->GetName() << " = alloca " << elem_ty_str << "\n";
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< MemoryTypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst);
os << " store " << TypeToString(*store->GetValue()->GetType()) << " "
<< ValueToString(store->GetValue())
<< ", " << MemoryTypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
if (!ret->GetValue()) {
os << " ret void\n";
} else {
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
}
break;
}
// CallInst类在 include/ir/IR.h 中定义
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
os << " ";
if (!call->GetType()->IsVoid()) {
os << call->GetName() << " = ";
}
os << "call " << TypeToString(*call->GetType()) << " @"
<< call->GetCallee()->GetName() << "(";
bool first = true;
for (auto* arg : call->GetArgs()) {
if (!first) os << ", ";
first = false;
os << TypeToString(*arg->GetType()) << " " << ValueToString(arg);
}
os << ")\n";
break;
}
// 在 IRPrinter.cpp 的 switch 语句中添加
case Opcode::Br:
case Opcode::CondBr: {
auto* br = static_cast<const BranchInst*>(inst);
if (!br->IsConditional()) {
os << " br label %" << br->GetTarget()->GetName() << "\n";
} else {
os << " br i1 " << ValueToString(br->GetCondition())
<< ", label %" << br->GetTrueTarget()->GetName()
<< ", label %" << br->GetFalseTarget()->GetName() << "\n";
}
break;
}
case Opcode::Icmp: {
auto* icmp = static_cast<const IcmpInst*>(inst);
os << " " << icmp->GetName() << " = icmp ";
switch (icmp->GetPredicate()) {
case IcmpInst::Predicate::EQ: os << "eq"; break;
case IcmpInst::Predicate::NE: os << "ne"; break;
case IcmpInst::Predicate::LT: os << "slt"; break;
case IcmpInst::Predicate::LE: os << "sle"; break;
case IcmpInst::Predicate::GT: os << "sgt"; break;
case IcmpInst::Predicate::GE: os << "sge"; break;
}
os << " " << TypeToString(*icmp->GetLhs()->GetType())
<< " " << ValueToString(icmp->GetLhs())
<< ", " << ValueToString(icmp->GetRhs()) << "\n";
break;
}
case Opcode::ZExt: {
auto* zext = static_cast<const ZExtInst*>(inst);
os << " " << zext->GetName() << " = zext "
<< TypeToString(*zext->GetSourceType()) << " "
<< ValueToString(zext->GetValue()) << " to "
<< TypeToString(*zext->GetTargetType()) << "\n";
break;
}
case Opcode::Trunc: {
auto* trunc = static_cast<const TruncInst*>(inst);
os << " " << trunc->GetName() << " = trunc "
<< TypeToString(*trunc->GetSourceType()) << " "
<< ValueToString(trunc->GetValue()) << " to "
<< TypeToString(*trunc->GetTargetType()) << "\n";
break;
}
case Opcode::GEP:{
// 打印为类似 LLVM 的 getelementptr 形式:
// getelementptr <elem_ty>, <base_ty> <base>, i32 <idx0>, i32 <idx1>, ...
os << " " << inst->GetName() << " = getelementptr ";
// 基地址类型使用第一个操作数的类型
Value* base = inst->GetOperand(0);
// GEP 的第一个类型参数应是基址指向的元素类型pointee
std::string elem_ty;
if (base->GetType()->IsPtrInt32()) elem_ty = "i32";
else if (base->GetType()->IsPtrFloat()) elem_ty = "float";
else if (base->GetType()->IsArray()) elem_ty = TypeToString(*base->GetType());
else elem_ty = TypeToString(*inst->GetType());
std::string base_ty = TypeToString(*base->GetType());
if (base->GetType()->IsArray()) {
base_ty += "*";
}
os << elem_ty << ", " << base_ty << " " << ValueToString(base);
// 后续操作数为索引,按照 i32 打印
// 特殊处理:如果 base 是标量指针i32*/float*)且第一个索引是常量 0
// 且后续还有索引,则丢弃第一个 0对 T* 来说多余且会导致无效 IR
size_t start_idx = 1;
if ((base->GetType()->IsPtrInt32() || base->GetType()->IsPtrFloat()) &&
inst->GetNumOperands() >= 3) {
// 检查第一个索引是否为常量 0
auto* first_idx = inst->GetOperand(1);
if (auto* ci = dynamic_cast<const ConstantInt*>(first_idx)) {
if (ci->GetValue() == 0) {
start_idx = 2; // 跳过第一个 0
}
}
}
for (size_t i = start_idx; i < inst->GetNumOperands(); ++i) {
os << ", i32 " << ValueToString(inst->GetOperand(i));
}
os << "\n";
break;
}
case Opcode::FCmp: {
auto* fcmp = static_cast<const FcmpInst*>(inst);
os << " " << fcmp->GetName() << " = fcmp ";
switch (fcmp->GetPredicate()) {
case FcmpInst::Predicate::OEQ: os << "oeq"; break;
case FcmpInst::Predicate::ONE: os << "one"; break;
case FcmpInst::Predicate::OLT: os << "olt"; break;
case FcmpInst::Predicate::OLE: os << "ole"; break;
case FcmpInst::Predicate::OGT: os << "ogt"; break;
case FcmpInst::Predicate::OGE: os << "oge"; break;
default: os << "oeq"; break;
}
os << " " << TypeToString(*fcmp->GetLhs()->GetType())
<< " " << ValueToString(fcmp->GetLhs())
<< ", " << ValueToString(fcmp->GetRhs()) << "\n";
break;
}
case Opcode::SIToFP: {
auto* sitofp = static_cast<const SIToFPInst*>(inst);
os << " " << sitofp->GetName() << " = sitofp "
<< TypeToString(*sitofp->GetValue()->GetType()) << " "
<< ValueToString(sitofp->GetValue()) << " to "
<< TypeToString(*sitofp->GetType()) << "\n";
break;
}
case Opcode::FPToSI: {
auto* fptosi = static_cast<const FPToSIInst*>(inst);
os << " " << fptosi->GetName() << " = fptosi "
<< TypeToString(*fptosi->GetValue()->GetType()) << " "
<< ValueToString(fptosi->GetValue()) << " to "
<< TypeToString(*fptosi->GetType()) << "\n";
break;
}
default: {
// 处理未知操作码
os << " ; 未知指令: " << OpcodeToString(inst->GetOpcode()) << "\n";
break;
}
}
}
}
os << "}\n";
}
}
// Writes a readable rendering of `constant` to `os`.
// - Integers and floats are streamed with operator<<.
// - Arrays are printed as "[e0, e1, ...]", recursively.
// - ConstantZero is printed as "zero".
// - ConstantAggregateZero is printed as "zeroinitializer".
// - Anything else, including null, writes nothing.
void IRPrinter::PrintConstant(const ConstantValue* constant, std::ostream& os) {
  if (const auto* as_int = dynamic_cast<const ConstantInt*>(constant)) {
    os << as_int->GetValue();
    return;
  }
  if (const auto* as_float = dynamic_cast<const ConstantFloat*>(constant)) {
    os << as_float->GetValue();
    return;
  }
  if (const auto* as_array = dynamic_cast<const ConstantArray*>(constant)) {
    os << "[";
    const char* separator = "";
    for (const auto* element : as_array->GetElements()) {
      os << separator;
      PrintConstant(element, os);
      separator = ", ";
    }
    os << "]";
    return;
  }
  if (dynamic_cast<const ConstantZero*>(constant)) {
    os << "zero";
    return;
  }
  if (dynamic_cast<const ConstantAggregateZero*>(constant)) {
    os << "zeroinitializer";
  }
}
} // namespace ir