forked from ppxf25tqu/nudt-compiler-cpp
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
391 lines
14 KiB
391 lines
14 KiB
#include "ir/IR.h"
|
|
|
|
#include <cstdint>
|
|
#include <cstring>
|
|
#include <ostream>
|
|
#include <sstream>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <algorithm>
|
|
#include <unordered_map>
|
|
|
|
#include "utils/Log.h"
|
|
|
|
namespace ir {
|
|
|
|
static const char* TypeToStr(const Type& ty) {
|
|
switch (ty.GetKind()) {
|
|
case Type::Kind::Void: return "void";
|
|
case Type::Kind::Int1: return "i1";
|
|
case Type::Kind::Int32: return "i32";
|
|
case Type::Kind::Float32: return "float";
|
|
case Type::Kind::PtrInt32: return "i32*";
|
|
case Type::Kind::PtrFloat32: return "float*";
|
|
}
|
|
throw std::runtime_error(FormatError("ir", "未知类型"));
|
|
}
|
|
|
|
// Returns the `icmp` condition keyword for `pred`, or "?" when the value
// matches none of the predicates handled below.
static const char* PredToStr(ICmpPredicate pred) {
  const char* keyword = "?";
  switch (pred) {
    case ICmpPredicate::EQ:  keyword = "eq";  break;
    case ICmpPredicate::NE:  keyword = "ne";  break;
    case ICmpPredicate::SLT: keyword = "slt"; break;
    case ICmpPredicate::SLE: keyword = "sle"; break;
    case ICmpPredicate::SGT: keyword = "sgt"; break;
    case ICmpPredicate::SGE: keyword = "sge"; break;
  }
  return keyword;
}
|
|
|
|
// Returns the `fcmp` condition keyword for `pred`, or "?" when the value
// matches none of the predicates handled below.
static const char* FPredToStr(FCmpPredicate pred) {
  const char* keyword = "?";
  switch (pred) {
    case FCmpPredicate::OEQ: keyword = "oeq"; break;
    case FCmpPredicate::ONE: keyword = "one"; break;
    case FCmpPredicate::OLT: keyword = "olt"; break;
    case FCmpPredicate::OLE: keyword = "ole"; break;
    case FCmpPredicate::OGT: keyword = "ogt"; break;
    case FCmpPredicate::OGE: keyword = "oge"; break;
  }
  return keyword;
}
|
|
|
|
// Maps a value to the number it is printed under: a value present in the map
// is rendered as `%<number>` instead of `%<its own name>`.
using RenameMap = std::unordered_map<const Value*, int>;
|
|
|
|
static std::string ValStr(const Value* v, const RenameMap& rm) {
|
|
if (!v) return "<null>";
|
|
if (dynamic_cast<const ConstantInt*>(v))
|
|
return std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
|
|
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
|
|
double d = static_cast<double>(cf->GetValue());
|
|
uint64_t bits;
|
|
std::memcpy(&bits, &d, sizeof(bits));
|
|
std::ostringstream oss;
|
|
oss << "0x" << std::hex << std::uppercase << bits;
|
|
return oss.str();
|
|
}
|
|
if (dynamic_cast<const BasicBlock*>(v))
|
|
return "%" + v->GetName();
|
|
if (auto* gv = dynamic_cast<const GlobalVariable*>(v)) {
|
|
if (gv->IsArray()) {
|
|
const char* et = gv->IsFloat() ? "float" : "i32";
|
|
return std::string("getelementptr ([") + std::to_string(gv->GetNumElements()) +
|
|
" x " + et + "], [" + std::to_string(gv->GetNumElements()) +
|
|
" x " + et + "]* @" + gv->GetName() + ", i32 0, i32 0)";
|
|
}
|
|
return "@" + v->GetName();
|
|
}
|
|
auto it = rm.find(v);
|
|
if (it != rm.end()) return "%" + std::to_string(it->second);
|
|
return "%" + v->GetName();
|
|
}
|
|
|
|
static std::string TypeVal(const Value* v, const RenameMap& rm) {
|
|
if (!v) return "void";
|
|
if (dynamic_cast<const ConstantInt*>(v))
|
|
return std::string(TypeToStr(*v->GetType())) + " " +
|
|
std::to_string(static_cast<const ConstantInt*>(v)->GetValue());
|
|
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
|
|
double d = static_cast<double>(cf->GetValue());
|
|
uint64_t bits;
|
|
std::memcpy(&bits, &d, sizeof(bits));
|
|
std::ostringstream oss;
|
|
oss << "float 0x" << std::hex << std::uppercase << bits;
|
|
return oss.str();
|
|
}
|
|
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v, rm);
|
|
}
|
|
|
|
// Print one instruction (non-alloca) using rename map
|
|
static void PrintInst(const Instruction* inst, std::ostream& os,
|
|
const RenameMap& rm) {
|
|
auto N = [&](const Value* v) -> std::string {
|
|
auto it = rm.find(v);
|
|
if (it != rm.end()) return std::to_string(it->second);
|
|
return v->GetName();
|
|
};
|
|
auto VS = [&](const Value* v) { return ValStr(v, rm); };
|
|
auto TV = [&](const Value* v) { return TypeVal(v, rm); };
|
|
|
|
switch (inst->GetOpcode()) {
|
|
case Opcode::Add: case Opcode::Sub: case Opcode::Mul:
|
|
case Opcode::Div: case Opcode::Mod: {
|
|
auto* bin = static_cast<const BinaryInst*>(inst);
|
|
const char* op = nullptr;
|
|
switch (bin->GetOpcode()) {
|
|
case Opcode::Add: op = "add"; break;
|
|
case Opcode::Sub: op = "sub"; break;
|
|
case Opcode::Mul: op = "mul"; break;
|
|
case Opcode::Div: op = "sdiv"; break;
|
|
case Opcode::Mod: op = "srem"; break;
|
|
default: op = "?"; break;
|
|
}
|
|
os << " %" << N(bin) << " = " << op << " i32 "
|
|
<< VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
|
|
break;
|
|
}
|
|
case Opcode::FAdd: case Opcode::FSub:
|
|
case Opcode::FMul: case Opcode::FDiv: {
|
|
auto* bin = static_cast<const BinaryInst*>(inst);
|
|
const char* op = nullptr;
|
|
switch (bin->GetOpcode()) {
|
|
case Opcode::FAdd: op = "fadd"; break;
|
|
case Opcode::FSub: op = "fsub"; break;
|
|
case Opcode::FMul: op = "fmul"; break;
|
|
case Opcode::FDiv: op = "fdiv"; break;
|
|
default: op = "?"; break;
|
|
}
|
|
os << " %" << N(bin) << " = " << op << " float "
|
|
<< VS(bin->GetLhs()) << ", " << VS(bin->GetRhs()) << "\n";
|
|
break;
|
|
}
|
|
case Opcode::ICmp: {
|
|
auto* cmp = static_cast<const ICmpInst*>(inst);
|
|
os << " %" << N(cmp) << " = icmp " << PredToStr(cmp->GetPredicate())
|
|
<< " i32 " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
|
|
break;
|
|
}
|
|
case Opcode::FCmp: {
|
|
auto* cmp = static_cast<const FCmpInst*>(inst);
|
|
os << " %" << N(cmp) << " = fcmp " << FPredToStr(cmp->GetPredicate())
|
|
<< " float " << VS(cmp->GetLhs()) << ", " << VS(cmp->GetRhs()) << "\n";
|
|
break;
|
|
}
|
|
case Opcode::Alloca: {
|
|
auto* al = static_cast<const AllocaInst*>(inst);
|
|
const char* et = al->GetType()->IsPtrFloat32() ? "float" : "i32";
|
|
if (al->IsArray())
|
|
os << " %" << N(al) << " = alloca " << et << ", i32 " << al->GetNumElements() << "\n";
|
|
else
|
|
os << " %" << N(al) << " = alloca " << et << "\n";
|
|
break;
|
|
}
|
|
case Opcode::Gep: {
|
|
auto* gep = static_cast<const GepInst*>(inst);
|
|
bool fp = gep->GetBasePtr()->GetType()->IsPtrFloat32();
|
|
os << " %" << N(gep) << " = getelementptr " << (fp ? "float" : "i32")
|
|
<< ", " << (fp ? "float*" : "i32*") << " "
|
|
<< VS(gep->GetBasePtr()) << ", i32 " << VS(gep->GetIndex()) << "\n";
|
|
break;
|
|
}
|
|
case Opcode::Load: {
|
|
auto* ld = static_cast<const LoadInst*>(inst);
|
|
bool fp = ld->GetPtr()->GetType()->IsPtrFloat32();
|
|
os << " %" << N(ld) << " = load " << (fp ? "float" : "i32")
|
|
<< ", " << (fp ? "float*" : "i32*") << " " << VS(ld->GetPtr()) << "\n";
|
|
break;
|
|
}
|
|
case Opcode::Store: {
|
|
auto* st = static_cast<const StoreInst*>(inst);
|
|
os << " store " << TV(st->GetValue()) << ", "
|
|
<< TypeToStr(*st->GetPtr()->GetType()) << " " << VS(st->GetPtr()) << "\n";
|
|
break;
|
|
}
|
|
case Opcode::Ret: {
|
|
auto* ret = static_cast<const ReturnInst*>(inst);
|
|
if (!ret->HasValue()) os << " ret void\n";
|
|
else os << " ret " << TV(ret->GetValue()) << "\n";
|
|
break;
|
|
}
|
|
case Opcode::Br: {
|
|
auto* br = static_cast<const BrInst*>(inst);
|
|
os << " br label %" << br->GetTarget()->GetName() << "\n";
|
|
break;
|
|
}
|
|
case Opcode::CondBr: {
|
|
auto* cbr = static_cast<const CondBrInst*>(inst);
|
|
os << " br i1 " << VS(cbr->GetCond()) << ", label %"
|
|
<< cbr->GetTrueBB()->GetName() << ", label %"
|
|
<< cbr->GetFalseBB()->GetName() << "\n";
|
|
break;
|
|
}
|
|
case Opcode::Call: {
|
|
auto* call = static_cast<const CallInst*>(inst);
|
|
if (!call->IsVoid() && !call->GetName().empty())
|
|
os << " %" << N(call) << " = ";
|
|
else
|
|
os << " ";
|
|
os << "call " << (call->IsVoid() ? "void" : TypeToStr(*call->GetType()))
|
|
<< " @" << call->GetCalleeName() << "(";
|
|
for (size_t i = 0; i < call->GetNumArgs(); ++i) {
|
|
if (i > 0) os << ", ";
|
|
os << TV(call->GetArg(i));
|
|
}
|
|
os << ")\n";
|
|
break;
|
|
}
|
|
case Opcode::ZExt: {
|
|
auto* ze = static_cast<const ZExtInst*>(inst);
|
|
os << " %" << N(ze) << " = zext i1 " << VS(ze->GetSrc()) << " to i32\n";
|
|
break;
|
|
}
|
|
case Opcode::SIToFP: {
|
|
auto* si = static_cast<const SIToFPInst*>(inst);
|
|
os << " %" << N(si) << " = sitofp i32 " << VS(si->GetSrc()) << " to float\n";
|
|
break;
|
|
}
|
|
case Opcode::FPToSI: {
|
|
auto* fp = static_cast<const FPToSIInst*>(inst);
|
|
os << " %" << N(fp) << " = fptosi float " << VS(fp->GetSrc()) << " to i32\n";
|
|
break;
|
|
}
|
|
case Opcode::Phi: {
|
|
auto* phi = static_cast<const PhiInst*>(inst);
|
|
os << " %" << N(phi) << " = phi " << TypeToStr(*phi->GetType()) << " ";
|
|
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
|
|
if (i > 0) os << ", ";
|
|
os << "[ " << VS(phi->GetIncomingValue(i)) << ", %"
|
|
<< phi->GetIncomingBlock(i)->GetName() << " ]";
|
|
}
|
|
os << "\n";
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
void IRPrinter::Print(const Module& module, std::ostream& os) {
|
|
// 1. 全局变量/常量
|
|
for (const auto& gv : module.GetGlobalVariables()) {
|
|
if (!gv) continue;
|
|
if (gv->IsConstant()) {
|
|
os << "@" << gv->GetName() << " = constant i32 " << gv->GetInitVal() << "\n";
|
|
} else if (gv->IsArray()) {
|
|
const char* et = gv->IsFloat() ? "float" : "i32";
|
|
os << "@" << gv->GetName() << " = global [" << gv->GetNumElements()
|
|
<< " x " << et << "] ";
|
|
if (!gv->HasInitVals()) {
|
|
os << "zeroinitializer\n";
|
|
} else if (gv->IsFloat()) {
|
|
const auto& vals = gv->GetInitValsF();
|
|
bool all_zero = std::all_of(vals.begin(), vals.end(), [](float f){ return f == 0.0f; });
|
|
if (all_zero) {
|
|
os << "zeroinitializer\n";
|
|
} else {
|
|
os << "[";
|
|
for (int i = 0; i < gv->GetNumElements(); ++i) {
|
|
if (i > 0) os << ", ";
|
|
float fv = (i < (int)vals.size()) ? vals[i] : 0.0f;
|
|
double d = static_cast<double>(fv);
|
|
uint64_t bits;
|
|
std::memcpy(&bits, &d, sizeof(bits));
|
|
std::ostringstream oss;
|
|
oss << "float 0x" << std::hex << std::uppercase << bits;
|
|
os << oss.str();
|
|
}
|
|
os << "]\n";
|
|
}
|
|
} else {
|
|
const auto& vals = gv->GetInitVals();
|
|
bool all_zero = std::all_of(vals.begin(), vals.end(), [](int v){ return v == 0; });
|
|
if (all_zero) {
|
|
os << "zeroinitializer\n";
|
|
} else {
|
|
os << "[";
|
|
for (int i = 0; i < gv->GetNumElements(); ++i) {
|
|
if (i > 0) os << ", ";
|
|
os << "i32 " << (i < (int)vals.size() ? vals[i] : 0);
|
|
}
|
|
os << "]\n";
|
|
}
|
|
}
|
|
} else {
|
|
if(gv->IsFloat()) {
|
|
double d = static_cast<double>(gv->GetInitValF());
|
|
uint64_t bits;
|
|
std::memcpy(&bits, &d, sizeof(bits));
|
|
std::ostringstream oss;
|
|
oss << "0x" << std::hex << std::uppercase << bits;
|
|
os << "@" << gv->GetName() << " = global float " << oss.str() << "\n";
|
|
} else
|
|
{
|
|
os << "@" << gv->GetName() << " = global i32 " << gv->GetInitVal() << "\n";
|
|
}
|
|
}
|
|
}
|
|
if (!module.GetGlobalVariables().empty()) os << "\n";
|
|
|
|
// 2. 外部声明
|
|
for (const auto& decl : module.GetExternalDecls()) {
|
|
os << "declare " << TypeToStr(*decl.ret_type) << " @" << decl.name << "(";
|
|
for (size_t i = 0; i < decl.param_types.size(); ++i) {
|
|
if (i > 0) os << ", ";
|
|
os << TypeToStr(*decl.param_types[i]);
|
|
}
|
|
os << ")\n";
|
|
}
|
|
if (!module.GetExternalDecls().empty()) os << "\n";
|
|
|
|
// 3. 函数定义
|
|
for (const auto& func : module.GetFunctions()) {
|
|
os << "define " << TypeToStr(*func->GetType()) << " @" << func->GetName() << "(";
|
|
for (size_t i = 0; i < func->GetNumArgs(); ++i) {
|
|
if (i > 0) os << ", ";
|
|
auto* arg = func->GetArgument(i);
|
|
os << TypeToStr(*arg->GetType()) << " %" << arg->GetName();
|
|
}
|
|
os << ") {\n";
|
|
|
|
// Build rename map: alloca instructions first (in block order), then rest
|
|
RenameMap rm;
|
|
int next_id = 0;
|
|
auto assign = [&](const Value* v) {
|
|
if (!v) return;
|
|
if (dynamic_cast<const ConstantInt*>(v)) return;
|
|
if (dynamic_cast<const ConstantFloat*>(v)) return;
|
|
if (dynamic_cast<const BasicBlock*>(v)) return;
|
|
if (dynamic_cast<const GlobalVariable*>(v)) return;
|
|
if (dynamic_cast<const Argument*>(v)) return;
|
|
if (rm.count(v) == 0) rm[v] = next_id++;
|
|
};
|
|
// Pass 1: all allocas across all blocks
|
|
for (const auto& bb : func->GetBlocks()) {
|
|
if (!bb) continue;
|
|
for (const auto& ip : bb->GetInstructions())
|
|
if (ip->GetOpcode() == Opcode::Alloca) assign(ip.get());
|
|
}
|
|
// Pass 2: all non-alloca instructions in block order
|
|
for (const auto& bb : func->GetBlocks()) {
|
|
if (!bb) continue;
|
|
for (const auto& ip : bb->GetInstructions())
|
|
if (ip->GetOpcode() != Opcode::Alloca) assign(ip.get());
|
|
}
|
|
|
|
// Print: entry block first with all allocas hoisted, then rest
|
|
bool first_bb = true;
|
|
for (const auto& bb : func->GetBlocks()) {
|
|
if (!bb) continue;
|
|
os << bb->GetName() << ":\n";
|
|
if (first_bb) {
|
|
first_bb = false;
|
|
// Print all allocas from all blocks (only for entry block)
|
|
for (const auto& bb2 : func->GetBlocks()) {
|
|
if (!bb2) continue;
|
|
for (const auto& ip : bb2->GetInstructions())
|
|
if (ip->GetOpcode() == Opcode::Alloca)
|
|
PrintInst(ip.get(), os, rm);
|
|
}
|
|
// Print PHI nodes of entry block
|
|
for (const auto& ip : bb->GetInstructions())
|
|
if (ip->GetOpcode() == Opcode::Phi)
|
|
PrintInst(ip.get(), os, rm);
|
|
// Print non-alloca non-phi instructions of entry block
|
|
for (const auto& ip : bb->GetInstructions())
|
|
if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi)
|
|
PrintInst(ip.get(), os, rm);
|
|
} else {
|
|
// Non-entry blocks: skip allocas (already printed)
|
|
// Print PHI nodes first
|
|
for (const auto& ip : bb->GetInstructions())
|
|
if (ip->GetOpcode() == Opcode::Phi)
|
|
PrintInst(ip.get(), os, rm);
|
|
// Print non-alloca non-phi instructions
|
|
for (const auto& ip : bb->GetInstructions())
|
|
if (ip->GetOpcode() != Opcode::Alloca && ip->GetOpcode() != Opcode::Phi)
|
|
PrintInst(ip.get(), os, rm);
|
|
}
|
|
}
|
|
os << "}\n\n";
|
|
}
|
|
}
|
|
|
|
} // namespace ir
|