forked from NUDT-compiler/nudt-compiler-cpp
Compare commits
2 Commits
master
...
parallel-o
| Author | SHA1 | Date |
|---|---|---|
|
|
93ff6fad02 | 1 week ago |
|
|
cbf1e6ba83 | 1 week ago |
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <iosfwd>
|
||||
#include <memory>
|
||||
|
||||
#include "mir/MIR.h"
|
||||
|
||||
namespace ir {
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace mir {
|
||||
|
||||
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
|
||||
void RunRegAlloc(MachineModule& module);
|
||||
void RunFrameLowering(MachineModule& module);
|
||||
void PrintAsm(const MachineModule& module, std::ostream& os);
|
||||
|
||||
} // namespace mir
|
||||
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "mir/MIR.h"
|
||||
|
||||
namespace mir {
|
||||
|
||||
bool RunPeephole(MachineModule& module);
|
||||
bool RunSpillReduction(MachineModule& module);
|
||||
bool RunCFGCleanup(MachineModule& module);
|
||||
void RunAddressHoisting(MachineModule& module);
|
||||
void VerifyMIR(const MachineModule& module);
|
||||
void RunMIRPreRegAllocPassPipeline(MachineModule& module);
|
||||
void RunMIRPostRegAllocPassPipeline(MachineModule& module);
|
||||
|
||||
} // namespace mir
|
||||
@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
namespace utils {
|
||||
|
||||
inline bool IsEnvFlagSet(const char* name) {
|
||||
const char* value = std::getenv(name);
|
||||
return value != nullptr && value[0] != '\0' && value[0] != '0';
|
||||
}
|
||||
|
||||
inline bool IsEnabledUnlessEnvFlag(const char* disable_flag_name) {
|
||||
return !IsEnvFlagSet(disable_flag_name);
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
@ -0,0 +1,214 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
[[noreturn]] void Fail(const Function* function, const BasicBlock* block,
|
||||
const std::string& message) {
|
||||
std::string where = "[ir-verify]";
|
||||
if (function != nullptr) {
|
||||
where += " function " + function->GetName();
|
||||
}
|
||||
if (block != nullptr) {
|
||||
where += " block " + block->GetName();
|
||||
}
|
||||
throw std::runtime_error(where + ": " + message);
|
||||
}
|
||||
|
||||
bool Contains(const std::vector<BasicBlock*>& blocks, const BasicBlock* needle) {
|
||||
return std::find(blocks.begin(), blocks.end(), needle) != blocks.end();
|
||||
}
|
||||
|
||||
bool SameType(const std::shared_ptr<Type>& lhs, const std::shared_ptr<Type>& rhs) {
|
||||
if (lhs == rhs) {
|
||||
return true;
|
||||
}
|
||||
if (!lhs || !rhs || lhs->GetKind() != rhs->GetKind()) {
|
||||
return false;
|
||||
}
|
||||
if (lhs->IsPointer()) {
|
||||
return true;
|
||||
}
|
||||
if (lhs->IsArray()) {
|
||||
return lhs->GetNumElements() == rhs->GetNumElements() &&
|
||||
SameType(lhs->GetElementType(), rhs->GetElementType());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CheckValueUse(const Function& function, const BasicBlock& block,
|
||||
const Instruction& inst, std::size_t operand_index) {
|
||||
auto* value = inst.GetOperand(operand_index);
|
||||
if (value == nullptr) {
|
||||
Fail(&function, &block, "null operand");
|
||||
}
|
||||
const auto& uses = value->GetUses();
|
||||
const bool found = std::any_of(uses.begin(), uses.end(), [&](const Use& use) {
|
||||
return use.GetUser() == &inst && use.GetOperandIndex() == operand_index;
|
||||
});
|
||||
if (!found) {
|
||||
Fail(&function, &block, "operand use-list is inconsistent");
|
||||
}
|
||||
}
|
||||
|
||||
void CheckTerminatorTargets(const Function& function, const BasicBlock& block,
|
||||
const std::unordered_set<const BasicBlock*>& blocks) {
|
||||
if (block.GetInstructions().empty()) {
|
||||
Fail(&function, &block, "empty block has no terminator");
|
||||
}
|
||||
const auto* terminator = block.GetInstructions().back().get();
|
||||
if (!terminator->IsTerminator()) {
|
||||
Fail(&function, &block, "block has no terminator");
|
||||
}
|
||||
|
||||
std::vector<BasicBlock*> expected;
|
||||
if (auto* br = dyncast<UncondBrInst>(terminator)) {
|
||||
expected.push_back(br->GetDest());
|
||||
} else if (auto* br = dyncast<CondBrInst>(terminator)) {
|
||||
if (!br->GetCondition() || !br->GetCondition()->IsBool()) {
|
||||
Fail(&function, &block, "conditional branch condition must be i1");
|
||||
}
|
||||
expected.push_back(br->GetThenBlock());
|
||||
expected.push_back(br->GetElseBlock());
|
||||
}
|
||||
|
||||
for (auto* succ : expected) {
|
||||
if (succ == nullptr || blocks.count(succ) == 0) {
|
||||
Fail(&function, &block, "terminator targets a block outside the function");
|
||||
}
|
||||
if (!Contains(block.GetSuccessors(), succ)) {
|
||||
Fail(&function, &block, "terminator target is missing from successor list");
|
||||
}
|
||||
}
|
||||
for (auto* succ : block.GetSuccessors()) {
|
||||
if (succ == nullptr || blocks.count(succ) == 0) {
|
||||
Fail(&function, &block, "successor list contains a block outside the function");
|
||||
}
|
||||
if (!Contains(succ->GetPredecessors(), const_cast<BasicBlock*>(&block))) {
|
||||
Fail(&function, &block, "successor/predecessor lists are inconsistent");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckInstructionTypes(const Function& function, const BasicBlock& block,
|
||||
const Instruction& inst) {
|
||||
for (std::size_t i = 0; i < inst.GetNumOperands(); ++i) {
|
||||
CheckValueUse(function, block, inst, i);
|
||||
}
|
||||
|
||||
if (auto* ret = dyncast<ReturnInst>(&inst)) {
|
||||
if (function.GetReturnType()->IsVoid()) {
|
||||
if (ret->HasReturnValue()) {
|
||||
Fail(&function, &block, "void function returns a value");
|
||||
}
|
||||
} else if (!ret->HasReturnValue() ||
|
||||
!SameType(function.GetReturnType(), ret->GetReturnValue()->GetType())) {
|
||||
Fail(&function, &block, "return value type does not match function type");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto* call = dyncast<CallInst>(&inst)) {
|
||||
auto* callee = call->GetCallee();
|
||||
if (callee == nullptr) {
|
||||
Fail(&function, &block, "call has no callee");
|
||||
}
|
||||
const auto args = call->GetArguments();
|
||||
if (args.size() != callee->GetParamTypes().size()) {
|
||||
Fail(&function, &block, "call argument count mismatch");
|
||||
}
|
||||
for (std::size_t i = 0; i < args.size(); ++i) {
|
||||
if (!SameType(args[i]->GetType(), callee->GetParamTypes()[i])) {
|
||||
Fail(&function, &block, "call argument type mismatch");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckFunction(const Function& function) {
|
||||
if (function.IsExternal()) {
|
||||
if (!function.GetBlocks().empty()) {
|
||||
Fail(&function, nullptr, "external function must not have blocks");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std::unordered_set<const BasicBlock*> blocks;
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
if (!block_ptr) {
|
||||
Fail(&function, nullptr, "null block");
|
||||
}
|
||||
auto* block = block_ptr.get();
|
||||
if (block->GetParent() != &function) {
|
||||
Fail(&function, block, "block parent is inconsistent");
|
||||
}
|
||||
blocks.insert(block);
|
||||
}
|
||||
if (function.GetEntryBlock() == nullptr || blocks.count(function.GetEntryBlock()) == 0) {
|
||||
Fail(&function, nullptr, "entry block is missing or outside the function");
|
||||
}
|
||||
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
const auto& block = *block_ptr;
|
||||
const auto& instructions = block.GetInstructions();
|
||||
for (std::size_t i = 0; i < instructions.size(); ++i) {
|
||||
auto* inst = instructions[i].get();
|
||||
if (inst == nullptr) {
|
||||
Fail(&function, &block, "null instruction");
|
||||
}
|
||||
if (inst->GetParent() != &block) {
|
||||
Fail(&function, &block, "instruction parent is inconsistent");
|
||||
}
|
||||
if (inst->IsTerminator() && i + 1 != instructions.size()) {
|
||||
Fail(&function, &block, "terminator is not the last instruction");
|
||||
}
|
||||
CheckInstructionTypes(function, block, *inst);
|
||||
}
|
||||
CheckTerminatorTargets(function, block, blocks);
|
||||
|
||||
for (auto* pred : block.GetPredecessors()) {
|
||||
if (pred == nullptr || blocks.count(pred) == 0) {
|
||||
Fail(&function, &block, "predecessor list contains a block outside the function");
|
||||
}
|
||||
if (!Contains(pred->GetSuccessors(), const_cast<BasicBlock*>(&block))) {
|
||||
Fail(&function, &block, "predecessor/successor lists are inconsistent");
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& inst_ptr : instructions) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
|
||||
auto* incoming_block = phi->GetIncomingBlock(i);
|
||||
if (!Contains(block.GetPredecessors(), incoming_block)) {
|
||||
Fail(&function, &block, "phi incoming block is not a predecessor");
|
||||
}
|
||||
if (!SameType(phi->GetType(), phi->GetIncomingValue(i)->GetType())) {
|
||||
Fail(&function, &block, "phi incoming value type mismatch");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void VerifyIR(const Module& module) {
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
CheckFunction(*function);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,456 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
#include "utils/OptConfig.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
constexpr const char* kRuntimeParallelFor = "__nudtc_parallel_for_i32";
|
||||
constexpr const char* kWorkerPrefix = "__nudtc_par_worker_";
|
||||
constexpr const char* kCapturePrefix = "__nudtc_par_cap_";
|
||||
constexpr int kDefaultParallelMinTrip = 8192;
|
||||
|
||||
struct CaptureInfo {
|
||||
Value* value = nullptr;
|
||||
GlobalValue* slot = nullptr;
|
||||
};
|
||||
|
||||
struct ParallelLoopInfo {
|
||||
Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
Value* begin = nullptr;
|
||||
Value* end = nullptr;
|
||||
loopmem::SimpleInductionVar iv;
|
||||
BinaryInst* compare = nullptr;
|
||||
BinaryInst* step = nullptr;
|
||||
std::vector<CaptureInfo> captures;
|
||||
};
|
||||
|
||||
bool HasPrefix(const std::string& value, const char* prefix) {
|
||||
return value.rfind(prefix, 0) == 0;
|
||||
}
|
||||
|
||||
void ReplaceTerminatorWithCondBr(BasicBlock* block, Value* cond,
|
||||
BasicBlock* then_block,
|
||||
BasicBlock* else_block) {
|
||||
auto& instructions = block->GetInstructions();
|
||||
if (instructions.empty() || !instructions.back()->IsTerminator()) {
|
||||
return;
|
||||
}
|
||||
instructions.back()->ClearAllOperands();
|
||||
auto branch =
|
||||
std::make_unique<CondBrInst>(cond, then_block, else_block, nullptr);
|
||||
branch->SetParent(block);
|
||||
instructions.back() = std::move(branch);
|
||||
}
|
||||
|
||||
int ReadPositiveEnvInt(const char* name, int fallback, int min_value, int max_value) {
|
||||
const char* raw = std::getenv(name);
|
||||
if (raw == nullptr || *raw == '\0') {
|
||||
return fallback;
|
||||
}
|
||||
char* end = nullptr;
|
||||
const long parsed = std::strtol(raw, &end, 10);
|
||||
if (end == raw || parsed < min_value || parsed > max_value) {
|
||||
return fallback;
|
||||
}
|
||||
return static_cast<int>(parsed);
|
||||
}
|
||||
|
||||
bool IsAlreadyParallelGuarded(BasicBlock* preheader) {
|
||||
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(preheader));
|
||||
if (!branch) {
|
||||
return false;
|
||||
}
|
||||
auto* cond = branch->GetCondition();
|
||||
return cond != nullptr && cond->GetName().find("%par.guard.") == 0;
|
||||
}
|
||||
|
||||
int NextWorkerId(const Module& module) {
|
||||
int next_id = 0;
|
||||
const std::string prefix = kWorkerPrefix;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (!function || !HasPrefix(function->GetName(), kWorkerPrefix)) {
|
||||
continue;
|
||||
}
|
||||
const auto suffix = function->GetName().substr(prefix.size());
|
||||
try {
|
||||
next_id = std::max(next_id, std::stoi(suffix) + 1);
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
return next_id;
|
||||
}
|
||||
|
||||
bool IsSupportedExternalValue(const Loop& loop, Value* value) {
|
||||
if (value == nullptr || value->IsConstant() || dyncast<GlobalValue>(value) ||
|
||||
dyncast<Function>(value) || dyncast<BasicBlock>(value)) {
|
||||
return true;
|
||||
}
|
||||
auto* inst = dyncast<Instruction>(value);
|
||||
return inst != nullptr && loop.Contains(inst->GetParent());
|
||||
}
|
||||
|
||||
bool IsCapturableValue(const Loop& loop, Value* value) {
|
||||
if (IsSupportedExternalValue(loop, value)) {
|
||||
return false;
|
||||
}
|
||||
if (!value || !value->GetType() || value->IsVoid() || value->IsLabel() ||
|
||||
value->GetType()->IsFunction() || value->IsArray()) {
|
||||
return false;
|
||||
}
|
||||
if (!value->GetType()->IsInt32() && !value->GetType()->IsFloat() &&
|
||||
!value->GetType()->IsPointer()) {
|
||||
return false;
|
||||
}
|
||||
return value->IsArgument() || dyncast<Instruction>(value) != nullptr;
|
||||
}
|
||||
|
||||
bool CollectCaptures(const Loop& loop, const DominatorTree& dom_tree,
|
||||
const ParallelLoopInfo& info,
|
||||
std::vector<Value*>& captures) {
|
||||
std::unordered_set<Value*> seen;
|
||||
auto observe = [&](Value* value) {
|
||||
if (IsSupportedExternalValue(loop, value)) {
|
||||
return true;
|
||||
}
|
||||
if (!IsCapturableValue(loop, value)) {
|
||||
return false;
|
||||
}
|
||||
if (auto* inst = dyncast<Instruction>(value)) {
|
||||
if (!dom_tree.Dominates(inst->GetParent(), info.preheader)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (seen.insert(value).second) {
|
||||
captures.push_back(value);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dyncast<PhiInst>(inst) || inst->IsTerminator() || inst == info.step) {
|
||||
continue;
|
||||
}
|
||||
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
|
||||
if (!observe(inst->GetOperand(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StepOnlyFeedsLoopPhi(BinaryInst* step, PhiInst* iv) {
|
||||
if (!step || !iv) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& use : step->GetUses()) {
|
||||
auto* user_inst = dyncast<Instruction>(use.GetUser());
|
||||
if (user_inst != iv) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ExitHasHeaderPhiUse(BasicBlock* exit, BasicBlock* header) {
|
||||
if (!exit || !header) {
|
||||
return true;
|
||||
}
|
||||
for (const auto& inst_ptr : exit->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
if (looputils::GetPhiIncomingIndex(phi, header) >= 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool HasWriteSelfDependence(const std::vector<loopmem::MemoryAccessInfo>& accesses,
|
||||
int iv_stride) {
|
||||
for (const auto& access : accesses) {
|
||||
if (access.is_write &&
|
||||
loopmem::HasCrossIterationDependence(access.ptr, access.ptr, iv_stride)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool BodyUsesOnlyCloneableState(const Loop& loop, const DominatorTree& dom_tree,
|
||||
const ParallelLoopInfo& info,
|
||||
std::vector<Value*>& captures) {
|
||||
for (const auto& inst_ptr : info.body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator() || inst == info.step) {
|
||||
continue;
|
||||
}
|
||||
if (!looputils::IsCloneableInstruction(inst) || dyncast<MemsetInst>(inst) ||
|
||||
dyncast<AllocaInst>(inst)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return CollectCaptures(loop, dom_tree, info, captures);
|
||||
}
|
||||
|
||||
bool MatchParallelLoop(Loop& loop, const DominatorTree& dom_tree,
|
||||
ParallelLoopInfo& info) {
|
||||
if (!loop.preheader || !loop.header || !loop.IsInnermost() ||
|
||||
IsAlreadyParallelGuarded(loop.preheader)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<PhiInst*> phis;
|
||||
loopmem::SimpleInductionVar iv;
|
||||
bool found_iv = false;
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
phis.push_back(phi);
|
||||
if (!found_iv && loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, iv)) {
|
||||
found_iv = true;
|
||||
}
|
||||
}
|
||||
if (!found_iv || phis.size() != 1 || iv.stride != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
|
||||
auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
|
||||
if (!branch || branch->GetThenBlock() != body || !compare ||
|
||||
compare->GetOpcode() != Opcode::ICmpLT || compare->GetLhs() != iv.phi ||
|
||||
!looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* step = dyncast<BinaryInst>(iv.latch_value);
|
||||
if (!step || step->GetParent() != body || !StepOnlyFeedsLoopPhi(step, iv.phi) ||
|
||||
ExitHasHeaderPhiUse(exit, loop.header)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
info.loop = &loop;
|
||||
info.preheader = loop.preheader;
|
||||
info.header = loop.header;
|
||||
info.body = body;
|
||||
info.exit = exit;
|
||||
const int begin_index = looputils::GetPhiIncomingIndex(iv.phi, loop.preheader);
|
||||
if (begin_index < 0) {
|
||||
return false;
|
||||
}
|
||||
info.begin = iv.phi->GetIncomingValue(begin_index);
|
||||
info.end = compare->GetRhs();
|
||||
info.iv = iv;
|
||||
info.compare = compare;
|
||||
info.step = step;
|
||||
if (!info.begin || !info.end) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto accesses = loopmem::CollectMemoryAccesses(loop, iv.phi);
|
||||
if (accesses.empty() || HasWriteSelfDependence(accesses, iv.stride) ||
|
||||
!loopmem::IsLoopParallelizable(loop, iv.phi, iv.stride, accesses)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<Value*> captures;
|
||||
if (!BodyUsesOnlyCloneableState(loop, dom_tree, info, captures)) {
|
||||
return false;
|
||||
}
|
||||
info.captures.clear();
|
||||
info.captures.reserve(captures.size());
|
||||
for (auto* value : captures) {
|
||||
info.captures.push_back({value, nullptr});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Function* EnsureRuntime(Module& module) {
|
||||
auto* runtime = module.CreateFunction(
|
||||
kRuntimeParallelFor, Type::GetVoidType(),
|
||||
{Type::GetInt32Type(), Type::GetInt32Type(), Type::GetInt32Type(),
|
||||
Type::GetPointerType()},
|
||||
{"%begin", "%end", "%step", "%body"}, true);
|
||||
runtime->SetEffectInfo(true, true, true, true, false, true, false);
|
||||
return runtime;
|
||||
}
|
||||
|
||||
Function* BuildWorker(Module& module, const ParallelLoopInfo& info, int worker_id) {
|
||||
auto* worker = module.CreateFunction(
|
||||
std::string(kWorkerPrefix) + std::to_string(worker_id), Type::GetVoidType(),
|
||||
{Type::GetInt32Type(), Type::GetInt32Type()}, {"%begin", "%end"}, false);
|
||||
worker->SetEffectInfo(true, true, true, true, false, false, false);
|
||||
|
||||
auto* entry = worker->CreateBlock("entry");
|
||||
auto* header = worker->CreateBlock("par.header");
|
||||
auto* body = worker->CreateBlock("par.body");
|
||||
auto* exit = worker->CreateBlock("par.exit");
|
||||
|
||||
std::unordered_map<Value*, Value*> remap;
|
||||
for (const auto& capture : info.captures) {
|
||||
if (!capture.value || !capture.slot) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* loaded = entry->Append<LoadInst>(
|
||||
capture.value->GetType(), capture.slot, nullptr,
|
||||
looputils::NextSyntheticName(*worker, "par.cap."));
|
||||
remap[capture.value] = loaded;
|
||||
}
|
||||
|
||||
entry->Append<UncondBrInst>(header, nullptr);
|
||||
entry->AddSuccessor(header);
|
||||
header->AddPredecessor(entry);
|
||||
|
||||
auto* worker_iv = header->Append<PhiInst>(
|
||||
Type::GetInt32Type(), nullptr,
|
||||
looputils::NextSyntheticName(*worker, "par.iv."));
|
||||
worker_iv->AddIncoming(worker->GetArgument(0), entry);
|
||||
auto* worker_cmp = header->Append<BinaryInst>(
|
||||
Opcode::ICmpLT, Type::GetBoolType(), worker_iv, worker->GetArgument(1), nullptr,
|
||||
looputils::NextSyntheticName(*worker, "par.cmp."));
|
||||
header->Append<CondBrInst>(worker_cmp, body, exit, nullptr);
|
||||
header->AddSuccessor(body);
|
||||
header->AddSuccessor(exit);
|
||||
body->AddPredecessor(header);
|
||||
exit->AddPredecessor(header);
|
||||
|
||||
remap[info.iv.phi] = worker_iv;
|
||||
for (const auto& inst_ptr : info.body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator() || inst == info.step) {
|
||||
continue;
|
||||
}
|
||||
looputils::CloneInstruction(*worker, inst, body, remap, "par.");
|
||||
}
|
||||
auto* next_iv = body->Append<BinaryInst>(
|
||||
Opcode::Add, Type::GetInt32Type(), worker_iv, looputils::ConstInt(1), nullptr,
|
||||
looputils::NextSyntheticName(*worker, "par.next."));
|
||||
worker_iv->AddIncoming(next_iv, body);
|
||||
body->Append<UncondBrInst>(header, nullptr);
|
||||
body->AddSuccessor(header);
|
||||
header->AddPredecessor(body);
|
||||
|
||||
exit->Append<ReturnInst>(nullptr, nullptr);
|
||||
return worker;
|
||||
}
|
||||
|
||||
void CreateCaptureSlots(Module& module, ParallelLoopInfo& info, int worker_id) {
|
||||
for (std::size_t i = 0; i < info.captures.size(); ++i) {
|
||||
auto& capture = info.captures[i];
|
||||
capture.slot = module.CreateGlobalValue(
|
||||
std::string(kCapturePrefix) + std::to_string(worker_id) + "_" +
|
||||
std::to_string(i),
|
||||
capture.value->GetType(), false, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
bool ParallelizeFirstLoopInFunction(Module& module, Function& function, int* worker_id) {
|
||||
if (function.IsExternal() || !function.GetEntryBlock() ||
|
||||
HasPrefix(function.GetName(), kWorkerPrefix) ||
|
||||
function.GetName() == kRuntimeParallelFor) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DominatorTree dom_tree(function);
|
||||
LoopInfo loop_info(function, dom_tree);
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
ParallelLoopInfo info;
|
||||
if (!MatchParallelLoop(*loop, dom_tree, info)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* runtime = EnsureRuntime(module);
|
||||
const int current_worker_id = (*worker_id)++;
|
||||
CreateCaptureSlots(module, info, current_worker_id);
|
||||
auto* worker = BuildWorker(module, info, current_worker_id);
|
||||
if (worker == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto* parallel_block = function.CreateBlock(
|
||||
looputils::NextSyntheticBlockName(function, "par.dispatch"));
|
||||
|
||||
const int min_trip =
|
||||
ReadPositiveEnvInt("NUDTC_PARALLEL_MIN_TRIP", kDefaultParallelMinTrip, 1,
|
||||
1 << 30);
|
||||
auto* trip_count = info.preheader->Insert<BinaryInst>(
|
||||
looputils::GetTerminatorIndex(info.preheader), Opcode::Sub,
|
||||
Type::GetInt32Type(), info.end, info.begin, nullptr,
|
||||
looputils::NextSyntheticName(function, "par.trip."));
|
||||
auto* large_enough = info.preheader->Insert<BinaryInst>(
|
||||
looputils::GetTerminatorIndex(info.preheader), Opcode::ICmpGE,
|
||||
Type::GetBoolType(), trip_count, looputils::ConstInt(min_trip), nullptr,
|
||||
looputils::NextSyntheticName(function, "par.guard."));
|
||||
|
||||
ReplaceTerminatorWithCondBr(info.preheader, large_enough, parallel_block,
|
||||
info.header);
|
||||
info.preheader->AddSuccessor(parallel_block);
|
||||
parallel_block->AddPredecessor(info.preheader);
|
||||
|
||||
for (const auto& capture : info.captures) {
|
||||
parallel_block->Append<StoreInst>(capture.value, capture.slot, nullptr);
|
||||
}
|
||||
parallel_block->Append<CallInst>(
|
||||
runtime,
|
||||
std::vector<Value*>{info.begin, info.end, looputils::ConstInt(1), worker},
|
||||
nullptr, "");
|
||||
parallel_block->Append<UncondBrInst>(info.exit, nullptr);
|
||||
parallel_block->AddSuccessor(info.exit);
|
||||
info.exit->AddPredecessor(parallel_block);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLoopParallelize(Module& module) {
|
||||
if (utils::IsEnvFlagSet("NUDTC_DISABLE_LOOP_PARALLELIZE")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<Function*> functions;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function && !function->IsExternal()) {
|
||||
functions.push_back(function.get());
|
||||
}
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
int worker_id = NextWorkerId(module);
|
||||
for (auto* function : functions) {
|
||||
changed |= ParallelizeFirstLoopInFunction(module, *function, &worker_id);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,607 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
#include "utils/OptConfig.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
constexpr int kMinConstantTripToVectorize = 16;
|
||||
|
||||
enum class VectorKind { I32, F32 };
|
||||
enum class VectorOp { Add, Sub, Mul };
|
||||
|
||||
struct VectorLoopInfo {
|
||||
Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
loopmem::SimpleInductionVar iv;
|
||||
Value* begin = nullptr;
|
||||
Value* end = nullptr;
|
||||
StoreInst* store = nullptr;
|
||||
LoadInst* lhs_load = nullptr;
|
||||
LoadInst* rhs_load = nullptr;
|
||||
BinaryInst* binary = nullptr;
|
||||
VectorKind kind = VectorKind::I32;
|
||||
VectorOp op = VectorOp::Add;
|
||||
};
|
||||
|
||||
struct FillLoopInfo {
|
||||
Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
loopmem::SimpleInductionVar iv;
|
||||
Value* begin = nullptr;
|
||||
Value* end = nullptr;
|
||||
StoreInst* store = nullptr;
|
||||
Value* fill_value = nullptr;
|
||||
VectorKind kind = VectorKind::I32;
|
||||
};
|
||||
|
||||
bool HasExitPhiUse(BasicBlock* exit, BasicBlock* header) {
|
||||
if (!exit || !header) {
|
||||
return true;
|
||||
}
|
||||
for (const auto& inst_ptr : exit->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
if (looputils::GetPhiIncomingIndex(phi, header) >= 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StepOnlyFeedsLoopPhi(BinaryInst* step, PhiInst* iv) {
|
||||
if (!step || !iv) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& use : step->GetUses()) {
|
||||
if (dyncast<Instruction>(use.GetUser()) != iv) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MatchVectorOpcode(Opcode opcode, VectorKind* kind, VectorOp* op) {
|
||||
switch (opcode) {
|
||||
case Opcode::Add:
|
||||
*kind = VectorKind::I32;
|
||||
*op = VectorOp::Add;
|
||||
return true;
|
||||
case Opcode::Sub:
|
||||
*kind = VectorKind::I32;
|
||||
*op = VectorOp::Sub;
|
||||
return true;
|
||||
case Opcode::Mul:
|
||||
*kind = VectorKind::I32;
|
||||
*op = VectorOp::Mul;
|
||||
return true;
|
||||
case Opcode::FAdd:
|
||||
*kind = VectorKind::F32;
|
||||
*op = VectorOp::Add;
|
||||
return true;
|
||||
case Opcode::FSub:
|
||||
*kind = VectorKind::F32;
|
||||
*op = VectorOp::Sub;
|
||||
return true;
|
||||
case Opcode::FMul:
|
||||
*kind = VectorKind::F32;
|
||||
*op = VectorOp::Mul;
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsExpectedScalarType(Value* value, VectorKind kind) {
|
||||
return value && value->GetType() &&
|
||||
((kind == VectorKind::I32 && value->GetType()->IsInt32()) ||
|
||||
(kind == VectorKind::F32 && value->GetType()->IsFloat()));
|
||||
}
|
||||
|
||||
bool IsDirectVectorIndex(Value* value, PhiInst* iv, const Loop& loop) {
|
||||
return value == iv || looputils::IsLoopInvariantValue(loop, value);
|
||||
}
|
||||
|
||||
Value* RemapIndexAtBegin(Value* value, PhiInst* iv, Value* begin, const Loop& loop) {
|
||||
if (value == iv) {
|
||||
return begin;
|
||||
}
|
||||
if (looputils::IsLoopInvariantValue(loop, value)) {
|
||||
return value;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Value* MaterializePointerAtBegin(Function& function, BasicBlock* insert_block,
|
||||
Value* pointer, PhiInst* iv, Value* begin,
|
||||
const Loop& loop) {
|
||||
if (looputils::IsLoopInvariantValue(loop, pointer)) {
|
||||
return pointer;
|
||||
}
|
||||
auto* gep = dyncast<GetElementPtrInst>(pointer);
|
||||
if (!gep || !gep->GetParent() || !loop.Contains(gep->GetParent())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* base =
|
||||
MaterializePointerAtBegin(function, insert_block, gep->GetPointer(), iv, begin, loop);
|
||||
if (!base) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<Value*> indices;
|
||||
indices.reserve(gep->GetNumIndices());
|
||||
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
|
||||
auto* index = gep->GetIndex(i);
|
||||
if (!IsDirectVectorIndex(index, iv, loop)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* mapped = RemapIndexAtBegin(index, iv, begin, loop);
|
||||
if (!mapped) {
|
||||
return nullptr;
|
||||
}
|
||||
indices.push_back(mapped);
|
||||
}
|
||||
|
||||
return insert_block->Insert<GetElementPtrInst>(
|
||||
looputils::GetTerminatorIndex(insert_block), gep->GetSourceType(), base,
|
||||
indices, nullptr, looputils::NextSyntheticName(function, "vec.ptr."));
|
||||
}
|
||||
|
||||
bool IsUnitStrideAccess(const loopmem::PointerInfo& ptr, PhiInst* iv, int access_size) {
|
||||
return ptr.byte_offset.valid && ptr.byte_offset.var == iv &&
|
||||
ptr.byte_offset.coeff == access_size;
|
||||
}
|
||||
|
||||
bool MatchVectorLoop(Loop& loop, VectorLoopInfo& info) {
|
||||
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<PhiInst*> phis;
|
||||
loopmem::SimpleInductionVar iv;
|
||||
bool found_iv = false;
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
phis.push_back(phi);
|
||||
if (!found_iv && loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, iv)) {
|
||||
found_iv = true;
|
||||
}
|
||||
}
|
||||
if (!found_iv || phis.size() != 1 || iv.stride != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
|
||||
auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
|
||||
if (!branch || branch->GetThenBlock() != body || !compare ||
|
||||
compare->GetOpcode() != Opcode::ICmpLT || compare->GetLhs() != iv.phi ||
|
||||
!looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
|
||||
return false;
|
||||
}
|
||||
auto* preheader_branch = dyncast<UncondBrInst>(loop.preheader->GetInstructions().empty()
|
||||
? nullptr
|
||||
: loop.preheader->GetInstructions().back().get());
|
||||
if (!preheader_branch || preheader_branch->GetDest() != loop.header ||
|
||||
loop.preheader->GetSuccessors().size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* step = dyncast<BinaryInst>(iv.latch_value);
|
||||
if (!step || step->GetParent() != body || !StepOnlyFeedsLoopPhi(step, iv.phi) ||
|
||||
HasExitPhiUse(exit, loop.header)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto* begin_const = dyncast<ConstantInt>(iv.start)) {
|
||||
if (auto* end_const = dyncast<ConstantInt>(compare->GetRhs())) {
|
||||
if (end_const->GetValue() - begin_const->GetValue() < kMinConstantTripToVectorize) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StoreInst* store = nullptr;
|
||||
BinaryInst* binary = nullptr;
|
||||
LoadInst* lhs_load = nullptr;
|
||||
LoadInst* rhs_load = nullptr;
|
||||
|
||||
for (const auto& inst_ptr : body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator() || inst == step) {
|
||||
continue;
|
||||
}
|
||||
if (auto* gep = dyncast<GetElementPtrInst>(inst)) {
|
||||
(void)gep;
|
||||
continue;
|
||||
}
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
if (lhs_load == nullptr) {
|
||||
lhs_load = load;
|
||||
} else if (rhs_load == nullptr) {
|
||||
rhs_load = load;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (auto* bin = dyncast<BinaryInst>(inst)) {
|
||||
if (binary != nullptr) {
|
||||
return false;
|
||||
}
|
||||
binary = bin;
|
||||
continue;
|
||||
}
|
||||
if (auto* st = dyncast<StoreInst>(inst)) {
|
||||
if (store != nullptr) {
|
||||
return false;
|
||||
}
|
||||
store = st;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!store || !binary || !lhs_load || !rhs_load || store->GetValue() != binary) {
|
||||
return false;
|
||||
}
|
||||
lhs_load = dyncast<LoadInst>(binary->GetLhs());
|
||||
rhs_load = dyncast<LoadInst>(binary->GetRhs());
|
||||
if (!lhs_load || !rhs_load) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VectorKind kind = VectorKind::I32;
|
||||
VectorOp op = VectorOp::Add;
|
||||
if (!MatchVectorOpcode(binary->GetOpcode(), &kind, &op) ||
|
||||
!IsExpectedScalarType(binary, kind) ||
|
||||
!IsExpectedScalarType(lhs_load, kind) ||
|
||||
!IsExpectedScalarType(rhs_load, kind) ||
|
||||
!IsExpectedScalarType(store->GetValue(), kind)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int access_size = kind == VectorKind::I32 ? 4 : 4;
|
||||
auto store_ptr = loopmem::AnalyzePointer(store->GetPtr(), iv.phi, loop, access_size);
|
||||
auto lhs_ptr = loopmem::AnalyzePointer(lhs_load->GetPtr(), iv.phi, loop, access_size);
|
||||
auto rhs_ptr = loopmem::AnalyzePointer(rhs_load->GetPtr(), iv.phi, loop, access_size);
|
||||
if (!IsUnitStrideAccess(store_ptr, iv.phi, access_size) ||
|
||||
!IsUnitStrideAccess(lhs_ptr, iv.phi, access_size) ||
|
||||
!IsUnitStrideAccess(rhs_ptr, iv.phi, access_size)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
info.loop = &loop;
|
||||
info.preheader = loop.preheader;
|
||||
info.header = loop.header;
|
||||
info.body = body;
|
||||
info.exit = exit;
|
||||
info.iv = iv;
|
||||
info.begin = iv.start;
|
||||
info.end = compare->GetRhs();
|
||||
info.store = store;
|
||||
info.lhs_load = lhs_load;
|
||||
info.rhs_load = rhs_load;
|
||||
info.binary = binary;
|
||||
info.kind = kind;
|
||||
info.op = op;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MatchFillLoop(Loop& loop, FillLoopInfo& info) {
|
||||
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<PhiInst*> phis;
|
||||
loopmem::SimpleInductionVar iv;
|
||||
bool found_iv = false;
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
phis.push_back(phi);
|
||||
if (!found_iv && loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, iv)) {
|
||||
found_iv = true;
|
||||
}
|
||||
}
|
||||
if (!found_iv || phis.size() != 1 || iv.stride != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
|
||||
auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
|
||||
if (!branch || branch->GetThenBlock() != body || !compare ||
|
||||
compare->GetOpcode() != Opcode::ICmpLT || compare->GetLhs() != iv.phi ||
|
||||
!looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
|
||||
return false;
|
||||
}
|
||||
auto* preheader_branch = dyncast<UncondBrInst>(
|
||||
loop.preheader->GetInstructions().empty()
|
||||
? nullptr
|
||||
: loop.preheader->GetInstructions().back().get());
|
||||
if (!preheader_branch || preheader_branch->GetDest() != loop.header ||
|
||||
loop.preheader->GetSuccessors().size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* step = dyncast<BinaryInst>(iv.latch_value);
|
||||
if (!step || step->GetParent() != body || !StepOnlyFeedsLoopPhi(step, iv.phi) ||
|
||||
HasExitPhiUse(exit, loop.header)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto* begin_const = dyncast<ConstantInt>(iv.start)) {
|
||||
if (auto* end_const = dyncast<ConstantInt>(compare->GetRhs())) {
|
||||
if (end_const->GetValue() - begin_const->GetValue() < kMinConstantTripToVectorize) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StoreInst* store = nullptr;
|
||||
for (const auto& inst_ptr : body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator() || inst == step) {
|
||||
continue;
|
||||
}
|
||||
if (auto* gep = dyncast<GetElementPtrInst>(inst)) {
|
||||
(void)gep;
|
||||
continue;
|
||||
}
|
||||
if (auto* st = dyncast<StoreInst>(inst)) {
|
||||
if (store != nullptr) {
|
||||
return false;
|
||||
}
|
||||
store = st;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!store || !store->GetValue() ||
|
||||
!looputils::IsLoopInvariantValue(loop, store->GetValue())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VectorKind kind = VectorKind::I32;
|
||||
if (store->GetValue()->GetType() && store->GetValue()->GetType()->IsFloat()) {
|
||||
kind = VectorKind::F32;
|
||||
} else if (!store->GetValue()->GetType() || !store->GetValue()->GetType()->IsInt32()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int access_size = 4;
|
||||
auto store_ptr = loopmem::AnalyzePointer(store->GetPtr(), iv.phi, loop, access_size);
|
||||
if (!IsUnitStrideAccess(store_ptr, iv.phi, access_size)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
info.loop = &loop;
|
||||
info.preheader = loop.preheader;
|
||||
info.header = loop.header;
|
||||
info.body = body;
|
||||
info.exit = exit;
|
||||
info.iv = iv;
|
||||
info.begin = iv.start;
|
||||
info.end = compare->GetRhs();
|
||||
info.store = store;
|
||||
info.fill_value = store->GetValue();
|
||||
info.kind = kind;
|
||||
return true;
|
||||
}
|
||||
|
||||
const char* HelperName(VectorKind kind, VectorOp op) {
|
||||
if (kind == VectorKind::I32) {
|
||||
switch (op) {
|
||||
case VectorOp::Add:
|
||||
return "__nudtc_neon_i32_add";
|
||||
case VectorOp::Sub:
|
||||
return "__nudtc_neon_i32_sub";
|
||||
case VectorOp::Mul:
|
||||
return "__nudtc_neon_i32_mul";
|
||||
}
|
||||
}
|
||||
switch (op) {
|
||||
case VectorOp::Add:
|
||||
return "__nudtc_neon_f32_add";
|
||||
case VectorOp::Sub:
|
||||
return "__nudtc_neon_f32_sub";
|
||||
case VectorOp::Mul:
|
||||
return "__nudtc_neon_f32_mul";
|
||||
}
|
||||
return "__nudtc_neon_i32_add";
|
||||
}
|
||||
|
||||
const char* FillHelperName(VectorKind kind) {
|
||||
return kind == VectorKind::I32 ? "__nudtc_neon_i32_fill" : "__nudtc_neon_f32_fill";
|
||||
}
|
||||
|
||||
Function* EnsureHelper(Module& module, VectorKind kind, VectorOp op) {
|
||||
auto* helper = module.CreateFunction(
|
||||
HelperName(kind, op), Type::GetVoidType(),
|
||||
{Type::GetPointerType(), Type::GetPointerType(), Type::GetPointerType(),
|
||||
Type::GetInt32Type()},
|
||||
{"%dst", "%lhs", "%rhs", "%n"}, true);
|
||||
helper->SetEffectInfo(false, false, true, true, false, false, false);
|
||||
return helper;
|
||||
}
|
||||
|
||||
Function* EnsureFillHelper(Module& module, VectorKind kind) {
|
||||
auto scalar_type =
|
||||
kind == VectorKind::I32 ? Type::GetInt32Type() : Type::GetFloatType();
|
||||
auto* helper = module.CreateFunction(
|
||||
FillHelperName(kind), Type::GetVoidType(),
|
||||
{Type::GetPointerType(), scalar_type, Type::GetInt32Type()},
|
||||
{"%dst", "%value", "%n"}, true);
|
||||
helper->SetEffectInfo(false, false, true, true, false, false, false);
|
||||
return helper;
|
||||
}
|
||||
|
||||
bool VectorizeLoop(Module& module, Function& function, VectorLoopInfo& info) {
|
||||
auto* dst = MaterializePointerAtBegin(function, info.preheader, info.store->GetPtr(),
|
||||
info.iv.phi, info.begin, *info.loop);
|
||||
auto* lhs = MaterializePointerAtBegin(function, info.preheader, info.lhs_load->GetPtr(),
|
||||
info.iv.phi, info.begin, *info.loop);
|
||||
auto* rhs = MaterializePointerAtBegin(function, info.preheader, info.rhs_load->GetPtr(),
|
||||
info.iv.phi, info.begin, *info.loop);
|
||||
if (!dst || !lhs || !rhs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* trip_count = info.preheader->Insert<BinaryInst>(
|
||||
looputils::GetTerminatorIndex(info.preheader), Opcode::Sub, Type::GetInt32Type(),
|
||||
info.end, info.begin, nullptr, looputils::NextSyntheticName(function, "vec.n."));
|
||||
|
||||
auto* vector_block =
|
||||
function.CreateBlock(looputils::NextSyntheticBlockName(function, "vec.dispatch"));
|
||||
|
||||
auto& preheader_insts = info.preheader->GetInstructions();
|
||||
if (preheader_insts.empty() || !preheader_insts.back()->IsTerminator()) {
|
||||
return false;
|
||||
}
|
||||
preheader_insts.back()->ClearAllOperands();
|
||||
auto branch = std::make_unique<UncondBrInst>(vector_block, nullptr);
|
||||
branch->SetParent(info.preheader);
|
||||
preheader_insts.back() = std::move(branch);
|
||||
info.preheader->RemoveSuccessor(info.header);
|
||||
info.header->RemovePredecessor(info.preheader);
|
||||
info.preheader->AddSuccessor(vector_block);
|
||||
vector_block->AddPredecessor(info.preheader);
|
||||
for (const auto& inst_ptr : info.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
const int incoming = looputils::GetPhiIncomingIndex(phi, info.preheader);
|
||||
if (incoming >= 0) {
|
||||
phi->RemoveOperand(static_cast<std::size_t>(2 * incoming + 1));
|
||||
phi->RemoveOperand(static_cast<std::size_t>(2 * incoming));
|
||||
}
|
||||
}
|
||||
|
||||
auto* helper = EnsureHelper(module, info.kind, info.op);
|
||||
vector_block->Append<CallInst>(helper, std::vector<Value*>{dst, lhs, rhs, trip_count},
|
||||
nullptr, "");
|
||||
vector_block->Append<UncondBrInst>(info.exit, nullptr);
|
||||
vector_block->AddSuccessor(info.exit);
|
||||
info.exit->AddPredecessor(vector_block);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool VectorizeFillLoop(Module& module, Function& function, FillLoopInfo& info) {
|
||||
auto* dst = MaterializePointerAtBegin(function, info.preheader, info.store->GetPtr(),
|
||||
info.iv.phi, info.begin, *info.loop);
|
||||
if (!dst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* trip_count = info.preheader->Insert<BinaryInst>(
|
||||
looputils::GetTerminatorIndex(info.preheader), Opcode::Sub, Type::GetInt32Type(),
|
||||
info.end, info.begin, nullptr, looputils::NextSyntheticName(function, "vec.n."));
|
||||
|
||||
auto* vector_block =
|
||||
function.CreateBlock(looputils::NextSyntheticBlockName(function, "vec.fill"));
|
||||
|
||||
auto& preheader_insts = info.preheader->GetInstructions();
|
||||
if (preheader_insts.empty() || !preheader_insts.back()->IsTerminator()) {
|
||||
return false;
|
||||
}
|
||||
preheader_insts.back()->ClearAllOperands();
|
||||
auto branch = std::make_unique<UncondBrInst>(vector_block, nullptr);
|
||||
branch->SetParent(info.preheader);
|
||||
preheader_insts.back() = std::move(branch);
|
||||
info.preheader->RemoveSuccessor(info.header);
|
||||
info.header->RemovePredecessor(info.preheader);
|
||||
info.preheader->AddSuccessor(vector_block);
|
||||
vector_block->AddPredecessor(info.preheader);
|
||||
for (const auto& inst_ptr : info.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
const int incoming = looputils::GetPhiIncomingIndex(phi, info.preheader);
|
||||
if (incoming >= 0) {
|
||||
phi->RemoveOperand(static_cast<std::size_t>(2 * incoming + 1));
|
||||
phi->RemoveOperand(static_cast<std::size_t>(2 * incoming));
|
||||
}
|
||||
}
|
||||
|
||||
auto* helper = EnsureFillHelper(module, info.kind);
|
||||
vector_block->Append<CallInst>(helper, std::vector<Value*>{dst, info.fill_value, trip_count},
|
||||
nullptr, "");
|
||||
vector_block->Append<UncondBrInst>(info.exit, nullptr);
|
||||
vector_block->AddSuccessor(info.exit);
|
||||
info.exit->AddPredecessor(vector_block);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RunLoopVectorizeOnFunction(Module& module, Function& function) {
|
||||
if (function.IsExternal() || !function.GetEntryBlock()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DominatorTree dom_tree(function);
|
||||
LoopInfo loop_info(function, dom_tree);
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
VectorLoopInfo info;
|
||||
if (MatchVectorLoop(*loop, info)) {
|
||||
return VectorizeLoop(module, function, info);
|
||||
}
|
||||
FillLoopInfo fill_info;
|
||||
if (MatchFillLoop(*loop, fill_info)) {
|
||||
return VectorizeFillLoop(module, function, fill_info);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLoopVectorize(Module& module) {
|
||||
if (utils::IsEnvFlagSet("NUDTC_DISABLE_LOOP_VECTORIZE")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function && !function->IsExternal()) {
|
||||
changed |= RunLoopVectorizeOnFunction(module, *function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -1,140 +0,0 @@
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace mir {
|
||||
namespace {
|
||||
|
||||
bool IsHoistCandidate(const MachineFunction& function, int object_index, int use_count) {
|
||||
const auto& object = function.GetStackObject(object_index);
|
||||
if (object.kind != StackObjectKind::Local) {
|
||||
return false;
|
||||
}
|
||||
if (use_count < 2) {
|
||||
return false;
|
||||
}
|
||||
if (object.size >= 4096) {
|
||||
return true;
|
||||
}
|
||||
return object.size >= 256 && use_count >= 4;
|
||||
}
|
||||
|
||||
bool IsPlainFrameLea(const MachineInstr& inst, int object_index) {
|
||||
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress() ||
|
||||
inst.GetOperands().empty() || inst.GetOperands()[0].GetKind() != OperandKind::VReg) {
|
||||
return false;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
return address.base_kind == AddrBaseKind::FrameObject &&
|
||||
address.base_index == object_index && address.const_offset == 0 &&
|
||||
address.scaled_vregs.empty();
|
||||
}
|
||||
|
||||
std::size_t FindEntryInsertPos(const MachineBasicBlock& block) {
|
||||
const auto& instructions = block.GetInstructions();
|
||||
std::size_t pos = 0;
|
||||
while (pos < instructions.size() &&
|
||||
instructions[pos].GetOpcode() == MachineInstr::Opcode::Arg) {
|
||||
++pos;
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RunAddressHoisting(MachineModule& module) {
|
||||
for (auto& function : module.GetFunctions()) {
|
||||
if (!function || function->GetBlocks().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_map<int, int> use_counts;
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
if (address.base_kind == AddrBaseKind::FrameObject && address.base_index >= 0) {
|
||||
++use_counts[address.base_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<int, int> base_vregs;
|
||||
for (const auto& [object_index, count] : use_counts) {
|
||||
if (!IsHoistCandidate(*function, object_index, count)) {
|
||||
continue;
|
||||
}
|
||||
base_vregs.emplace(object_index, -1);
|
||||
}
|
||||
if (base_vregs.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
auto it = base_vregs.find(address.base_index);
|
||||
if (it == base_vregs.end()) {
|
||||
continue;
|
||||
}
|
||||
if (it->second >= 0) {
|
||||
continue;
|
||||
}
|
||||
if (IsPlainFrameLea(inst, address.base_index)) {
|
||||
it->second = inst.GetOperands()[0].GetVReg();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto& entry_block = *function->GetBlocks().front();
|
||||
auto& entry_insts = entry_block.GetInstructions();
|
||||
std::size_t insert_pos = FindEntryInsertPos(entry_block);
|
||||
|
||||
for (auto& [object_index, base_vreg] : base_vregs) {
|
||||
if (base_vreg >= 0) {
|
||||
continue;
|
||||
}
|
||||
base_vreg = function->NewVReg(ValueType::Ptr);
|
||||
MachineInstr lea(MachineInstr::Opcode::Lea, {MachineOperand::VReg(base_vreg)});
|
||||
AddressExpr address;
|
||||
address.base_kind = AddrBaseKind::FrameObject;
|
||||
address.base_index = object_index;
|
||||
lea.SetAddress(std::move(address));
|
||||
entry_insts.insert(entry_insts.begin() + static_cast<std::ptrdiff_t>(insert_pos),
|
||||
std::move(lea));
|
||||
++insert_pos;
|
||||
}
|
||||
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
auto& address = inst.GetAddress();
|
||||
auto it = base_vregs.find(address.base_index);
|
||||
if (it == base_vregs.end()) {
|
||||
continue;
|
||||
}
|
||||
if (IsPlainFrameLea(inst, address.base_index) &&
|
||||
inst.GetOperands()[0].GetKind() == OperandKind::VReg &&
|
||||
inst.GetOperands()[0].GetVReg() == it->second) {
|
||||
continue;
|
||||
}
|
||||
if (address.base_kind != AddrBaseKind::FrameObject || address.base_index < 0) {
|
||||
continue;
|
||||
}
|
||||
address.base_kind = AddrBaseKind::VReg;
|
||||
address.base_index = it->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,368 @@
|
||||
#include "AsmPrinterSupport.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "utils/Log.h"
|
||||
#include "utils/OptConfig.h"
|
||||
|
||||
namespace mir {
|
||||
|
||||
int AlignTo(int value, int align) {
|
||||
if (align <= 1) {
|
||||
return value;
|
||||
}
|
||||
return ((value + align - 1) / align) * align;
|
||||
}
|
||||
|
||||
bool IsPowerOfTwo(std::int64_t value) {
|
||||
return value > 0 && (value & (value - 1)) == 0;
|
||||
}
|
||||
|
||||
int Log2(std::int64_t value) {
|
||||
int shift = 0;
|
||||
while (value > 1) {
|
||||
value >>= 1;
|
||||
++shift;
|
||||
}
|
||||
return shift;
|
||||
}
|
||||
|
||||
int CountBits64(std::uint64_t value) {
|
||||
int count = 0;
|
||||
while (value != 0) {
|
||||
value &= value - 1;
|
||||
++count;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
std::vector<int> SetBitPositions(std::uint64_t value) {
|
||||
std::vector<int> positions;
|
||||
for (int bit = 0; bit < 63; ++bit) {
|
||||
if ((value & (1ull << bit)) != 0) {
|
||||
positions.push_back(bit);
|
||||
}
|
||||
}
|
||||
return positions;
|
||||
}
|
||||
|
||||
SignedDivMagic ComputeSignedDivMagic(std::int64_t divisor) {
|
||||
const std::uint64_t two31 = 1ull << 31;
|
||||
const std::uint64_t abs_divisor =
|
||||
divisor < 0 ? static_cast<std::uint64_t>(-divisor)
|
||||
: static_cast<std::uint64_t>(divisor);
|
||||
const std::uint64_t divisor_bits = static_cast<std::uint32_t>(divisor);
|
||||
const std::uint64_t t = two31 + (divisor_bits >> 31);
|
||||
const std::uint64_t anc = t - 1 - (t % abs_divisor);
|
||||
|
||||
int p = 31;
|
||||
std::uint64_t q1 = two31 / anc;
|
||||
std::uint64_t r1 = two31 - q1 * anc;
|
||||
std::uint64_t q2 = two31 / abs_divisor;
|
||||
std::uint64_t r2 = two31 - q2 * abs_divisor;
|
||||
|
||||
while (true) {
|
||||
++p;
|
||||
q1 <<= 1;
|
||||
r1 <<= 1;
|
||||
if (r1 >= anc) {
|
||||
++q1;
|
||||
r1 -= anc;
|
||||
}
|
||||
q2 <<= 1;
|
||||
r2 <<= 1;
|
||||
if (r2 >= abs_divisor) {
|
||||
++q2;
|
||||
r2 -= abs_divisor;
|
||||
}
|
||||
const std::uint64_t delta = abs_divisor - r2;
|
||||
if (q1 > delta || (q1 == delta && r1 != 0)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::int64_t multiplier = static_cast<std::int64_t>(q2 + 1);
|
||||
if (divisor < 0) {
|
||||
multiplier = -multiplier;
|
||||
}
|
||||
multiplier = static_cast<std::int32_t>(static_cast<std::uint32_t>(multiplier));
|
||||
return {multiplier, p - 32};
|
||||
}
|
||||
|
||||
std::uint64_t ComputeU64ModuloMagic(std::int64_t divisor) {
|
||||
const auto abs_divisor = static_cast<std::uint64_t>(divisor);
|
||||
return ~std::uint64_t{0} / abs_divisor;
|
||||
}
|
||||
|
||||
const char* GetDRegName(int index) {
|
||||
static const char* kNames[] = {
|
||||
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7",
|
||||
"d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15",
|
||||
"d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23",
|
||||
"d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"};
|
||||
if (index < 0 || index >= 32) {
|
||||
throw std::runtime_error("float register index out of range");
|
||||
}
|
||||
return kNames[index];
|
||||
}
|
||||
|
||||
int ToAsmAlign(int align) {
|
||||
int value = 0;
|
||||
int current = 1;
|
||||
while (current < align) {
|
||||
current <<= 1;
|
||||
++value;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
std::uint32_t FloatBits(float value) {
|
||||
std::uint32_t bits = 0;
|
||||
std::memcpy(&bits, &value, sizeof(bits));
|
||||
return bits;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool Is32BitRegName(const char* reg) {
|
||||
return reg != nullptr && reg[0] == 'w';
|
||||
}
|
||||
|
||||
bool IsAddSubImm12(std::int64_t value) {
|
||||
return value >= 0 && value <= 4095;
|
||||
}
|
||||
|
||||
bool IsAddSubImm12Shifted(std::int64_t value) {
|
||||
return value >= 0 && value <= (4095ll << 12) && (value & 0xfffll) == 0;
|
||||
}
|
||||
|
||||
bool IsShiftedContiguousMask32(std::uint32_t value) {
|
||||
if (value == 0 || value == 0xffffffffu) {
|
||||
return false;
|
||||
}
|
||||
for (int start = 0; start < 32; ++start) {
|
||||
std::uint32_t mask = 0;
|
||||
for (int bit = start; bit < 32; ++bit) {
|
||||
mask |= (1u << bit);
|
||||
if (mask == value) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool IsAddSubImm(std::int64_t value) {
|
||||
return IsAddSubImm12(value) || IsAddSubImm12Shifted(value);
|
||||
}
|
||||
|
||||
bool IsLogicalImm32(std::int64_t value) {
|
||||
if (value < 0 || value > 0xffffffffll) {
|
||||
return false;
|
||||
}
|
||||
return IsShiftedContiguousMask32(static_cast<std::uint32_t>(value));
|
||||
}
|
||||
|
||||
bool AsmImmLoweringEnabled() {
|
||||
return utils::IsEnabledUnlessEnvFlag("NUDTC_DISABLE_ASM_IMM_LOWERING");
|
||||
}
|
||||
|
||||
void EmitAddSubImm(std::ostream& os, const char* opcode, const char* dst,
|
||||
const char* src, std::int64_t value) {
|
||||
if (!IsAddSubImm(value)) {
|
||||
throw std::runtime_error(FormatError("mir", "invalid add/sub immediate"));
|
||||
}
|
||||
os << " " << opcode << " " << dst << ", " << src << ", #";
|
||||
if (IsAddSubImm12(value)) {
|
||||
os << value << "\n";
|
||||
return;
|
||||
}
|
||||
os << (value >> 12) << ", lsl #12\n";
|
||||
}
|
||||
|
||||
void EmitAdjustRegByImm(std::ostream& os, const char* dst, const char* src,
|
||||
std::int64_t value) {
|
||||
if (value == 0) {
|
||||
if (std::string(dst) != src) {
|
||||
os << " mov " << dst << ", " << src << "\n";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const char* opcode = value >= 0 ? "add" : "sub";
|
||||
std::uint64_t remaining = value >= 0 ? static_cast<std::uint64_t>(value)
|
||||
: static_cast<std::uint64_t>(-value);
|
||||
bool first = true;
|
||||
auto emit_chunk = [&](std::uint64_t amount, bool shifted) {
|
||||
const char* current_src = first ? src : dst;
|
||||
os << " " << opcode << " " << dst << ", " << current_src << ", #" << amount;
|
||||
if (shifted) {
|
||||
os << ", lsl #12";
|
||||
}
|
||||
os << "\n";
|
||||
first = false;
|
||||
};
|
||||
|
||||
while (remaining >= 4096) {
|
||||
const std::uint64_t units = std::min<std::uint64_t>(remaining >> 12, 4095);
|
||||
emit_chunk(units, true);
|
||||
remaining -= units << 12;
|
||||
}
|
||||
if (remaining > 0) {
|
||||
emit_chunk(remaining, false);
|
||||
}
|
||||
}
|
||||
|
||||
void EmitMoveImm(std::ostream& os, const char* reg, std::int64_t value) {
|
||||
if (reg == nullptr || reg[0] == '\0') {
|
||||
throw std::runtime_error(FormatError("mir", "invalid register for immediate materialization"));
|
||||
}
|
||||
|
||||
const bool is32 = Is32BitRegName(reg);
|
||||
if (value == 0) {
|
||||
os << " mov " << reg << ", #0\n";
|
||||
return;
|
||||
}
|
||||
|
||||
if (is32) {
|
||||
const std::uint32_t bits = static_cast<std::uint32_t>(value);
|
||||
bool emitted = false;
|
||||
for (int shift = 0; shift <= 16; shift += 16) {
|
||||
const std::uint32_t chunk = (bits >> shift) & 0xffffu;
|
||||
if (chunk == 0 && emitted) {
|
||||
continue;
|
||||
}
|
||||
if (!emitted) {
|
||||
os << " movz " << reg << ", #" << chunk;
|
||||
if (shift != 0) {
|
||||
os << ", lsl #" << shift;
|
||||
}
|
||||
os << "\n";
|
||||
emitted = true;
|
||||
} else if (chunk != 0) {
|
||||
os << " movk " << reg << ", #" << chunk;
|
||||
if (shift != 0) {
|
||||
os << ", lsl #" << shift;
|
||||
}
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const std::uint64_t bits = static_cast<std::uint64_t>(value);
|
||||
bool emitted = false;
|
||||
for (int shift = 0; shift <= 48; shift += 16) {
|
||||
const std::uint64_t chunk = (bits >> shift) & 0xffffull;
|
||||
if (chunk == 0 && emitted) {
|
||||
continue;
|
||||
}
|
||||
if (!emitted) {
|
||||
os << " movz " << reg << ", #" << chunk;
|
||||
if (shift != 0) {
|
||||
os << ", lsl #" << shift;
|
||||
}
|
||||
os << "\n";
|
||||
emitted = true;
|
||||
} else if (chunk != 0) {
|
||||
os << " movk " << reg << ", #" << chunk;
|
||||
if (shift != 0) {
|
||||
os << ", lsl #" << shift;
|
||||
}
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void EmitCopy(std::ostream& os, const char* dst, const char* src, bool is_float) {
|
||||
if (std::string(dst) == src) {
|
||||
return;
|
||||
}
|
||||
os << " " << (is_float ? "fmov" : "mov") << " " << dst << ", " << src << "\n";
|
||||
}
|
||||
|
||||
int GetAddressShift(ValueType type) {
|
||||
switch (GetValueSize(type)) {
|
||||
case 4:
|
||||
return 2;
|
||||
case 8:
|
||||
return 3;
|
||||
case 16:
|
||||
return 4;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool TryEmitBaseOffsetAccess(ValueType type, const char* value_reg, const char* base_reg,
|
||||
std::int64_t offset, bool is_store, std::ostream& os) {
|
||||
const int size = GetValueSize(type);
|
||||
const char* mnemonic = is_store ? "str" : "ldr";
|
||||
if (offset == 0) {
|
||||
os << " " << mnemonic << " " << value_reg << ", [" << base_reg << "]\n";
|
||||
return true;
|
||||
}
|
||||
if (offset >= 0 && size > 0 && offset % size == 0 && offset / size <= 4095) {
|
||||
os << " " << mnemonic << " " << value_reg << ", [" << base_reg << ", #" << offset
|
||||
<< "]\n";
|
||||
return true;
|
||||
}
|
||||
if (offset >= -256 && offset <= 255) {
|
||||
os << " " << (is_store ? "stur" : "ldur") << " " << value_reg << ", [" << base_reg
|
||||
<< ", #" << offset << "]\n";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void EmitLoadFromAddr(ValueType type, const char* dst, const char* addr_reg,
|
||||
std::ostream& os) {
|
||||
switch (type) {
|
||||
case ValueType::I1:
|
||||
case ValueType::I32:
|
||||
os << " ldr " << dst << ", [" << addr_reg << "]\n";
|
||||
break;
|
||||
case ValueType::F32:
|
||||
os << " ldr " << dst << ", [" << addr_reg << "]\n";
|
||||
break;
|
||||
case ValueType::Ptr:
|
||||
os << " ldr " << dst << ", [" << addr_reg << "]\n";
|
||||
break;
|
||||
case ValueType::I32x4:
|
||||
case ValueType::F32x4:
|
||||
os << " ldr " << dst << ", [" << addr_reg << "]\n";
|
||||
break;
|
||||
case ValueType::Void:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void EmitStoreToAddr(ValueType type, const char* src, const char* addr_reg,
|
||||
std::ostream& os) {
|
||||
switch (type) {
|
||||
case ValueType::I1:
|
||||
case ValueType::I32:
|
||||
os << " str " << src << ", [" << addr_reg << "]\n";
|
||||
break;
|
||||
case ValueType::F32:
|
||||
os << " str " << src << ", [" << addr_reg << "]\n";
|
||||
break;
|
||||
case ValueType::Ptr:
|
||||
os << " str " << src << ", [" << addr_reg << "]\n";
|
||||
break;
|
||||
case ValueType::I32x4:
|
||||
case ValueType::F32x4:
|
||||
os << " str " << src << ", [" << addr_reg << "]\n";
|
||||
break;
|
||||
case ValueType::Void:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <iosfwd>
|
||||
#include <vector>
|
||||
|
||||
namespace mir {
|
||||
|
||||
struct SignedDivMagic {
|
||||
std::int64_t multiplier = 0;
|
||||
int shift = 0;
|
||||
};
|
||||
|
||||
int AlignTo(int value, int align);
|
||||
bool IsPowerOfTwo(std::int64_t value);
|
||||
int Log2(std::int64_t value);
|
||||
int CountBits64(std::uint64_t value);
|
||||
std::vector<int> SetBitPositions(std::uint64_t value);
|
||||
SignedDivMagic ComputeSignedDivMagic(std::int64_t divisor);
|
||||
std::uint64_t ComputeU64ModuloMagic(std::int64_t divisor);
|
||||
const char* GetDRegName(int index);
|
||||
int ToAsmAlign(int align);
|
||||
std::uint32_t FloatBits(float value);
|
||||
|
||||
bool IsAddSubImm(std::int64_t value);
|
||||
bool IsLogicalImm32(std::int64_t value);
|
||||
bool AsmImmLoweringEnabled();
|
||||
void EmitAddSubImm(std::ostream& os, const char* opcode, const char* dst,
|
||||
const char* src, std::int64_t value);
|
||||
void EmitAdjustRegByImm(std::ostream& os, const char* dst, const char* src,
|
||||
std::int64_t value);
|
||||
void EmitMoveImm(std::ostream& os, const char* reg, std::int64_t value);
|
||||
void EmitCopy(std::ostream& os, const char* dst, const char* src, bool is_float);
|
||||
int GetAddressShift(ValueType type);
|
||||
bool TryEmitBaseOffsetAccess(ValueType type, const char* value_reg,
|
||||
const char* base_reg, std::int64_t offset,
|
||||
bool is_store, std::ostream& os);
|
||||
void EmitLoadFromAddr(ValueType type, const char* dst, const char* addr_reg,
|
||||
std::ostream& os);
|
||||
void EmitStoreToAddr(ValueType type, const char* src, const char* addr_reg,
|
||||
std::ostream& os);
|
||||
|
||||
} // namespace mir
|
||||
@ -0,0 +1,464 @@
|
||||
#include "mir/Passes.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/OptConfig.h"
|
||||
|
||||
namespace mir {
|
||||
namespace {
|
||||
|
||||
bool IsHoistCandidate(const MachineFunction& function, int object_index, int use_count) {
|
||||
const auto& object = function.GetStackObject(object_index);
|
||||
if (object.kind != StackObjectKind::Local) {
|
||||
return false;
|
||||
}
|
||||
if (use_count < 2) {
|
||||
return false;
|
||||
}
|
||||
if (object.size >= 4096) {
|
||||
return true;
|
||||
}
|
||||
return object.size >= 256 && use_count >= 4;
|
||||
}
|
||||
|
||||
bool IsPlainFrameLea(const MachineInstr& inst, int object_index) {
|
||||
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress() ||
|
||||
inst.GetOperands().empty() || inst.GetOperands()[0].GetKind() != OperandKind::VReg) {
|
||||
return false;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
return address.base_kind == AddrBaseKind::FrameObject &&
|
||||
address.base_index == object_index && address.const_offset == 0 &&
|
||||
address.scaled_vregs.empty();
|
||||
}
|
||||
|
||||
bool IsPlainGlobalLea(const MachineInstr& inst, const std::string& symbol) {
|
||||
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress() ||
|
||||
inst.GetOperands().empty() || inst.GetOperands()[0].GetKind() != OperandKind::VReg) {
|
||||
return false;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
return address.base_kind == AddrBaseKind::Global && address.symbol == symbol &&
|
||||
address.const_offset == 0 && address.scaled_vregs.empty();
|
||||
}
|
||||
|
||||
bool HasCallClobberingInstruction(const MachineBasicBlock& block) {
|
||||
for (const auto& inst : block.GetInstructions()) {
|
||||
if (inst.GetOpcode() == MachineInstr::Opcode::Call ||
|
||||
inst.GetOpcode() == MachineInstr::Opcode::Memset) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool HasCallClobberingInstruction(const MachineFunction& function) {
|
||||
for (const auto& block : function.GetBlocks()) {
|
||||
if (HasCallClobberingInstruction(*block)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t FindEntryInsertPos(const MachineBasicBlock& block) {
|
||||
const auto& instructions = block.GetInstructions();
|
||||
std::size_t pos = 0;
|
||||
while (pos < instructions.size() &&
|
||||
instructions[pos].GetOpcode() == MachineInstr::Opcode::Arg) {
|
||||
++pos;
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
struct AddressStemKey {
|
||||
AddrBaseKind base_kind = AddrBaseKind::None;
|
||||
int base_index = -1;
|
||||
std::string symbol;
|
||||
std::vector<std::pair<int, std::int64_t>> scaled_vregs;
|
||||
|
||||
bool operator==(const AddressStemKey& rhs) const {
|
||||
return base_kind == rhs.base_kind && base_index == rhs.base_index &&
|
||||
symbol == rhs.symbol && scaled_vregs == rhs.scaled_vregs;
|
||||
}
|
||||
};
|
||||
|
||||
struct AddressStemKeyHash {
|
||||
std::size_t operator()(const AddressStemKey& key) const {
|
||||
std::size_t h = static_cast<std::size_t>(key.base_kind);
|
||||
h ^= std::hash<int>{}(key.base_index) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
h ^= std::hash<std::string>{}(key.symbol) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
for (const auto& term : key.scaled_vregs) {
|
||||
h ^= std::hash<int>{}(term.first) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
h ^= std::hash<std::int64_t>{}(term.second) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
struct AddressStemInfo {
|
||||
int count = 0;
|
||||
std::size_t first_pos = 0;
|
||||
std::size_t existing_pos = 0;
|
||||
int base_vreg = -1;
|
||||
bool has_existing_base = false;
|
||||
};
|
||||
|
||||
bool ShouldHoistAddressStem(const AddressExpr& address) {
|
||||
if (address.base_kind == AddrBaseKind::None || address.scaled_vregs.empty()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
AddressStemKey MakeAddressStemKey(const AddressExpr& address) {
|
||||
AddressStemKey key;
|
||||
key.base_kind = address.base_kind;
|
||||
key.base_index = address.base_index;
|
||||
key.symbol = address.symbol;
|
||||
key.scaled_vregs = address.scaled_vregs;
|
||||
return key;
|
||||
}
|
||||
|
||||
AddressExpr MakeStemAddress(const AddressStemKey& key) {
|
||||
AddressExpr address;
|
||||
address.base_kind = key.base_kind;
|
||||
address.base_index = key.base_index;
|
||||
address.symbol = key.symbol;
|
||||
address.scaled_vregs = key.scaled_vregs;
|
||||
return address;
|
||||
}
|
||||
|
||||
bool IsExistingStemLea(const MachineInstr& inst, const AddressStemKey& key) {
|
||||
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress() ||
|
||||
inst.GetOperands().empty() || inst.GetOperands()[0].GetKind() != OperandKind::VReg) {
|
||||
return false;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
return address.const_offset == 0 && MakeAddressStemKey(address) == key;
|
||||
}
|
||||
|
||||
bool RunScaledAddressStemHoisting(MachineFunction& function, MachineBasicBlock& block) {
|
||||
if (HasCallClobberingInstruction(block)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& instructions = block.GetInstructions();
|
||||
std::unordered_map<AddressStemKey, AddressStemInfo, AddressStemKeyHash> stems;
|
||||
for (std::size_t i = 0; i < instructions.size(); ++i) {
|
||||
const auto& inst = instructions[i];
|
||||
if (!inst.HasAddress() || !ShouldHoistAddressStem(inst.GetAddress())) {
|
||||
continue;
|
||||
}
|
||||
const auto key = MakeAddressStemKey(inst.GetAddress());
|
||||
auto& info = stems[key];
|
||||
if (info.count == 0) {
|
||||
info.first_pos = i;
|
||||
}
|
||||
++info.count;
|
||||
if (IsExistingStemLea(inst, key)) {
|
||||
info.base_vreg = inst.GetOperands()[0].GetVReg();
|
||||
info.existing_pos = i;
|
||||
info.has_existing_base = true;
|
||||
}
|
||||
}
|
||||
|
||||
struct SelectedStem {
|
||||
AddressStemKey key;
|
||||
std::size_t insert_pos = 0;
|
||||
int base_vreg = -1;
|
||||
bool needs_insert = true;
|
||||
};
|
||||
std::vector<SelectedStem> selected;
|
||||
for (auto& [key, info] : stems) {
|
||||
if (info.count < 3) {
|
||||
continue;
|
||||
}
|
||||
const bool reuse_existing = info.has_existing_base && info.existing_pos == info.first_pos;
|
||||
if (!reuse_existing) {
|
||||
info.base_vreg = function.NewVReg(ValueType::Ptr);
|
||||
}
|
||||
selected.push_back({key, info.first_pos, info.base_vreg, !reuse_existing});
|
||||
}
|
||||
if (selected.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::sort(selected.begin(), selected.end(), [](const SelectedStem& lhs,
|
||||
const SelectedStem& rhs) {
|
||||
return lhs.insert_pos < rhs.insert_pos;
|
||||
});
|
||||
|
||||
auto find_selected = [&](const AddressExpr& address) -> const SelectedStem* {
|
||||
if (!ShouldHoistAddressStem(address)) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto key = MakeAddressStemKey(address);
|
||||
for (const auto& stem : selected) {
|
||||
if (stem.key == key) {
|
||||
return &stem;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
for (auto& inst : instructions) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
auto* stem = find_selected(inst.GetAddress());
|
||||
if (stem == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (IsExistingStemLea(inst, stem->key) &&
|
||||
inst.GetOperands()[0].GetKind() == OperandKind::VReg &&
|
||||
inst.GetOperands()[0].GetVReg() == stem->base_vreg) {
|
||||
continue;
|
||||
}
|
||||
const std::int64_t old_offset = inst.GetAddress().const_offset;
|
||||
auto& address = inst.GetAddress();
|
||||
address.base_kind = AddrBaseKind::VReg;
|
||||
address.base_index = stem->base_vreg;
|
||||
address.symbol.clear();
|
||||
address.const_offset = old_offset;
|
||||
address.scaled_vregs.clear();
|
||||
}
|
||||
|
||||
std::vector<MachineInstr> rewritten;
|
||||
rewritten.reserve(instructions.size() + selected.size());
|
||||
std::size_t next_stem = 0;
|
||||
for (std::size_t i = 0; i < instructions.size(); ++i) {
|
||||
while (next_stem < selected.size() && selected[next_stem].insert_pos == i) {
|
||||
const auto& stem = selected[next_stem];
|
||||
if (stem.needs_insert) {
|
||||
MachineInstr lea(MachineInstr::Opcode::Lea,
|
||||
{MachineOperand::VReg(stem.base_vreg)});
|
||||
lea.SetAddress(MakeStemAddress(stem.key));
|
||||
rewritten.push_back(std::move(lea));
|
||||
}
|
||||
++next_stem;
|
||||
}
|
||||
rewritten.push_back(std::move(instructions[i]));
|
||||
}
|
||||
while (next_stem < selected.size()) {
|
||||
const auto& stem = selected[next_stem];
|
||||
if (stem.needs_insert) {
|
||||
MachineInstr lea(MachineInstr::Opcode::Lea,
|
||||
{MachineOperand::VReg(stem.base_vreg)});
|
||||
lea.SetAddress(MakeStemAddress(stem.key));
|
||||
rewritten.push_back(std::move(lea));
|
||||
}
|
||||
++next_stem;
|
||||
}
|
||||
instructions = std::move(rewritten);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RunAddressHoisting(MachineModule& module) {
|
||||
for (auto& function : module.GetFunctions()) {
|
||||
if (!function || function->GetBlocks().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool run_global_addr_hoist =
|
||||
utils::IsEnabledUnlessEnvFlag("NUDTC_DISABLE_GLOBAL_ADDR_HOIST");
|
||||
const bool has_call_clobber = HasCallClobberingInstruction(*function);
|
||||
std::unordered_map<int, int> use_counts;
|
||||
std::unordered_map<std::string, int> global_use_counts;
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
if (address.base_kind == AddrBaseKind::FrameObject && address.base_index >= 0) {
|
||||
++use_counts[address.base_index];
|
||||
} else if (run_global_addr_hoist && !has_call_clobber &&
|
||||
address.base_kind == AddrBaseKind::Global && !address.symbol.empty()) {
|
||||
++global_use_counts[address.symbol];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<int, int> base_vregs;
|
||||
for (const auto& [object_index, count] : use_counts) {
|
||||
if (!IsHoistCandidate(*function, object_index, count)) {
|
||||
continue;
|
||||
}
|
||||
base_vregs.emplace(object_index, -1);
|
||||
}
|
||||
std::unordered_map<std::string, int> global_base_vregs;
|
||||
for (const auto& [symbol, count] : global_use_counts) {
|
||||
if (count >= 2) {
|
||||
global_base_vregs.emplace(symbol, -1);
|
||||
}
|
||||
}
|
||||
|
||||
if (base_vregs.empty() && global_base_vregs.empty() && !run_global_addr_hoist) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Existing LEA instructions may be inside non-entry blocks and do not
|
||||
// necessarily dominate all rewritten users. Always create a fresh base in
|
||||
// the entry block for function-wide hoisting.
|
||||
|
||||
auto& entry_block = *function->GetBlocks().front();
|
||||
auto& entry_insts = entry_block.GetInstructions();
|
||||
std::size_t insert_pos = FindEntryInsertPos(entry_block);
|
||||
|
||||
for (auto& [object_index, base_vreg] : base_vregs) {
|
||||
if (base_vreg >= 0) {
|
||||
continue;
|
||||
}
|
||||
base_vreg = function->NewVReg(ValueType::Ptr);
|
||||
MachineInstr lea(MachineInstr::Opcode::Lea, {MachineOperand::VReg(base_vreg)});
|
||||
AddressExpr address;
|
||||
address.base_kind = AddrBaseKind::FrameObject;
|
||||
address.base_index = object_index;
|
||||
lea.SetAddress(std::move(address));
|
||||
entry_insts.insert(entry_insts.begin() + static_cast<std::ptrdiff_t>(insert_pos),
|
||||
std::move(lea));
|
||||
++insert_pos;
|
||||
}
|
||||
for (auto& [symbol, base_vreg] : global_base_vregs) {
|
||||
if (base_vreg >= 0) {
|
||||
continue;
|
||||
}
|
||||
base_vreg = function->NewVReg(ValueType::Ptr);
|
||||
MachineInstr lea(MachineInstr::Opcode::Lea, {MachineOperand::VReg(base_vreg)});
|
||||
AddressExpr address;
|
||||
address.base_kind = AddrBaseKind::Global;
|
||||
address.symbol = symbol;
|
||||
lea.SetAddress(std::move(address));
|
||||
entry_insts.insert(entry_insts.begin() + static_cast<std::ptrdiff_t>(insert_pos),
|
||||
std::move(lea));
|
||||
++insert_pos;
|
||||
}
|
||||
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
auto& address = inst.GetAddress();
|
||||
auto it = base_vregs.find(address.base_index);
|
||||
if (it == base_vregs.end()) {
|
||||
continue;
|
||||
}
|
||||
if (IsPlainFrameLea(inst, address.base_index) &&
|
||||
inst.GetOperands()[0].GetKind() == OperandKind::VReg &&
|
||||
inst.GetOperands()[0].GetVReg() == it->second) {
|
||||
continue;
|
||||
}
|
||||
if (address.base_kind != AddrBaseKind::FrameObject || address.base_index < 0) {
|
||||
continue;
|
||||
}
|
||||
address.base_kind = AddrBaseKind::VReg;
|
||||
address.base_index = it->second;
|
||||
}
|
||||
}
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
auto& address = inst.GetAddress();
|
||||
if (address.base_kind != AddrBaseKind::Global || address.symbol.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto it = global_base_vregs.find(address.symbol);
|
||||
if (it == global_base_vregs.end()) {
|
||||
continue;
|
||||
}
|
||||
if (IsPlainGlobalLea(inst, address.symbol) &&
|
||||
inst.GetOperands()[0].GetKind() == OperandKind::VReg &&
|
||||
inst.GetOperands()[0].GetVReg() == it->second) {
|
||||
continue;
|
||||
}
|
||||
address.base_kind = AddrBaseKind::VReg;
|
||||
address.base_index = it->second;
|
||||
}
|
||||
}
|
||||
|
||||
if (!run_global_addr_hoist) {
|
||||
continue;
|
||||
}
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
if (HasCallClobberingInstruction(*block)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, int> block_global_counts;
|
||||
for (const auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
if (address.base_kind == AddrBaseKind::Global && !address.symbol.empty()) {
|
||||
++block_global_counts[address.symbol];
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, int> block_base_vregs;
|
||||
for (const auto& [symbol, count] : block_global_counts) {
|
||||
if (count >= 3) {
|
||||
block_base_vregs.emplace(symbol, -1);
|
||||
}
|
||||
}
|
||||
if (block_base_vregs.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& instructions = block->GetInstructions();
|
||||
std::size_t insert_pos = FindEntryInsertPos(*block);
|
||||
for (auto& [symbol, base_vreg] : block_base_vregs) {
|
||||
if (base_vreg >= 0) {
|
||||
continue;
|
||||
}
|
||||
base_vreg = function->NewVReg(ValueType::Ptr);
|
||||
MachineInstr lea(MachineInstr::Opcode::Lea, {MachineOperand::VReg(base_vreg)});
|
||||
AddressExpr address;
|
||||
address.base_kind = AddrBaseKind::Global;
|
||||
address.symbol = symbol;
|
||||
lea.SetAddress(std::move(address));
|
||||
instructions.insert(instructions.begin() + static_cast<std::ptrdiff_t>(insert_pos),
|
||||
std::move(lea));
|
||||
++insert_pos;
|
||||
}
|
||||
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
auto& address = inst.GetAddress();
|
||||
if (address.base_kind != AddrBaseKind::Global || address.symbol.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto it = block_base_vregs.find(address.symbol);
|
||||
if (it == block_base_vregs.end()) {
|
||||
continue;
|
||||
}
|
||||
if (IsPlainGlobalLea(inst, address.symbol) &&
|
||||
inst.GetOperands()[0].GetKind() == OperandKind::VReg &&
|
||||
inst.GetOperands()[0].GetVReg() == it->second) {
|
||||
continue;
|
||||
}
|
||||
address.base_kind = AddrBaseKind::VReg;
|
||||
address.base_index = it->second;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
RunScaledAddressStemHoisting(*function, *block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
@ -0,0 +1,117 @@
|
||||
#include "mir/Passes.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace mir {
|
||||
namespace {
|
||||
|
||||
[[noreturn]] void Fail(const MachineFunction* function,
|
||||
const MachineBasicBlock* block,
|
||||
const std::string& message) {
|
||||
std::string where = "[mir-verify]";
|
||||
if (function != nullptr) {
|
||||
where += " function " + function->GetName();
|
||||
}
|
||||
if (block != nullptr) {
|
||||
where += " block " + block->GetName();
|
||||
}
|
||||
throw std::runtime_error(where + ": " + message);
|
||||
}
|
||||
|
||||
void CheckVReg(const MachineFunction& function, const MachineBasicBlock& block,
|
||||
int vreg) {
|
||||
if (vreg < 0 || vreg >= static_cast<int>(function.GetVRegs().size())) {
|
||||
Fail(&function, &block, "instruction references invalid virtual register");
|
||||
}
|
||||
}
|
||||
|
||||
void CheckAddress(const MachineFunction& function, const MachineBasicBlock& block,
|
||||
const AddressExpr& address) {
|
||||
if (address.base_kind == AddrBaseKind::FrameObject) {
|
||||
if (address.base_index < 0 ||
|
||||
address.base_index >= static_cast<int>(function.GetStackObjects().size())) {
|
||||
Fail(&function, &block, "address references invalid stack object");
|
||||
}
|
||||
} else if (address.base_kind == AddrBaseKind::VReg) {
|
||||
CheckVReg(function, block, address.base_index);
|
||||
}
|
||||
for (const auto& term : address.scaled_vregs) {
|
||||
CheckVReg(function, block, term.first);
|
||||
}
|
||||
}
|
||||
|
||||
void CheckBlockTargets(const MachineFunction& function,
|
||||
const MachineBasicBlock& block,
|
||||
const std::unordered_set<std::string>& block_names,
|
||||
const MachineInstr& instr) {
|
||||
for (const auto& operand : instr.GetOperands()) {
|
||||
if (operand.GetKind() == OperandKind::Block &&
|
||||
block_names.count(operand.GetText()) == 0) {
|
||||
Fail(&function, &block, "branch references unknown block");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckFunction(const MachineFunction& function) {
|
||||
std::unordered_set<std::string> block_names;
|
||||
for (const auto& block : function.GetBlocks()) {
|
||||
if (!block) {
|
||||
Fail(&function, nullptr, "null block");
|
||||
}
|
||||
if (!block_names.insert(block->GetName()).second) {
|
||||
Fail(&function, block.get(), "duplicate block name");
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& object : function.GetStackObjects()) {
|
||||
if (object.index < 0 || object.size < 0 || object.align <= 0) {
|
||||
Fail(&function, nullptr, "invalid stack object");
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
const auto& block = *block_ptr;
|
||||
const auto& instructions = block.GetInstructions();
|
||||
for (std::size_t i = 0; i < instructions.size(); ++i) {
|
||||
const auto& instr = instructions[i];
|
||||
if (instr.IsTerminator() && i + 1 != instructions.size()) {
|
||||
Fail(&function, &block, "terminator is not the last instruction");
|
||||
}
|
||||
for (int def : instr.GetDefs()) {
|
||||
CheckVReg(function, block, def);
|
||||
}
|
||||
for (int use : instr.GetUses()) {
|
||||
CheckVReg(function, block, use);
|
||||
}
|
||||
if (instr.HasAddress()) {
|
||||
CheckAddress(function, block, instr.GetAddress());
|
||||
}
|
||||
CheckBlockTargets(function, block, block_names, instr);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& vreg : function.GetVRegs()) {
|
||||
const auto& allocation = function.GetAllocation(vreg.id);
|
||||
if (allocation.kind == Allocation::Kind::Spill &&
|
||||
(allocation.stack_object < 0 ||
|
||||
allocation.stack_object >=
|
||||
static_cast<int>(function.GetStackObjects().size()))) {
|
||||
Fail(&function, nullptr, "spill allocation references invalid stack object");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void VerifyMIR(const MachineModule& module) {
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
CheckFunction(*function);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
Loading…
Reference in new issue