// You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
// nudt-compiler-cpp/src/ir/IRPrinter.cpp
//
// 391 lines
// 14 KiB

#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <algorithm>
#include <unordered_map>
#include "utils/Log.h"
namespace ir {
// Returns the textual type name the printer emits for `ty`.
// Throws std::runtime_error when the kind matches none of the listed cases.
static const char* TypeToStr(const Type& ty) {
  const char* text = nullptr;
  switch (ty.GetKind()) {
    case Type::Kind::Void:
      text = "void";
      break;
    case Type::Kind::Int1:
      text = "i1";
      break;
    case Type::Kind::Int32:
      text = "i32";
      break;
    case Type::Kind::Float32:
      text = "float";
      break;
    case Type::Kind::PtrInt32:
      text = "i32*";
      break;
    case Type::Kind::PtrFloat32:
      text = "float*";
      break;
  }
  if (text == nullptr) {
    // No case matched: the kind value lies outside the known enumerators.
    throw std::runtime_error(FormatError("ir", "未知类型"));
  }
  return text;
}
// Returns the icmp predicate keyword for `pred`, or "?" when no case matches.
static const char* PredToStr(ICmpPredicate pred) {
  const char* keyword = "?";
  switch (pred) {
    case ICmpPredicate::EQ:
      keyword = "eq";
      break;
    case ICmpPredicate::NE:
      keyword = "ne";
      break;
    case ICmpPredicate::SLT:
      keyword = "slt";
      break;
    case ICmpPredicate::SLE:
      keyword = "sle";
      break;
    case ICmpPredicate::SGT:
      keyword = "sgt";
      break;
    case ICmpPredicate::SGE:
      keyword = "sge";
      break;
  }
  return keyword;
}
// Returns the fcmp predicate keyword for `pred`, or "?" when no case matches.
static const char* FPredToStr(FCmpPredicate pred) {
  const char* keyword = "?";
  switch (pred) {
    case FCmpPredicate::OEQ:
      keyword = "oeq";
      break;
    case FCmpPredicate::ONE:
      keyword = "one";
      break;
    case FCmpPredicate::OLT:
      keyword = "olt";
      break;
    case FCmpPredicate::OLE:
      keyword = "ole";
      break;
    case FCmpPredicate::OGT:
      keyword = "ogt";
      break;
    case FCmpPredicate::OGE:
      keyword = "oge";
      break;
  }
  return keyword;
}
// Maps a value to the numeric id printed in place of its own name ("%<id>").
using RenameMap = std::unordered_map<const Value*, int>;
static std::string ValStr(const Value* v, const RenameMap& rm) {
if (!v) return "<null>";
if (dynamic_cast<const ConstantInt*>(v))
return std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
return oss.str();
}
if (dynamic_cast<const BasicBlock*>(v))
return "%" + v->GetName();
if (auto* gv = dynamic_cast<const GlobalVariable*>(v)) {
if (gv->IsArray()) {
const char* et = gv->IsFloat() ? "float" : "i32";
return std::string("getelementptr ([") + std::to_string(gv->GetNumElements()) +
" x " + et + "], [" + std::to_string(gv->GetNumElements()) +
" x " + et + "]* @" + gv->GetName() + ", i32 0, i32 0)";
}
return "@" + v->GetName();
}
auto it = rm.find(v);
if (it != rm.end()) return "%" + std::to_string(it->second);
return "%" + v->GetName();
}
static std::string TypeVal(const Value* v, const RenameMap& rm) {
if (!v) return "void";
if (dynamic_cast<const ConstantInt*>(v))
return std::string(TypeToStr(*v->GetType())) + " " +
std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
return oss.str();
}
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v, rm);
}
// Print one instruction (non-alloca) using rename map
static void PrintInst(const Instruction* inst, std::ostream& os,
const RenameMap& rm) {
auto N = [&](const Value* v) -> std::string {
auto it = rm.find(v);
if (it != rm.end()) return std::to_string(it->second);
return v->GetName();
};
auto VS = [&](const Value* v) { return ValStr(v, rm); };
auto TV = [&](const Value* v) { return TypeVal(v, rm); };
switch (inst->GetOpcode()) {
case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
case Opcode::Div: case Opcode::Mod: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op = nullptr;
switch (bin->GetOpcode()) {
case Opcode::Add: op = "add"; break;
case Opcode::Sub: op = "sub"; break;
case Opcode::Mul: op = "mul"; break;
case Opcode::Div: op = "sdiv"; break;
case Opcode::Mod: op = "srem"; break;
default: op = "?"; break;
}
os << " %" << N(bin) << " = " << op << " i32 "
<< VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
break;
}
case Opcode::FAdd: case Opcode::FSub:
case Opcode::FMul: case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst);
const char* op = nullptr;
switch (bin->GetOpcode()) {
case Opcode::FAdd: op = "fadd"; break;
case Opcode::FSub: op = "fsub"; break;
case Opcode::FMul: op = "fmul"; break;
case Opcode::FDiv: op = "fdiv"; break;
default: op = "?"; break;
}
os << " %" << N(bin) << " = " << op << " float "
<< VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
break;
}
case Opcode::ICmp: {
auto* cmp = static_cast<const ICmpInst*>(inst);
os << " %" << N(cmp) << " = icmp " << PredToStr(cmp->GetPredicate())
<< " i32 " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
break;
}
case Opcode::FCmp: {
auto* cmp = static_cast<const FCmpInst*>(inst);
os << " %" << N(cmp) << " = fcmp " << FPredToStr(cmp->GetPredicate())
<< " float " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* al = static_cast<const AllocaInst*>(inst);
const char* et = al->GetType()->IsPtrFloat32() ? "float" : "i32";
if (al->IsArray())
os << " %" << N(al) << " = alloca " << et << ", i32 " << al->GetNumElements() << "\n";
else
os << " %" << N(al) << " = alloca " << et << "\n";
break;
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
bool fp = gep->GetBasePtr()->GetType()->IsPtrFloat32();
os << " %" << N(gep) << " = getelementptr " << (fp ? "float" : "i32")
<< ", " << (fp ? "float*" : "i32*") << " "
<< VS(gep->GetBasePtr()) << ", i32 " << VS(gep->GetIndex()) << "\n";
break;
}
case Opcode::Load: {
auto* ld = static_cast<const LoadInst*>(inst);
bool fp = ld->GetPtr()->GetType()->IsPtrFloat32();
os << " %" << N(ld) << " = load " << (fp ? "float" : "i32")
<< ", " << (fp ? "float*" : "i32*") << " " << VS(ld->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* st = static_cast<const StoreInst*>(inst);
os << " store " << TV(st->GetValue()) << ", "
<< TypeToStr(*st->GetPtr()->GetType()) << " " << VS(st->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
if (!ret->HasValue()) os << " ret void\n";
else os << " ret " << TV(ret->GetValue()) << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BrInst*>(inst);
os << " br label %" << br->GetTarget()->GetName() << "\n";
break;
}
case Opcode::CondBr: {
auto* cbr = static_cast<const CondBrInst*>(inst);
os << " br i1 " << VS(cbr->GetCond()) << ", label %"
<< cbr->GetTrueBB()->GetName() << ", label %"
<< cbr->GetFalseBB()->GetName() << "\n";
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
if (!call->IsVoid() && !call->GetName().empty())
os << " %" << N(call) << " = ";
else
os << " ";
os << "call " << (call->IsVoid() ? "void" : TypeToStr(*call->GetType()))
<< " @" << call->GetCalleeName() << "(";
for (size_t i = 0; i < call->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
os << TV(call->GetArg(i));
}
os << ")\n";
break;
}
case Opcode::ZExt: {
auto* ze = static_cast<const ZExtInst*>(inst);
os << " %" << N(ze) << " = zext i1 " << VS(ze->GetSrc()) << " to i32\n";
break;
}
case Opcode::SIToFP: {
auto* si = static_cast<const SIToFPInst*>(inst);
os << " %" << N(si) << " = sitofp i32 " << VS(si->GetSrc()) << " to float\n";
break;
}
case Opcode::FPToSI: {
auto* fp = static_cast<const FPToSIInst*>(inst);
os << " %" << N(fp) << " = fptosi float " << VS(fp->GetSrc()) << " to i32\n";
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " %" << N(phi) << " = phi " << TypeToStr(*phi->GetType()) << " ";
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (i > 0) os << ", ";
os << "[ " << VS(phi->GetIncomingValue(i)) << ", %"
<< phi->GetIncomingBlock(i)->GetName() << " ]";
}
os << "\n";
break;
}
}
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
// 1. 全局变量/常量
for (const auto& gv : module.GetGlobalVariables()) {
if (!gv) continue;
if (gv->IsConstant()) {
os << "@" << gv->GetName() << " = constant i32 " << gv->GetInitVal() << "\n";
} else if (gv->IsArray()) {
const char* et = gv->IsFloat() ? "float" : "i32";
os << "@" << gv->GetName() << " = global [" << gv->GetNumElements()
<< " x " << et << "] ";
if (!gv->HasInitVals()) {
os << "zeroinitializer\n";
} else if (gv->IsFloat()) {
const auto& vals = gv->GetInitValsF();
bool all_zero = std::all_of(vals.begin(), vals.end(), [](float f){ return f == 0.0f; });
if (all_zero) {
os << "zeroinitializer\n";
} else {
os << "[";
for (int i = 0; i < gv->GetNumElements(); ++i) {
if (i > 0) os << ", ";
float fv = (i < (int)vals.size()) ? vals[i] : 0.0f;
double d = static_cast<double>(fv);
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
os << oss.str();
}
os << "]\n";
}
} else {
const auto& vals = gv->GetInitVals();
bool all_zero = std::all_of(vals.begin(), vals.end(), [](int v){ return v == 0; });
if (all_zero) {
os << "zeroinitializer\n";
} else {
os << "[";
for (int i = 0; i < gv->GetNumElements(); ++i) {
if (i > 0) os << ", ";
os << "i32 " << (i < (int)vals.size() ? vals[i] : 0);
}
os << "]\n";
}
}
} else {
if(gv->IsFloat()) {
double d = static_cast<double>(gv->GetInitValF());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
os << "@" << gv->GetName() << " = global float " << oss.str() << "\n";
} else
{
os << "@" << gv->GetName() << " = global i32 " << gv->GetInitVal() << "\n";
}
}
}
if (!module.GetGlobalVariables().empty()) os << "\n";
// 2. 外部声明
for (const auto& decl : module.GetExternalDecls()) {
os << "declare " << TypeToStr(*decl.ret_type) << " @" << decl.name << "(";
for (size_t i = 0; i < decl.param_types.size(); ++i) {
if (i > 0) os << ", ";
os << TypeToStr(*decl.param_types[i]);
}
os << ")\n";
}
if (!module.GetExternalDecls().empty()) os << "\n";
// 3. 函数定义
for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName() << "(";
for (size_t i = 0; i < func->GetNumArgs(); ++i) {
if (i > 0) os << ", ";
auto* arg = func->GetArgument(i);
os << TypeToStr(*arg->GetType()) << " %" << arg->GetName();
}
os << ") {\n";
// Build rename map: alloca instructions first (in block order), then rest
RenameMap rm;
int next_id = 0;
auto assign = [&](const Value* v) {
if (!v) return;
if (dynamic_cast<const ConstantInt*>(v)) return;
if (dynamic_cast<const ConstantFloat*>(v)) return;
if (dynamic_cast<const BasicBlock*>(v)) return;
if (dynamic_cast<const GlobalVariable*>(v)) return;
if (dynamic_cast<const Argument*>(v)) return;
if (rm.count(v) == 0) rm[v] = next_id++;
};
// Pass 1: all allocas across all blocks
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Alloca) assign(ip.get());
}
// Pass 2: all non-alloca instructions in block order
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca) assign(ip.get());
}
// Print: entry block first with all allocas hoisted, then rest
bool first_bb = true;
for (const auto& bb : func->GetBlocks()) {
if (!bb) continue;
os << bb->GetName() << ":\n";
if (first_bb) {
first_bb = false;
// Print all allocas from all blocks (only for entry block)
for (const auto& bb2 : func->GetBlocks()) {
if (!bb2) continue;
for (const auto& ip : bb2->GetInstructions())
if (ip->GetOpcode() == Opcode::Alloca)
PrintInst(ip.get(), os, rm);
}
// Print PHI nodes of entry block
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Phi)
PrintInst(ip.get(), os, rm);
// Print non-alloca non-phi instructions of entry block
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi)
PrintInst(ip.get(), os, rm);
} else {
// Non-entry blocks: skip allocas (already printed)
// Print PHI nodes first
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() == Opcode::Phi)
PrintInst(ip.get(), os, rm);
// Print non-alloca non-phi instructions
for (const auto& ip : bb->GetInstructions())
if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi)
PrintInst(ip.get(), os, rm);
}
}
os << "}\n\n";
}
}
} // namespace ir