Compare commits
24 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
4d9c159dd2 | 1 week ago |
|
|
e55421f447 | 1 week ago |
|
|
69892ef133 | 2 weeks ago |
|
|
407be0fca1 | 2 weeks ago |
|
|
08ce9d96ab | 2 weeks ago |
|
|
bcfbf52488 | 2 weeks ago |
|
|
4cb9354ab4 | 2 weeks ago |
|
|
b33ede5457 | 3 weeks ago |
|
|
252073efe8 | 3 weeks ago |
|
|
c252a676ac | 3 weeks ago |
|
|
abcae58661 | 4 weeks ago |
|
|
1ed7ab0d1b | 4 weeks ago |
|
|
f56f9772a3 | 4 weeks ago |
|
|
29b7bf7357 | 4 weeks ago |
|
|
691f99831c | 1 month ago |
|
|
8157f8d021 | 1 month ago |
|
|
a89c5fb0e4 | 1 month ago |
|
|
ed15fa1c72 | 1 month ago |
|
|
b1c34228b1 | 1 month ago |
|
|
29d1315410 | 1 month ago |
|
|
e4fed12b92 | 2 months ago |
|
|
472f059af7 | 2 months ago |
|
|
96dda8642a | 2 months ago |
|
|
f83b83c664 | 2 months ago |
@ -1,70 +1,72 @@
|
||||
# =========================
|
||||
# Build / CMake
|
||||
# =========================
|
||||
build/
|
||||
cmake-build-*/
|
||||
out/
|
||||
dist/
|
||||
|
||||
CMakeFiles/
|
||||
CMakeCache.txt
|
||||
cmake_install.cmake
|
||||
install_manifest.txt
|
||||
Makefile
|
||||
compile_commands.json
|
||||
.ninja_deps
|
||||
.ninja_log
|
||||
|
||||
# =========================
|
||||
# Generated / intermediate
|
||||
# =========================
|
||||
*.o
|
||||
*.obj
|
||||
*.a
|
||||
*.lib
|
||||
*.so
|
||||
*.dylib
|
||||
*.dll
|
||||
*.exe
|
||||
*.out
|
||||
!test/test_case/**/*.out
|
||||
*.app
|
||||
*.pdb
|
||||
*.ilk
|
||||
*.dSYM/
|
||||
|
||||
*.log
|
||||
*.tmp
|
||||
*.swp
|
||||
*.swo
|
||||
*.bak
|
||||
|
||||
# ANTLR 生成物(通常在 build/,这里额外兜底)
|
||||
**/generated/antlr4/
|
||||
**/antlr4-generated/
|
||||
*.tokens
|
||||
*.interp
|
||||
|
||||
# =========================
|
||||
# IDE / Editor
|
||||
# =========================
|
||||
.vscode/
|
||||
.idea/
|
||||
.fleet/
|
||||
.vs/
|
||||
*.code-workspace
|
||||
|
||||
# CLion
|
||||
cmake-build-debug/
|
||||
cmake-build-release/
|
||||
|
||||
# =========================
|
||||
# OS / misc
|
||||
# =========================
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# =========================
|
||||
# Project outputs
|
||||
# =========================
|
||||
test/test_result/
|
||||
# =========================
|
||||
# Build / CMake
|
||||
# =========================
|
||||
build/
|
||||
build_*/
|
||||
cmake-build-*/
|
||||
out/
|
||||
output/
|
||||
dist/
|
||||
|
||||
CMakeFiles/
|
||||
CMakeCache.txt
|
||||
cmake_install.cmake
|
||||
install_manifest.txt
|
||||
Makefile
|
||||
compile_commands.json
|
||||
.ninja_deps
|
||||
.ninja_log
|
||||
|
||||
# =========================
|
||||
# Generated / intermediate
|
||||
# =========================
|
||||
*.o
|
||||
*.obj
|
||||
*.a
|
||||
*.lib
|
||||
*.so
|
||||
*.dylib
|
||||
*.dll
|
||||
*.exe
|
||||
*.out
|
||||
!test/test_case/**/*.out
|
||||
*.app
|
||||
*.pdb
|
||||
*.ilk
|
||||
*.dSYM/
|
||||
|
||||
*.log
|
||||
*.tmp
|
||||
*.swp
|
||||
*.swo
|
||||
*.bak
|
||||
|
||||
# ANTLR 生成物(通常在 build/,这里额外兜底)
|
||||
**/generated/antlr4/
|
||||
**/antlr4-generated/
|
||||
*.tokens
|
||||
*.interp
|
||||
|
||||
# =========================
|
||||
# IDE / Editor
|
||||
# =========================
|
||||
.vscode/
|
||||
.idea/
|
||||
.fleet/
|
||||
.vs/
|
||||
*.code-workspace
|
||||
|
||||
# CLion
|
||||
cmake-build-debug/
|
||||
cmake-build-release/
|
||||
|
||||
# =========================
|
||||
# OS / misc
|
||||
# =========================
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# =========================
|
||||
# Project outputs
|
||||
# =========================
|
||||
test/test_result/
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
|
||||
class DominatorTree {
|
||||
public:
|
||||
explicit DominatorTree(Function& function);
|
||||
|
||||
void Recalculate();
|
||||
|
||||
Function& GetFunction() const { return *function_; }
|
||||
bool IsReachable(BasicBlock* block) const;
|
||||
bool Dominates(BasicBlock* dom, BasicBlock* node) const;
|
||||
bool Dominates(Instruction* dom, Instruction* user) const;
|
||||
BasicBlock* GetIDom(BasicBlock* block) const;
|
||||
const std::vector<BasicBlock*>& GetChildren(BasicBlock* block) const;
|
||||
const std::vector<BasicBlock*>& GetReversePostOrder() const {
|
||||
return reverse_post_order_;
|
||||
}
|
||||
|
||||
private:
|
||||
Function* function_ = nullptr;
|
||||
std::vector<BasicBlock*> reverse_post_order_;
|
||||
std::unordered_map<BasicBlock*, std::size_t> block_index_;
|
||||
std::vector<std::vector<std::uint8_t>> dominates_;
|
||||
std::unordered_map<BasicBlock*, BasicBlock*> immediate_dominator_;
|
||||
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_children_;
|
||||
};
|
||||
|
||||
struct Loop {
|
||||
BasicBlock* header = nullptr;
|
||||
std::unordered_set<BasicBlock*> blocks;
|
||||
std::vector<BasicBlock*> block_list;
|
||||
std::vector<BasicBlock*> latches;
|
||||
std::vector<BasicBlock*> exiting_blocks;
|
||||
std::vector<BasicBlock*> exit_blocks;
|
||||
BasicBlock* preheader = nullptr;
|
||||
Loop* parent = nullptr;
|
||||
std::vector<Loop*> subloops;
|
||||
|
||||
bool Contains(BasicBlock* block) const;
|
||||
bool Contains(const Loop* other) const;
|
||||
bool IsInnermost() const;
|
||||
};
|
||||
|
||||
class LoopInfo {
|
||||
public:
|
||||
LoopInfo(Function& function, const DominatorTree& dom_tree);
|
||||
|
||||
void Recalculate();
|
||||
|
||||
const std::vector<std::unique_ptr<Loop>>& GetLoops() const { return loops_; }
|
||||
std::vector<Loop*> GetTopLevelLoops() const;
|
||||
std::vector<Loop*> GetLoopsInPostOrder() const;
|
||||
Loop* GetLoopFor(BasicBlock* block) const;
|
||||
|
||||
private:
|
||||
Function* function_ = nullptr;
|
||||
const DominatorTree* dom_tree_ = nullptr;
|
||||
std::vector<std::unique_ptr<Loop>> loops_;
|
||||
std::vector<Loop*> top_level_loops_;
|
||||
std::unordered_map<BasicBlock*, Loop*> block_to_loop_;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
namespace ir {
|
||||
|
||||
class Module;
|
||||
|
||||
void RunMem2Reg(Module& module);
|
||||
bool RunConstFold(Module& module);
|
||||
bool RunConstProp(Module& module);
|
||||
bool RunFunctionInlining(Module& module);
|
||||
bool RunTailRecursionElim(Module& module);
|
||||
bool RunArithmeticSimplify(Module& module);
|
||||
bool RunCSE(Module& module);
|
||||
bool RunGVN(Module& module);
|
||||
bool RunLoadStoreElim(Module& module);
|
||||
bool RunDCE(Module& module);
|
||||
bool RunCFGSimplify(Module& module);
|
||||
bool RunLICM(Module& module);
|
||||
bool RunLoopMemoryPromotion(Module& module);
|
||||
bool RunLoopUnswitch(Module& module);
|
||||
bool RunLoopStrengthReduction(Module& module);
|
||||
bool RunLoopUnroll(Module& module);
|
||||
bool RunLoopFission(Module& module);
|
||||
void RunIRPassPipeline(Module& module);
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include <iterator>
|
||||
|
||||
namespace ir {
|
||||
|
||||
template <typename IterT> struct range {
|
||||
using iterator = IterT;
|
||||
using value_type = typename std::iterator_traits<iterator>::value_type;
|
||||
using reference = typename std::iterator_traits<iterator>::reference;
|
||||
|
||||
private:
|
||||
iterator b;
|
||||
iterator e;
|
||||
|
||||
public:
|
||||
explicit range(iterator b, iterator e) : b(b), e(e) {}
|
||||
iterator begin() { return b; }
|
||||
iterator end() { return e; }
|
||||
iterator begin() const { return b; }
|
||||
iterator end() const { return e; }
|
||||
auto size() const { return std::distance(b, e); }
|
||||
auto empty() const { return b == e; }
|
||||
};
|
||||
|
||||
//! create `range` object from iterator pair [begin, end)
|
||||
template <typename IterT> range<IterT> make_range(IterT b, IterT e) {
|
||||
return range<IterT>(b, e);
|
||||
}
|
||||
//! create `range` object from a container who has `begin()` and `end()` methods
|
||||
template <typename ContainerT>
|
||||
range<typename ContainerT::iterator> make_range(ContainerT &c) {
|
||||
return make_range(c.begin(), c.end());
|
||||
}
|
||||
//! create `range` object from a container who has `begin()` and `end()` methods
|
||||
template <typename ContainerT>
|
||||
range<typename ContainerT::const_iterator> make_range(const ContainerT &c) {
|
||||
return make_range(c.begin(), c.end());
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -1,119 +1,300 @@
|
||||
#pragma once
|
||||
|
||||
#include <initializer_list>
|
||||
#include <iosfwd>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace mir {
|
||||
|
||||
class MIRContext {
|
||||
public:
|
||||
MIRContext() = default;
|
||||
};
|
||||
|
||||
MIRContext& DefaultContext();
|
||||
|
||||
enum class PhysReg { W0, W8, W9, X29, X30, SP };
|
||||
|
||||
const char* PhysRegName(PhysReg reg);
|
||||
|
||||
enum class Opcode {
|
||||
Prologue,
|
||||
Epilogue,
|
||||
MovImm,
|
||||
LoadStack,
|
||||
StoreStack,
|
||||
AddRR,
|
||||
Ret,
|
||||
};
|
||||
|
||||
class Operand {
|
||||
public:
|
||||
enum class Kind { Reg, Imm, FrameIndex };
|
||||
|
||||
static Operand Reg(PhysReg reg);
|
||||
static Operand Imm(int value);
|
||||
static Operand FrameIndex(int index);
|
||||
|
||||
Kind GetKind() const { return kind_; }
|
||||
PhysReg GetReg() const { return reg_; }
|
||||
int GetImm() const { return imm_; }
|
||||
int GetFrameIndex() const { return imm_; }
|
||||
|
||||
private:
|
||||
Operand(Kind kind, PhysReg reg, int imm);
|
||||
|
||||
Kind kind_;
|
||||
PhysReg reg_;
|
||||
int imm_;
|
||||
};
|
||||
|
||||
class MachineInstr {
|
||||
public:
|
||||
MachineInstr(Opcode opcode, std::vector<Operand> operands = {});
|
||||
|
||||
Opcode GetOpcode() const { return opcode_; }
|
||||
const std::vector<Operand>& GetOperands() const { return operands_; }
|
||||
|
||||
private:
|
||||
Opcode opcode_;
|
||||
std::vector<Operand> operands_;
|
||||
};
|
||||
|
||||
struct FrameSlot {
|
||||
int index = 0;
|
||||
int size = 4;
|
||||
int offset = 0;
|
||||
};
|
||||
|
||||
class MachineBasicBlock {
|
||||
public:
|
||||
explicit MachineBasicBlock(std::string name);
|
||||
|
||||
const std::string& GetName() const { return name_; }
|
||||
std::vector<MachineInstr>& GetInstructions() { return instructions_; }
|
||||
const std::vector<MachineInstr>& GetInstructions() const { return instructions_; }
|
||||
|
||||
MachineInstr& Append(Opcode opcode,
|
||||
std::initializer_list<Operand> operands = {});
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::vector<MachineInstr> instructions_;
|
||||
};
|
||||
|
||||
class MachineFunction {
|
||||
public:
|
||||
explicit MachineFunction(std::string name);
|
||||
|
||||
const std::string& GetName() const { return name_; }
|
||||
MachineBasicBlock& GetEntry() { return entry_; }
|
||||
const MachineBasicBlock& GetEntry() const { return entry_; }
|
||||
|
||||
int CreateFrameIndex(int size = 4);
|
||||
FrameSlot& GetFrameSlot(int index);
|
||||
const FrameSlot& GetFrameSlot(int index) const;
|
||||
const std::vector<FrameSlot>& GetFrameSlots() const { return frame_slots_; }
|
||||
|
||||
int GetFrameSize() const { return frame_size_; }
|
||||
void SetFrameSize(int size) { frame_size_ = size; }
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
MachineBasicBlock entry_;
|
||||
std::vector<FrameSlot> frame_slots_;
|
||||
int frame_size_ = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module);
|
||||
void RunRegAlloc(MachineFunction& function);
|
||||
void RunFrameLowering(MachineFunction& function);
|
||||
void PrintAsm(const MachineFunction& function, std::ostream& os);
|
||||
|
||||
} // namespace mir
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <iosfwd>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
class Module;
|
||||
}
|
||||
|
||||
namespace mir {
|
||||
|
||||
class MIRContext {
|
||||
public:
|
||||
MIRContext() = default;
|
||||
};
|
||||
|
||||
MIRContext& DefaultContext();
|
||||
|
||||
enum class ValueType { Void, I1, I32, F32, Ptr };
|
||||
|
||||
enum class RegClass { GPR, FPR };
|
||||
|
||||
enum class CondCode { EQ, NE, LT, GT, LE, GE };
|
||||
|
||||
enum class StackObjectKind { Local, Spill, SavedGPR, SavedFPR };
|
||||
|
||||
enum class AddrBaseKind { None, FrameObject, Global, VReg };
|
||||
|
||||
enum class OperandKind { Invalid, VReg, Imm, Block, Symbol };
|
||||
|
||||
struct PhysReg {
|
||||
RegClass reg_class = RegClass::GPR;
|
||||
int index = -1;
|
||||
|
||||
bool IsValid() const { return index >= 0; }
|
||||
bool operator==(const PhysReg& rhs) const {
|
||||
return reg_class == rhs.reg_class && index == rhs.index;
|
||||
}
|
||||
};
|
||||
|
||||
bool IsGPR(ValueType type);
|
||||
bool IsFPR(ValueType type);
|
||||
int GetValueSize(ValueType type);
|
||||
int GetValueAlign(ValueType type);
|
||||
const char* GetPhysRegName(PhysReg reg, ValueType type);
|
||||
|
||||
class MachineOperand {
|
||||
public:
|
||||
MachineOperand() = default;
|
||||
static MachineOperand VReg(int reg);
|
||||
static MachineOperand Imm(std::int64_t value);
|
||||
static MachineOperand Block(std::string name);
|
||||
static MachineOperand Symbol(std::string name);
|
||||
|
||||
OperandKind GetKind() const { return kind_; }
|
||||
int GetVReg() const { return vreg_; }
|
||||
std::int64_t GetImm() const { return imm_; }
|
||||
const std::string& GetText() const { return text_; }
|
||||
|
||||
private:
|
||||
MachineOperand(OperandKind kind, int vreg, std::int64_t imm, std::string text);
|
||||
|
||||
OperandKind kind_ = OperandKind::Invalid;
|
||||
int vreg_ = -1;
|
||||
std::int64_t imm_ = 0;
|
||||
std::string text_;
|
||||
};
|
||||
|
||||
struct AddressExpr {
|
||||
AddrBaseKind base_kind = AddrBaseKind::None;
|
||||
int base_index = -1;
|
||||
std::string symbol;
|
||||
std::int64_t const_offset = 0;
|
||||
std::vector<std::pair<int, std::int64_t>> scaled_vregs;
|
||||
};
|
||||
|
||||
struct StackObject {
|
||||
int index = -1;
|
||||
StackObjectKind kind = StackObjectKind::Local;
|
||||
int size = 0;
|
||||
int align = 1;
|
||||
int offset = 0;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
struct VRegInfo {
|
||||
int id = -1;
|
||||
ValueType type = ValueType::Void;
|
||||
};
|
||||
|
||||
struct Allocation {
|
||||
enum class Kind { Unassigned, PhysReg, Spill };
|
||||
|
||||
Kind kind = Kind::Unassigned;
|
||||
PhysReg phys;
|
||||
int stack_object = -1;
|
||||
};
|
||||
|
||||
class MachineInstr {
|
||||
public:
|
||||
enum class Opcode {
|
||||
Arg,
|
||||
Copy,
|
||||
Load,
|
||||
Store,
|
||||
Lea,
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
Div,
|
||||
Rem,
|
||||
ModMul,
|
||||
ModPow,
|
||||
DigitExtractPow2,
|
||||
And,
|
||||
Or,
|
||||
Xor,
|
||||
Shl,
|
||||
AShr,
|
||||
LShr,
|
||||
FAdd,
|
||||
FSub,
|
||||
FMul,
|
||||
FDiv,
|
||||
FSqrt,
|
||||
FNeg,
|
||||
ICmp,
|
||||
FCmp,
|
||||
ZExt,
|
||||
ItoF,
|
||||
FtoI,
|
||||
Br,
|
||||
CondBr,
|
||||
Call,
|
||||
Ret,
|
||||
Memset,
|
||||
Unreachable,
|
||||
};
|
||||
|
||||
explicit MachineInstr(Opcode opcode,
|
||||
std::vector<MachineOperand> operands = {});
|
||||
|
||||
Opcode GetOpcode() const { return opcode_; }
|
||||
const std::vector<MachineOperand>& GetOperands() const { return operands_; }
|
||||
std::vector<MachineOperand>& GetOperands() { return operands_; }
|
||||
|
||||
void SetCondCode(CondCode code) { cond_code_ = code; }
|
||||
CondCode GetCondCode() const { return cond_code_; }
|
||||
|
||||
void SetAddress(AddressExpr address) {
|
||||
address_ = std::move(address);
|
||||
has_address_ = true;
|
||||
}
|
||||
bool HasAddress() const { return has_address_; }
|
||||
const AddressExpr& GetAddress() const { return address_; }
|
||||
AddressExpr& GetAddress() { return address_; }
|
||||
|
||||
void SetCallInfo(std::string callee, std::vector<ValueType> arg_types,
|
||||
ValueType return_type) {
|
||||
callee_ = std::move(callee);
|
||||
call_arg_types_ = std::move(arg_types);
|
||||
call_return_type_ = return_type;
|
||||
}
|
||||
const std::string& GetCallee() const { return callee_; }
|
||||
const std::vector<ValueType>& GetCallArgTypes() const { return call_arg_types_; }
|
||||
ValueType GetCallReturnType() const { return call_return_type_; }
|
||||
|
||||
void SetValueType(ValueType type) { value_type_ = type; }
|
||||
ValueType GetValueType() const { return value_type_; }
|
||||
|
||||
bool IsTerminator() const;
|
||||
std::vector<int> GetDefs() const;
|
||||
std::vector<int> GetUses() const;
|
||||
|
||||
private:
|
||||
Opcode opcode_;
|
||||
std::vector<MachineOperand> operands_;
|
||||
CondCode cond_code_ = CondCode::EQ;
|
||||
AddressExpr address_;
|
||||
bool has_address_ = false;
|
||||
std::string callee_;
|
||||
std::vector<ValueType> call_arg_types_;
|
||||
ValueType call_return_type_ = ValueType::Void;
|
||||
ValueType value_type_ = ValueType::Void;
|
||||
};
|
||||
|
||||
class MachineBasicBlock {
|
||||
public:
|
||||
explicit MachineBasicBlock(std::string name);
|
||||
|
||||
const std::string& GetName() const { return name_; }
|
||||
std::vector<MachineInstr>& GetInstructions() { return instructions_; }
|
||||
const std::vector<MachineInstr>& GetInstructions() const { return instructions_; }
|
||||
|
||||
MachineInstr& Append(MachineInstr::Opcode opcode,
|
||||
std::vector<MachineOperand> operands = {});
|
||||
MachineInstr& Append(MachineInstr instr);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::vector<MachineInstr> instructions_;
|
||||
};
|
||||
|
||||
class MachineFunction {
|
||||
public:
|
||||
MachineFunction(std::string name, ValueType return_type,
|
||||
std::vector<ValueType> param_types);
|
||||
|
||||
const std::string& GetName() const { return name_; }
|
||||
ValueType GetReturnType() const { return return_type_; }
|
||||
const std::vector<ValueType>& GetParamTypes() const { return param_types_; }
|
||||
|
||||
MachineBasicBlock* CreateBlock(const std::string& name);
|
||||
std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() { return blocks_; }
|
||||
const std::vector<std::unique_ptr<MachineBasicBlock>>& GetBlocks() const {
|
||||
return blocks_;
|
||||
}
|
||||
|
||||
int NewVReg(ValueType type);
|
||||
const VRegInfo& GetVRegInfo(int id) const;
|
||||
VRegInfo& GetVRegInfo(int id);
|
||||
const std::vector<VRegInfo>& GetVRegs() const { return vregs_; }
|
||||
|
||||
int CreateStackObject(int size, int align, StackObjectKind kind,
|
||||
const std::string& name = "");
|
||||
StackObject& GetStackObject(int index);
|
||||
const StackObject& GetStackObject(int index) const;
|
||||
const std::vector<StackObject>& GetStackObjects() const { return stack_objects_; }
|
||||
|
||||
void SetAllocation(int vreg, Allocation allocation);
|
||||
const Allocation& GetAllocation(int vreg) const;
|
||||
Allocation& GetAllocation(int vreg);
|
||||
|
||||
void AddUsedCalleeSavedGPR(int reg_index);
|
||||
void AddUsedCalleeSavedFPR(int reg_index);
|
||||
const std::vector<int>& GetUsedCalleeSavedGPRs() const {
|
||||
return used_callee_saved_gprs_;
|
||||
}
|
||||
const std::vector<int>& GetUsedCalleeSavedFPRs() const {
|
||||
return used_callee_saved_fprs_;
|
||||
}
|
||||
|
||||
void SetFrameSize(int size) { frame_size_ = size; }
|
||||
int GetFrameSize() const { return frame_size_; }
|
||||
|
||||
void SetMaxOutgoingArgBytes(int bytes) { max_outgoing_arg_bytes_ = bytes; }
|
||||
int GetMaxOutgoingArgBytes() const { return max_outgoing_arg_bytes_; }
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
ValueType return_type_ = ValueType::Void;
|
||||
std::vector<ValueType> param_types_;
|
||||
std::vector<std::unique_ptr<MachineBasicBlock>> blocks_;
|
||||
std::vector<VRegInfo> vregs_;
|
||||
std::vector<StackObject> stack_objects_;
|
||||
std::vector<Allocation> allocations_;
|
||||
std::vector<int> used_callee_saved_gprs_;
|
||||
std::vector<int> used_callee_saved_fprs_;
|
||||
int frame_size_ = 0;
|
||||
int max_outgoing_arg_bytes_ = 0;
|
||||
};
|
||||
|
||||
class MachineModule {
|
||||
public:
|
||||
explicit MachineModule(const ir::Module& source) : source_(&source) {}
|
||||
|
||||
const ir::Module& GetSourceModule() const { return *source_; }
|
||||
std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() { return functions_; }
|
||||
const std::vector<std::unique_ptr<MachineFunction>>& GetFunctions() const {
|
||||
return functions_;
|
||||
}
|
||||
|
||||
MachineFunction* AddFunction(std::unique_ptr<MachineFunction> function);
|
||||
|
||||
private:
|
||||
const ir::Module* source_ = nullptr;
|
||||
std::vector<std::unique_ptr<MachineFunction>> functions_;
|
||||
};
|
||||
|
||||
std::unique_ptr<MachineModule> LowerToMIR(const ir::Module& module);
|
||||
bool RunPeephole(MachineModule& module);
|
||||
bool RunSpillReduction(MachineModule& module);
|
||||
bool RunCFGCleanup(MachineModule& module);
|
||||
void RunMIRPreRegAllocPassPipeline(MachineModule& module);
|
||||
void RunMIRPostRegAllocPassPipeline(MachineModule& module);
|
||||
void RunAddressHoisting(MachineModule& module);
|
||||
void RunRegAlloc(MachineModule& module);
|
||||
void RunFrameLowering(MachineModule& module);
|
||||
void PrintAsm(const MachineModule& module, std::ostream& os);
|
||||
|
||||
} // namespace mir
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
BUILD_DIR="$REPO_ROOT/build_lab1"
|
||||
COMPILER="$BUILD_DIR/bin/compiler"
|
||||
ANTLR_JAR="$REPO_ROOT/third_party/antlr-4.13.2-complete.jar"
|
||||
RUN_ROOT="$REPO_ROOT/output/logs/lab1"
|
||||
RUN_NAME="lab1_$(date +%Y%m%d_%H%M%S)"
|
||||
RUN_DIR="$RUN_ROOT/$RUN_NAME"
|
||||
WHOLE_LOG="$RUN_DIR/whole.log"
|
||||
FAIL_DIR="$RUN_DIR/failures"
|
||||
LEGACY_SAVE_TREE=false
|
||||
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
TEST_DIRS=()
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--save-tree)
|
||||
LEGACY_SAVE_TREE=true
|
||||
;;
|
||||
*)
|
||||
TEST_DIRS+=("$1")
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
mkdir -p "$RUN_DIR"
|
||||
: > "$WHOLE_LOG"
|
||||
|
||||
log_plain() {
|
||||
printf '%s\n' "$*"
|
||||
printf '%s\n' "$*" >> "$WHOLE_LOG"
|
||||
}
|
||||
|
||||
log_color() {
|
||||
local color="$1"
|
||||
shift
|
||||
local message="$*"
|
||||
printf '%b%s%b\n' "$color" "$message" "$NC"
|
||||
printf '%s\n' "$message" >> "$WHOLE_LOG"
|
||||
}
|
||||
|
||||
append_file_to_whole_log() {
|
||||
local title="$1"
|
||||
local file="$2"
|
||||
{
|
||||
printf '\n===== %s =====\n' "$title"
|
||||
cat "$file"
|
||||
printf '\n'
|
||||
} >> "$WHOLE_LOG"
|
||||
}
|
||||
|
||||
cleanup_tmp_dir() {
|
||||
local dir="$1"
|
||||
if [[ -d "$dir" ]]; then
|
||||
rm -rf "$dir"
|
||||
fi
|
||||
}
|
||||
|
||||
discover_default_test_dirs() {
|
||||
local roots=(
|
||||
"$REPO_ROOT/test/test_case"
|
||||
"$REPO_ROOT/test/class_test_case"
|
||||
)
|
||||
local root
|
||||
for root in "${roots[@]}"; do
|
||||
[[ -d "$root" ]] || continue
|
||||
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
|
||||
done | sort -z
|
||||
}
|
||||
|
||||
prune_empty_run_dirs() {
|
||||
if [[ -d "$RUN_DIR/.tmp" ]]; then
|
||||
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
|
||||
fi
|
||||
if [[ -d "$FAIL_DIR" ]]; then
|
||||
rmdir "$FAIL_DIR" 2>/dev/null || true
|
||||
fi
|
||||
}
|
||||
|
||||
if [[ ${#TEST_DIRS[@]} -eq 0 ]]; then
|
||||
while IFS= read -r -d '' test_dir; do
|
||||
TEST_DIRS+=("$test_dir")
|
||||
done < <(discover_default_test_dirs)
|
||||
fi
|
||||
|
||||
log_plain "Run directory: $RUN_DIR"
|
||||
log_plain "Whole log: $WHOLE_LOG"
|
||||
if [[ "$LEGACY_SAVE_TREE" == true ]]; then
|
||||
log_color "$YELLOW" "Warning: --save-tree is deprecated; successful case artifacts will still be deleted."
|
||||
fi
|
||||
|
||||
log_plain "==> [1/3] Generate ANTLR Lexer/Parser"
|
||||
mkdir -p "$BUILD_DIR/generated/antlr4"
|
||||
if ! java -jar "$ANTLR_JAR" \
|
||||
-Dlanguage=Cpp \
|
||||
-visitor -no-listener \
|
||||
-Xexact-output-dir \
|
||||
-o "$BUILD_DIR/generated/antlr4" \
|
||||
"$REPO_ROOT/src/antlr4/SysY.g4" >> "$WHOLE_LOG" 2>&1; then
|
||||
log_color "$RED" "ANTLR generation failed. See $WHOLE_LOG"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_plain "==> [2/3] Configure and build parse-only compiler"
|
||||
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" -DCMAKE_BUILD_TYPE=Release -DCOMPILER_PARSE_ONLY=ON >> "$WHOLE_LOG" 2>&1; then
|
||||
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
|
||||
exit 1
|
||||
fi
|
||||
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
|
||||
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_plain "==> [3/3] Run parse validation suite"
|
||||
PASS=0
|
||||
FAIL=0
|
||||
FAIL_LIST=()
|
||||
|
||||
test_one() {
|
||||
local sy_file="$1"
|
||||
local rel="$2"
|
||||
local safe_name="${rel//\//_}"
|
||||
local case_key="${safe_name%.sy}"
|
||||
local tmp_dir="$RUN_DIR/.tmp/$case_key"
|
||||
local fail_case_dir="$FAIL_DIR/$case_key"
|
||||
local tree_file="$tmp_dir/parse.tree"
|
||||
local case_log="$tmp_dir/error.log"
|
||||
|
||||
cleanup_tmp_dir "$tmp_dir"
|
||||
cleanup_tmp_dir "$fail_case_dir"
|
||||
mkdir -p "$tmp_dir"
|
||||
|
||||
if "$COMPILER" --emit-parse-tree "$sy_file" > "$tree_file" 2> "$case_log"; then
|
||||
cleanup_tmp_dir "$tmp_dir"
|
||||
return 0
|
||||
fi
|
||||
|
||||
mkdir -p "$FAIL_DIR"
|
||||
{
|
||||
printf 'Command: %s --emit-parse-tree %s\n' "$COMPILER" "$sy_file"
|
||||
if [[ -s "$case_log" ]]; then
|
||||
printf '\n'
|
||||
cat "$case_log"
|
||||
fi
|
||||
} > "$tmp_dir/error.log.tmp"
|
||||
mv "$tmp_dir/error.log.tmp" "$case_log"
|
||||
mv "$tmp_dir" "$fail_case_dir"
|
||||
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
|
||||
return 1
|
||||
}
|
||||
|
||||
for test_dir in "${TEST_DIRS[@]}"; do
|
||||
if [[ ! -d "$test_dir" ]]; then
|
||||
log_color "$YELLOW" "skip missing dir: $test_dir"
|
||||
continue
|
||||
fi
|
||||
|
||||
while IFS= read -r -d '' sy_file; do
|
||||
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
|
||||
if test_one "$sy_file" "$rel"; then
|
||||
log_color "$GREEN" "PASS $rel"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
log_color "$RED" "FAIL $rel"
|
||||
FAIL=$((FAIL + 1))
|
||||
FAIL_LIST+=("$rel")
|
||||
fi
|
||||
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
|
||||
done
|
||||
|
||||
prune_empty_run_dirs
|
||||
|
||||
log_plain ""
|
||||
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
|
||||
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
|
||||
log_plain "failed cases:"
|
||||
for f in "${FAIL_LIST[@]}"; do
|
||||
safe_name="${f//\//_}"
|
||||
log_plain "- $f"
|
||||
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
|
||||
done
|
||||
else
|
||||
log_plain "all successful case artifacts have been deleted automatically."
|
||||
fi
|
||||
log_plain "whole log saved to: $WHOLE_LOG"
|
||||
|
||||
[[ $FAIL -eq 0 ]]
|
||||
@ -0,0 +1,288 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
VERIFY_SCRIPT="$REPO_ROOT/scripts/verify_ir.sh"
|
||||
BUILD_DIR="$REPO_ROOT/build_lab2"
|
||||
RUN_ROOT="$REPO_ROOT/output/logs/lab2"
|
||||
LAST_RUN_FILE="$RUN_ROOT/last_run.txt"
|
||||
LAST_FAILED_FILE="$RUN_ROOT/last_failed.txt"
|
||||
RUN_NAME="lab2_$(date +%Y%m%d_%H%M%S)"
|
||||
RUN_DIR="$RUN_ROOT/$RUN_NAME"
|
||||
WHOLE_LOG="$RUN_DIR/whole.log"
|
||||
FAIL_DIR="$RUN_DIR/failures"
|
||||
LEGACY_SAVE_IR=false
|
||||
FAILED_ONLY=false
|
||||
FALLBACK_TO_FULL=false
|
||||
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
TEST_DIRS=()
|
||||
TEST_FILES=()
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--save-ir)
|
||||
LEGACY_SAVE_IR=true
|
||||
;;
|
||||
--failed-only)
|
||||
FAILED_ONLY=true
|
||||
;;
|
||||
*)
|
||||
if [[ -f "$1" ]]; then
|
||||
TEST_FILES+=("$1")
|
||||
else
|
||||
TEST_DIRS+=("$1")
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
mkdir -p "$RUN_DIR"
|
||||
: > "$WHOLE_LOG"
|
||||
printf '%s\n' "$RUN_DIR" > "$LAST_RUN_FILE"
|
||||
|
||||
log_plain() {
|
||||
printf '%s\n' "$*"
|
||||
printf '%s\n' "$*" >> "$WHOLE_LOG"
|
||||
}
|
||||
|
||||
log_color() {
|
||||
local color="$1"
|
||||
shift
|
||||
local message="$*"
|
||||
printf '%b%s%b\n' "$color" "$message" "$NC"
|
||||
printf '%s\n' "$message" >> "$WHOLE_LOG"
|
||||
}
|
||||
|
||||
append_file_to_whole_log() {
|
||||
local title="$1"
|
||||
local file="$2"
|
||||
{
|
||||
printf '\n===== %s =====\n' "$title"
|
||||
cat "$file"
|
||||
printf '\n'
|
||||
} >> "$WHOLE_LOG"
|
||||
}
|
||||
|
||||
cleanup_tmp_dir() {
|
||||
local dir="$1"
|
||||
if [[ -d "$dir" ]]; then
|
||||
rm -rf "$dir"
|
||||
fi
|
||||
}
|
||||
|
||||
discover_default_test_dirs() {
|
||||
local roots=(
|
||||
"$REPO_ROOT/test/test_case"
|
||||
"$REPO_ROOT/test/class_test_case"
|
||||
)
|
||||
local root
|
||||
for root in "${roots[@]}"; do
|
||||
[[ -d "$root" ]] || continue
|
||||
find "$root" -mindepth 1 -maxdepth 1 -type d -print0
|
||||
done | sort -z
|
||||
}
|
||||
|
||||
prune_empty_run_dirs() {
|
||||
if [[ -d "$RUN_DIR/.tmp" ]]; then
|
||||
rmdir "$RUN_DIR/.tmp" 2>/dev/null || true
|
||||
fi
|
||||
if [[ -d "$FAIL_DIR" ]]; then
|
||||
rmdir "$FAIL_DIR" 2>/dev/null || true
|
||||
fi
|
||||
}
|
||||
|
||||
now_ns() {
|
||||
date +%s%N
|
||||
}
|
||||
|
||||
format_duration_ns() {
|
||||
local ns="$1"
|
||||
local sec=$((ns / 1000000000))
|
||||
local ms=$(((ns % 1000000000) / 1000000))
|
||||
printf '%d.%03ds' "$sec" "$ms"
|
||||
}
|
||||
|
||||
is_transient_io_failure() {
|
||||
local log_file="$1"
|
||||
[[ -f "$log_file" ]] || return 1
|
||||
grep -Eq \
|
||||
'Permission denied|Text file busy|Device or resource busy|Stale file handle|Input/output error|Resource temporarily unavailable|Read-only file system' \
|
||||
"$log_file"
|
||||
}
|
||||
|
||||
test_one() {
|
||||
local sy_file="$1"
|
||||
local rel="$2"
|
||||
local safe_name="${rel//\//_}"
|
||||
local case_key="${safe_name%.sy}"
|
||||
local tmp_dir="$RUN_DIR/.tmp/$case_key"
|
||||
local fail_case_dir="$FAIL_DIR/$case_key"
|
||||
local case_log="$tmp_dir/error.log"
|
||||
local attempt=1
|
||||
|
||||
cleanup_tmp_dir "$fail_case_dir"
|
||||
|
||||
while true; do
|
||||
cleanup_tmp_dir "$tmp_dir"
|
||||
mkdir -p "$tmp_dir"
|
||||
|
||||
if "$VERIFY_SCRIPT" "$sy_file" "$tmp_dir" --run > "$case_log" 2>&1; then
|
||||
cleanup_tmp_dir "$tmp_dir"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [[ $attempt -eq 1 ]] && is_transient_io_failure "$case_log"; then
|
||||
log_color "$YELLOW" "RETRY $rel (transient I/O failure)"
|
||||
attempt=$((attempt + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
break
|
||||
done
|
||||
|
||||
mkdir -p "$FAIL_DIR"
|
||||
mv "$tmp_dir" "$fail_case_dir"
|
||||
append_file_to_whole_log "$rel" "$fail_case_dir/error.log"
|
||||
return 1
|
||||
}
|
||||
|
||||
run_case() {
|
||||
local sy_file="$1"
|
||||
local rel
|
||||
local case_start_ns
|
||||
local case_end_ns
|
||||
local case_elapsed_ns
|
||||
rel="$(realpath --relative-to="$REPO_ROOT" "$sy_file")"
|
||||
case_start_ns=$(now_ns)
|
||||
|
||||
if test_one "$sy_file" "$rel"; then
|
||||
case_end_ns=$(now_ns)
|
||||
case_elapsed_ns=$((case_end_ns - case_start_ns))
|
||||
log_color "$GREEN" "PASS $rel [$(format_duration_ns "$case_elapsed_ns")]"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
case_end_ns=$(now_ns)
|
||||
case_elapsed_ns=$((case_end_ns - case_start_ns))
|
||||
log_color "$RED" "FAIL $rel [$(format_duration_ns "$case_elapsed_ns")]"
|
||||
FAIL=$((FAIL + 1))
|
||||
FAIL_LIST+=("$rel")
|
||||
fi
|
||||
}
|
||||
|
||||
TOTAL_START_NS=$(now_ns)
|
||||
|
||||
if [[ "$FAILED_ONLY" == true ]]; then
|
||||
if [[ -f "$LAST_FAILED_FILE" ]]; then
|
||||
while IFS= read -r sy_file; do
|
||||
[[ -n "$sy_file" ]] || continue
|
||||
[[ -f "$sy_file" ]] || continue
|
||||
TEST_FILES+=("$sy_file")
|
||||
done < "$LAST_FAILED_FILE"
|
||||
fi
|
||||
|
||||
if [[ ${#TEST_FILES[@]} -eq 0 ]]; then
|
||||
FALLBACK_TO_FULL=true
|
||||
FAILED_ONLY=false
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ "$FAILED_ONLY" == false && ${#TEST_DIRS[@]} -eq 0 && ${#TEST_FILES[@]} -eq 0 ]]; then
|
||||
while IFS= read -r -d '' test_dir; do
|
||||
TEST_DIRS+=("$test_dir")
|
||||
done < <(discover_default_test_dirs)
|
||||
fi
|
||||
|
||||
log_plain "Run directory: $RUN_DIR"
|
||||
log_plain "Whole log: $WHOLE_LOG"
|
||||
if [[ "$LEGACY_SAVE_IR" == true ]]; then
|
||||
log_color "$YELLOW" "Warning: --save-ir is deprecated; successful case artifacts will still be deleted."
|
||||
fi
|
||||
if [[ "$FAILED_ONLY" == true ]]; then
|
||||
log_plain "Mode: rerun cached failed cases only"
|
||||
fi
|
||||
if [[ "$FALLBACK_TO_FULL" == true ]]; then
|
||||
log_color "$YELLOW" "No cached failed cases found, fallback to full suite."
|
||||
fi
|
||||
|
||||
if [[ ! -f "$VERIFY_SCRIPT" ]]; then
|
||||
log_color "$RED" "missing verify script: $VERIFY_SCRIPT"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_plain "==> [1/2] Configure and build compiler"
|
||||
BUILD_START_NS=$(now_ns)
|
||||
if ! cmake -S "$REPO_ROOT" -B "$BUILD_DIR" >> "$WHOLE_LOG" 2>&1; then
|
||||
log_color "$RED" "CMake configure failed. See $WHOLE_LOG"
|
||||
exit 1
|
||||
fi
|
||||
if ! cmake --build "$BUILD_DIR" -j "$(nproc)" >> "$WHOLE_LOG" 2>&1; then
|
||||
log_color "$RED" "Compiler build failed. See $WHOLE_LOG"
|
||||
exit 1
|
||||
fi
|
||||
BUILD_END_NS=$(now_ns)
|
||||
BUILD_ELAPSED_NS=$((BUILD_END_NS - BUILD_START_NS))
|
||||
|
||||
log_plain "==> [2/2] Run IR validation suite"
|
||||
VALIDATION_START_NS=$(now_ns)
|
||||
PASS=0
|
||||
FAIL=0
|
||||
FAIL_LIST=()
|
||||
|
||||
if [[ "$FAILED_ONLY" == true ]]; then
|
||||
for sy_file in "${TEST_FILES[@]}"; do
|
||||
run_case "$sy_file"
|
||||
done
|
||||
else
|
||||
for sy_file in "${TEST_FILES[@]}"; do
|
||||
run_case "$sy_file"
|
||||
done
|
||||
|
||||
for test_dir in "${TEST_DIRS[@]}"; do
|
||||
if [[ ! -d "$test_dir" ]]; then
|
||||
log_color "$YELLOW" "skip missing dir: $test_dir"
|
||||
continue
|
||||
fi
|
||||
|
||||
while IFS= read -r -d '' sy_file; do
|
||||
run_case "$sy_file"
|
||||
done < <(find "$test_dir" -maxdepth 1 -type f -name '*.sy' -print0 | sort -z)
|
||||
done
|
||||
fi
|
||||
|
||||
rm -f "$LAST_FAILED_FILE"
|
||||
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
|
||||
for f in "${FAIL_LIST[@]}"; do
|
||||
printf '%s/%s\n' "$REPO_ROOT" "$f" >> "$LAST_FAILED_FILE"
|
||||
done
|
||||
fi
|
||||
|
||||
prune_empty_run_dirs
|
||||
VALIDATION_END_NS=$(now_ns)
|
||||
VALIDATION_ELAPSED_NS=$((VALIDATION_END_NS - VALIDATION_START_NS))
|
||||
TOTAL_END_NS=$(now_ns)
|
||||
TOTAL_ELAPSED_NS=$((TOTAL_END_NS - TOTAL_START_NS))
|
||||
|
||||
log_plain ""
|
||||
log_plain "summary: ${PASS} PASS / ${FAIL} FAIL / total $((PASS + FAIL))"
|
||||
log_plain "build elapsed: $(format_duration_ns "$BUILD_ELAPSED_NS")"
|
||||
log_plain "validation elapsed: $(format_duration_ns "$VALIDATION_ELAPSED_NS")"
|
||||
log_plain "total elapsed: $(format_duration_ns "$TOTAL_ELAPSED_NS")"
|
||||
if [[ ${#FAIL_LIST[@]} -gt 0 ]]; then
|
||||
log_plain "failed cases:"
|
||||
for f in "${FAIL_LIST[@]}"; do
|
||||
safe_name="${f//\//_}"
|
||||
log_plain "- $f"
|
||||
log_plain " artifacts: $FAIL_DIR/${safe_name%.sy}"
|
||||
done
|
||||
else
|
||||
log_plain "all successful case artifacts have been deleted automatically."
|
||||
fi
|
||||
log_plain "whole log saved to: $WHOLE_LOG"
|
||||
|
||||
[[ $FAIL -eq 0 ]]
|
||||
@ -0,0 +1,591 @@
|
||||
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
|
||||
|
||||
import org.antlr.v4.runtime.ParserRuleContext;
|
||||
import org.antlr.v4.runtime.tree.ErrorNode;
|
||||
import org.antlr.v4.runtime.tree.TerminalNode;
|
||||
|
||||
/**
|
||||
* This class provides an empty implementation of {@link SysYListener},
|
||||
* which can be extended to create a listener which only needs to handle a subset
|
||||
* of the available methods.
|
||||
*/
|
||||
@SuppressWarnings("CheckReturnValue")
|
||||
public class SysYBaseListener implements SysYListener {
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterCompUnit(SysYParser.CompUnitContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitCompUnit(SysYParser.CompUnitContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterDecl(SysYParser.DeclContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitDecl(SysYParser.DeclContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterConstDecl(SysYParser.ConstDeclContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitConstDecl(SysYParser.ConstDeclContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBType(SysYParser.BTypeContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBType(SysYParser.BTypeContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterConstDef(SysYParser.ConstDefContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitConstDef(SysYParser.ConstDefContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterConstInitVal(SysYParser.ConstInitValContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitConstInitVal(SysYParser.ConstInitValContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterVarDecl(SysYParser.VarDeclContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitVarDecl(SysYParser.VarDeclContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterVarDef(SysYParser.VarDefContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitVarDef(SysYParser.VarDefContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterInitVal(SysYParser.InitValContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitInitVal(SysYParser.InitValContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterFuncDef(SysYParser.FuncDefContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitFuncDef(SysYParser.FuncDefContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterFuncType(SysYParser.FuncTypeContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitFuncType(SysYParser.FuncTypeContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterFuncFParams(SysYParser.FuncFParamsContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitFuncFParams(SysYParser.FuncFParamsContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterFuncFParam(SysYParser.FuncFParamContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitFuncFParam(SysYParser.FuncFParamContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBlock(SysYParser.BlockContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBlock(SysYParser.BlockContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBlockItem(SysYParser.BlockItemContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBlockItem(SysYParser.BlockItemContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterAssignStmt(SysYParser.AssignStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitAssignStmt(SysYParser.AssignStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterExpStmt(SysYParser.ExpStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitExpStmt(SysYParser.ExpStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBlockStmt(SysYParser.BlockStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBlockStmt(SysYParser.BlockStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterIfStmt(SysYParser.IfStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitIfStmt(SysYParser.IfStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterWhileStmt(SysYParser.WhileStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitWhileStmt(SysYParser.WhileStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBreakStmt(SysYParser.BreakStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBreakStmt(SysYParser.BreakStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterContinueStmt(SysYParser.ContinueStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitContinueStmt(SysYParser.ContinueStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterReturnStmt(SysYParser.ReturnStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitReturnStmt(SysYParser.ReturnStmtContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterExp(SysYParser.ExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitExp(SysYParser.ExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterCond(SysYParser.CondContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitCond(SysYParser.CondContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterLVal(SysYParser.LValContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitLVal(SysYParser.LValContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterPrimaryExp(SysYParser.PrimaryExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitPrimaryExp(SysYParser.PrimaryExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterNumber(SysYParser.NumberContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitNumber(SysYParser.NumberContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterCallUnaryExp(SysYParser.CallUnaryExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitCallUnaryExp(SysYParser.CallUnaryExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterOpUnaryExp(SysYParser.OpUnaryExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitOpUnaryExp(SysYParser.OpUnaryExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterUnaryOp(SysYParser.UnaryOpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitUnaryOp(SysYParser.UnaryOpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterFuncRParams(SysYParser.FuncRParamsContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitFuncRParams(SysYParser.FuncRParamsContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBinaryMulExp(SysYParser.BinaryMulExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBinaryMulExp(SysYParser.BinaryMulExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterUnaryMulExp(SysYParser.UnaryMulExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitUnaryMulExp(SysYParser.UnaryMulExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBinaryAddExp(SysYParser.BinaryAddExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBinaryAddExp(SysYParser.BinaryAddExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterMulAddExp(SysYParser.MulAddExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitMulAddExp(SysYParser.MulAddExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterAddRelExp(SysYParser.AddRelExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitAddRelExp(SysYParser.AddRelExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBinaryRelExp(SysYParser.BinaryRelExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBinaryRelExp(SysYParser.BinaryRelExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBinaryEqExp(SysYParser.BinaryEqExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBinaryEqExp(SysYParser.BinaryEqExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterRelEqExp(SysYParser.RelEqExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitRelEqExp(SysYParser.RelEqExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterEqLAndExp(SysYParser.EqLAndExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitEqLAndExp(SysYParser.EqLAndExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterAndLOrExp(SysYParser.AndLOrExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitAndLOrExp(SysYParser.AndLOrExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterConstExp(SysYParser.ConstExpContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitConstExp(SysYParser.ConstExpContext ctx) { }
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void enterEveryRule(ParserRuleContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void exitEveryRule(ParserRuleContext ctx) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void visitTerminal(TerminalNode node) { }
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* <p>The default implementation does nothing.</p>
|
||||
*/
|
||||
@Override public void visitErrorNode(ErrorNode node) { }
|
||||
}
|
||||
@ -0,0 +1,358 @@
|
||||
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
|
||||
import org.antlr.v4.runtime.Lexer;
|
||||
import org.antlr.v4.runtime.CharStream;
|
||||
import org.antlr.v4.runtime.Token;
|
||||
import org.antlr.v4.runtime.TokenStream;
|
||||
import org.antlr.v4.runtime.*;
|
||||
import org.antlr.v4.runtime.atn.*;
|
||||
import org.antlr.v4.runtime.dfa.DFA;
|
||||
import org.antlr.v4.runtime.misc.*;
|
||||
|
||||
@SuppressWarnings({"all", "warnings", "unchecked", "unused", "cast", "CheckReturnValue", "this-escape"})
|
||||
public class SysYLexer extends Lexer {
|
||||
static { RuntimeMetaData.checkVersion("4.13.1", RuntimeMetaData.VERSION); }
|
||||
|
||||
protected static final DFA[] _decisionToDFA;
|
||||
protected static final PredictionContextCache _sharedContextCache =
|
||||
new PredictionContextCache();
|
||||
public static final int
|
||||
CONST=1, INT=2, FLOAT=3, VOID=4, IF=5, ELSE=6, WHILE=7, BREAK=8, CONTINUE=9,
|
||||
RETURN=10, ADD=11, SUB=12, MUL=13, DIV=14, MOD=15, ASSIGN=16, EQ=17, NE=18,
|
||||
LT=19, LE=20, GT=21, GE=22, NOT=23, AND=24, OR=25, LPAREN=26, RPAREN=27,
|
||||
LBRACK=28, RBRACK=29, LBRACE=30, RBRACE=31, COMMA=32, SEMI=33, Ident=34,
|
||||
IntConst=35, FloatConst=36, WS=37, LINE_COMMENT=38, BLOCK_COMMENT=39;
|
||||
public static String[] channelNames = {
|
||||
"DEFAULT_TOKEN_CHANNEL", "HIDDEN"
|
||||
};
|
||||
|
||||
public static String[] modeNames = {
|
||||
"DEFAULT_MODE"
|
||||
};
|
||||
|
||||
private static String[] makeRuleNames() {
|
||||
return new String[] {
|
||||
"CONST", "INT", "FLOAT", "VOID", "IF", "ELSE", "WHILE", "BREAK", "CONTINUE",
|
||||
"RETURN", "ADD", "SUB", "MUL", "DIV", "MOD", "ASSIGN", "EQ", "NE", "LT",
|
||||
"LE", "GT", "GE", "NOT", "AND", "OR", "LPAREN", "RPAREN", "LBRACK", "RBRACK",
|
||||
"LBRACE", "RBRACE", "COMMA", "SEMI", "Ident", "Digit", "NonzeroDigit",
|
||||
"OctDigit", "HexDigit", "DecInteger", "OctInteger", "HexInteger", "DecFraction",
|
||||
"DecExponent", "DecFloat", "HexFraction", "BinExponent", "HexFloat",
|
||||
"IntConst", "FloatConst", "WS", "LINE_COMMENT", "BLOCK_COMMENT"
|
||||
};
|
||||
}
|
||||
public static final String[] ruleNames = makeRuleNames();
|
||||
|
||||
private static String[] makeLiteralNames() {
|
||||
return new String[] {
|
||||
null, "'const'", "'int'", "'float'", "'void'", "'if'", "'else'", "'while'",
|
||||
"'break'", "'continue'", "'return'", "'+'", "'-'", "'*'", "'/'", "'%'",
|
||||
"'='", "'=='", "'!='", "'<'", "'<='", "'>'", "'>='", "'!'", "'&&'", "'||'",
|
||||
"'('", "')'", "'['", "']'", "'{'", "'}'", "','", "';'"
|
||||
};
|
||||
}
|
||||
private static final String[] _LITERAL_NAMES = makeLiteralNames();
|
||||
private static String[] makeSymbolicNames() {
|
||||
return new String[] {
|
||||
null, "CONST", "INT", "FLOAT", "VOID", "IF", "ELSE", "WHILE", "BREAK",
|
||||
"CONTINUE", "RETURN", "ADD", "SUB", "MUL", "DIV", "MOD", "ASSIGN", "EQ",
|
||||
"NE", "LT", "LE", "GT", "GE", "NOT", "AND", "OR", "LPAREN", "RPAREN",
|
||||
"LBRACK", "RBRACK", "LBRACE", "RBRACE", "COMMA", "SEMI", "Ident", "IntConst",
|
||||
"FloatConst", "WS", "LINE_COMMENT", "BLOCK_COMMENT"
|
||||
};
|
||||
}
|
||||
private static final String[] _SYMBOLIC_NAMES = makeSymbolicNames();
|
||||
public static final Vocabulary VOCABULARY = new VocabularyImpl(_LITERAL_NAMES, _SYMBOLIC_NAMES);
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #VOCABULARY} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public static final String[] tokenNames;
|
||||
static {
|
||||
tokenNames = new String[_SYMBOLIC_NAMES.length];
|
||||
for (int i = 0; i < tokenNames.length; i++) {
|
||||
tokenNames[i] = VOCABULARY.getLiteralName(i);
|
||||
if (tokenNames[i] == null) {
|
||||
tokenNames[i] = VOCABULARY.getSymbolicName(i);
|
||||
}
|
||||
|
||||
if (tokenNames[i] == null) {
|
||||
tokenNames[i] = "<INVALID>";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
public String[] getTokenNames() {
|
||||
return tokenNames;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
public Vocabulary getVocabulary() {
|
||||
return VOCABULARY;
|
||||
}
|
||||
|
||||
|
||||
public SysYLexer(CharStream input) {
|
||||
super(input);
|
||||
_interp = new LexerATNSimulator(this,_ATN,_decisionToDFA,_sharedContextCache);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getGrammarFileName() { return "SysY.g4"; }
|
||||
|
||||
@Override
|
||||
public String[] getRuleNames() { return ruleNames; }
|
||||
|
||||
@Override
|
||||
public String getSerializedATN() { return _serializedATN; }
|
||||
|
||||
@Override
|
||||
public String[] getChannelNames() { return channelNames; }
|
||||
|
||||
@Override
|
||||
public String[] getModeNames() { return modeNames; }
|
||||
|
||||
@Override
|
||||
public ATN getATN() { return _ATN; }
|
||||
|
||||
public static final String _serializedATN =
|
||||
"\u0004\u0000\'\u0171\u0006\uffff\uffff\u0002\u0000\u0007\u0000\u0002\u0001"+
|
||||
"\u0007\u0001\u0002\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004"+
|
||||
"\u0007\u0004\u0002\u0005\u0007\u0005\u0002\u0006\u0007\u0006\u0002\u0007"+
|
||||
"\u0007\u0007\u0002\b\u0007\b\u0002\t\u0007\t\u0002\n\u0007\n\u0002\u000b"+
|
||||
"\u0007\u000b\u0002\f\u0007\f\u0002\r\u0007\r\u0002\u000e\u0007\u000e\u0002"+
|
||||
"\u000f\u0007\u000f\u0002\u0010\u0007\u0010\u0002\u0011\u0007\u0011\u0002"+
|
||||
"\u0012\u0007\u0012\u0002\u0013\u0007\u0013\u0002\u0014\u0007\u0014\u0002"+
|
||||
"\u0015\u0007\u0015\u0002\u0016\u0007\u0016\u0002\u0017\u0007\u0017\u0002"+
|
||||
"\u0018\u0007\u0018\u0002\u0019\u0007\u0019\u0002\u001a\u0007\u001a\u0002"+
|
||||
"\u001b\u0007\u001b\u0002\u001c\u0007\u001c\u0002\u001d\u0007\u001d\u0002"+
|
||||
"\u001e\u0007\u001e\u0002\u001f\u0007\u001f\u0002 \u0007 \u0002!\u0007"+
|
||||
"!\u0002\"\u0007\"\u0002#\u0007#\u0002$\u0007$\u0002%\u0007%\u0002&\u0007"+
|
||||
"&\u0002\'\u0007\'\u0002(\u0007(\u0002)\u0007)\u0002*\u0007*\u0002+\u0007"+
|
||||
"+\u0002,\u0007,\u0002-\u0007-\u0002.\u0007.\u0002/\u0007/\u00020\u0007"+
|
||||
"0\u00021\u00071\u00022\u00072\u00023\u00073\u0001\u0000\u0001\u0000\u0001"+
|
||||
"\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0001\u0001\u0001\u0001"+
|
||||
"\u0001\u0001\u0001\u0001\u0002\u0001\u0002\u0001\u0002\u0001\u0002\u0001"+
|
||||
"\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0003\u0001\u0003\u0001"+
|
||||
"\u0003\u0001\u0004\u0001\u0004\u0001\u0004\u0001\u0005\u0001\u0005\u0001"+
|
||||
"\u0005\u0001\u0005\u0001\u0005\u0001\u0006\u0001\u0006\u0001\u0006\u0001"+
|
||||
"\u0006\u0001\u0006\u0001\u0006\u0001\u0007\u0001\u0007\u0001\u0007\u0001"+
|
||||
"\u0007\u0001\u0007\u0001\u0007\u0001\b\u0001\b\u0001\b\u0001\b\u0001\b"+
|
||||
"\u0001\b\u0001\b\u0001\b\u0001\b\u0001\t\u0001\t\u0001\t\u0001\t\u0001"+
|
||||
"\t\u0001\t\u0001\t\u0001\n\u0001\n\u0001\u000b\u0001\u000b\u0001\f\u0001"+
|
||||
"\f\u0001\r\u0001\r\u0001\u000e\u0001\u000e\u0001\u000f\u0001\u000f\u0001"+
|
||||
"\u0010\u0001\u0010\u0001\u0010\u0001\u0011\u0001\u0011\u0001\u0011\u0001"+
|
||||
"\u0012\u0001\u0012\u0001\u0013\u0001\u0013\u0001\u0013\u0001\u0014\u0001"+
|
||||
"\u0014\u0001\u0015\u0001\u0015\u0001\u0015\u0001\u0016\u0001\u0016\u0001"+
|
||||
"\u0017\u0001\u0017\u0001\u0017\u0001\u0018\u0001\u0018\u0001\u0018\u0001"+
|
||||
"\u0019\u0001\u0019\u0001\u001a\u0001\u001a\u0001\u001b\u0001\u001b\u0001"+
|
||||
"\u001c\u0001\u001c\u0001\u001d\u0001\u001d\u0001\u001e\u0001\u001e\u0001"+
|
||||
"\u001f\u0001\u001f\u0001 \u0001 \u0001!\u0001!\u0005!\u00d9\b!\n!\f!\u00dc"+
|
||||
"\t!\u0001\"\u0001\"\u0001#\u0001#\u0001$\u0001$\u0001%\u0001%\u0001&\u0001"+
|
||||
"&\u0005&\u00e8\b&\n&\f&\u00eb\t&\u0001\'\u0001\'\u0005\'\u00ef\b\'\n\'"+
|
||||
"\f\'\u00f2\t\'\u0001(\u0001(\u0001(\u0004(\u00f7\b(\u000b(\f(\u00f8\u0001"+
|
||||
")\u0004)\u00fc\b)\u000b)\f)\u00fd\u0001)\u0001)\u0005)\u0102\b)\n)\f)"+
|
||||
"\u0105\t)\u0001)\u0001)\u0004)\u0109\b)\u000b)\f)\u010a\u0003)\u010d\b"+
|
||||
")\u0001*\u0001*\u0003*\u0111\b*\u0001*\u0004*\u0114\b*\u000b*\f*\u0115"+
|
||||
"\u0001+\u0001+\u0003+\u011a\b+\u0001+\u0001+\u0001+\u0003+\u011f\b+\u0001"+
|
||||
",\u0005,\u0122\b,\n,\f,\u0125\t,\u0001,\u0001,\u0004,\u0129\b,\u000b,"+
|
||||
"\f,\u012a\u0001,\u0004,\u012e\b,\u000b,\f,\u012f\u0001,\u0001,\u0003,"+
|
||||
"\u0134\b,\u0001-\u0001-\u0003-\u0138\b-\u0001-\u0004-\u013b\b-\u000b-"+
|
||||
"\f-\u013c\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0001.\u0003"+
|
||||
".\u0147\b.\u0001/\u0001/\u0001/\u0003/\u014c\b/\u00010\u00010\u00030\u0150"+
|
||||
"\b0\u00011\u00041\u0153\b1\u000b1\f1\u0154\u00011\u00011\u00012\u0001"+
|
||||
"2\u00012\u00012\u00052\u015d\b2\n2\f2\u0160\t2\u00012\u00012\u00013\u0001"+
|
||||
"3\u00013\u00013\u00053\u0168\b3\n3\f3\u016b\t3\u00013\u00013\u00013\u0001"+
|
||||
"3\u00013\u0001\u0169\u00004\u0001\u0001\u0003\u0002\u0005\u0003\u0007"+
|
||||
"\u0004\t\u0005\u000b\u0006\r\u0007\u000f\b\u0011\t\u0013\n\u0015\u000b"+
|
||||
"\u0017\f\u0019\r\u001b\u000e\u001d\u000f\u001f\u0010!\u0011#\u0012%\u0013"+
|
||||
"\'\u0014)\u0015+\u0016-\u0017/\u00181\u00193\u001a5\u001b7\u001c9\u001d"+
|
||||
";\u001e=\u001f? A!C\"E\u0000G\u0000I\u0000K\u0000M\u0000O\u0000Q\u0000"+
|
||||
"S\u0000U\u0000W\u0000Y\u0000[\u0000]\u0000_#a$c%e&g\'\u0001\u0000\f\u0003"+
|
||||
"\u0000AZ__az\u0004\u000009AZ__az\u0001\u000009\u0001\u000019\u0001\u0000"+
|
||||
"07\u0003\u000009AFaf\u0002\u0000XXxx\u0002\u0000EEee\u0002\u0000++--\u0002"+
|
||||
"\u0000PPpp\u0003\u0000\t\n\r\r \u0002\u0000\n\n\r\r\u017c\u0000\u0001"+
|
||||
"\u0001\u0000\u0000\u0000\u0000\u0003\u0001\u0000\u0000\u0000\u0000\u0005"+
|
||||
"\u0001\u0000\u0000\u0000\u0000\u0007\u0001\u0000\u0000\u0000\u0000\t\u0001"+
|
||||
"\u0000\u0000\u0000\u0000\u000b\u0001\u0000\u0000\u0000\u0000\r\u0001\u0000"+
|
||||
"\u0000\u0000\u0000\u000f\u0001\u0000\u0000\u0000\u0000\u0011\u0001\u0000"+
|
||||
"\u0000\u0000\u0000\u0013\u0001\u0000\u0000\u0000\u0000\u0015\u0001\u0000"+
|
||||
"\u0000\u0000\u0000\u0017\u0001\u0000\u0000\u0000\u0000\u0019\u0001\u0000"+
|
||||
"\u0000\u0000\u0000\u001b\u0001\u0000\u0000\u0000\u0000\u001d\u0001\u0000"+
|
||||
"\u0000\u0000\u0000\u001f\u0001\u0000\u0000\u0000\u0000!\u0001\u0000\u0000"+
|
||||
"\u0000\u0000#\u0001\u0000\u0000\u0000\u0000%\u0001\u0000\u0000\u0000\u0000"+
|
||||
"\'\u0001\u0000\u0000\u0000\u0000)\u0001\u0000\u0000\u0000\u0000+\u0001"+
|
||||
"\u0000\u0000\u0000\u0000-\u0001\u0000\u0000\u0000\u0000/\u0001\u0000\u0000"+
|
||||
"\u0000\u00001\u0001\u0000\u0000\u0000\u00003\u0001\u0000\u0000\u0000\u0000"+
|
||||
"5\u0001\u0000\u0000\u0000\u00007\u0001\u0000\u0000\u0000\u00009\u0001"+
|
||||
"\u0000\u0000\u0000\u0000;\u0001\u0000\u0000\u0000\u0000=\u0001\u0000\u0000"+
|
||||
"\u0000\u0000?\u0001\u0000\u0000\u0000\u0000A\u0001\u0000\u0000\u0000\u0000"+
|
||||
"C\u0001\u0000\u0000\u0000\u0000_\u0001\u0000\u0000\u0000\u0000a\u0001"+
|
||||
"\u0000\u0000\u0000\u0000c\u0001\u0000\u0000\u0000\u0000e\u0001\u0000\u0000"+
|
||||
"\u0000\u0000g\u0001\u0000\u0000\u0000\u0001i\u0001\u0000\u0000\u0000\u0003"+
|
||||
"o\u0001\u0000\u0000\u0000\u0005s\u0001\u0000\u0000\u0000\u0007y\u0001"+
|
||||
"\u0000\u0000\u0000\t~\u0001\u0000\u0000\u0000\u000b\u0081\u0001\u0000"+
|
||||
"\u0000\u0000\r\u0086\u0001\u0000\u0000\u0000\u000f\u008c\u0001\u0000\u0000"+
|
||||
"\u0000\u0011\u0092\u0001\u0000\u0000\u0000\u0013\u009b\u0001\u0000\u0000"+
|
||||
"\u0000\u0015\u00a2\u0001\u0000\u0000\u0000\u0017\u00a4\u0001\u0000\u0000"+
|
||||
"\u0000\u0019\u00a6\u0001\u0000\u0000\u0000\u001b\u00a8\u0001\u0000\u0000"+
|
||||
"\u0000\u001d\u00aa\u0001\u0000\u0000\u0000\u001f\u00ac\u0001\u0000\u0000"+
|
||||
"\u0000!\u00ae\u0001\u0000\u0000\u0000#\u00b1\u0001\u0000\u0000\u0000%"+
|
||||
"\u00b4\u0001\u0000\u0000\u0000\'\u00b6\u0001\u0000\u0000\u0000)\u00b9"+
|
||||
"\u0001\u0000\u0000\u0000+\u00bb\u0001\u0000\u0000\u0000-\u00be\u0001\u0000"+
|
||||
"\u0000\u0000/\u00c0\u0001\u0000\u0000\u00001\u00c3\u0001\u0000\u0000\u0000"+
|
||||
"3\u00c6\u0001\u0000\u0000\u00005\u00c8\u0001\u0000\u0000\u00007\u00ca"+
|
||||
"\u0001\u0000\u0000\u00009\u00cc\u0001\u0000\u0000\u0000;\u00ce\u0001\u0000"+
|
||||
"\u0000\u0000=\u00d0\u0001\u0000\u0000\u0000?\u00d2\u0001\u0000\u0000\u0000"+
|
||||
"A\u00d4\u0001\u0000\u0000\u0000C\u00d6\u0001\u0000\u0000\u0000E\u00dd"+
|
||||
"\u0001\u0000\u0000\u0000G\u00df\u0001\u0000\u0000\u0000I\u00e1\u0001\u0000"+
|
||||
"\u0000\u0000K\u00e3\u0001\u0000\u0000\u0000M\u00e5\u0001\u0000\u0000\u0000"+
|
||||
"O\u00ec\u0001\u0000\u0000\u0000Q\u00f3\u0001\u0000\u0000\u0000S\u010c"+
|
||||
"\u0001\u0000\u0000\u0000U\u010e\u0001\u0000\u0000\u0000W\u011e\u0001\u0000"+
|
||||
"\u0000\u0000Y\u0133\u0001\u0000\u0000\u0000[\u0135\u0001\u0000\u0000\u0000"+
|
||||
"]\u0146\u0001\u0000\u0000\u0000_\u014b\u0001\u0000\u0000\u0000a\u014f"+
|
||||
"\u0001\u0000\u0000\u0000c\u0152\u0001\u0000\u0000\u0000e\u0158\u0001\u0000"+
|
||||
"\u0000\u0000g\u0163\u0001\u0000\u0000\u0000ij\u0005c\u0000\u0000jk\u0005"+
|
||||
"o\u0000\u0000kl\u0005n\u0000\u0000lm\u0005s\u0000\u0000mn\u0005t\u0000"+
|
||||
"\u0000n\u0002\u0001\u0000\u0000\u0000op\u0005i\u0000\u0000pq\u0005n\u0000"+
|
||||
"\u0000qr\u0005t\u0000\u0000r\u0004\u0001\u0000\u0000\u0000st\u0005f\u0000"+
|
||||
"\u0000tu\u0005l\u0000\u0000uv\u0005o\u0000\u0000vw\u0005a\u0000\u0000"+
|
||||
"wx\u0005t\u0000\u0000x\u0006\u0001\u0000\u0000\u0000yz\u0005v\u0000\u0000"+
|
||||
"z{\u0005o\u0000\u0000{|\u0005i\u0000\u0000|}\u0005d\u0000\u0000}\b\u0001"+
|
||||
"\u0000\u0000\u0000~\u007f\u0005i\u0000\u0000\u007f\u0080\u0005f\u0000"+
|
||||
"\u0000\u0080\n\u0001\u0000\u0000\u0000\u0081\u0082\u0005e\u0000\u0000"+
|
||||
"\u0082\u0083\u0005l\u0000\u0000\u0083\u0084\u0005s\u0000\u0000\u0084\u0085"+
|
||||
"\u0005e\u0000\u0000\u0085\f\u0001\u0000\u0000\u0000\u0086\u0087\u0005"+
|
||||
"w\u0000\u0000\u0087\u0088\u0005h\u0000\u0000\u0088\u0089\u0005i\u0000"+
|
||||
"\u0000\u0089\u008a\u0005l\u0000\u0000\u008a\u008b\u0005e\u0000\u0000\u008b"+
|
||||
"\u000e\u0001\u0000\u0000\u0000\u008c\u008d\u0005b\u0000\u0000\u008d\u008e"+
|
||||
"\u0005r\u0000\u0000\u008e\u008f\u0005e\u0000\u0000\u008f\u0090\u0005a"+
|
||||
"\u0000\u0000\u0090\u0091\u0005k\u0000\u0000\u0091\u0010\u0001\u0000\u0000"+
|
||||
"\u0000\u0092\u0093\u0005c\u0000\u0000\u0093\u0094\u0005o\u0000\u0000\u0094"+
|
||||
"\u0095\u0005n\u0000\u0000\u0095\u0096\u0005t\u0000\u0000\u0096\u0097\u0005"+
|
||||
"i\u0000\u0000\u0097\u0098\u0005n\u0000\u0000\u0098\u0099\u0005u\u0000"+
|
||||
"\u0000\u0099\u009a\u0005e\u0000\u0000\u009a\u0012\u0001\u0000\u0000\u0000"+
|
||||
"\u009b\u009c\u0005r\u0000\u0000\u009c\u009d\u0005e\u0000\u0000\u009d\u009e"+
|
||||
"\u0005t\u0000\u0000\u009e\u009f\u0005u\u0000\u0000\u009f\u00a0\u0005r"+
|
||||
"\u0000\u0000\u00a0\u00a1\u0005n\u0000\u0000\u00a1\u0014\u0001\u0000\u0000"+
|
||||
"\u0000\u00a2\u00a3\u0005+\u0000\u0000\u00a3\u0016\u0001\u0000\u0000\u0000"+
|
||||
"\u00a4\u00a5\u0005-\u0000\u0000\u00a5\u0018\u0001\u0000\u0000\u0000\u00a6"+
|
||||
"\u00a7\u0005*\u0000\u0000\u00a7\u001a\u0001\u0000\u0000\u0000\u00a8\u00a9"+
|
||||
"\u0005/\u0000\u0000\u00a9\u001c\u0001\u0000\u0000\u0000\u00aa\u00ab\u0005"+
|
||||
"%\u0000\u0000\u00ab\u001e\u0001\u0000\u0000\u0000\u00ac\u00ad\u0005=\u0000"+
|
||||
"\u0000\u00ad \u0001\u0000\u0000\u0000\u00ae\u00af\u0005=\u0000\u0000\u00af"+
|
||||
"\u00b0\u0005=\u0000\u0000\u00b0\"\u0001\u0000\u0000\u0000\u00b1\u00b2"+
|
||||
"\u0005!\u0000\u0000\u00b2\u00b3\u0005=\u0000\u0000\u00b3$\u0001\u0000"+
|
||||
"\u0000\u0000\u00b4\u00b5\u0005<\u0000\u0000\u00b5&\u0001\u0000\u0000\u0000"+
|
||||
"\u00b6\u00b7\u0005<\u0000\u0000\u00b7\u00b8\u0005=\u0000\u0000\u00b8("+
|
||||
"\u0001\u0000\u0000\u0000\u00b9\u00ba\u0005>\u0000\u0000\u00ba*\u0001\u0000"+
|
||||
"\u0000\u0000\u00bb\u00bc\u0005>\u0000\u0000\u00bc\u00bd\u0005=\u0000\u0000"+
|
||||
"\u00bd,\u0001\u0000\u0000\u0000\u00be\u00bf\u0005!\u0000\u0000\u00bf."+
|
||||
"\u0001\u0000\u0000\u0000\u00c0\u00c1\u0005&\u0000\u0000\u00c1\u00c2\u0005"+
|
||||
"&\u0000\u0000\u00c20\u0001\u0000\u0000\u0000\u00c3\u00c4\u0005|\u0000"+
|
||||
"\u0000\u00c4\u00c5\u0005|\u0000\u0000\u00c52\u0001\u0000\u0000\u0000\u00c6"+
|
||||
"\u00c7\u0005(\u0000\u0000\u00c74\u0001\u0000\u0000\u0000\u00c8\u00c9\u0005"+
|
||||
")\u0000\u0000\u00c96\u0001\u0000\u0000\u0000\u00ca\u00cb\u0005[\u0000"+
|
||||
"\u0000\u00cb8\u0001\u0000\u0000\u0000\u00cc\u00cd\u0005]\u0000\u0000\u00cd"+
|
||||
":\u0001\u0000\u0000\u0000\u00ce\u00cf\u0005{\u0000\u0000\u00cf<\u0001"+
|
||||
"\u0000\u0000\u0000\u00d0\u00d1\u0005}\u0000\u0000\u00d1>\u0001\u0000\u0000"+
|
||||
"\u0000\u00d2\u00d3\u0005,\u0000\u0000\u00d3@\u0001\u0000\u0000\u0000\u00d4"+
|
||||
"\u00d5\u0005;\u0000\u0000\u00d5B\u0001\u0000\u0000\u0000\u00d6\u00da\u0007"+
|
||||
"\u0000\u0000\u0000\u00d7\u00d9\u0007\u0001\u0000\u0000\u00d8\u00d7\u0001"+
|
||||
"\u0000\u0000\u0000\u00d9\u00dc\u0001\u0000\u0000\u0000\u00da\u00d8\u0001"+
|
||||
"\u0000\u0000\u0000\u00da\u00db\u0001\u0000\u0000\u0000\u00dbD\u0001\u0000"+
|
||||
"\u0000\u0000\u00dc\u00da\u0001\u0000\u0000\u0000\u00dd\u00de\u0007\u0002"+
|
||||
"\u0000\u0000\u00deF\u0001\u0000\u0000\u0000\u00df\u00e0\u0007\u0003\u0000"+
|
||||
"\u0000\u00e0H\u0001\u0000\u0000\u0000\u00e1\u00e2\u0007\u0004\u0000\u0000"+
|
||||
"\u00e2J\u0001\u0000\u0000\u0000\u00e3\u00e4\u0007\u0005\u0000\u0000\u00e4"+
|
||||
"L\u0001\u0000\u0000\u0000\u00e5\u00e9\u0003G#\u0000\u00e6\u00e8\u0003"+
|
||||
"E\"\u0000\u00e7\u00e6\u0001\u0000\u0000\u0000\u00e8\u00eb\u0001\u0000"+
|
||||
"\u0000\u0000\u00e9\u00e7\u0001\u0000\u0000\u0000\u00e9\u00ea\u0001\u0000"+
|
||||
"\u0000\u0000\u00eaN\u0001\u0000\u0000\u0000\u00eb\u00e9\u0001\u0000\u0000"+
|
||||
"\u0000\u00ec\u00f0\u00050\u0000\u0000\u00ed\u00ef\u0003I$\u0000\u00ee"+
|
||||
"\u00ed\u0001\u0000\u0000\u0000\u00ef\u00f2\u0001\u0000\u0000\u0000\u00f0"+
|
||||
"\u00ee\u0001\u0000\u0000\u0000\u00f0\u00f1\u0001\u0000\u0000\u0000\u00f1"+
|
||||
"P\u0001\u0000\u0000\u0000\u00f2\u00f0\u0001\u0000\u0000\u0000\u00f3\u00f4"+
|
||||
"\u00050\u0000\u0000\u00f4\u00f6\u0007\u0006\u0000\u0000\u00f5\u00f7\u0003"+
|
||||
"K%\u0000\u00f6\u00f5\u0001\u0000\u0000\u0000\u00f7\u00f8\u0001\u0000\u0000"+
|
||||
"\u0000\u00f8\u00f6\u0001\u0000\u0000\u0000\u00f8\u00f9\u0001\u0000\u0000"+
|
||||
"\u0000\u00f9R\u0001\u0000\u0000\u0000\u00fa\u00fc\u0003E\"\u0000\u00fb"+
|
||||
"\u00fa\u0001\u0000\u0000\u0000\u00fc\u00fd\u0001\u0000\u0000\u0000\u00fd"+
|
||||
"\u00fb\u0001\u0000\u0000\u0000\u00fd\u00fe\u0001\u0000\u0000\u0000\u00fe"+
|
||||
"\u00ff\u0001\u0000\u0000\u0000\u00ff\u0103\u0005.\u0000\u0000\u0100\u0102"+
|
||||
"\u0003E\"\u0000\u0101\u0100\u0001\u0000\u0000\u0000\u0102\u0105\u0001"+
|
||||
"\u0000\u0000\u0000\u0103\u0101\u0001\u0000\u0000\u0000\u0103\u0104\u0001"+
|
||||
"\u0000\u0000\u0000\u0104\u010d\u0001\u0000\u0000\u0000\u0105\u0103\u0001"+
|
||||
"\u0000\u0000\u0000\u0106\u0108\u0005.\u0000\u0000\u0107\u0109\u0003E\""+
|
||||
"\u0000\u0108\u0107\u0001\u0000\u0000\u0000\u0109\u010a\u0001\u0000\u0000"+
|
||||
"\u0000\u010a\u0108\u0001\u0000\u0000\u0000\u010a\u010b\u0001\u0000\u0000"+
|
||||
"\u0000\u010b\u010d\u0001\u0000\u0000\u0000\u010c\u00fb\u0001\u0000\u0000"+
|
||||
"\u0000\u010c\u0106\u0001\u0000\u0000\u0000\u010dT\u0001\u0000\u0000\u0000"+
|
||||
"\u010e\u0110\u0007\u0007\u0000\u0000\u010f\u0111\u0007\b\u0000\u0000\u0110"+
|
||||
"\u010f\u0001\u0000\u0000\u0000\u0110\u0111\u0001\u0000\u0000\u0000\u0111"+
|
||||
"\u0113\u0001\u0000\u0000\u0000\u0112\u0114\u0003E\"\u0000\u0113\u0112"+
|
||||
"\u0001\u0000\u0000\u0000\u0114\u0115\u0001\u0000\u0000\u0000\u0115\u0113"+
|
||||
"\u0001\u0000\u0000\u0000\u0115\u0116\u0001\u0000\u0000\u0000\u0116V\u0001"+
|
||||
"\u0000\u0000\u0000\u0117\u0119\u0003S)\u0000\u0118\u011a\u0003U*\u0000"+
|
||||
"\u0119\u0118\u0001\u0000\u0000\u0000\u0119\u011a\u0001\u0000\u0000\u0000"+
|
||||
"\u011a\u011f\u0001\u0000\u0000\u0000\u011b\u011c\u0003M&\u0000\u011c\u011d"+
|
||||
"\u0003U*\u0000\u011d\u011f\u0001\u0000\u0000\u0000\u011e\u0117\u0001\u0000"+
|
||||
"\u0000\u0000\u011e\u011b\u0001\u0000\u0000\u0000\u011fX\u0001\u0000\u0000"+
|
||||
"\u0000\u0120\u0122\u0003K%\u0000\u0121\u0120\u0001\u0000\u0000\u0000\u0122"+
|
||||
"\u0125\u0001\u0000\u0000\u0000\u0123\u0121\u0001\u0000\u0000\u0000\u0123"+
|
||||
"\u0124\u0001\u0000\u0000\u0000\u0124\u0126\u0001\u0000\u0000\u0000\u0125"+
|
||||
"\u0123\u0001\u0000\u0000\u0000\u0126\u0128\u0005.\u0000\u0000\u0127\u0129"+
|
||||
"\u0003K%\u0000\u0128\u0127\u0001\u0000\u0000\u0000\u0129\u012a\u0001\u0000"+
|
||||
"\u0000\u0000\u012a\u0128\u0001\u0000\u0000\u0000\u012a\u012b\u0001\u0000"+
|
||||
"\u0000\u0000\u012b\u0134\u0001\u0000\u0000\u0000\u012c\u012e\u0003K%\u0000"+
|
||||
"\u012d\u012c\u0001\u0000\u0000\u0000\u012e\u012f\u0001\u0000\u0000\u0000"+
|
||||
"\u012f\u012d\u0001\u0000\u0000\u0000\u012f\u0130\u0001\u0000\u0000\u0000"+
|
||||
"\u0130\u0131\u0001\u0000\u0000\u0000\u0131\u0132\u0005.\u0000\u0000\u0132"+
|
||||
"\u0134\u0001\u0000\u0000\u0000\u0133\u0123\u0001\u0000\u0000\u0000\u0133"+
|
||||
"\u012d\u0001\u0000\u0000\u0000\u0134Z\u0001\u0000\u0000\u0000\u0135\u0137"+
|
||||
"\u0007\t\u0000\u0000\u0136\u0138\u0007\b\u0000\u0000\u0137\u0136\u0001"+
|
||||
"\u0000\u0000\u0000\u0137\u0138\u0001\u0000\u0000\u0000\u0138\u013a\u0001"+
|
||||
"\u0000\u0000\u0000\u0139\u013b\u0003E\"\u0000\u013a\u0139\u0001\u0000"+
|
||||
"\u0000\u0000\u013b\u013c\u0001\u0000\u0000\u0000\u013c\u013a\u0001\u0000"+
|
||||
"\u0000\u0000\u013c\u013d\u0001\u0000\u0000\u0000\u013d\\\u0001\u0000\u0000"+
|
||||
"\u0000\u013e\u013f\u00050\u0000\u0000\u013f\u0140\u0007\u0006\u0000\u0000"+
|
||||
"\u0140\u0141\u0003Y,\u0000\u0141\u0142\u0003[-\u0000\u0142\u0147\u0001"+
|
||||
"\u0000\u0000\u0000\u0143\u0144\u0003Q(\u0000\u0144\u0145\u0003[-\u0000"+
|
||||
"\u0145\u0147\u0001\u0000\u0000\u0000\u0146\u013e\u0001\u0000\u0000\u0000"+
|
||||
"\u0146\u0143\u0001\u0000\u0000\u0000\u0147^\u0001\u0000\u0000\u0000\u0148"+
|
||||
"\u014c\u0003M&\u0000\u0149\u014c\u0003O\'\u0000\u014a\u014c\u0003Q(\u0000"+
|
||||
"\u014b\u0148\u0001\u0000\u0000\u0000\u014b\u0149\u0001\u0000\u0000\u0000"+
|
||||
"\u014b\u014a\u0001\u0000\u0000\u0000\u014c`\u0001\u0000\u0000\u0000\u014d"+
|
||||
"\u0150\u0003W+\u0000\u014e\u0150\u0003].\u0000\u014f\u014d\u0001\u0000"+
|
||||
"\u0000\u0000\u014f\u014e\u0001\u0000\u0000\u0000\u0150b\u0001\u0000\u0000"+
|
||||
"\u0000\u0151\u0153\u0007\n\u0000\u0000\u0152\u0151\u0001\u0000\u0000\u0000"+
|
||||
"\u0153\u0154\u0001\u0000\u0000\u0000\u0154\u0152\u0001\u0000\u0000\u0000"+
|
||||
"\u0154\u0155\u0001\u0000\u0000\u0000\u0155\u0156\u0001\u0000\u0000\u0000"+
|
||||
"\u0156\u0157\u00061\u0000\u0000\u0157d\u0001\u0000\u0000\u0000\u0158\u0159"+
|
||||
"\u0005/\u0000\u0000\u0159\u015a\u0005/\u0000\u0000\u015a\u015e\u0001\u0000"+
|
||||
"\u0000\u0000\u015b\u015d\b\u000b\u0000\u0000\u015c\u015b\u0001\u0000\u0000"+
|
||||
"\u0000\u015d\u0160\u0001\u0000\u0000\u0000\u015e\u015c\u0001\u0000\u0000"+
|
||||
"\u0000\u015e\u015f\u0001\u0000\u0000\u0000\u015f\u0161\u0001\u0000\u0000"+
|
||||
"\u0000\u0160\u015e\u0001\u0000\u0000\u0000\u0161\u0162\u00062\u0000\u0000"+
|
||||
"\u0162f\u0001\u0000\u0000\u0000\u0163\u0164\u0005/\u0000\u0000\u0164\u0165"+
|
||||
"\u0005*\u0000\u0000\u0165\u0169\u0001\u0000\u0000\u0000\u0166\u0168\t"+
|
||||
"\u0000\u0000\u0000\u0167\u0166\u0001\u0000\u0000\u0000\u0168\u016b\u0001"+
|
||||
"\u0000\u0000\u0000\u0169\u016a\u0001\u0000\u0000\u0000\u0169\u0167\u0001"+
|
||||
"\u0000\u0000\u0000\u016a\u016c\u0001\u0000\u0000\u0000\u016b\u0169\u0001"+
|
||||
"\u0000\u0000\u0000\u016c\u016d\u0005*\u0000\u0000\u016d\u016e\u0005/\u0000"+
|
||||
"\u0000\u016e\u016f\u0001\u0000\u0000\u0000\u016f\u0170\u00063\u0000\u0000"+
|
||||
"\u0170h\u0001\u0000\u0000\u0000\u0019\u0000\u00da\u00e9\u00f0\u00f8\u00fd"+
|
||||
"\u0103\u010a\u010c\u0110\u0115\u0119\u011e\u0123\u012a\u012f\u0133\u0137"+
|
||||
"\u013c\u0146\u014b\u014f\u0154\u015e\u0169\u0001\u0006\u0000\u0000";
|
||||
public static final ATN _ATN =
|
||||
new ATNDeserializer().deserialize(_serializedATN.toCharArray());
|
||||
static {
|
||||
_decisionToDFA = new DFA[_ATN.getNumberOfDecisions()];
|
||||
for (int i = 0; i < _ATN.getNumberOfDecisions(); i++) {
|
||||
_decisionToDFA[i] = new DFA(_ATN.getDecisionState(i), i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,515 @@
|
||||
// Generated from /root/sysy2026/nudt-compiler-cpp/src/antlr4/SysY.g4 by ANTLR 4.13.1
|
||||
import org.antlr.v4.runtime.tree.ParseTreeListener;
|
||||
|
||||
/**
|
||||
* This interface defines a complete listener for a parse tree produced by
|
||||
* {@link SysYParser}.
|
||||
*/
|
||||
public interface SysYListener extends ParseTreeListener {
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#compUnit}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterCompUnit(SysYParser.CompUnitContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#compUnit}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitCompUnit(SysYParser.CompUnitContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#decl}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterDecl(SysYParser.DeclContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#decl}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitDecl(SysYParser.DeclContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#constDecl}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterConstDecl(SysYParser.ConstDeclContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#constDecl}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitConstDecl(SysYParser.ConstDeclContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#bType}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBType(SysYParser.BTypeContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#bType}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBType(SysYParser.BTypeContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#constDef}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterConstDef(SysYParser.ConstDefContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#constDef}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitConstDef(SysYParser.ConstDefContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#constInitVal}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterConstInitVal(SysYParser.ConstInitValContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#constInitVal}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitConstInitVal(SysYParser.ConstInitValContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#varDecl}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterVarDecl(SysYParser.VarDeclContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#varDecl}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitVarDecl(SysYParser.VarDeclContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#varDef}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterVarDef(SysYParser.VarDefContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#varDef}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitVarDef(SysYParser.VarDefContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#initVal}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterInitVal(SysYParser.InitValContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#initVal}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitInitVal(SysYParser.InitValContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#funcDef}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterFuncDef(SysYParser.FuncDefContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#funcDef}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitFuncDef(SysYParser.FuncDefContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#funcType}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterFuncType(SysYParser.FuncTypeContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#funcType}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitFuncType(SysYParser.FuncTypeContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#funcFParams}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterFuncFParams(SysYParser.FuncFParamsContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#funcFParams}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitFuncFParams(SysYParser.FuncFParamsContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#funcFParam}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterFuncFParam(SysYParser.FuncFParamContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#funcFParam}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitFuncFParam(SysYParser.FuncFParamContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#block}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBlock(SysYParser.BlockContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#block}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBlock(SysYParser.BlockContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#blockItem}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBlockItem(SysYParser.BlockItemContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#blockItem}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBlockItem(SysYParser.BlockItemContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code assignStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterAssignStmt(SysYParser.AssignStmtContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code assignStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitAssignStmt(SysYParser.AssignStmtContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code expStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterExpStmt(SysYParser.ExpStmtContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code expStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitExpStmt(SysYParser.ExpStmtContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code blockStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBlockStmt(SysYParser.BlockStmtContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code blockStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBlockStmt(SysYParser.BlockStmtContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code ifStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterIfStmt(SysYParser.IfStmtContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code ifStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitIfStmt(SysYParser.IfStmtContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code whileStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterWhileStmt(SysYParser.WhileStmtContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code whileStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitWhileStmt(SysYParser.WhileStmtContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code breakStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBreakStmt(SysYParser.BreakStmtContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code breakStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBreakStmt(SysYParser.BreakStmtContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code continueStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterContinueStmt(SysYParser.ContinueStmtContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code continueStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitContinueStmt(SysYParser.ContinueStmtContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code returnStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterReturnStmt(SysYParser.ReturnStmtContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code returnStmt}
|
||||
* labeled alternative in {@link SysYParser#stmt}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitReturnStmt(SysYParser.ReturnStmtContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#exp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterExp(SysYParser.ExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#exp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitExp(SysYParser.ExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#cond}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterCond(SysYParser.CondContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#cond}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitCond(SysYParser.CondContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#lVal}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterLVal(SysYParser.LValContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#lVal}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitLVal(SysYParser.LValContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#primaryExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterPrimaryExp(SysYParser.PrimaryExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#primaryExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitPrimaryExp(SysYParser.PrimaryExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#number}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterNumber(SysYParser.NumberContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#number}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitNumber(SysYParser.NumberContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code primaryUnaryExp}
|
||||
* labeled alternative in {@link SysYParser#unaryExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code primaryUnaryExp}
|
||||
* labeled alternative in {@link SysYParser#unaryExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitPrimaryUnaryExp(SysYParser.PrimaryUnaryExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code callUnaryExp}
|
||||
* labeled alternative in {@link SysYParser#unaryExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterCallUnaryExp(SysYParser.CallUnaryExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code callUnaryExp}
|
||||
* labeled alternative in {@link SysYParser#unaryExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitCallUnaryExp(SysYParser.CallUnaryExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code opUnaryExp}
|
||||
* labeled alternative in {@link SysYParser#unaryExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterOpUnaryExp(SysYParser.OpUnaryExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code opUnaryExp}
|
||||
* labeled alternative in {@link SysYParser#unaryExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitOpUnaryExp(SysYParser.OpUnaryExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#unaryOp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterUnaryOp(SysYParser.UnaryOpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#unaryOp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitUnaryOp(SysYParser.UnaryOpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#funcRParams}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterFuncRParams(SysYParser.FuncRParamsContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#funcRParams}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitFuncRParams(SysYParser.FuncRParamsContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code binaryMulExp}
|
||||
* labeled alternative in {@link SysYParser#mulExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBinaryMulExp(SysYParser.BinaryMulExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code binaryMulExp}
|
||||
* labeled alternative in {@link SysYParser#mulExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBinaryMulExp(SysYParser.BinaryMulExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code unaryMulExp}
|
||||
* labeled alternative in {@link SysYParser#mulExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterUnaryMulExp(SysYParser.UnaryMulExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code unaryMulExp}
|
||||
* labeled alternative in {@link SysYParser#mulExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitUnaryMulExp(SysYParser.UnaryMulExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code binaryAddExp}
|
||||
* labeled alternative in {@link SysYParser#addExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBinaryAddExp(SysYParser.BinaryAddExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code binaryAddExp}
|
||||
* labeled alternative in {@link SysYParser#addExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBinaryAddExp(SysYParser.BinaryAddExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code mulAddExp}
|
||||
* labeled alternative in {@link SysYParser#addExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterMulAddExp(SysYParser.MulAddExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code mulAddExp}
|
||||
* labeled alternative in {@link SysYParser#addExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitMulAddExp(SysYParser.MulAddExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code addRelExp}
|
||||
* labeled alternative in {@link SysYParser#relExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterAddRelExp(SysYParser.AddRelExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code addRelExp}
|
||||
* labeled alternative in {@link SysYParser#relExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitAddRelExp(SysYParser.AddRelExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code binaryRelExp}
|
||||
* labeled alternative in {@link SysYParser#relExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBinaryRelExp(SysYParser.BinaryRelExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code binaryRelExp}
|
||||
* labeled alternative in {@link SysYParser#relExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBinaryRelExp(SysYParser.BinaryRelExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code binaryEqExp}
|
||||
* labeled alternative in {@link SysYParser#eqExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBinaryEqExp(SysYParser.BinaryEqExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code binaryEqExp}
|
||||
* labeled alternative in {@link SysYParser#eqExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBinaryEqExp(SysYParser.BinaryEqExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code relEqExp}
|
||||
* labeled alternative in {@link SysYParser#eqExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterRelEqExp(SysYParser.RelEqExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code relEqExp}
|
||||
* labeled alternative in {@link SysYParser#eqExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitRelEqExp(SysYParser.RelEqExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code eqLAndExp}
|
||||
* labeled alternative in {@link SysYParser#lAndExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterEqLAndExp(SysYParser.EqLAndExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code eqLAndExp}
|
||||
* labeled alternative in {@link SysYParser#lAndExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitEqLAndExp(SysYParser.EqLAndExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code binaryLAndExp}
|
||||
* labeled alternative in {@link SysYParser#lAndExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code binaryLAndExp}
|
||||
* labeled alternative in {@link SysYParser#lAndExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBinaryLAndExp(SysYParser.BinaryLAndExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code andLOrExp}
|
||||
* labeled alternative in {@link SysYParser#lOrExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterAndLOrExp(SysYParser.AndLOrExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code andLOrExp}
|
||||
* labeled alternative in {@link SysYParser#lOrExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitAndLOrExp(SysYParser.AndLOrExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by the {@code binaryLOrExp}
|
||||
* labeled alternative in {@link SysYParser#lOrExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by the {@code binaryLOrExp}
|
||||
* labeled alternative in {@link SysYParser#lOrExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitBinaryLOrExp(SysYParser.BinaryLOrExpContext ctx);
|
||||
/**
|
||||
* Enter a parse tree produced by {@link SysYParser#constExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void enterConstExp(SysYParser.ConstExpContext ctx);
|
||||
/**
|
||||
* Exit a parse tree produced by {@link SysYParser#constExp}.
|
||||
* @param ctx the parse tree
|
||||
*/
|
||||
void exitConstExp(SysYParser.ConstExpContext ctx);
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,98 +1,178 @@
|
||||
// SysY 子集语法:支持形如
|
||||
// int main() { int a = 1; int b = 2; return a + b; }
|
||||
// 的最小返回表达式编译。
|
||||
|
||||
// 后续需要自行添加
|
||||
|
||||
grammar SysY;
|
||||
|
||||
/*===-------------------------------------------===*/
|
||||
/* Lexer rules */
|
||||
/*===-------------------------------------------===*/
|
||||
|
||||
INT: 'int';
|
||||
RETURN: 'return';
|
||||
|
||||
ASSIGN: '=';
|
||||
ADD: '+';
|
||||
|
||||
LPAREN: '(';
|
||||
RPAREN: ')';
|
||||
LBRACE: '{';
|
||||
RBRACE: '}';
|
||||
SEMICOLON: ';';
|
||||
|
||||
ID: [a-zA-Z_][a-zA-Z_0-9]*;
|
||||
ILITERAL: [0-9]+;
|
||||
|
||||
WS: [ \t\r\n] -> skip;
|
||||
LINECOMMENT: '//' ~[\r\n]* -> skip;
|
||||
BLOCKCOMMENT: '/*' .*? '*/' -> skip;
|
||||
|
||||
/*===-------------------------------------------===*/
|
||||
/* Syntax rules */
|
||||
/*===-------------------------------------------===*/
|
||||
|
||||
compUnit
|
||||
: funcDef EOF
|
||||
;
|
||||
|
||||
decl
|
||||
: btype varDef SEMICOLON
|
||||
;
|
||||
|
||||
btype
|
||||
: INT
|
||||
;
|
||||
|
||||
varDef
|
||||
: lValue (ASSIGN initValue)?
|
||||
;
|
||||
|
||||
initValue
|
||||
: exp
|
||||
;
|
||||
|
||||
funcDef
|
||||
: funcType ID LPAREN RPAREN blockStmt
|
||||
;
|
||||
|
||||
funcType
|
||||
: INT
|
||||
;
|
||||
|
||||
blockStmt
|
||||
: LBRACE blockItem* RBRACE
|
||||
;
|
||||
|
||||
blockItem
|
||||
: decl
|
||||
| stmt
|
||||
;
|
||||
|
||||
stmt
|
||||
: returnStmt
|
||||
;
|
||||
|
||||
returnStmt
|
||||
: RETURN exp SEMICOLON
|
||||
;
|
||||
|
||||
exp
|
||||
: LPAREN exp RPAREN # parenExp
|
||||
| var # varExp
|
||||
| number # numberExp
|
||||
| exp ADD exp # additiveExp
|
||||
;
|
||||
|
||||
var
|
||||
: ID
|
||||
;
|
||||
|
||||
lValue
|
||||
: ID
|
||||
;
|
||||
|
||||
number
|
||||
: ILITERAL
|
||||
;
|
||||
grammar SysY;
|
||||
|
||||
////Grammer
|
||||
|
||||
module: compUnit EOF;
|
||||
|
||||
compUnit: (decl | funcDef)+;
|
||||
|
||||
decl: constDecl | varDecl;
|
||||
|
||||
constDecl: CONST bType constDef (COMMA constDef)* SEMI;
|
||||
|
||||
bType: INT | FLOAT;
|
||||
|
||||
constDef: Ident (LBRACK constExp RBRACK)* ASSIGN constInitVal;
|
||||
|
||||
constInitVal:
|
||||
constExp
|
||||
| LBRACE (constInitVal (COMMA constInitVal)*)? RBRACE;
|
||||
|
||||
varDecl: bType varDef (COMMA varDef)* SEMI;
|
||||
|
||||
varDef:
|
||||
Ident (LBRACK constExp RBRACK)*
|
||||
| Ident (LBRACK constExp RBRACK)* ASSIGN initVal;
|
||||
|
||||
initVal: exp | LBRACE (initVal (COMMA initVal)*)? RBRACE;
|
||||
|
||||
funcDef: funcType Ident LPAREN funcFParams? RPAREN block;
|
||||
|
||||
funcType: VOID | INT | FLOAT;
|
||||
|
||||
funcFParams: funcFParam (COMMA funcFParam)*;
|
||||
|
||||
funcFParam:
|
||||
bType Ident
|
||||
| bType Ident LBRACK RBRACK (LBRACK exp RBRACK)*;
|
||||
|
||||
block: LBRACE blockItem* RBRACE;
|
||||
|
||||
blockItem: decl | stmt;
|
||||
|
||||
stmt:
|
||||
lVal ASSIGN exp SEMI
|
||||
| exp? SEMI
|
||||
| block
|
||||
| IF LPAREN cond RPAREN stmt (ELSE stmt)?
|
||||
| WHILE LPAREN cond RPAREN stmt
|
||||
| BREAK SEMI
|
||||
| CONTINUE SEMI
|
||||
| RETURN exp? SEMI;
|
||||
|
||||
exp: addExp;
|
||||
|
||||
cond: lOrExp;
|
||||
|
||||
lVal: Ident (LBRACK exp RBRACK)*;
|
||||
|
||||
primaryExp: LPAREN exp RPAREN | lVal | number;
|
||||
|
||||
number: IntConst | FloatConst;
|
||||
|
||||
unaryExp:
|
||||
primaryExp
|
||||
| Ident LPAREN funcRParams? RPAREN
|
||||
| unaryOp unaryExp;
|
||||
|
||||
unaryOp: ADD | SUB | NOT;
|
||||
|
||||
funcRParams: exp (COMMA exp)*;
|
||||
|
||||
mulExp:
|
||||
unaryExp
|
||||
| mulExp op = (MUL | DIV | MOD) unaryExp;
|
||||
|
||||
addExp:
|
||||
mulExp
|
||||
| addExp op = (ADD | SUB) mulExp;
|
||||
|
||||
relExp:
|
||||
addExp
|
||||
| relExp op = (LT | GT | LE | GE) addExp;
|
||||
|
||||
eqExp:
|
||||
relExp
|
||||
| eqExp op = (EQ | NE) relExp;
|
||||
|
||||
lAndExp: eqExp | lAndExp AND eqExp ;
|
||||
|
||||
lOrExp: lAndExp | lOrExp OR lAndExp ;
|
||||
|
||||
constExp: addExp;
|
||||
|
||||
////Lexer
|
||||
//keywords
|
||||
CONST: 'const';
|
||||
INT: 'int';
|
||||
FLOAT: 'float';
|
||||
VOID: 'void';
|
||||
IF: 'if';
|
||||
ELSE: 'else';
|
||||
WHILE: 'while';
|
||||
BREAK: 'break';
|
||||
CONTINUE: 'continue';
|
||||
RETURN: 'return';
|
||||
|
||||
//operators
|
||||
ADD: '+';
|
||||
SUB: '-';
|
||||
MUL: '*';
|
||||
DIV: '/';
|
||||
MOD: '%';
|
||||
ASSIGN: '=';
|
||||
EQ: '==';
|
||||
NE: '!=';
|
||||
LT: '<';
|
||||
LE: '<=';
|
||||
GT: '>';
|
||||
GE: '>=';
|
||||
NOT: '!';
|
||||
AND: '&&';
|
||||
OR: '||';
|
||||
|
||||
//括号
|
||||
LPAREN: '(';
|
||||
RPAREN: ')';
|
||||
LBRACK: '[';
|
||||
RBRACK: ']';
|
||||
LBRACE: '{';
|
||||
RBRACE: '}';
|
||||
COMMA: ',';
|
||||
SEMI: ';';
|
||||
|
||||
//标识符
|
||||
Ident: [a-zA-Z_] [a-zA-Z_0-9]*;
|
||||
|
||||
//数字常量片段
|
||||
// 十进制数字
|
||||
fragment Digit: [0-9];
|
||||
// 非零十进制数字
|
||||
fragment NonzeroDigit: [1-9];
|
||||
// 八进制数字
|
||||
fragment OctDigit: [0-7];
|
||||
// 十六进制数字
|
||||
fragment HexDigit: [0-9a-fA-F];
|
||||
// 十进制整数:非零开头,后接若干十进制数字
|
||||
fragment DecInteger: NonzeroDigit Digit*;
|
||||
// 八进制整数:以 0 开头
|
||||
fragment OctInteger: '0' OctDigit*;
|
||||
// 十六进制整数:以 0x 或 0X 开头
|
||||
fragment HexInteger: '0' [xX] HexDigit+;
|
||||
// 十进制小数部分
|
||||
fragment DecFraction: Digit+ '.' Digit* | '.' Digit+;
|
||||
// 十进制指数部分
|
||||
fragment DecExponent: [eE] [+\-]? Digit+;
|
||||
// 十进制浮点数
|
||||
fragment DecFloat:
|
||||
DecFraction DecExponent?
|
||||
| DecInteger DecExponent;
|
||||
// 十六进制小数部分
|
||||
fragment HexFraction: HexDigit* '.' HexDigit+ | HexDigit+ '.';
|
||||
// 十六进制浮点数的二进制指数部分
|
||||
fragment BinExponent: [pP] [+\-]? Digit+;
|
||||
// 十六进制浮点数
|
||||
fragment HexFloat:
|
||||
'0' [xX] HexFraction BinExponent
|
||||
| HexInteger BinExponent;
|
||||
//整型常量
|
||||
IntConst: DecInteger | OctInteger | HexInteger;
|
||||
//浮点常量
|
||||
FloatConst: DecFloat | HexFloat;
|
||||
|
||||
//空白符规则
|
||||
WS: [ \t\r\n]+ -> skip;
|
||||
// 单行注释
|
||||
LINE_COMMENT: '//' ~[\r\n]* -> skip;
|
||||
// 跨行注释
|
||||
BLOCK_COMMENT: '/*' .*? '*/' -> skip;
|
||||
@ -1,11 +1,12 @@
|
||||
// GlobalValue 占位实现:
|
||||
// - 具体的全局初始化器、打印和链接语义需要自行补全
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
namespace ir {
|
||||
|
||||
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
|
||||
: User(std::move(ty), std::move(name)) {}
|
||||
GlobalValue::GlobalValue(std::shared_ptr<Type> object_type,
|
||||
const std::string& name, bool is_const, Value* init)
|
||||
: User(Type::GetPointerType(object_type), name),
|
||||
object_type_(std::move(object_type)),
|
||||
is_const_(is_const),
|
||||
init_(init) {}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace ir
|
||||
|
||||
@ -1,89 +1,213 @@
|
||||
// IR 构建工具:
|
||||
// - 管理插入点(当前基本块/位置)
|
||||
// - 提供创建各类指令的便捷接口,降低 IRGen 复杂度
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include "utils/Log.h"
|
||||
|
||||
namespace ir {
|
||||
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb)
|
||||
: ctx_(ctx), insert_block_(bb) {}
|
||||
namespace {
|
||||
|
||||
BasicBlock* RequireInsertBlock(BasicBlock* block) {
|
||||
if (!block) {
|
||||
throw std::runtime_error("IRBuilder has no insert block");
|
||||
}
|
||||
return block;
|
||||
}
|
||||
|
||||
bool IsFloatBinaryOp(Opcode op) {
|
||||
return op == Opcode::FAdd || op == Opcode::FSub || op == Opcode::FMul ||
|
||||
op == Opcode::FDiv || op == Opcode::FRem || op == Opcode::FCmpEQ ||
|
||||
op == Opcode::FCmpNE || op == Opcode::FCmpLT || op == Opcode::FCmpGT ||
|
||||
op == Opcode::FCmpLE || op == Opcode::FCmpGE;
|
||||
}
|
||||
|
||||
bool IsCompareOp(Opcode op) {
|
||||
return (op >= Opcode::ICmpEQ && op <= Opcode::ICmpGE) ||
|
||||
(op >= Opcode::FCmpEQ && op <= Opcode::FCmpGE);
|
||||
}
|
||||
|
||||
std::shared_ptr<Type> ResultTypeForBinary(Opcode op, Value* lhs) {
|
||||
if (IsCompareOp(op)) {
|
||||
return Type::GetInt1Type();
|
||||
}
|
||||
if (IsFloatBinaryOp(op)) {
|
||||
return Type::GetFloatType();
|
||||
}
|
||||
return lhs->GetType();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
IRBuilder::IRBuilder(Context& ctx, BasicBlock* bb) : ctx_(ctx), insert_block_(bb) {}
|
||||
|
||||
void IRBuilder::SetInsertPoint(BasicBlock* bb) { insert_block_ = bb; }
|
||||
|
||||
BasicBlock* IRBuilder::GetInsertBlock() const { return insert_block_; }
|
||||
ConstantInt* IRBuilder::CreateConstInt(int v) { return ctx_.GetConstInt(v); }
|
||||
|
||||
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
|
||||
return new ConstantFloat(Type::GetFloatType(), v);
|
||||
}
|
||||
|
||||
ConstantI1* IRBuilder::CreateConstBool(bool v) { return ctx_.GetConstBool(v); }
|
||||
|
||||
ConstantInt* IRBuilder::CreateConstInt(int v) {
|
||||
// 常量不需要挂在基本块里,由 Context 负责去重与生命周期。
|
||||
return ctx_.GetConstInt(v);
|
||||
ConstantArrayValue* IRBuilder::CreateConstArray(std::shared_ptr<Type> array_type,
|
||||
const std::vector<Value*>& elements,
|
||||
const std::vector<size_t>& dims,
|
||||
const std::string& name) {
|
||||
return new ConstantArrayValue(std::move(array_type), elements, dims, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
if (!lhs) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "IRBuilder::CreateBinary 缺少 lhs"));
|
||||
}
|
||||
if (!rhs) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "IRBuilder::CreateBinary 缺少 rhs"));
|
||||
}
|
||||
return insert_block_->Append<BinaryInst>(op, lhs->GetType(), lhs, rhs, name);
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<BinaryInst>(op, ResultTypeForBinary(op, lhs), lhs, rhs,
|
||||
nullptr, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::Add, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::Sub, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::Mul, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::Div, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateRem(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::Rem, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateAnd(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::And, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateOr(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::Or, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateXor(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::Xor, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateShl(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::Shl, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateAShr(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::AShr, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateLShr(Value* lhs, Value* rhs, const std::string& name) {
|
||||
return CreateBinary(Opcode::LShr, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateICmp(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(op, lhs, rhs, name);
|
||||
}
|
||||
BinaryInst* IRBuilder::CreateFCmp(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(op, lhs, rhs, name);
|
||||
}
|
||||
|
||||
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name);
|
||||
UnaryInst* IRBuilder::CreateNeg(Value* operand, const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<UnaryInst>(Opcode::Neg, operand->GetType(), operand, nullptr,
|
||||
name);
|
||||
}
|
||||
UnaryInst* IRBuilder::CreateNot(Value* operand, const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<UnaryInst>(Opcode::Not, operand->GetType(), operand, nullptr,
|
||||
name);
|
||||
}
|
||||
UnaryInst* IRBuilder::CreateFNeg(Value* operand, const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<UnaryInst>(Opcode::FNeg, operand->GetType(), operand,
|
||||
nullptr, name);
|
||||
}
|
||||
UnaryInst* IRBuilder::CreateFtoI(Value* operand, const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<UnaryInst>(Opcode::FtoI, Type::GetInt32Type(), operand,
|
||||
nullptr, name);
|
||||
}
|
||||
UnaryInst* IRBuilder::CreateIToF(Value* operand, const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<UnaryInst>(Opcode::IToF, Type::GetFloatType(), operand,
|
||||
nullptr, name);
|
||||
}
|
||||
|
||||
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
if (!ptr) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
|
||||
}
|
||||
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name);
|
||||
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> allocated_type,
|
||||
const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<AllocaInst>(std::move(allocated_type), nullptr, name);
|
||||
}
|
||||
|
||||
LoadInst* IRBuilder::CreateLoad(Value* ptr, std::shared_ptr<Type> value_type,
|
||||
const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<LoadInst>(std::move(value_type), ptr, nullptr, name);
|
||||
}
|
||||
|
||||
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
if (!val) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "IRBuilder::CreateStore 缺少 val"));
|
||||
}
|
||||
if (!ptr) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "IRBuilder::CreateStore 缺少 ptr"));
|
||||
}
|
||||
return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr);
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<StoreInst>(val, ptr, nullptr);
|
||||
}
|
||||
|
||||
ReturnInst* IRBuilder::CreateRet(Value* v) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
if (!v) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
|
||||
}
|
||||
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
|
||||
UncondBrInst* IRBuilder::CreateBr(BasicBlock* dest) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
auto* inst = block->Append<UncondBrInst>(dest, nullptr);
|
||||
block->AddSuccessor(dest);
|
||||
dest->AddPredecessor(block);
|
||||
return inst;
|
||||
}
|
||||
|
||||
CondBrInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* then_bb,
|
||||
BasicBlock* else_bb) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
auto* inst = block->Append<CondBrInst>(cond, then_bb, else_bb, nullptr);
|
||||
block->AddSuccessor(then_bb);
|
||||
block->AddSuccessor(else_bb);
|
||||
then_bb->AddPredecessor(block);
|
||||
else_bb->AddPredecessor(block);
|
||||
return inst;
|
||||
}
|
||||
|
||||
ReturnInst* IRBuilder::CreateRet(Value* val) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<ReturnInst>(val, nullptr);
|
||||
}
|
||||
|
||||
UnreachableInst* IRBuilder::CreateUnreachable() {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<UnreachableInst>(nullptr);
|
||||
}
|
||||
|
||||
CallInst* IRBuilder::CreateCall(Function* callee, const std::vector<Value*>& args,
|
||||
const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
std::string real_name = callee->GetReturnType()->IsVoid() ? std::string() : name;
|
||||
return block->Append<CallInst>(callee, args, nullptr, real_name);
|
||||
}
|
||||
|
||||
GetElementPtrInst* IRBuilder::CreateGEP(Value* ptr,
|
||||
std::shared_ptr<Type> source_type,
|
||||
const std::vector<Value*>& indices,
|
||||
const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<GetElementPtrInst>(std::move(source_type), ptr, indices,
|
||||
nullptr, name);
|
||||
}
|
||||
|
||||
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> type,
|
||||
const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<PhiInst>(std::move(type), nullptr, name);
|
||||
}
|
||||
|
||||
ZextInst* IRBuilder::CreateZext(Value* val, std::shared_ptr<Type> target_type,
|
||||
const std::string& name) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<ZextInst>(val, std::move(target_type), nullptr, name);
|
||||
}
|
||||
|
||||
MemsetInst* IRBuilder::CreateMemset(Value* dst, Value* val, Value* len,
|
||||
Value* is_volatile) {
|
||||
auto* block = RequireInsertBlock(insert_block_);
|
||||
return block->Append<MemsetInst>(dst, val, len, is_volatile, nullptr);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace ir
|
||||
|
||||
@ -1,21 +1,45 @@
|
||||
// 保存函数列表并提供模块级上下文访问。
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
namespace ir {
|
||||
|
||||
Context& Module::GetContext() { return context_; }
|
||||
Function* Module::CreateFunction(
|
||||
const std::string& name, std::shared_ptr<Type> ret_type,
|
||||
const std::vector<std::shared_ptr<Type>>& param_types,
|
||||
const std::vector<std::string>& param_names, bool is_external) {
|
||||
if (auto* existing = GetFunction(name)) {
|
||||
existing->SetExternal(existing->IsExternal() && is_external);
|
||||
return existing;
|
||||
}
|
||||
auto func = std::make_unique<Function>(name, std::move(ret_type), param_types,
|
||||
param_names, is_external);
|
||||
auto* ptr = func.get();
|
||||
functions_.push_back(std::move(func));
|
||||
function_map_[name] = ptr;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
const Context& Module::GetContext() const { return context_; }
|
||||
Function* Module::GetFunction(const std::string& name) const {
|
||||
auto it = function_map_.find(name);
|
||||
return it == function_map_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
Function* Module::CreateFunction(const std::string& name,
|
||||
std::shared_ptr<Type> ret_type) {
|
||||
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type)));
|
||||
return functions_.back().get();
|
||||
GlobalValue* Module::CreateGlobalValue(const std::string& name,
|
||||
std::shared_ptr<Type> object_type,
|
||||
bool is_const, Value* init) {
|
||||
if (auto* existing = GetGlobalValue(name)) {
|
||||
return existing;
|
||||
}
|
||||
auto global =
|
||||
std::make_unique<GlobalValue>(std::move(object_type), name, is_const, init);
|
||||
auto* ptr = global.get();
|
||||
globals_.push_back(std::move(global));
|
||||
global_map_[name] = ptr;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
|
||||
return functions_;
|
||||
GlobalValue* Module::GetGlobalValue(const std::string& name) const {
|
||||
auto it = global_map_.find(name);
|
||||
return it == global_map_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace ir
|
||||
|
||||
@ -1,31 +1,111 @@
|
||||
// 当前仅支持 void、i32 和 i32*。
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace ir {
|
||||
|
||||
Type::Type(Kind k) : kind_(k) {}
|
||||
Type::Type(Kind kind) : kind_(kind) {}
|
||||
|
||||
Type::Type(Kind kind, std::shared_ptr<Type> element_type, size_t num_elements)
|
||||
: kind_(kind),
|
||||
element_type_(std::move(element_type)),
|
||||
num_elements_(num_elements) {}
|
||||
|
||||
const std::shared_ptr<Type>& Type::GetVoidType() {
|
||||
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Void);
|
||||
static const auto type = std::make_shared<Type>(Kind::Void);
|
||||
return type;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Type>& Type::GetInt1Type() {
|
||||
static const auto type = std::make_shared<Type>(Kind::Int1);
|
||||
return type;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Type>& Type::GetInt32Type() {
|
||||
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Int32);
|
||||
static const auto type = std::make_shared<Type>(Kind::Int32);
|
||||
return type;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
|
||||
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrInt32);
|
||||
const std::shared_ptr<Type>& Type::GetFloatType() {
|
||||
static const auto type = std::make_shared<Type>(Kind::Float);
|
||||
return type;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Type>& Type::GetLabelType() {
|
||||
static const auto type = std::make_shared<Type>(Kind::Label);
|
||||
return type;
|
||||
}
|
||||
|
||||
Type::Kind Type::GetKind() const { return kind_; }
|
||||
const std::shared_ptr<Type>& Type::GetBoolType() { return GetInt1Type(); }
|
||||
|
||||
bool Type::IsVoid() const { return kind_ == Kind::Void; }
|
||||
std::shared_ptr<Type> Type::GetPointerType(std::shared_ptr<Type> pointee) {
|
||||
return std::make_shared<Type>(Kind::Pointer, std::move(pointee));
|
||||
}
|
||||
|
||||
bool Type::IsInt32() const { return kind_ == Kind::Int32; }
|
||||
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
|
||||
static const auto type = std::make_shared<Type>(Kind::Pointer);
|
||||
return type;
|
||||
}
|
||||
|
||||
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
|
||||
std::shared_ptr<Type> Type::GetArrayType(std::shared_ptr<Type> element_type,
|
||||
size_t num_elements) {
|
||||
return std::make_shared<Type>(Kind::Array, std::move(element_type), num_elements);
|
||||
}
|
||||
|
||||
int Type::GetSize() const {
|
||||
switch (kind_) {
|
||||
case Kind::Void:
|
||||
case Kind::Label:
|
||||
case Kind::Function:
|
||||
return 0;
|
||||
case Kind::Int1:
|
||||
return 1;
|
||||
case Kind::Int32:
|
||||
case Kind::Float:
|
||||
return 4;
|
||||
case Kind::Pointer:
|
||||
return 8;
|
||||
case Kind::Array:
|
||||
return static_cast<int>(num_elements_) *
|
||||
(element_type_ ? element_type_->GetSize() : 0);
|
||||
}
|
||||
throw std::runtime_error("unknown IR type kind");
|
||||
}
|
||||
|
||||
void Type::Print(std::ostream& os) const {
|
||||
switch (kind_) {
|
||||
case Kind::Void:
|
||||
os << "void";
|
||||
return;
|
||||
case Kind::Int1:
|
||||
os << "i1";
|
||||
return;
|
||||
case Kind::Int32:
|
||||
os << "i32";
|
||||
return;
|
||||
case Kind::Float:
|
||||
os << "float";
|
||||
return;
|
||||
case Kind::Label:
|
||||
os << "label";
|
||||
return;
|
||||
case Kind::Function:
|
||||
os << "fn";
|
||||
return;
|
||||
case Kind::Pointer:
|
||||
os << "ptr";
|
||||
return;
|
||||
case Kind::Array:
|
||||
os << "[" << num_elements_ << " x ";
|
||||
if (element_type_) {
|
||||
element_type_->Print(os);
|
||||
} else {
|
||||
os << "void";
|
||||
}
|
||||
os << "]";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace ir
|
||||
|
||||
@ -1,4 +1,167 @@
|
||||
// 支配树分析:
|
||||
// - 构建/查询 Dominator Tree 及相关关系
|
||||
// - 为 mem2reg、CFG 优化与循环分析提供基础能力
|
||||
#include "ir/Analysis.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
std::vector<BasicBlock*> BuildReversePostOrder(Function& function) {
|
||||
std::vector<BasicBlock*> post_order;
|
||||
auto* entry = function.GetEntryBlock();
|
||||
if (!entry) {
|
||||
return post_order;
|
||||
}
|
||||
|
||||
std::unordered_set<BasicBlock*> visited;
|
||||
std::function<void(BasicBlock*)> dfs = [&](BasicBlock* block) {
|
||||
if (!block || !visited.insert(block).second) {
|
||||
return;
|
||||
}
|
||||
for (auto* succ : block->GetSuccessors()) {
|
||||
dfs(succ);
|
||||
}
|
||||
post_order.push_back(block);
|
||||
};
|
||||
dfs(entry);
|
||||
std::reverse(post_order.begin(), post_order.end());
|
||||
return post_order;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DominatorTree::DominatorTree(Function& function) : function_(&function) {
|
||||
Recalculate();
|
||||
}
|
||||
|
||||
void DominatorTree::Recalculate() {
|
||||
reverse_post_order_ = BuildReversePostOrder(*function_);
|
||||
block_index_.clear();
|
||||
dominates_.clear();
|
||||
immediate_dominator_.clear();
|
||||
dom_children_.clear();
|
||||
|
||||
const auto num_blocks = reverse_post_order_.size();
|
||||
for (std::size_t i = 0; i < num_blocks; ++i) {
|
||||
block_index_.emplace(reverse_post_order_[i], i);
|
||||
}
|
||||
if (num_blocks == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dominates_.assign(num_blocks, std::vector<std::uint8_t>(num_blocks, 1));
|
||||
dominates_[0].assign(num_blocks, 0);
|
||||
dominates_[0][0] = 1;
|
||||
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (std::size_t i = 1; i < num_blocks; ++i) {
|
||||
auto* block = reverse_post_order_[i];
|
||||
std::vector<std::uint8_t> next(num_blocks, 1);
|
||||
bool has_reachable_pred = false;
|
||||
for (auto* pred : block->GetPredecessors()) {
|
||||
auto pred_it = block_index_.find(pred);
|
||||
if (pred_it == block_index_.end()) {
|
||||
continue;
|
||||
}
|
||||
has_reachable_pred = true;
|
||||
const auto& pred_dom = dominates_[pred_it->second];
|
||||
for (std::size_t bit = 0; bit < num_blocks; ++bit) {
|
||||
next[bit] &= pred_dom[bit];
|
||||
}
|
||||
}
|
||||
if (!has_reachable_pred) {
|
||||
next.assign(num_blocks, 0);
|
||||
}
|
||||
next[i] = 1;
|
||||
if (next != dominates_[i]) {
|
||||
dominates_[i] = std::move(next);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::size_t> dom_depth(num_blocks, 0);
|
||||
for (std::size_t i = 0; i < num_blocks; ++i) {
|
||||
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
|
||||
if (dominates_[i][candidate]) {
|
||||
++dom_depth[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (std::size_t i = 1; i < num_blocks; ++i) {
|
||||
auto* block = reverse_post_order_[i];
|
||||
BasicBlock* idom = nullptr;
|
||||
std::size_t best_depth = 0;
|
||||
for (std::size_t candidate = 0; candidate < num_blocks; ++candidate) {
|
||||
if (candidate == i || !dominates_[i][candidate]) {
|
||||
continue;
|
||||
}
|
||||
auto* candidate_block = reverse_post_order_[candidate];
|
||||
if (idom == nullptr || dom_depth[candidate] > best_depth) {
|
||||
idom = candidate_block;
|
||||
best_depth = dom_depth[candidate];
|
||||
}
|
||||
}
|
||||
immediate_dominator_.emplace(block, idom);
|
||||
if (idom) {
|
||||
dom_children_[idom].push_back(block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool DominatorTree::IsReachable(BasicBlock* block) const {
|
||||
return block != nullptr && block_index_.find(block) != block_index_.end();
|
||||
}
|
||||
|
||||
bool DominatorTree::Dominates(BasicBlock* dom, BasicBlock* node) const {
|
||||
if (!dom || !node) {
|
||||
return false;
|
||||
}
|
||||
const auto dom_it = block_index_.find(dom);
|
||||
const auto node_it = block_index_.find(node);
|
||||
if (dom_it == block_index_.end() || node_it == block_index_.end()) {
|
||||
return false;
|
||||
}
|
||||
return dominates_[node_it->second][dom_it->second] != 0;
|
||||
}
|
||||
|
||||
bool DominatorTree::Dominates(Instruction* dom, Instruction* user) const {
|
||||
if (!dom || !user) {
|
||||
return false;
|
||||
}
|
||||
if (dom == user) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto* dom_block = dom->GetParent();
|
||||
auto* user_block = user->GetParent();
|
||||
if (dom_block != user_block) {
|
||||
return Dominates(dom_block, user_block);
|
||||
}
|
||||
|
||||
for (const auto& inst_ptr : dom_block->GetInstructions()) {
|
||||
if (inst_ptr.get() == dom) {
|
||||
return true;
|
||||
}
|
||||
if (inst_ptr.get() == user) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
BasicBlock* DominatorTree::GetIDom(BasicBlock* block) const {
|
||||
auto it = immediate_dominator_.find(block);
|
||||
return it == immediate_dominator_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
const std::vector<BasicBlock*>& DominatorTree::GetChildren(BasicBlock* block) const {
|
||||
static const std::vector<BasicBlock*> kEmpty;
|
||||
auto it = dom_children_.find(block);
|
||||
return it == dom_children_.end() ? kEmpty : it->second;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@ -1,4 +1,214 @@
|
||||
// 循环分析:
|
||||
// - 识别循环结构与层级关系
|
||||
// - 为后续优化(可选)提供循环信息
|
||||
#include "ir/Analysis.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
std::vector<BasicBlock*> CollectNaturalLoopBlocks(BasicBlock* header,
|
||||
BasicBlock* latch) {
|
||||
std::vector<BasicBlock*> stack{latch};
|
||||
std::unordered_set<BasicBlock*> loop_blocks{header, latch};
|
||||
while (!stack.empty()) {
|
||||
auto* block = stack.back();
|
||||
stack.pop_back();
|
||||
for (auto* pred : block->GetPredecessors()) {
|
||||
if (!pred || !loop_blocks.insert(pred).second) {
|
||||
continue;
|
||||
}
|
||||
stack.push_back(pred);
|
||||
}
|
||||
}
|
||||
return {loop_blocks.begin(), loop_blocks.end()};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool Loop::Contains(BasicBlock* block) const {
|
||||
return block != nullptr && blocks.find(block) != blocks.end();
|
||||
}
|
||||
|
||||
bool Loop::Contains(const Loop* other) const {
|
||||
if (!other) {
|
||||
return false;
|
||||
}
|
||||
for (auto* block : other->blocks) {
|
||||
if (!Contains(block)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Loop::IsInnermost() const { return subloops.empty(); }
|
||||
|
||||
LoopInfo::LoopInfo(Function& function, const DominatorTree& dom_tree)
|
||||
: function_(&function), dom_tree_(&dom_tree) {
|
||||
Recalculate();
|
||||
}
|
||||
|
||||
void LoopInfo::Recalculate() {
|
||||
loops_.clear();
|
||||
top_level_loops_.clear();
|
||||
block_to_loop_.clear();
|
||||
|
||||
std::unordered_map<BasicBlock*, Loop*> loops_by_header;
|
||||
for (auto* block : dom_tree_->GetReversePostOrder()) {
|
||||
for (auto* succ : block->GetSuccessors()) {
|
||||
if (!dom_tree_->Dominates(succ, block)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Loop* loop = nullptr;
|
||||
auto it = loops_by_header.find(succ);
|
||||
if (it == loops_by_header.end()) {
|
||||
auto new_loop = std::make_unique<Loop>();
|
||||
new_loop->header = succ;
|
||||
loop = new_loop.get();
|
||||
loops_.push_back(std::move(new_loop));
|
||||
loops_by_header.emplace(succ, loop);
|
||||
} else {
|
||||
loop = it->second;
|
||||
}
|
||||
|
||||
if (std::find(loop->latches.begin(), loop->latches.end(), block) ==
|
||||
loop->latches.end()) {
|
||||
loop->latches.push_back(block);
|
||||
}
|
||||
for (auto* natural_block : CollectNaturalLoopBlocks(succ, block)) {
|
||||
loop->blocks.insert(natural_block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<BasicBlock*, std::size_t> function_order;
|
||||
for (std::size_t i = 0; i < function_->GetBlocks().size(); ++i) {
|
||||
function_order.emplace(function_->GetBlocks()[i].get(), i);
|
||||
}
|
||||
|
||||
for (const auto& loop_ptr : loops_) {
|
||||
auto& loop = *loop_ptr;
|
||||
loop.block_list.clear();
|
||||
loop.exiting_blocks.clear();
|
||||
loop.exit_blocks.clear();
|
||||
loop.subloops.clear();
|
||||
loop.parent = nullptr;
|
||||
|
||||
for (const auto& block_ptr : function_->GetBlocks()) {
|
||||
if (loop.Contains(block_ptr.get())) {
|
||||
loop.block_list.push_back(block_ptr.get());
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(loop.latches.begin(), loop.latches.end(),
|
||||
[&](BasicBlock* lhs, BasicBlock* rhs) {
|
||||
return function_order[lhs] < function_order[rhs];
|
||||
});
|
||||
|
||||
std::vector<BasicBlock*> outside_preds;
|
||||
for (auto* pred : loop.header->GetPredecessors()) {
|
||||
if (!loop.Contains(pred)) {
|
||||
outside_preds.push_back(pred);
|
||||
}
|
||||
}
|
||||
if (outside_preds.size() == 1 &&
|
||||
outside_preds.front()->GetSuccessors().size() == 1) {
|
||||
loop.preheader = outside_preds.front();
|
||||
} else {
|
||||
loop.preheader = nullptr;
|
||||
}
|
||||
|
||||
std::unordered_set<BasicBlock*> exiting_seen;
|
||||
std::unordered_set<BasicBlock*> exit_seen;
|
||||
for (auto* block : loop.block_list) {
|
||||
for (auto* succ : block->GetSuccessors()) {
|
||||
if (loop.Contains(succ)) {
|
||||
continue;
|
||||
}
|
||||
if (exiting_seen.insert(block).second) {
|
||||
loop.exiting_blocks.push_back(block);
|
||||
}
|
||||
if (exit_seen.insert(succ).second) {
|
||||
loop.exit_blocks.push_back(succ);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(loop.exiting_blocks.begin(), loop.exiting_blocks.end(),
|
||||
[&](BasicBlock* lhs, BasicBlock* rhs) {
|
||||
return function_order[lhs] < function_order[rhs];
|
||||
});
|
||||
std::sort(loop.exit_blocks.begin(), loop.exit_blocks.end(),
|
||||
[&](BasicBlock* lhs, BasicBlock* rhs) {
|
||||
return function_order[lhs] < function_order[rhs];
|
||||
});
|
||||
}
|
||||
|
||||
for (const auto& loop_ptr : loops_) {
|
||||
auto* loop = loop_ptr.get();
|
||||
Loop* parent = nullptr;
|
||||
for (const auto& candidate_ptr : loops_) {
|
||||
auto* candidate = candidate_ptr.get();
|
||||
if (candidate == loop || !candidate->Contains(loop)) {
|
||||
continue;
|
||||
}
|
||||
if (!parent || candidate->blocks.size() < parent->blocks.size()) {
|
||||
parent = candidate;
|
||||
}
|
||||
}
|
||||
loop->parent = parent;
|
||||
if (parent) {
|
||||
parent->subloops.push_back(loop);
|
||||
} else {
|
||||
top_level_loops_.push_back(loop);
|
||||
}
|
||||
}
|
||||
|
||||
auto loop_order = [&](Loop* lhs, Loop* rhs) {
|
||||
return function_order[lhs->header] < function_order[rhs->header];
|
||||
};
|
||||
std::sort(top_level_loops_.begin(), top_level_loops_.end(), loop_order);
|
||||
for (const auto& loop_ptr : loops_) {
|
||||
std::sort(loop_ptr->subloops.begin(), loop_ptr->subloops.end(), loop_order);
|
||||
}
|
||||
|
||||
for (const auto& block_ptr : function_->GetBlocks()) {
|
||||
Loop* innermost = nullptr;
|
||||
for (const auto& loop_ptr : loops_) {
|
||||
auto* loop = loop_ptr.get();
|
||||
if (!loop->Contains(block_ptr.get())) {
|
||||
continue;
|
||||
}
|
||||
if (!innermost || loop->blocks.size() < innermost->blocks.size()) {
|
||||
innermost = loop;
|
||||
}
|
||||
}
|
||||
if (innermost) {
|
||||
block_to_loop_.emplace(block_ptr.get(), innermost);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Loop*> LoopInfo::GetTopLevelLoops() const { return top_level_loops_; }
|
||||
|
||||
std::vector<Loop*> LoopInfo::GetLoopsInPostOrder() const {
|
||||
std::vector<Loop*> ordered;
|
||||
std::function<void(Loop*)> dfs = [&](Loop* loop) {
|
||||
for (auto* subloop : loop->subloops) {
|
||||
dfs(subloop);
|
||||
}
|
||||
ordered.push_back(loop);
|
||||
};
|
||||
for (auto* loop : top_level_loops_) {
|
||||
dfs(loop);
|
||||
}
|
||||
return ordered;
|
||||
}
|
||||
|
||||
Loop* LoopInfo::GetLoopFor(BasicBlock* block) const {
|
||||
auto it = block_to_loop_.find(block);
|
||||
return it == block_to_loop_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@ -0,0 +1,137 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
bool IsPowerOfTwoPositive(int value) {
|
||||
return value > 0 && (value & (value - 1)) == 0;
|
||||
}
|
||||
|
||||
std::size_t FindInstructionIndex(BasicBlock* block, Instruction* inst) {
|
||||
if (!block || !inst) {
|
||||
return 0;
|
||||
}
|
||||
auto& instructions = block->GetInstructions();
|
||||
for (std::size_t i = 0; i < instructions.size(); ++i) {
|
||||
if (instructions[i].get() == inst) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return instructions.size();
|
||||
}
|
||||
|
||||
bool IsZero(Value* value) {
|
||||
if (auto* ci = dyncast<ConstantInt>(value)) {
|
||||
return ci->GetValue() == 0;
|
||||
}
|
||||
if (auto* cb = dyncast<ConstantI1>(value)) {
|
||||
return !cb->GetValue();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* OtherCompareOperand(BinaryInst* cmp, Value* value) {
|
||||
if (!cmp || cmp->GetNumOperands() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
if (cmp->GetLhs() == value) {
|
||||
return cmp->GetRhs();
|
||||
}
|
||||
if (cmp->GetRhs() == value) {
|
||||
return cmp->GetLhs();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool SimplifyPowerOfTwoRemTests(Function& function) {
|
||||
bool changed = false;
|
||||
std::vector<Instruction*> dead_rems;
|
||||
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
auto* block = block_ptr.get();
|
||||
if (!block) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* rem = dyncast<BinaryInst>(inst_ptr.get());
|
||||
if (!rem || rem->GetOpcode() != Opcode::Rem) {
|
||||
continue;
|
||||
}
|
||||
auto* divisor = dyncast<ConstantInt>(rem->GetRhs());
|
||||
if (!divisor || !IsPowerOfTwoPositive(divisor->GetValue())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int mask_value = divisor->GetValue() - 1;
|
||||
if (mask_value == 0) {
|
||||
rem->ReplaceAllUsesWith(looputils::ConstInt(0));
|
||||
dead_rems.push_back(rem);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<BinaryInst*> compare_uses;
|
||||
bool all_uses_are_zero_tests = !rem->GetUses().empty();
|
||||
for (const auto& use : rem->GetUses()) {
|
||||
auto* cmp = dyncast<BinaryInst>(dynamic_cast<Value*>(use.GetUser()));
|
||||
if (!cmp || (cmp->GetOpcode() != Opcode::ICmpEQ &&
|
||||
cmp->GetOpcode() != Opcode::ICmpNE) ||
|
||||
!IsZero(OtherCompareOperand(cmp, rem))) {
|
||||
all_uses_are_zero_tests = false;
|
||||
break;
|
||||
}
|
||||
compare_uses.push_back(cmp);
|
||||
}
|
||||
if (!all_uses_are_zero_tests || compare_uses.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto insert_index = FindInstructionIndex(block, rem) + 1;
|
||||
auto* masked = block->Insert<BinaryInst>(
|
||||
insert_index, Opcode::And, Type::GetInt32Type(), rem->GetLhs(),
|
||||
looputils::ConstInt(mask_value), nullptr,
|
||||
looputils::NextSyntheticName(function, "pow2.mask."));
|
||||
|
||||
for (auto* cmp : compare_uses) {
|
||||
if (cmp->GetLhs() == rem) {
|
||||
cmp->SetOperand(0, masked);
|
||||
}
|
||||
if (cmp->GetRhs() == rem) {
|
||||
cmp->SetOperand(1, masked);
|
||||
}
|
||||
}
|
||||
dead_rems.push_back(rem);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* rem : dead_rems) {
|
||||
if (rem->GetUses().empty() && rem->GetParent()) {
|
||||
rem->GetParent()->EraseInstruction(rem);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunArithmeticSimplify(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (!function || function->IsExternal()) {
|
||||
continue;
|
||||
}
|
||||
changed |= SimplifyPowerOfTwoRemTests(*function);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -1,4 +1,107 @@
|
||||
// CFG 简化:
|
||||
// - 删除不可达块、合并空块、简化分支等
|
||||
// - 改善 IR 结构,便于后续优化与后端生成
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
bool TryGetConstBranchTarget(CondBrInst* br, BasicBlock*& target, BasicBlock*& removed) {
|
||||
if (!br) {
|
||||
return false;
|
||||
}
|
||||
auto* then_block = br->GetThenBlock();
|
||||
auto* else_block = br->GetElseBlock();
|
||||
if (then_block == else_block) {
|
||||
target = then_block;
|
||||
removed = nullptr;
|
||||
return true;
|
||||
}
|
||||
if (auto* cond = dyncast<ConstantI1>(br->GetCondition())) {
|
||||
target = cond->GetValue() ? then_block : else_block;
|
||||
removed = cond->GetValue() ? else_block : then_block;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool SimplifyBlockTerminator(BasicBlock* block) {
|
||||
if (!block || block->GetInstructions().empty()) {
|
||||
return false;
|
||||
}
|
||||
auto* term = block->GetInstructions().back().get();
|
||||
auto* condbr = dyncast<CondBrInst>(term);
|
||||
if (!condbr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
BasicBlock* target = nullptr;
|
||||
BasicBlock* removed = nullptr;
|
||||
if (!TryGetConstBranchTarget(condbr, target, removed)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (removed) {
|
||||
passutils::RemoveIncomingFromSuccessor(removed, block);
|
||||
removed->RemovePredecessor(block);
|
||||
block->RemoveSuccessor(removed);
|
||||
}
|
||||
passutils::ReplaceTerminatorWithBr(block, target);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SimplifyPhiNodes(Function& function) {
|
||||
bool changed = false;
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
bool local_changed = true;
|
||||
while (local_changed) {
|
||||
local_changed = false;
|
||||
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
if (!passutils::SimplifyPhiInst(phi)) {
|
||||
continue;
|
||||
}
|
||||
local_changed = true;
|
||||
changed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RunCFGSimplifyOnFunction(Function& function) {
|
||||
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
changed |= SimplifyBlockTerminator(block_ptr.get());
|
||||
}
|
||||
|
||||
changed |= passutils::RemoveUnreachableBlocks(function);
|
||||
changed |= SimplifyPhiNodes(function);
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunCFGSimplify(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunCFGSimplifyOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@ -1,4 +1,469 @@
|
||||
// IR 常量折叠:
|
||||
// - 折叠可判定的常量表达式
|
||||
// - 简化常量控制流分支(按实现范围裁剪)
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
Value* GetInt32Const(Context& ctx, std::int32_t value) {
|
||||
return ctx.GetConstInt(static_cast<int>(value));
|
||||
}
|
||||
|
||||
Value* GetBoolConst(Context& ctx, bool value) { return ctx.GetConstBool(value); }
|
||||
|
||||
Value* GetFloatConst(float value) {
|
||||
return new ConstantFloat(Type::GetFloatType(), value);
|
||||
}
|
||||
|
||||
bool TryGetInt32(Value* value, std::int32_t& out) {
|
||||
if (auto* ci = dyncast<ConstantInt>(value)) {
|
||||
out = static_cast<std::int32_t>(ci->GetValue());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool TryGetBool(Value* value, bool& out) {
|
||||
if (auto* cb = dyncast<ConstantI1>(value)) {
|
||||
out = cb->GetValue();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool TryGetFloat(Value* value, float& out) {
|
||||
if (auto* cf = dyncast<ConstantFloat>(value)) {
|
||||
out = cf->GetValue();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsZeroValue(Value* value) {
|
||||
std::int32_t i32 = 0;
|
||||
bool i1 = false;
|
||||
float f32 = 0.0f;
|
||||
return (TryGetInt32(value, i32) && i32 == 0) || (TryGetBool(value, i1) && !i1) ||
|
||||
(TryGetFloat(value, f32) && passutils::FloatBits(f32) == 0);
|
||||
}
|
||||
|
||||
bool IsOneValue(Value* value) {
|
||||
std::int32_t i32 = 0;
|
||||
bool i1 = false;
|
||||
float f32 = 0.0f;
|
||||
return (TryGetInt32(value, i32) && i32 == 1) || (TryGetBool(value, i1) && i1) ||
|
||||
(TryGetFloat(value, f32) &&
|
||||
passutils::FloatBits(f32) == passutils::FloatBits(1.0f));
|
||||
}
|
||||
|
||||
bool IsAllOnesInt(Value* value) {
|
||||
std::int32_t i32 = 0;
|
||||
return TryGetInt32(value, i32) && i32 == -1;
|
||||
}
|
||||
|
||||
std::int32_t WrapInt32(std::uint32_t value) {
|
||||
return static_cast<std::int32_t>(value);
|
||||
}
|
||||
|
||||
Value* FoldBinary(Context& ctx, BinaryInst* inst) {
|
||||
const auto opcode = inst->GetOpcode();
|
||||
auto* lhs = inst->GetLhs();
|
||||
auto* rhs = inst->GetRhs();
|
||||
|
||||
std::int32_t lhs_i32 = 0;
|
||||
std::int32_t rhs_i32 = 0;
|
||||
bool lhs_i1 = false;
|
||||
bool rhs_i1 = false;
|
||||
float lhs_f32 = 0.0f;
|
||||
float rhs_f32 = 0.0f;
|
||||
|
||||
const bool has_lhs_i32 = TryGetInt32(lhs, lhs_i32);
|
||||
const bool has_rhs_i32 = TryGetInt32(rhs, rhs_i32);
|
||||
const bool has_lhs_i1 = TryGetBool(lhs, lhs_i1);
|
||||
const bool has_rhs_i1 = TryGetBool(rhs, rhs_i1);
|
||||
const bool has_lhs_f32 = TryGetFloat(lhs, lhs_f32);
|
||||
const bool has_rhs_f32 = TryGetFloat(rhs, rhs_f32);
|
||||
|
||||
if (has_lhs_i32 && has_rhs_i32) {
|
||||
switch (opcode) {
|
||||
case Opcode::Add:
|
||||
return GetInt32Const(
|
||||
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) +
|
||||
static_cast<std::uint32_t>(rhs_i32)));
|
||||
case Opcode::Sub:
|
||||
return GetInt32Const(
|
||||
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) -
|
||||
static_cast<std::uint32_t>(rhs_i32)));
|
||||
case Opcode::Mul:
|
||||
return GetInt32Const(
|
||||
ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32) *
|
||||
static_cast<std::uint32_t>(rhs_i32)));
|
||||
case Opcode::Div:
|
||||
if (rhs_i32 == 0 ||
|
||||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(ctx, lhs_i32 / rhs_i32);
|
||||
case Opcode::Rem:
|
||||
if (rhs_i32 == 0 ||
|
||||
(lhs_i32 == std::numeric_limits<std::int32_t>::min() && rhs_i32 == -1)) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(ctx, lhs_i32 % rhs_i32);
|
||||
case Opcode::And:
|
||||
return GetInt32Const(ctx, lhs_i32 & rhs_i32);
|
||||
case Opcode::Or:
|
||||
return GetInt32Const(ctx, lhs_i32 | rhs_i32);
|
||||
case Opcode::Xor:
|
||||
return GetInt32Const(ctx, lhs_i32 ^ rhs_i32);
|
||||
case Opcode::Shl:
|
||||
if (rhs_i32 < 0 || rhs_i32 >= 32) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(ctx, WrapInt32(static_cast<std::uint32_t>(lhs_i32)
|
||||
<< rhs_i32));
|
||||
case Opcode::AShr:
|
||||
if (rhs_i32 < 0 || rhs_i32 >= 32) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(ctx, lhs_i32 >> rhs_i32);
|
||||
case Opcode::LShr:
|
||||
if (rhs_i32 < 0 || rhs_i32 >= 32) {
|
||||
return nullptr;
|
||||
}
|
||||
return GetInt32Const(
|
||||
ctx,
|
||||
WrapInt32(static_cast<std::uint32_t>(lhs_i32) >> rhs_i32));
|
||||
case Opcode::ICmpEQ:
|
||||
return GetBoolConst(ctx, lhs_i32 == rhs_i32);
|
||||
case Opcode::ICmpNE:
|
||||
return GetBoolConst(ctx, lhs_i32 != rhs_i32);
|
||||
case Opcode::ICmpLT:
|
||||
return GetBoolConst(ctx, lhs_i32 < rhs_i32);
|
||||
case Opcode::ICmpGT:
|
||||
return GetBoolConst(ctx, lhs_i32 > rhs_i32);
|
||||
case Opcode::ICmpLE:
|
||||
return GetBoolConst(ctx, lhs_i32 <= rhs_i32);
|
||||
case Opcode::ICmpGE:
|
||||
return GetBoolConst(ctx, lhs_i32 >= rhs_i32);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_lhs_i1 && has_rhs_i1) {
|
||||
switch (opcode) {
|
||||
case Opcode::And:
|
||||
return GetBoolConst(ctx, lhs_i1 && rhs_i1);
|
||||
case Opcode::Or:
|
||||
return GetBoolConst(ctx, lhs_i1 || rhs_i1);
|
||||
case Opcode::Xor:
|
||||
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
|
||||
case Opcode::ICmpEQ:
|
||||
return GetBoolConst(ctx, lhs_i1 == rhs_i1);
|
||||
case Opcode::ICmpNE:
|
||||
return GetBoolConst(ctx, lhs_i1 != rhs_i1);
|
||||
case Opcode::ICmpLT:
|
||||
return GetBoolConst(ctx, static_cast<int>(lhs_i1) < static_cast<int>(rhs_i1));
|
||||
case Opcode::ICmpGT:
|
||||
return GetBoolConst(ctx, static_cast<int>(lhs_i1) > static_cast<int>(rhs_i1));
|
||||
case Opcode::ICmpLE:
|
||||
return GetBoolConst(ctx, static_cast<int>(lhs_i1) <= static_cast<int>(rhs_i1));
|
||||
case Opcode::ICmpGE:
|
||||
return GetBoolConst(ctx, static_cast<int>(lhs_i1) >= static_cast<int>(rhs_i1));
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_lhs_f32 && has_rhs_f32) {
|
||||
switch (opcode) {
|
||||
case Opcode::FAdd:
|
||||
return GetFloatConst(lhs_f32 + rhs_f32);
|
||||
case Opcode::FSub:
|
||||
return GetFloatConst(lhs_f32 - rhs_f32);
|
||||
case Opcode::FMul:
|
||||
return GetFloatConst(lhs_f32 * rhs_f32);
|
||||
case Opcode::FDiv:
|
||||
return GetFloatConst(lhs_f32 / rhs_f32);
|
||||
case Opcode::FRem:
|
||||
return GetFloatConst(std::fmod(lhs_f32, rhs_f32));
|
||||
case Opcode::FCmpEQ:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 == rhs_f32);
|
||||
case Opcode::FCmpNE:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 != rhs_f32);
|
||||
case Opcode::FCmpLT:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 < rhs_f32);
|
||||
case Opcode::FCmpGT:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 > rhs_f32);
|
||||
case Opcode::FCmpLE:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 <= rhs_f32);
|
||||
case Opcode::FCmpGE:
|
||||
return GetBoolConst(
|
||||
ctx, !std::isnan(lhs_f32) && !std::isnan(rhs_f32) && lhs_f32 >= rhs_f32);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
switch (opcode) {
|
||||
case Opcode::Add:
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsZeroValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::Sub:
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::Mul:
|
||||
if (IsOneValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsOneValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::Div:
|
||||
if (IsOneValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::Rem:
|
||||
if ((has_rhs_i32 && (rhs_i32 == 1 || rhs_i32 == -1)) ||
|
||||
(has_rhs_i1 && rhs_i1)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
if (IsZeroValue(lhs) && !IsZeroValue(rhs)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::And:
|
||||
if (IsZeroValue(lhs) || IsZeroValue(rhs)) {
|
||||
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
|
||||
: GetInt32Const(ctx, 0);
|
||||
}
|
||||
if (has_lhs_i1 && lhs_i1) {
|
||||
return rhs;
|
||||
}
|
||||
if (has_rhs_i1 && rhs_i1) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsAllOnesInt(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsAllOnesInt(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (passutils::AreEquivalentValues(lhs, rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::Or:
|
||||
if (IsZeroValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (has_lhs_i1 && lhs_i1) {
|
||||
return GetBoolConst(ctx, true);
|
||||
}
|
||||
if (has_rhs_i1 && rhs_i1) {
|
||||
return GetBoolConst(ctx, true);
|
||||
}
|
||||
if (passutils::AreEquivalentValues(lhs, rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::Xor:
|
||||
if (IsZeroValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (passutils::AreEquivalentValues(lhs, rhs)) {
|
||||
return inst->GetType()->IsInt1() ? GetBoolConst(ctx, false)
|
||||
: GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
if (IsZeroValue(lhs)) {
|
||||
return GetInt32Const(ctx, 0);
|
||||
}
|
||||
break;
|
||||
case Opcode::FAdd:
|
||||
if (IsZeroValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::FSub:
|
||||
if (IsZeroValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::FMul:
|
||||
if (IsOneValue(lhs)) {
|
||||
return rhs;
|
||||
}
|
||||
if (IsOneValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
case Opcode::FDiv:
|
||||
if (IsOneValue(rhs)) {
|
||||
return lhs;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Value* FoldUnary(Context& ctx, UnaryInst* inst) {
|
||||
auto* operand = inst->GetOprd();
|
||||
std::int32_t i32 = 0;
|
||||
bool i1 = false;
|
||||
float f32 = 0.0f;
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Neg:
|
||||
if (TryGetInt32(operand, i32)) {
|
||||
return GetInt32Const(ctx, WrapInt32(0u - static_cast<std::uint32_t>(i32)));
|
||||
}
|
||||
break;
|
||||
case Opcode::Not:
|
||||
if (TryGetBool(operand, i1)) {
|
||||
return GetBoolConst(ctx, !i1);
|
||||
}
|
||||
if (TryGetInt32(operand, i32)) {
|
||||
return GetInt32Const(ctx, i32 ^ 1);
|
||||
}
|
||||
break;
|
||||
case Opcode::FNeg:
|
||||
if (TryGetFloat(operand, f32)) {
|
||||
return GetFloatConst(-f32);
|
||||
}
|
||||
break;
|
||||
case Opcode::FtoI:
|
||||
if (TryGetFloat(operand, f32)) {
|
||||
return GetInt32Const(ctx, static_cast<std::int32_t>(f32));
|
||||
}
|
||||
break;
|
||||
case Opcode::IToF:
|
||||
if (TryGetInt32(operand, i32)) {
|
||||
return GetFloatConst(static_cast<float>(i32));
|
||||
}
|
||||
if (TryGetBool(operand, i1)) {
|
||||
return GetFloatConst(i1 ? 1.0f : 0.0f);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Value* FoldZext(Context& ctx, ZextInst* inst) {
|
||||
auto* value = inst->GetValue();
|
||||
bool i1 = false;
|
||||
std::int32_t i32 = 0;
|
||||
if (inst->GetType()->IsInt1()) {
|
||||
if (TryGetBool(value, i1)) {
|
||||
return GetBoolConst(ctx, i1);
|
||||
}
|
||||
if (TryGetInt32(value, i32)) {
|
||||
return GetBoolConst(ctx, i32 != 0);
|
||||
}
|
||||
}
|
||||
if (inst->GetType()->IsInt32()) {
|
||||
if (TryGetBool(value, i1)) {
|
||||
return GetInt32Const(ctx, i1 ? 1 : 0);
|
||||
}
|
||||
if (TryGetInt32(value, i32)) {
|
||||
return GetInt32Const(ctx, i32);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool FoldFunction(Function& function, Context& ctx) {
|
||||
if (function.IsExternal()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
std::vector<Instruction*> to_remove;
|
||||
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
Value* replacement = nullptr;
|
||||
if (auto* binary = dyncast<BinaryInst>(inst)) {
|
||||
replacement = FoldBinary(ctx, binary);
|
||||
} else if (auto* unary = dyncast<UnaryInst>(inst)) {
|
||||
replacement = FoldUnary(ctx, unary);
|
||||
} else if (auto* zext = dyncast<ZextInst>(inst)) {
|
||||
replacement = FoldZext(ctx, zext);
|
||||
}
|
||||
|
||||
if (!replacement || replacement == inst) {
|
||||
continue;
|
||||
}
|
||||
inst->ReplaceAllUsesWith(replacement);
|
||||
to_remove.push_back(inst);
|
||||
changed = true;
|
||||
}
|
||||
|
||||
for (auto* inst : to_remove) {
|
||||
block_ptr->EraseInstruction(inst);
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunConstFold(Module& module) {
|
||||
bool changed = false;
|
||||
auto& ctx = module.GetContext();
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= FoldFunction(*function, ctx);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@ -0,0 +1,219 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "MemoryUtils.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct ExprKey {
|
||||
Opcode opcode = Opcode::Add;
|
||||
std::uintptr_t result_type = 0;
|
||||
std::uintptr_t aux_type = 0;
|
||||
struct OperandKey {
|
||||
int kind = 0;
|
||||
std::intptr_t value = 0;
|
||||
|
||||
bool operator==(const OperandKey& rhs) const {
|
||||
return kind == rhs.kind && value == rhs.value;
|
||||
}
|
||||
};
|
||||
std::vector<OperandKey> operands;
|
||||
|
||||
bool operator==(const ExprKey& rhs) const {
|
||||
return opcode == rhs.opcode && result_type == rhs.result_type &&
|
||||
aux_type == rhs.aux_type && operands == rhs.operands;
|
||||
}
|
||||
};
|
||||
|
||||
struct ExprKeyHash {
|
||||
std::size_t operator()(const ExprKey& key) const {
|
||||
std::size_t h = static_cast<std::size_t>(key.opcode);
|
||||
h ^= std::hash<std::uintptr_t>{}(key.result_type) + 0x9e3779b9 + (h << 6) +
|
||||
(h >> 2);
|
||||
h ^= std::hash<std::uintptr_t>{}(key.aux_type) + 0x9e3779b9 + (h << 6) +
|
||||
(h >> 2);
|
||||
for (auto operand : key.operands) {
|
||||
h ^= std::hash<int>{}(operand.kind) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
h ^= std::hash<std::intptr_t>{}(operand.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
ExprKey::OperandKey BuildOperandKey(Value* value) {
|
||||
if (auto* ci = dyncast<ConstantInt>(value)) {
|
||||
return {1, ci->GetValue()};
|
||||
}
|
||||
if (auto* cb = dyncast<ConstantI1>(value)) {
|
||||
return {2, cb->GetValue() ? 1 : 0};
|
||||
}
|
||||
if (auto* cf = dyncast<ConstantFloat>(value)) {
|
||||
return {3, static_cast<std::intptr_t>(passutils::FloatBits(cf->GetValue()))};
|
||||
}
|
||||
return {0, reinterpret_cast<std::intptr_t>(value)};
|
||||
}
|
||||
|
||||
struct ScopedExpr {
|
||||
ExprKey key;
|
||||
Value* previous = nullptr;
|
||||
bool had_previous = false;
|
||||
};
|
||||
|
||||
bool IsSupportedGVNInstruction(Instruction* inst) {
|
||||
if (!inst || inst->IsVoid()) {
|
||||
return false;
|
||||
}
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::FRem:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Zext:
|
||||
return true;
|
||||
case Opcode::Call:
|
||||
return memutils::IsPureCall(dyncast<CallInst>(inst));
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ExprKey BuildExprKey(Instruction* inst) {
|
||||
ExprKey key;
|
||||
key.opcode = inst->GetOpcode();
|
||||
key.result_type =
|
||||
reinterpret_cast<std::uintptr_t>(inst->GetType().get());
|
||||
if (auto* gep = dyncast<GetElementPtrInst>(inst)) {
|
||||
key.aux_type = reinterpret_cast<std::uintptr_t>(gep->GetSourceType().get());
|
||||
}
|
||||
key.operands.reserve(inst->GetNumOperands());
|
||||
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
|
||||
key.operands.push_back(BuildOperandKey(inst->GetOperand(i)));
|
||||
}
|
||||
if (inst->GetNumOperands() == 2 &&
|
||||
passutils::IsCommutativeOpcode(inst->GetOpcode()) &&
|
||||
(key.operands[1].kind < key.operands[0].kind ||
|
||||
(key.operands[1].kind == key.operands[0].kind &&
|
||||
key.operands[1].value < key.operands[0].value))) {
|
||||
std::swap(key.operands[0], key.operands[1]);
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
bool RunGVNInDomSubtree(
|
||||
BasicBlock* block, const DominatorTree& dom_tree,
|
||||
std::unordered_map<ExprKey, Value*, ExprKeyHash>& available) {
|
||||
if (!block) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
std::vector<ScopedExpr> scope;
|
||||
std::vector<Instruction*> to_remove;
|
||||
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (!IsSupportedGVNInstruction(inst)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto key = BuildExprKey(inst);
|
||||
auto it = available.find(key);
|
||||
if (it != available.end()) {
|
||||
inst->ReplaceAllUsesWith(it->second);
|
||||
to_remove.push_back(inst);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
ScopedExpr scoped{key, nullptr, false};
|
||||
auto existing = available.find(key);
|
||||
if (existing != available.end()) {
|
||||
scoped.previous = existing->second;
|
||||
scoped.had_previous = true;
|
||||
existing->second = inst;
|
||||
} else {
|
||||
available.emplace(key, inst);
|
||||
}
|
||||
scope.push_back(std::move(scoped));
|
||||
}
|
||||
|
||||
for (auto* inst : to_remove) {
|
||||
block->EraseInstruction(inst);
|
||||
}
|
||||
|
||||
for (auto* child : dom_tree.GetChildren(block)) {
|
||||
changed |= RunGVNInDomSubtree(child, dom_tree, available);
|
||||
}
|
||||
|
||||
for (auto it = scope.rbegin(); it != scope.rend(); ++it) {
|
||||
if (it->had_previous) {
|
||||
available[it->key] = it->previous;
|
||||
} else {
|
||||
available.erase(it->key);
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RunGVNOnFunction(Function& function) {
|
||||
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
DominatorTree dom_tree(function);
|
||||
std::unordered_map<ExprKey, Value*, ExprKeyHash> available;
|
||||
return RunGVNInDomSubtree(function.GetEntryBlock(), dom_tree, available);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunGVN(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunGVNOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,756 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
#include "MathIdiomUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct InlineCandidateInfo {
|
||||
bool valid = false;
|
||||
int cost = 0;
|
||||
bool has_nested_call = false;
|
||||
bool has_control_flow = false;
|
||||
};
|
||||
|
||||
bool IsInlineableInstruction(const Instruction* inst) {
|
||||
if (!inst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::FRem:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF:
|
||||
case Opcode::Load:
|
||||
case Opcode::Store:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Phi:
|
||||
case Opcode::Zext:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Call:
|
||||
case Opcode::Return:
|
||||
case Opcode::Br:
|
||||
case Opcode::CondBr:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
int EstimateInstructionCost(const Instruction* inst) {
|
||||
if (!inst) {
|
||||
return 0;
|
||||
}
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Phi:
|
||||
case Opcode::Return:
|
||||
return 0;
|
||||
case Opcode::Load:
|
||||
case Opcode::Store:
|
||||
case Opcode::Memset:
|
||||
return 3;
|
||||
case Opcode::Call:
|
||||
return 8;
|
||||
case Opcode::GetElementPtr:
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
InlineCandidateInfo AnalyzeInlineCandidate(const Function& function) {
|
||||
InlineCandidateInfo info;
|
||||
if (function.IsExternal() || function.IsRecursive()) {
|
||||
return info;
|
||||
}
|
||||
if (function.GetBlocks().empty() || function.GetBlocks().size() > 16) {
|
||||
return info;
|
||||
}
|
||||
|
||||
bool saw_return = false;
|
||||
for (const auto& block : function.GetBlocks()) {
|
||||
if (!block || block->GetInstructions().empty()) {
|
||||
return info;
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
|
||||
auto* inst = block->GetInstructions()[i].get();
|
||||
if (!IsInlineableInstruction(inst) || dyncast<AllocaInst>(inst) ||
|
||||
dyncast<UnreachableInst>(inst)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
if (dyncast<ReturnInst>(inst)) {
|
||||
if (i + 1 != block->GetInstructions().size()) {
|
||||
return {};
|
||||
}
|
||||
saw_return = true;
|
||||
continue;
|
||||
}
|
||||
if ((dyncast<UncondBrInst>(inst) || dyncast<CondBrInst>(inst)) &&
|
||||
i + 1 != block->GetInstructions().size()) {
|
||||
return {};
|
||||
}
|
||||
if (dyncast<CondBrInst>(inst) || dyncast<UncondBrInst>(inst)) {
|
||||
info.has_control_flow = true;
|
||||
}
|
||||
if (dyncast<CallInst>(inst)) {
|
||||
info.has_nested_call = true;
|
||||
}
|
||||
info.cost += EstimateInstructionCost(inst);
|
||||
}
|
||||
}
|
||||
|
||||
if (!saw_return) {
|
||||
return {};
|
||||
}
|
||||
|
||||
info.valid = true;
|
||||
return info;
|
||||
}
|
||||
|
||||
std::unordered_map<Function*, int> CountDirectCalls(Module& module) {
|
||||
std::unordered_map<Function*, int> counts;
|
||||
for (const auto& function_ptr : module.GetFunctions()) {
|
||||
if (!function_ptr) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& block_ptr : function_ptr->GetBlocks()) {
|
||||
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
|
||||
if (auto* call = dyncast<CallInst>(inst_ptr.get())) {
|
||||
if (auto* callee = call->GetCallee()) {
|
||||
++counts[callee];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return counts;
|
||||
}
|
||||
|
||||
bool ShouldInlineCallSite(const Function& caller, const CallInst& call,
|
||||
const InlineCandidateInfo& callee_info, int call_count) {
|
||||
auto* callee = call.GetCallee();
|
||||
if (!callee || callee == &caller || !callee_info.valid) {
|
||||
return false;
|
||||
}
|
||||
if (mathidiom::IsToleranceNewtonSqrtShape(*callee)) {
|
||||
return false;
|
||||
}
|
||||
if (mathidiom::IsPow2DigitExtractShape(*callee)) {
|
||||
return false;
|
||||
}
|
||||
if (callee_info.has_control_flow && callee_info.has_nested_call) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int budget = callee->CanDiscardUnusedCall() ? 96 : 72;
|
||||
if (call_count <= 1) {
|
||||
budget += 48;
|
||||
}
|
||||
if (callee_info.has_nested_call) {
|
||||
budget -= 8;
|
||||
}
|
||||
if (callee_info.has_control_flow) {
|
||||
budget -= 12;
|
||||
}
|
||||
if (callee->MayWriteMemory()) {
|
||||
budget -= 4;
|
||||
}
|
||||
return callee_info.cost <= budget;
|
||||
}
|
||||
|
||||
Instruction* CloneInstructionAt(Function& function, Instruction* inst, BasicBlock* dest,
|
||||
std::size_t insert_index,
|
||||
std::unordered_map<Value*, Value*>& remap) {
|
||||
if (!inst || !dest) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto name = inst->IsVoid() ? std::string()
|
||||
: looputils::NextSyntheticName(function, "inline.");
|
||||
auto remap_operand = [&](Value* value) { return looputils::RemapValue(remap, value); };
|
||||
auto remember = [&](Instruction* clone) {
|
||||
if (clone && !inst->IsVoid()) {
|
||||
remap[inst] = clone;
|
||||
}
|
||||
return clone;
|
||||
};
|
||||
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::FRem:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE: {
|
||||
auto* bin = static_cast<BinaryInst*>(inst);
|
||||
return remember(dest->Insert<BinaryInst>(insert_index, inst->GetOpcode(), inst->GetType(),
|
||||
remap_operand(bin->GetLhs()),
|
||||
remap_operand(bin->GetRhs()), nullptr, name));
|
||||
}
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF: {
|
||||
auto* un = static_cast<UnaryInst*>(inst);
|
||||
return remember(dest->Insert<UnaryInst>(insert_index, inst->GetOpcode(), inst->GetType(),
|
||||
remap_operand(un->GetOprd()), nullptr, name));
|
||||
}
|
||||
case Opcode::Load: {
|
||||
auto* load = static_cast<LoadInst*>(inst);
|
||||
return remember(dest->Insert<LoadInst>(insert_index, inst->GetType(),
|
||||
remap_operand(load->GetPtr()), nullptr, name));
|
||||
}
|
||||
case Opcode::Store: {
|
||||
auto* store = static_cast<StoreInst*>(inst);
|
||||
return dest->Insert<StoreInst>(insert_index, remap_operand(store->GetValue()),
|
||||
remap_operand(store->GetPtr()), nullptr);
|
||||
}
|
||||
case Opcode::Memset: {
|
||||
auto* memset = static_cast<MemsetInst*>(inst);
|
||||
return dest->Insert<MemsetInst>(insert_index, remap_operand(memset->GetDest()),
|
||||
remap_operand(memset->GetValue()),
|
||||
remap_operand(memset->GetLength()),
|
||||
remap_operand(memset->GetIsVolatile()), nullptr);
|
||||
}
|
||||
case Opcode::GetElementPtr: {
|
||||
auto* gep = static_cast<GetElementPtrInst*>(inst);
|
||||
std::vector<Value*> indices;
|
||||
indices.reserve(gep->GetNumIndices());
|
||||
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
|
||||
indices.push_back(remap_operand(gep->GetIndex(i)));
|
||||
}
|
||||
return remember(dest->Insert<GetElementPtrInst>(
|
||||
insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()), indices, nullptr,
|
||||
name));
|
||||
}
|
||||
case Opcode::Zext: {
|
||||
auto* zext = static_cast<ZextInst*>(inst);
|
||||
return remember(dest->Insert<ZextInst>(insert_index, remap_operand(zext->GetValue()),
|
||||
inst->GetType(), nullptr, name));
|
||||
}
|
||||
case Opcode::Call: {
|
||||
auto* call = static_cast<CallInst*>(inst);
|
||||
std::vector<Value*> args;
|
||||
const auto original_args = call->GetArguments();
|
||||
args.reserve(original_args.size());
|
||||
for (auto* arg : original_args) {
|
||||
args.push_back(remap_operand(arg));
|
||||
}
|
||||
return remember(
|
||||
dest->Insert<CallInst>(insert_index, call->GetCallee(), args, nullptr, name));
|
||||
}
|
||||
case Opcode::Return:
|
||||
case Opcode::Alloca:
|
||||
case Opcode::Phi:
|
||||
case Opcode::Br:
|
||||
case Opcode::CondBr:
|
||||
case Opcode::Unreachable:
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool InlineCallSite(Function& caller, CallInst* call) {
|
||||
if (!call) {
|
||||
return false;
|
||||
}
|
||||
auto* callee = call->GetCallee();
|
||||
if (!callee || callee->GetBlocks().size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& callee_args = callee->GetArguments();
|
||||
const auto call_args = call->GetArguments();
|
||||
if (callee_args.size() != call_args.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* block = call->GetParent();
|
||||
if (!block) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& instructions = block->GetInstructions();
|
||||
auto call_it = std::find_if(instructions.begin(), instructions.end(),
|
||||
[&](const std::unique_ptr<Instruction>& current) {
|
||||
return current.get() == call;
|
||||
});
|
||||
if (call_it == instructions.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t insert_index = static_cast<std::size_t>(call_it - instructions.begin());
|
||||
std::unordered_map<Value*, Value*> remap;
|
||||
for (std::size_t i = 0; i < call_args.size(); ++i) {
|
||||
remap[callee_args[i].get()] = call_args[i];
|
||||
}
|
||||
|
||||
Value* return_value = nullptr;
|
||||
for (const auto& inst_ptr : callee->GetBlocks().front()->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (auto* ret = dyncast<ReturnInst>(inst)) {
|
||||
if (ret->HasReturnValue()) {
|
||||
return_value = looputils::RemapValue(remap, ret->GetReturnValue());
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (!CloneInstructionAt(caller, inst, block, insert_index, remap)) {
|
||||
return false;
|
||||
}
|
||||
++insert_index;
|
||||
}
|
||||
|
||||
if (!call->GetType()->IsVoid()) {
|
||||
if (!return_value) {
|
||||
return false;
|
||||
}
|
||||
call->ReplaceAllUsesWith(return_value);
|
||||
}
|
||||
|
||||
block->EraseInstruction(call);
|
||||
return true;
|
||||
}
|
||||
|
||||
void ReplaceIncomingBlock(BasicBlock* block, BasicBlock* old_pred, BasicBlock* new_pred) {
|
||||
if (!block || !old_pred || !new_pred) {
|
||||
return;
|
||||
}
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
const int index = looputils::GetPhiIncomingIndex(phi, old_pred);
|
||||
if (index >= 0) {
|
||||
phi->SetOperand(static_cast<std::size_t>(2 * index + 1), new_pred);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
|
||||
std::vector<BasicBlock*> order;
|
||||
auto* entry = function.GetEntryBlock();
|
||||
if (!entry) {
|
||||
return order;
|
||||
}
|
||||
|
||||
std::vector<BasicBlock*> stack{entry};
|
||||
std::unordered_set<BasicBlock*> visited;
|
||||
while (!stack.empty()) {
|
||||
auto* block = stack.back();
|
||||
stack.pop_back();
|
||||
if (!block || !visited.insert(block).second) {
|
||||
continue;
|
||||
}
|
||||
order.push_back(block);
|
||||
const auto& succs = block->GetSuccessors();
|
||||
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
|
||||
if (*it) {
|
||||
stack.push_back(*it);
|
||||
}
|
||||
}
|
||||
}
|
||||
return order;
|
||||
}
|
||||
|
||||
BasicBlock* SplitBlockAfterCall(Function& caller, BasicBlock* block, CallInst* call) {
|
||||
if (!block || !call) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto& instructions = block->GetInstructions();
|
||||
auto call_it = std::find_if(instructions.begin(), instructions.end(),
|
||||
[&](const std::unique_ptr<Instruction>& current) {
|
||||
return current.get() == call;
|
||||
});
|
||||
if (call_it == instructions.end() || std::next(call_it) == instructions.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* continuation =
|
||||
caller.CreateBlock(looputils::NextSyntheticBlockName(caller, "inline.cont"));
|
||||
auto& continuation_insts = continuation->GetInstructions();
|
||||
for (auto it = std::next(call_it); it != instructions.end(); ++it) {
|
||||
(*it)->SetParent(continuation);
|
||||
continuation_insts.push_back(std::move(*it));
|
||||
}
|
||||
instructions.erase(std::next(call_it), instructions.end());
|
||||
|
||||
auto old_succs = block->GetSuccessors();
|
||||
for (auto* succ : old_succs) {
|
||||
block->RemoveSuccessor(succ);
|
||||
succ->RemovePredecessor(block);
|
||||
succ->AddPredecessor(continuation);
|
||||
continuation->AddSuccessor(succ);
|
||||
ReplaceIncomingBlock(succ, block, continuation);
|
||||
}
|
||||
return continuation;
|
||||
}
|
||||
|
||||
bool CanInlineCFGCallSite(Function& caller, CallInst* call,
|
||||
std::vector<BasicBlock*>& callee_blocks) {
|
||||
auto* callee = call ? call->GetCallee() : nullptr;
|
||||
if (!call || !callee || callee->GetBlocks().size() <= 1 ||
|
||||
callee == &caller) {
|
||||
return false;
|
||||
}
|
||||
if (mathidiom::IsToleranceNewtonSqrtShape(*callee)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& callee_args = callee->GetArguments();
|
||||
const auto call_args = call->GetArguments();
|
||||
if (callee_args.size() != call_args.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
callee_blocks = CollectReachableBlocks(*callee);
|
||||
if (callee_blocks.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<BasicBlock*> reachable(callee_blocks.begin(), callee_blocks.end());
|
||||
for (auto* block : callee_blocks) {
|
||||
if (!block || block->GetInstructions().empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool seen_non_phi = false;
|
||||
for (std::size_t i = 0; i < block->GetInstructions().size(); ++i) {
|
||||
auto* inst = block->GetInstructions()[i].get();
|
||||
if (dyncast<AllocaInst>(inst) || dyncast<UnreachableInst>(inst) ||
|
||||
!IsInlineableInstruction(inst)) {
|
||||
return false;
|
||||
}
|
||||
if (dyncast<PhiInst>(inst)) {
|
||||
if (seen_non_phi) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
seen_non_phi = true;
|
||||
|
||||
if (auto* br = dyncast<UncondBrInst>(inst)) {
|
||||
if (i + 1 != block->GetInstructions().size() ||
|
||||
reachable.count(br->GetDest()) == 0) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* condbr = dyncast<CondBrInst>(inst)) {
|
||||
if (i + 1 != block->GetInstructions().size() ||
|
||||
reachable.count(condbr->GetThenBlock()) == 0 ||
|
||||
reachable.count(condbr->GetElseBlock()) == 0) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (dyncast<ReturnInst>(inst)) {
|
||||
if (i + 1 != block->GetInstructions().size()) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (inst->IsTerminator() || !looputils::IsCloneableInstruction(inst)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InlineCFGCallSite(Function& caller, CallInst* call) {
|
||||
auto* callee = call ? call->GetCallee() : nullptr;
|
||||
if (!call || !callee || callee->GetBlocks().size() <= 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& callee_args = callee->GetArguments();
|
||||
const auto call_args = call->GetArguments();
|
||||
if (callee_args.size() != call_args.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<BasicBlock*> callee_blocks;
|
||||
if (!CanInlineCFGCallSite(caller, call, callee_blocks)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* call_block = call->GetParent();
|
||||
auto* continuation = SplitBlockAfterCall(caller, call_block, call);
|
||||
if (!call_block || !continuation) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_map<Value*, Value*> remap;
|
||||
for (std::size_t i = 0; i < call_args.size(); ++i) {
|
||||
remap[callee_args[i].get()] = call_args[i];
|
||||
}
|
||||
|
||||
std::unordered_map<BasicBlock*, BasicBlock*> block_map;
|
||||
for (auto* block : callee_blocks) {
|
||||
block_map[block] =
|
||||
caller.CreateBlock(looputils::NextSyntheticBlockName(caller, "inline.bb"));
|
||||
}
|
||||
|
||||
for (auto* block : callee_blocks) {
|
||||
auto* clone = block_map.at(block);
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
auto* cloned_phi = clone->Append<PhiInst>(
|
||||
phi->GetType(), nullptr,
|
||||
looputils::NextSyntheticName(caller, "inline.phi."));
|
||||
remap[phi] = cloned_phi;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : callee_blocks) {
|
||||
auto* clone = block_map.at(block);
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dyncast<PhiInst>(inst)) {
|
||||
continue;
|
||||
}
|
||||
if (inst->IsTerminator()) {
|
||||
continue;
|
||||
}
|
||||
if (!CloneInstructionAt(caller, inst, clone,
|
||||
looputils::GetTerminatorIndex(clone), remap)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : callee_blocks) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
auto* cloned_phi = static_cast<PhiInst*>(remap.at(phi));
|
||||
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
|
||||
auto* incoming_block = phi->GetIncomingBlock(i);
|
||||
auto block_it = block_map.find(incoming_block);
|
||||
if (block_it == block_map.end()) {
|
||||
return false;
|
||||
}
|
||||
cloned_phi->AddIncoming(looputils::RemapValue(remap, phi->GetIncomingValue(i)),
|
||||
block_it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<BasicBlock*, Value*>> return_edges;
|
||||
for (auto* block : callee_blocks) {
|
||||
auto* clone = block_map.at(block);
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dyncast<PhiInst>(inst) || !inst->IsTerminator()) {
|
||||
continue;
|
||||
}
|
||||
if (auto* ret = dyncast<ReturnInst>(inst)) {
|
||||
clone->Append<UncondBrInst>(continuation, nullptr);
|
||||
clone->AddSuccessor(continuation);
|
||||
continuation->AddPredecessor(clone);
|
||||
return_edges.emplace_back(
|
||||
clone, ret->HasReturnValue() ? looputils::RemapValue(remap, ret->GetReturnValue())
|
||||
: nullptr);
|
||||
continue;
|
||||
}
|
||||
if (auto* br = dyncast<UncondBrInst>(inst)) {
|
||||
auto* target = block_map.at(br->GetDest());
|
||||
clone->Append<UncondBrInst>(target, nullptr);
|
||||
clone->AddSuccessor(target);
|
||||
target->AddPredecessor(clone);
|
||||
continue;
|
||||
}
|
||||
if (auto* condbr = dyncast<CondBrInst>(inst)) {
|
||||
auto* then_block = block_map.at(condbr->GetThenBlock());
|
||||
auto* else_block = block_map.at(condbr->GetElseBlock());
|
||||
clone->Append<CondBrInst>(looputils::RemapValue(remap, condbr->GetCondition()),
|
||||
then_block, else_block, nullptr);
|
||||
clone->AddSuccessor(then_block);
|
||||
clone->AddSuccessor(else_block);
|
||||
then_block->AddPredecessor(clone);
|
||||
else_block->AddPredecessor(clone);
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
call_block->Append<UncondBrInst>(block_map.at(callee->GetEntryBlock()), nullptr);
|
||||
call_block->AddSuccessor(block_map.at(callee->GetEntryBlock()));
|
||||
block_map.at(callee->GetEntryBlock())->AddPredecessor(call_block);
|
||||
|
||||
if (!call->GetType()->IsVoid()) {
|
||||
Value* return_value = nullptr;
|
||||
if (return_edges.size() == 1) {
|
||||
return_value = return_edges.front().second;
|
||||
} else {
|
||||
auto* phi = continuation->Insert<PhiInst>(
|
||||
looputils::GetFirstNonPhiIndex(continuation), call->GetType(), nullptr,
|
||||
looputils::NextSyntheticName(caller, "inline.ret."));
|
||||
for (const auto& [pred, value] : return_edges) {
|
||||
if (!value) {
|
||||
return false;
|
||||
}
|
||||
phi->AddIncoming(value, pred);
|
||||
}
|
||||
return_value = phi;
|
||||
}
|
||||
if (!return_value) {
|
||||
return false;
|
||||
}
|
||||
call->ReplaceAllUsesWith(return_value);
|
||||
}
|
||||
|
||||
call_block->EraseInstruction(call);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RunFunctionInliningOnFunction(
|
||||
Function& function,
|
||||
const std::unordered_map<Function*, InlineCandidateInfo>& callee_info,
|
||||
const std::unordered_map<Function*, int>& call_counts) {
|
||||
if (function.IsExternal()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
std::vector<BasicBlock*> block_snapshot;
|
||||
block_snapshot.reserve(function.GetBlocks().size());
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
if (block_ptr) {
|
||||
block_snapshot.push_back(block_ptr.get());
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : block_snapshot) {
|
||||
if (!block) {
|
||||
continue;
|
||||
}
|
||||
std::vector<CallInst*> calls;
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
if (auto* call = dyncast<CallInst>(inst_ptr.get())) {
|
||||
calls.push_back(call);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* call : calls) {
|
||||
auto* callee = call->GetCallee();
|
||||
if (!callee) {
|
||||
continue;
|
||||
}
|
||||
auto info_it = callee_info.find(callee);
|
||||
if (info_it == callee_info.end()) {
|
||||
continue;
|
||||
}
|
||||
const int call_count =
|
||||
call_counts.count(callee) != 0 ? call_counts.at(callee) : 0;
|
||||
if (!ShouldInlineCallSite(function, *call, info_it->second, call_count)) {
|
||||
continue;
|
||||
}
|
||||
if (callee->GetBlocks().size() == 1) {
|
||||
changed |= InlineCallSite(function, call);
|
||||
} else {
|
||||
changed |= InlineCFGCallSite(function, call);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunFunctionInlining(Module& module) {
|
||||
std::unordered_map<Function*, InlineCandidateInfo> callee_info;
|
||||
for (const auto& function_ptr : module.GetFunctions()) {
|
||||
if (function_ptr) {
|
||||
callee_info.emplace(function_ptr.get(), AnalyzeInlineCandidate(*function_ptr));
|
||||
}
|
||||
}
|
||||
|
||||
const auto call_counts = CountDirectCalls(module);
|
||||
bool changed = false;
|
||||
for (const auto& function_ptr : module.GetFunctions()) {
|
||||
if (function_ptr) {
|
||||
changed |= RunFunctionInliningOnFunction(*function_ptr, callee_info, call_counts);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,236 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
#include "MemoryUtils.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct HoistedLoadKey {
|
||||
memutils::AddressKey address;
|
||||
std::uintptr_t type_id = 0;
|
||||
|
||||
bool operator==(const HoistedLoadKey& rhs) const {
|
||||
return type_id == rhs.type_id && address == rhs.address;
|
||||
}
|
||||
};
|
||||
|
||||
struct HoistedLoadKeyHash {
|
||||
std::size_t operator()(const HoistedLoadKey& key) const {
|
||||
std::size_t h = memutils::AddressKeyHash{}(key.address);
|
||||
h ^= std::hash<std::uintptr_t>{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
bool IsHoistableInstruction(const Instruction* inst) {
|
||||
if (!inst || inst->IsTerminator() || inst->IsVoid()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Zext:
|
||||
case Opcode::Load:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool IsLoopInvariantInstruction(
|
||||
const Loop& loop, Instruction* inst,
|
||||
const std::unordered_set<Instruction*>& invariant_insts,
|
||||
PhiInst* iv, int iv_stride,
|
||||
const std::vector<loopmem::MemoryAccessInfo>& accesses,
|
||||
const memutils::EscapeSummary& escapes) {
|
||||
if (!IsHoistableInstruction(inst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
|
||||
auto* operand = inst->GetOperand(i);
|
||||
auto* operand_inst = dyncast<Instruction>(operand);
|
||||
if (!operand_inst) {
|
||||
continue;
|
||||
}
|
||||
if (!loop.Contains(operand_inst->GetParent())) {
|
||||
continue;
|
||||
}
|
||||
if (invariant_insts.find(operand_inst) == invariant_insts.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
return loopmem::IsSafeInvariantLoadToHoist(loop, load, iv, iv_stride, accesses, &escapes);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HoistLoopInvariants(Function& function, const Loop& loop,
|
||||
BasicBlock* preheader) {
|
||||
if (!preheader) {
|
||||
return false;
|
||||
}
|
||||
|
||||
loopmem::SimpleInductionVar induction_var;
|
||||
PhiInst* iv = nullptr;
|
||||
int iv_stride = 1;
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
if (loopmem::MatchSimpleInductionVariable(loop, preheader, phi, induction_var)) {
|
||||
iv = induction_var.phi;
|
||||
iv_stride = induction_var.stride;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const auto escapes = memutils::AnalyzeEscapes(function);
|
||||
const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes);
|
||||
|
||||
std::unordered_set<Instruction*> invariant_insts;
|
||||
std::vector<Instruction*> hoist_list;
|
||||
|
||||
bool progress = true;
|
||||
while (progress) {
|
||||
progress = false;
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
auto* block = block_ptr.get();
|
||||
if (!loop.Contains(block) || block == preheader) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (invariant_insts.find(inst) != invariant_insts.end()) {
|
||||
continue;
|
||||
}
|
||||
if (!IsLoopInvariantInstruction(loop, inst, invariant_insts, iv, iv_stride,
|
||||
accesses, escapes)) {
|
||||
continue;
|
||||
}
|
||||
invariant_insts.insert(inst);
|
||||
hoist_list.push_back(inst);
|
||||
progress = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
std::unordered_map<HoistedLoadKey, LoadInst*, HoistedLoadKeyHash> hoisted_loads;
|
||||
for (auto* inst : hoist_list) {
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
auto ptr = loopmem::AnalyzePointer(load->GetPtr(), iv, loop,
|
||||
load->GetType()->GetSize(), &escapes);
|
||||
if (ptr.exact_key_valid) {
|
||||
HoistedLoadKey key{ptr.exact_key,
|
||||
reinterpret_cast<std::uintptr_t>(load->GetType().get())};
|
||||
auto it = hoisted_loads.find(key);
|
||||
if (it != hoisted_loads.end()) {
|
||||
load->ReplaceAllUsesWith(it->second);
|
||||
load->GetParent()->EraseInstruction(load);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
auto* moved = dyncast<LoadInst>(
|
||||
looputils::MoveInstructionBeforeTerminator(load, preheader));
|
||||
if (moved) {
|
||||
hoisted_loads.emplace(std::move(key), moved);
|
||||
changed = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (looputils::MoveInstructionBeforeTerminator(inst, preheader)) {
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RunLICMOnFunction(Function& function) {
|
||||
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
while (true) {
|
||||
DominatorTree dom_tree(function);
|
||||
LoopInfo loop_info(function, dom_tree);
|
||||
bool local_changed = false;
|
||||
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
auto* old_preheader = loop->preheader;
|
||||
auto* preheader = looputils::EnsurePreheader(function, *loop);
|
||||
bool loop_changed = preheader != old_preheader;
|
||||
loop_changed |= HoistLoopInvariants(function, *loop, preheader);
|
||||
if (!loop_changed) {
|
||||
continue;
|
||||
}
|
||||
changed = true;
|
||||
local_changed = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!local_changed) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLICM(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunLICMOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,323 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "MemoryUtils.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct AvailableValue {
|
||||
Value* value = nullptr;
|
||||
|
||||
bool operator==(const AvailableValue& rhs) const {
|
||||
return passutils::AreEquivalentValues(value, rhs.value) || value == rhs.value;
|
||||
}
|
||||
};
|
||||
|
||||
using MemoryState =
|
||||
std::unordered_map<memutils::AddressKey, AvailableValue,
|
||||
memutils::AddressKeyHash>;
|
||||
|
||||
bool SameAvailableValue(const AvailableValue& lhs, const AvailableValue& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
|
||||
bool SameMemoryState(const MemoryState& lhs, const MemoryState& rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& [key, value] : lhs) {
|
||||
auto it = rhs.find(key);
|
||||
if (it == rhs.end() || !SameAvailableValue(value, it->second)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MemoryState MeetMemoryStates(const std::vector<MemoryState*>& predecessors) {
|
||||
if (predecessors.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
MemoryState in = *predecessors.front();
|
||||
for (auto it = in.begin(); it != in.end();) {
|
||||
bool keep = true;
|
||||
for (std::size_t i = 1; i < predecessors.size(); ++i) {
|
||||
auto pred_it = predecessors[i]->find(it->first);
|
||||
if (pred_it == predecessors[i]->end() ||
|
||||
!SameAvailableValue(it->second, pred_it->second)) {
|
||||
keep = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!keep) {
|
||||
it = in.erase(it);
|
||||
continue;
|
||||
}
|
||||
++it;
|
||||
}
|
||||
return in;
|
||||
}
|
||||
|
||||
void InvalidateAliasStates(MemoryState& state,
|
||||
const memutils::AddressKey& key) {
|
||||
for (auto it = state.begin(); it != state.end();) {
|
||||
if (memutils::MayAliasConservatively(it->first, key)) {
|
||||
it = state.erase(it);
|
||||
continue;
|
||||
}
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
void InvalidateStatesForCall(MemoryState& state, Function* callee) {
|
||||
for (auto it = state.begin(); it != state.end();) {
|
||||
if (memutils::CallMayWriteRoot(callee, it->first.kind)) {
|
||||
it = state.erase(it);
|
||||
continue;
|
||||
}
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
void SimulateInstruction(const memutils::EscapeSummary& escapes, Instruction* inst,
|
||||
MemoryState& state) {
|
||||
if (!inst) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
|
||||
state.clear();
|
||||
return;
|
||||
}
|
||||
if (state.find(key) == state.end()) {
|
||||
state[key] = {load};
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto* store = dyncast<StoreInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
|
||||
state.clear();
|
||||
return;
|
||||
}
|
||||
InvalidateAliasStates(state, key);
|
||||
state[key] = {store->GetValue()};
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
InvalidateStatesForCall(state, call->GetCallee());
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto* memset = dyncast<MemsetInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
|
||||
state.clear();
|
||||
return;
|
||||
}
|
||||
InvalidateAliasStates(state, key);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
MemoryState SimulateBlock(const memutils::EscapeSummary& escapes, BasicBlock* block,
|
||||
const MemoryState& in_state) {
|
||||
MemoryState state = in_state;
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
SimulateInstruction(escapes, inst_ptr.get(), state);
|
||||
}
|
||||
return state;
|
||||
}
|
||||
|
||||
bool MarkLoadObserved(
|
||||
const memutils::AddressKey& key,
|
||||
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
|
||||
pending_stores) {
|
||||
bool changed = false;
|
||||
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
|
||||
if (memutils::MayAliasConservatively(it->first, key)) {
|
||||
it = pending_stores.erase(it);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
++it;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
void InvalidatePendingForCall(
|
||||
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>&
|
||||
pending_stores,
|
||||
Function* callee) {
|
||||
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
|
||||
if (memutils::CallMayReadRoot(callee, it->first.kind) ||
|
||||
memutils::CallMayWriteRoot(callee, it->first.kind)) {
|
||||
it = pending_stores.erase(it);
|
||||
continue;
|
||||
}
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
bool OptimizeBlock(
|
||||
const memutils::EscapeSummary& escapes, BasicBlock* block,
|
||||
const MemoryState& in_state) {
|
||||
bool changed = false;
|
||||
MemoryState state = in_state;
|
||||
std::unordered_map<memutils::AddressKey, StoreInst*, memutils::AddressKeyHash>
|
||||
pending_stores;
|
||||
std::vector<Instruction*> to_remove;
|
||||
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(load->GetPtr(), &escapes, key)) {
|
||||
state.clear();
|
||||
pending_stores.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
MarkLoadObserved(key, pending_stores);
|
||||
auto it = state.find(key);
|
||||
if (it != state.end() && it->second.value != load) {
|
||||
load->ReplaceAllUsesWith(it->second.value);
|
||||
to_remove.push_back(load);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
if (state.find(key) == state.end()) {
|
||||
state[key] = {load};
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* store = dyncast<StoreInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
|
||||
state.clear();
|
||||
pending_stores.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto it = pending_stores.begin(); it != pending_stores.end();) {
|
||||
if (!memutils::MayAliasConservatively(it->first, key)) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
if (it->first == key) {
|
||||
to_remove.push_back(it->second);
|
||||
changed = true;
|
||||
}
|
||||
it = pending_stores.erase(it);
|
||||
}
|
||||
|
||||
pending_stores.emplace(key, store);
|
||||
InvalidateAliasStates(state, key);
|
||||
state[key] = {store->GetValue()};
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
InvalidateStatesForCall(state, call->GetCallee());
|
||||
InvalidatePendingForCall(pending_stores, call->GetCallee());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* memset = dyncast<MemsetInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key)) {
|
||||
state.clear();
|
||||
pending_stores.clear();
|
||||
continue;
|
||||
}
|
||||
InvalidateAliasStates(state, key);
|
||||
MarkLoadObserved(key, pending_stores);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* inst : to_remove) {
|
||||
if (inst->GetParent() == block) {
|
||||
block->EraseInstruction(inst);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RunLoadStoreElimOnFunction(Function& function) {
|
||||
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto escapes = memutils::AnalyzeEscapes(function);
|
||||
const auto reachable_blocks = passutils::CollectReachableBlocks(function);
|
||||
if (reachable_blocks.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_map<BasicBlock*, MemoryState> in_states;
|
||||
std::unordered_map<BasicBlock*, MemoryState> out_states;
|
||||
|
||||
bool dataflow_changed = true;
|
||||
while (dataflow_changed) {
|
||||
dataflow_changed = false;
|
||||
for (auto* block : reachable_blocks) {
|
||||
MemoryState in_state;
|
||||
if (block != function.GetEntryBlock()) {
|
||||
std::vector<MemoryState*> predecessors;
|
||||
for (auto* pred : block->GetPredecessors()) {
|
||||
auto it = out_states.find(pred);
|
||||
if (it != out_states.end()) {
|
||||
predecessors.push_back(&it->second);
|
||||
}
|
||||
}
|
||||
in_state = MeetMemoryStates(predecessors);
|
||||
}
|
||||
|
||||
auto out_state = SimulateBlock(escapes, block, in_state);
|
||||
auto in_it = in_states.find(block);
|
||||
if (in_it == in_states.end() || !SameMemoryState(in_it->second, in_state)) {
|
||||
in_states[block] = in_state;
|
||||
dataflow_changed = true;
|
||||
}
|
||||
auto out_it = out_states.find(block);
|
||||
if (out_it == out_states.end() || !SameMemoryState(out_it->second, out_state)) {
|
||||
out_states[block] = std::move(out_state);
|
||||
dataflow_changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
for (auto* block : reachable_blocks) {
|
||||
changed |= OptimizeBlock(escapes, block, in_states[block]);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLoadStoreElim(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunLoadStoreElimOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,326 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct FissionLoopInfo {
|
||||
Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
CondBrInst* branch = nullptr;
|
||||
BinaryInst* compare = nullptr;
|
||||
Opcode compare_opcode = Opcode::ICmpLT;
|
||||
Value* bound = nullptr;
|
||||
loopmem::SimpleInductionVar induction_var;
|
||||
PhiInst* iv = nullptr;
|
||||
BinaryInst* step_inst = nullptr;
|
||||
};
|
||||
|
||||
bool HasSyntheticLoopTag(const std::string& name) {
|
||||
return name.find("unroll.") != std::string::npos ||
|
||||
name.find("fission.") != std::string::npos;
|
||||
}
|
||||
|
||||
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
|
||||
if (!loop.preheader || !loop.header || !body) {
|
||||
return true;
|
||||
}
|
||||
return HasSyntheticLoopTag(loop.preheader->GetName()) ||
|
||||
HasSyntheticLoopTag(loop.header->GetName()) ||
|
||||
HasSyntheticLoopTag(body->GetName());
|
||||
}
|
||||
|
||||
Opcode SwapCompareOpcode(Opcode opcode) {
|
||||
switch (opcode) {
|
||||
case Opcode::ICmpLT:
|
||||
return Opcode::ICmpGT;
|
||||
case Opcode::ICmpLE:
|
||||
return Opcode::ICmpGE;
|
||||
case Opcode::ICmpGT:
|
||||
return Opcode::ICmpLT;
|
||||
case Opcode::ICmpGE:
|
||||
return Opcode::ICmpLE;
|
||||
default:
|
||||
return opcode;
|
||||
}
|
||||
}
|
||||
|
||||
bool MatchFissionLoop(Loop& loop, FissionLoopInfo& info) {
|
||||
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
|
||||
return false;
|
||||
}
|
||||
if (IsAlreadyTransformedLoop(loop, body)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<PhiInst*> phis;
|
||||
loopmem::SimpleInductionVar induction_var;
|
||||
bool found_iv = false;
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
phis.push_back(phi);
|
||||
if (!found_iv &&
|
||||
loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
|
||||
found_iv = true;
|
||||
}
|
||||
}
|
||||
if (!found_iv || phis.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
|
||||
auto* compare = branch ? dyncast<BinaryInst>(branch->GetCondition()) : nullptr;
|
||||
if (!branch || branch->GetThenBlock() != body || !compare) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Opcode compare_opcode = compare->GetOpcode();
|
||||
Value* bound = nullptr;
|
||||
if (compare->GetLhs() == induction_var.phi &&
|
||||
looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
|
||||
bound = compare->GetRhs();
|
||||
} else if (compare->GetRhs() == induction_var.phi &&
|
||||
looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
|
||||
bound = compare->GetLhs();
|
||||
compare_opcode = SwapCompareOpcode(compare_opcode);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* step_inst = dyncast<BinaryInst>(induction_var.latch_value);
|
||||
if (!step_inst || step_inst->GetParent() != body) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& inst_ptr : body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator() || inst == step_inst) {
|
||||
continue;
|
||||
}
|
||||
if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
|
||||
dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
info.loop = &loop;
|
||||
info.preheader = loop.preheader;
|
||||
info.header = loop.header;
|
||||
info.body = body;
|
||||
info.exit = exit;
|
||||
info.branch = branch;
|
||||
info.compare = compare;
|
||||
info.compare_opcode = compare_opcode;
|
||||
info.bound = bound;
|
||||
info.induction_var = induction_var;
|
||||
info.iv = induction_var.phi;
|
||||
info.step_inst = step_inst;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ContainsInterestingPayload(const std::vector<Instruction*>& group) {
|
||||
bool has_memory = false;
|
||||
for (auto* inst : group) {
|
||||
if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
|
||||
has_memory = true;
|
||||
}
|
||||
}
|
||||
return has_memory;
|
||||
}
|
||||
|
||||
Value* RemapExitValue(Value* value, PhiInst* old_iv, PhiInst* new_iv) {
|
||||
if (value == old_iv) {
|
||||
return new_iv;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
bool BuildSecondLoop(Function& function, const FissionLoopInfo& info,
|
||||
const std::vector<Instruction*>& second_group) {
|
||||
auto* second_header =
|
||||
function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.header"));
|
||||
auto* second_body =
|
||||
function.CreateBlock(looputils::NextSyntheticBlockName(function, "fission.body"));
|
||||
|
||||
const int preheader_index = looputils::GetPhiIncomingIndex(info.iv, info.preheader);
|
||||
if (preheader_index < 0) {
|
||||
return false;
|
||||
}
|
||||
auto* second_iv = second_header->Append<PhiInst>(
|
||||
info.iv->GetType(), nullptr,
|
||||
looputils::NextSyntheticName(function, "fission.iv."));
|
||||
second_iv->AddIncoming(info.iv->GetIncomingValue(preheader_index), info.header);
|
||||
|
||||
auto* second_cmp = second_header->Append<BinaryInst>(
|
||||
info.compare_opcode, Type::GetBoolType(), second_iv, info.bound, nullptr,
|
||||
looputils::NextSyntheticName(function, "fission.cmp."));
|
||||
second_header->Append<CondBrInst>(second_cmp, second_body, info.exit, nullptr);
|
||||
second_header->AddPredecessor(info.header);
|
||||
second_header->AddSuccessor(second_body);
|
||||
second_header->AddSuccessor(info.exit);
|
||||
|
||||
std::unordered_map<Value*, Value*> remap;
|
||||
remap[info.iv] = second_iv;
|
||||
std::unordered_set<Instruction*> selected(second_group.begin(), second_group.end());
|
||||
selected.insert(info.step_inst);
|
||||
for (const auto& inst_ptr : info.body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator() || selected.find(inst) == selected.end()) {
|
||||
continue;
|
||||
}
|
||||
looputils::CloneInstruction(function, inst, second_body, remap, "fission.");
|
||||
}
|
||||
auto* cloned_step_value = looputils::RemapValue(remap, info.step_inst);
|
||||
if (!cloned_step_value) {
|
||||
return false;
|
||||
}
|
||||
second_iv->AddIncoming(cloned_step_value, second_body);
|
||||
second_body->Append<UncondBrInst>(second_header, nullptr);
|
||||
second_body->AddPredecessor(second_header);
|
||||
second_body->AddSuccessor(second_header);
|
||||
second_header->AddPredecessor(second_body);
|
||||
|
||||
if (!looputils::RedirectSuccessorEdge(info.header, info.exit, second_header)) {
|
||||
return false;
|
||||
}
|
||||
info.exit->RemovePredecessor(info.header);
|
||||
info.exit->AddPredecessor(second_header);
|
||||
|
||||
for (const auto& inst_ptr : info.exit->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
const int incoming = looputils::GetPhiIncomingIndex(phi, info.header);
|
||||
if (incoming < 0) {
|
||||
continue;
|
||||
}
|
||||
phi->SetOperand(static_cast<std::size_t>(2 * incoming),
|
||||
RemapExitValue(phi->GetIncomingValue(incoming), info.iv, second_iv));
|
||||
phi->SetOperand(static_cast<std::size_t>(2 * incoming + 1), second_header);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RunLoopFissionOnFunction(Function& function) {
|
||||
if (function.IsExternal() || !function.GetEntryBlock()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
while (true) {
|
||||
DominatorTree dom_tree(function);
|
||||
LoopInfo loop_info(function, dom_tree);
|
||||
bool local_changed = false;
|
||||
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
FissionLoopInfo info;
|
||||
if (!MatchFissionLoop(*loop, info)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto accesses = loopmem::CollectMemoryAccesses(*loop, info.iv);
|
||||
std::vector<Instruction*> payload;
|
||||
for (const auto& inst_ptr : info.body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator() || inst == info.step_inst) {
|
||||
continue;
|
||||
}
|
||||
payload.push_back(inst);
|
||||
}
|
||||
if (payload.size() < 2) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int chosen_cut = -1;
|
||||
std::vector<Instruction*> first_group;
|
||||
std::vector<Instruction*> second_group;
|
||||
for (std::size_t cut = 1; cut < payload.size(); ++cut) {
|
||||
std::vector<Instruction*> first(payload.begin(), payload.begin() + static_cast<long long>(cut));
|
||||
std::vector<Instruction*> second(payload.begin() + static_cast<long long>(cut),
|
||||
payload.end());
|
||||
if (!ContainsInterestingPayload(first) || !ContainsInterestingPayload(second)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_set<Instruction*> first_set(first.begin(), first.end());
|
||||
std::unordered_set<Instruction*> second_set(second.begin(), second.end());
|
||||
if (loopmem::HasScalarDependenceAcrossCut(first, second_set) ||
|
||||
loopmem::HasMemoryDependenceAcrossCut(accesses, first_set, second_set,
|
||||
info.induction_var.stride)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
chosen_cut = static_cast<int>(cut);
|
||||
first_group = std::move(first);
|
||||
second_group = std::move(second);
|
||||
break;
|
||||
}
|
||||
|
||||
if (chosen_cut < 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_set<Instruction*> keep(first_group.begin(), first_group.end());
|
||||
keep.insert(info.step_inst);
|
||||
std::vector<Instruction*> to_remove;
|
||||
for (const auto& inst_ptr : info.body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator() || keep.find(inst) != keep.end()) {
|
||||
continue;
|
||||
}
|
||||
to_remove.push_back(inst);
|
||||
}
|
||||
if (!BuildSecondLoop(function, info, second_group)) {
|
||||
continue;
|
||||
}
|
||||
for (auto* inst : to_remove) {
|
||||
info.body->EraseInstruction(inst);
|
||||
}
|
||||
|
||||
changed = true;
|
||||
local_changed = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!local_changed) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLoopFission(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunLoopFissionOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,855 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
#include "MemoryUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct DominatorInfo {
|
||||
std::vector<BasicBlock*> blocks;
|
||||
std::unordered_map<BasicBlock*, size_t> index;
|
||||
std::vector<std::vector<bool>> dominates;
|
||||
std::vector<BasicBlock*> idom;
|
||||
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_tree_children;
|
||||
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dominance_frontier;
|
||||
};
|
||||
|
||||
enum class SeedStateKind { Unavailable, Available, Conflict };
|
||||
|
||||
struct SeedState {
|
||||
SeedStateKind kind = SeedStateKind::Unavailable;
|
||||
StoreInst* store = nullptr;
|
||||
|
||||
bool operator==(const SeedState& rhs) const {
|
||||
return kind == rhs.kind && store == rhs.store;
|
||||
}
|
||||
bool operator!=(const SeedState& rhs) const { return !(*this == rhs); }
|
||||
};
|
||||
|
||||
struct CandidateKey {
|
||||
memutils::AddressKey address;
|
||||
std::uintptr_t type_id = 0;
|
||||
|
||||
bool operator==(const CandidateKey& rhs) const {
|
||||
return type_id == rhs.type_id && address == rhs.address;
|
||||
}
|
||||
};
|
||||
|
||||
struct CandidateKeyHash {
|
||||
std::size_t operator()(const CandidateKey& key) const {
|
||||
std::size_t h = memutils::AddressKeyHash{}(key.address);
|
||||
h ^= std::hash<std::uintptr_t>{}(key.type_id) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
struct PromotionCandidate {
|
||||
CandidateKey key;
|
||||
std::shared_ptr<Type> value_type;
|
||||
loopmem::PointerInfo pointer_info;
|
||||
std::vector<LoadInst*> loads;
|
||||
std::vector<StoreInst*> stores;
|
||||
StoreInst* seed_store = nullptr;
|
||||
Value* canonical_ptr = nullptr;
|
||||
Value* initial_value = nullptr;
|
||||
std::unordered_set<BasicBlock*> def_blocks;
|
||||
std::unordered_map<BasicBlock*, PhiInst*> phis;
|
||||
|
||||
int EstimatedBenefit() const {
|
||||
return static_cast<int>(loads.size()) + 2 * static_cast<int>(stores.size()) - 1;
|
||||
}
|
||||
};
|
||||
|
||||
bool IsScalarPromotableType(const std::shared_ptr<Type>& type) {
|
||||
return type && (type->IsInt1() || type->IsInt32() || type->IsFloat());
|
||||
}
|
||||
|
||||
int CountFunctionInstructions(const Function& function) {
|
||||
int count = 0;
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
if (!block_ptr) {
|
||||
continue;
|
||||
}
|
||||
count += static_cast<int>(block_ptr->GetInstructions().size());
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
int CountLoopInstructions(const Loop& loop) {
|
||||
int count = 0;
|
||||
for (auto* block : loop.block_list) {
|
||||
if (!block) {
|
||||
continue;
|
||||
}
|
||||
count += static_cast<int>(block->GetInstructions().size());
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
bool ShouldAnalyzeFunction(const Function& function) {
|
||||
constexpr int kMaxFunctionInstructions = 2000;
|
||||
return CountFunctionInstructions(function) <= kMaxFunctionInstructions;
|
||||
}
|
||||
|
||||
bool ShouldAnalyzeLoop(const Loop& loop) {
|
||||
constexpr int kMaxLoopBlocks = 8;
|
||||
constexpr int kMaxLoopInstructions = 96;
|
||||
return static_cast<int>(loop.block_list.size()) <= kMaxLoopBlocks &&
|
||||
CountLoopInstructions(loop) <= kMaxLoopInstructions;
|
||||
}
|
||||
|
||||
bool DominatesBlock(const DominatorInfo& info, BasicBlock* dom, BasicBlock* block) {
|
||||
if (!dom || !block) {
|
||||
return false;
|
||||
}
|
||||
auto dom_it = info.index.find(dom);
|
||||
auto block_it = info.index.find(block);
|
||||
if (dom_it == info.index.end() || block_it == info.index.end()) {
|
||||
return false;
|
||||
}
|
||||
return info.dominates[block_it->second][dom_it->second];
|
||||
}
|
||||
|
||||
std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
|
||||
std::vector<BasicBlock*> order;
|
||||
auto* entry = function.GetEntryBlock();
|
||||
if (!entry) {
|
||||
return order;
|
||||
}
|
||||
|
||||
std::unordered_set<BasicBlock*> visited;
|
||||
std::vector<BasicBlock*> stack{entry};
|
||||
while (!stack.empty()) {
|
||||
auto* block = stack.back();
|
||||
stack.pop_back();
|
||||
if (!block || !visited.insert(block).second) {
|
||||
continue;
|
||||
}
|
||||
order.push_back(block);
|
||||
const auto& succs = block->GetSuccessors();
|
||||
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
|
||||
if (*it) {
|
||||
stack.push_back(*it);
|
||||
}
|
||||
}
|
||||
}
|
||||
return order;
|
||||
}
|
||||
|
||||
std::vector<bool> IntersectDominators(const std::vector<std::vector<bool>>& doms,
|
||||
const std::vector<size_t>& pred_indices,
|
||||
size_t self_index) {
|
||||
std::vector<bool> result(doms.size(), true);
|
||||
if (pred_indices.empty()) {
|
||||
std::fill(result.begin(), result.end(), false);
|
||||
result[self_index] = true;
|
||||
return result;
|
||||
}
|
||||
|
||||
result = doms[pred_indices.front()];
|
||||
for (size_t i = 1; i < pred_indices.size(); ++i) {
|
||||
const auto& pred_dom = doms[pred_indices[i]];
|
||||
for (size_t j = 0; j < result.size(); ++j) {
|
||||
result[j] = result[j] && pred_dom[j];
|
||||
}
|
||||
}
|
||||
result[self_index] = true;
|
||||
return result;
|
||||
}
|
||||
|
||||
DominatorInfo BuildDominatorInfo(Function& function) {
|
||||
DominatorInfo info;
|
||||
info.blocks = CollectReachableBlocks(function);
|
||||
info.idom.resize(info.blocks.size(), nullptr);
|
||||
info.dominates.assign(info.blocks.size(),
|
||||
std::vector<bool>(info.blocks.size(), true));
|
||||
|
||||
if (info.blocks.empty()) {
|
||||
return info;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < info.blocks.size(); ++i) {
|
||||
info.index[info.blocks[i]] = i;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < info.blocks.size(); ++i) {
|
||||
std::fill(info.dominates[i].begin(), info.dominates[i].end(), i != 0);
|
||||
info.dominates[i][i] = true;
|
||||
}
|
||||
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (size_t i = 1; i < info.blocks.size(); ++i) {
|
||||
std::vector<size_t> pred_indices;
|
||||
for (auto* pred : info.blocks[i]->GetPredecessors()) {
|
||||
auto it = info.index.find(pred);
|
||||
if (it != info.index.end()) {
|
||||
pred_indices.push_back(it->second);
|
||||
}
|
||||
}
|
||||
auto new_dom = IntersectDominators(info.dominates, pred_indices, i);
|
||||
if (new_dom != info.dominates[i]) {
|
||||
info.dominates[i] = std::move(new_dom);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < info.blocks.size(); ++i) {
|
||||
BasicBlock* candidate_idom = nullptr;
|
||||
for (size_t j = 0; j < info.blocks.size(); ++j) {
|
||||
if (i == j || !info.dominates[i][j]) {
|
||||
continue;
|
||||
}
|
||||
bool is_immediate = true;
|
||||
for (size_t k = 0; k < info.blocks.size(); ++k) {
|
||||
if (k == i || k == j || !info.dominates[i][k]) {
|
||||
continue;
|
||||
}
|
||||
if (info.dominates[k][j]) {
|
||||
is_immediate = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (is_immediate) {
|
||||
candidate_idom = info.blocks[j];
|
||||
break;
|
||||
}
|
||||
}
|
||||
info.idom[i] = candidate_idom;
|
||||
if (candidate_idom) {
|
||||
info.dom_tree_children[candidate_idom].push_back(info.blocks[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : info.blocks) {
|
||||
info.dominance_frontier[block] = {};
|
||||
}
|
||||
|
||||
for (auto* block : info.blocks) {
|
||||
std::vector<BasicBlock*> reachable_preds;
|
||||
for (auto* pred : block->GetPredecessors()) {
|
||||
if (info.index.find(pred) != info.index.end()) {
|
||||
reachable_preds.push_back(pred);
|
||||
}
|
||||
}
|
||||
if (reachable_preds.size() < 2) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* idom_block = info.idom[info.index[block]];
|
||||
for (auto* pred : reachable_preds) {
|
||||
auto* runner = pred;
|
||||
while (runner && runner != idom_block) {
|
||||
auto& frontier = info.dominance_frontier[runner];
|
||||
if (std::find(frontier.begin(), frontier.end(), block) == frontier.end()) {
|
||||
frontier.push_back(block);
|
||||
}
|
||||
auto idom_it = info.index.find(runner);
|
||||
if (idom_it == info.index.end()) {
|
||||
break;
|
||||
}
|
||||
runner = info.idom[idom_it->second];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
SeedState MergeSeedState(const SeedState& lhs, const SeedState& rhs) {
|
||||
if (lhs == rhs) {
|
||||
return lhs;
|
||||
}
|
||||
return {SeedStateKind::Conflict, nullptr};
|
||||
}
|
||||
|
||||
SeedState TransferSeedState(const SeedState& in, BasicBlock* block,
|
||||
const PromotionCandidate& candidate,
|
||||
const memutils::EscapeSummary& escapes) {
|
||||
SeedState state = in;
|
||||
if (!block) {
|
||||
return state;
|
||||
}
|
||||
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
if (memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
|
||||
state = {SeedStateKind::Unavailable, nullptr};
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (auto* memset = dyncast<MemsetInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key) ||
|
||||
memutils::MayAliasConservatively(key, candidate.key.address)) {
|
||||
state = {SeedStateKind::Unavailable, nullptr};
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* store = dyncast<StoreInst>(inst);
|
||||
if (!store) {
|
||||
continue;
|
||||
}
|
||||
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
|
||||
state = {SeedStateKind::Unavailable, nullptr};
|
||||
continue;
|
||||
}
|
||||
if (!memutils::MayAliasConservatively(key, candidate.key.address)) {
|
||||
continue;
|
||||
}
|
||||
if (key == candidate.key.address && store->GetValue()->GetType() == candidate.value_type) {
|
||||
state = {SeedStateKind::Available, store};
|
||||
} else {
|
||||
state = {SeedStateKind::Unavailable, nullptr};
|
||||
}
|
||||
}
|
||||
return state;
|
||||
}
|
||||
|
||||
StoreInst* FindSeedStoreInPreheader(const Loop& loop,
|
||||
const PromotionCandidate& candidate,
|
||||
const memutils::EscapeSummary& escapes) {
|
||||
auto* preheader = loop.preheader;
|
||||
if (!preheader) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
StoreInst* seed = nullptr;
|
||||
for (const auto& inst_ptr : preheader->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
if (memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
|
||||
seed = nullptr;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (auto* memset = dyncast<MemsetInst>(inst)) {
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(memset->GetDest(), &escapes, key) ||
|
||||
memutils::MayAliasConservatively(key, candidate.key.address)) {
|
||||
seed = nullptr;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* store = dyncast<StoreInst>(inst);
|
||||
if (!store) {
|
||||
continue;
|
||||
}
|
||||
|
||||
memutils::AddressKey key;
|
||||
if (!memutils::BuildExactAddressKey(store->GetPtr(), &escapes, key)) {
|
||||
seed = nullptr;
|
||||
continue;
|
||||
}
|
||||
if (!memutils::MayAliasConservatively(key, candidate.key.address)) {
|
||||
continue;
|
||||
}
|
||||
if (key == candidate.key.address && store->GetValue()->GetType() == candidate.value_type) {
|
||||
seed = store;
|
||||
} else {
|
||||
seed = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return seed;
|
||||
}
|
||||
|
||||
StoreInst* FindReachingSeedStoreAtLoopEntry(Function& function, const Loop& loop,
|
||||
const PromotionCandidate& candidate,
|
||||
const memutils::EscapeSummary& escapes) {
|
||||
auto* preheader = loop.preheader;
|
||||
if (!preheader) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto blocks = CollectReachableBlocks(function);
|
||||
std::unordered_map<BasicBlock*, SeedState> in_state;
|
||||
std::unordered_map<BasicBlock*, SeedState> out_state;
|
||||
for (auto* block : blocks) {
|
||||
in_state[block] = {SeedStateKind::Unavailable, nullptr};
|
||||
out_state[block] = {SeedStateKind::Unavailable, nullptr};
|
||||
}
|
||||
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (auto* block : blocks) {
|
||||
SeedState merged{SeedStateKind::Unavailable, nullptr};
|
||||
bool first_pred = true;
|
||||
for (auto* pred : block->GetPredecessors()) {
|
||||
auto it = out_state.find(pred);
|
||||
if (it == out_state.end()) {
|
||||
continue;
|
||||
}
|
||||
if (first_pred) {
|
||||
merged = it->second;
|
||||
first_pred = false;
|
||||
} else {
|
||||
merged = MergeSeedState(merged, it->second);
|
||||
}
|
||||
}
|
||||
|
||||
if (block == function.GetEntryBlock() && first_pred) {
|
||||
merged = {SeedStateKind::Unavailable, nullptr};
|
||||
}
|
||||
|
||||
SeedState next_out = TransferSeedState(merged, block, candidate, escapes);
|
||||
if (merged != in_state[block] || next_out != out_state[block]) {
|
||||
in_state[block] = merged;
|
||||
out_state[block] = next_out;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto it = out_state.find(preheader);
|
||||
if (it == out_state.end() || it->second.kind != SeedStateKind::Available) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second.store;
|
||||
}
|
||||
|
||||
bool ExitBlocksArePromotable(const Loop& loop) {
|
||||
for (auto* exit : loop.exit_blocks) {
|
||||
if (!exit) {
|
||||
return false;
|
||||
}
|
||||
for (auto* pred : exit->GetPredecessors()) {
|
||||
if (!loop.Contains(pred)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return !loop.exit_blocks.empty();
|
||||
}
|
||||
|
||||
bool IsSafeToPromoteCandidate(const Loop& loop, const PromotionCandidate& candidate,
|
||||
const std::vector<loopmem::MemoryAccessInfo>& accesses,
|
||||
int iv_stride, const DominatorInfo& dom_info) {
|
||||
if (!candidate.seed_store || !candidate.canonical_ptr || !candidate.initial_value) {
|
||||
return false;
|
||||
}
|
||||
if (loop.parent != nullptr && candidate.seed_store->GetParent() != loop.preheader) {
|
||||
return false;
|
||||
}
|
||||
if (!DominatesBlock(dom_info, candidate.seed_store->GetParent(), loop.preheader)) {
|
||||
return false;
|
||||
}
|
||||
auto* ptr_inst = dyncast<Instruction>(candidate.canonical_ptr);
|
||||
if (ptr_inst &&
|
||||
(loop.Contains(ptr_inst->GetParent()) ||
|
||||
!DominatesBlock(dom_info, ptr_inst->GetParent(), loop.preheader))) {
|
||||
return false;
|
||||
}
|
||||
if (!ExitBlocksArePromotable(loop)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* call = dyncast<CallInst>(inst_ptr.get());
|
||||
if (!call) {
|
||||
continue;
|
||||
}
|
||||
if (memutils::CallMayReadRoot(call->GetCallee(), candidate.pointer_info.root_kind) ||
|
||||
memutils::CallMayWriteRoot(call->GetCallee(), candidate.pointer_info.root_kind)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& access : accesses) {
|
||||
if (!loopmem::MayAliasSameIteration(candidate.pointer_info, access.ptr) &&
|
||||
!loopmem::HasCrossIterationDependence(candidate.pointer_info, access.ptr, iv_stride)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!access.ptr.exact_key_valid || !(access.ptr.exact_key == candidate.key.address)) {
|
||||
return false;
|
||||
}
|
||||
if (isa<MemsetInst>(access.inst)) {
|
||||
return false;
|
||||
}
|
||||
if (auto* load = dyncast<LoadInst>(access.inst)) {
|
||||
if (load->GetType() != candidate.value_type) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (auto* store = dyncast<StoreInst>(access.inst)) {
|
||||
if (store->GetValue()->GetType() != candidate.value_type) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void InsertPhiNodes(const Loop& loop, PromotionCandidate& candidate,
|
||||
const DominatorInfo& dom_info, Function& function) {
|
||||
std::queue<BasicBlock*> worklist;
|
||||
std::unordered_set<BasicBlock*> queued;
|
||||
|
||||
for (auto* block : candidate.def_blocks) {
|
||||
worklist.push(block);
|
||||
queued.insert(block);
|
||||
}
|
||||
|
||||
while (!worklist.empty()) {
|
||||
auto* block = worklist.front();
|
||||
worklist.pop();
|
||||
|
||||
auto frontier_it = dom_info.dominance_frontier.find(block);
|
||||
if (frontier_it == dom_info.dominance_frontier.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto* frontier_block : frontier_it->second) {
|
||||
if (!loop.Contains(frontier_block)) {
|
||||
continue;
|
||||
}
|
||||
if (candidate.phis.find(frontier_block) != candidate.phis.end()) {
|
||||
continue;
|
||||
}
|
||||
auto* phi = frontier_block->Insert<PhiInst>(
|
||||
looputils::GetFirstNonPhiIndex(frontier_block), candidate.value_type, nullptr,
|
||||
looputils::NextSyntheticName(function, "lmp.phi."));
|
||||
candidate.phis[frontier_block] = phi;
|
||||
if (candidate.def_blocks.insert(frontier_block).second && queued.insert(frontier_block).second) {
|
||||
worklist.push(frontier_block);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RenameCandidateInLoop(
|
||||
BasicBlock* block, const Loop& loop, PromotionCandidate& candidate,
|
||||
const DominatorInfo& dom_info, std::vector<Value*>& stack,
|
||||
std::unordered_map<BasicBlock*, Value*>& block_out) {
|
||||
if (!block || !loop.Contains(block)) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t pushed = 0;
|
||||
PhiInst* block_phi = nullptr;
|
||||
auto phi_it = candidate.phis.find(block);
|
||||
if (phi_it != candidate.phis.end()) {
|
||||
block_phi = phi_it->second;
|
||||
stack.push_back(block_phi);
|
||||
++pushed;
|
||||
}
|
||||
|
||||
std::vector<Instruction*> to_remove;
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst == block_phi) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
auto it = std::find(candidate.loads.begin(), candidate.loads.end(), load);
|
||||
if (it == candidate.loads.end()) {
|
||||
continue;
|
||||
}
|
||||
load->ReplaceAllUsesWith(stack.back());
|
||||
to_remove.push_back(load);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* store = dyncast<StoreInst>(inst);
|
||||
if (!store) {
|
||||
continue;
|
||||
}
|
||||
auto it = std::find(candidate.stores.begin(), candidate.stores.end(), store);
|
||||
if (it == candidate.stores.end()) {
|
||||
continue;
|
||||
}
|
||||
stack.push_back(store->GetValue());
|
||||
++pushed;
|
||||
to_remove.push_back(store);
|
||||
}
|
||||
|
||||
block_out[block] = stack.back();
|
||||
for (auto* succ : block->GetSuccessors()) {
|
||||
if (!loop.Contains(succ)) {
|
||||
continue;
|
||||
}
|
||||
auto succ_phi_it = candidate.phis.find(succ);
|
||||
if (succ_phi_it == candidate.phis.end()) {
|
||||
continue;
|
||||
}
|
||||
succ_phi_it->second->AddIncoming(stack.back(), block);
|
||||
}
|
||||
|
||||
auto child_it = dom_info.dom_tree_children.find(block);
|
||||
if (child_it != dom_info.dom_tree_children.end()) {
|
||||
for (auto* child : child_it->second) {
|
||||
RenameCandidateInLoop(child, loop, candidate, dom_info, stack, block_out);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* inst : to_remove) {
|
||||
if (inst->GetParent() == block) {
|
||||
block->EraseInstruction(inst);
|
||||
}
|
||||
}
|
||||
|
||||
while (pushed > 0) {
|
||||
stack.pop_back();
|
||||
--pushed;
|
||||
}
|
||||
}
|
||||
|
||||
void InsertExitStores(Function& function, const Loop& loop, PromotionCandidate& candidate,
|
||||
const std::unordered_map<BasicBlock*, Value*>& block_out) {
|
||||
std::unordered_set<BasicBlock*> seen;
|
||||
for (auto* exit : loop.exit_blocks) {
|
||||
if (!exit || !seen.insert(exit).second) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<BasicBlock*> preds;
|
||||
preds.reserve(exit->GetPredecessors().size());
|
||||
for (auto* pred : exit->GetPredecessors()) {
|
||||
if (loop.Contains(pred)) {
|
||||
preds.push_back(pred);
|
||||
}
|
||||
}
|
||||
if (preds.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Value* final_value = nullptr;
|
||||
auto insert_index = looputils::GetFirstNonPhiIndex(exit);
|
||||
if (preds.size() == 1) {
|
||||
auto it = block_out.find(preds.front());
|
||||
if (it == block_out.end()) {
|
||||
continue;
|
||||
}
|
||||
final_value = it->second;
|
||||
} else {
|
||||
auto* phi = exit->Insert<PhiInst>(insert_index, candidate.value_type, nullptr,
|
||||
looputils::NextSyntheticName(function, "lmp.exit."));
|
||||
++insert_index;
|
||||
for (auto* pred : preds) {
|
||||
auto it = block_out.find(pred);
|
||||
if (it != block_out.end()) {
|
||||
phi->AddIncoming(it->second, pred);
|
||||
}
|
||||
}
|
||||
final_value = phi;
|
||||
}
|
||||
exit->Insert<StoreInst>(insert_index, final_value, candidate.canonical_ptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
bool PromoteCandidate(Function& function, const Loop& loop, PromotionCandidate& candidate,
|
||||
const DominatorInfo& dom_info) {
|
||||
if (!candidate.seed_store || !candidate.initial_value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
InsertPhiNodes(loop, candidate, dom_info, function);
|
||||
|
||||
auto header_phi_it = candidate.phis.find(loop.header);
|
||||
if (header_phi_it != candidate.phis.end()) {
|
||||
header_phi_it->second->AddIncoming(candidate.initial_value, loop.preheader);
|
||||
}
|
||||
|
||||
std::vector<Value*> stack{candidate.initial_value};
|
||||
std::unordered_map<BasicBlock*, Value*> block_out;
|
||||
RenameCandidateInLoop(loop.header, loop, candidate, dom_info, stack, block_out);
|
||||
InsertExitStores(function, loop, candidate, block_out);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<PromotionCandidate> CollectCandidates(
|
||||
const Loop& loop, const std::vector<loopmem::MemoryAccessInfo>& accesses,
|
||||
const memutils::EscapeSummary& escapes, int iv_stride, Function& function,
|
||||
const DominatorInfo& dom_info) {
|
||||
constexpr std::size_t kMaxLoopAccesses = 64;
|
||||
if (accesses.size() > kMaxLoopAccesses) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::unordered_map<CandidateKey, PromotionCandidate, CandidateKeyHash> groups;
|
||||
|
||||
for (const auto& access : accesses) {
|
||||
if (!access.ptr.exact_key_valid || !access.ptr.invariant_address) {
|
||||
continue;
|
||||
}
|
||||
if (!access.is_read && !access.is_write) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::shared_ptr<Type> value_type;
|
||||
if (auto* load = dyncast<LoadInst>(access.inst)) {
|
||||
value_type = load->GetType();
|
||||
} else if (auto* store = dyncast<StoreInst>(access.inst)) {
|
||||
value_type = store->GetValue()->GetType();
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!IsScalarPromotableType(value_type)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
CandidateKey key{access.ptr.exact_key,
|
||||
reinterpret_cast<std::uintptr_t>(value_type.get())};
|
||||
auto& candidate = groups[key];
|
||||
candidate.key = key;
|
||||
candidate.value_type = value_type;
|
||||
candidate.pointer_info = access.ptr;
|
||||
if (auto* load = dyncast<LoadInst>(access.inst)) {
|
||||
candidate.loads.push_back(load);
|
||||
} else if (auto* store = dyncast<StoreInst>(access.inst)) {
|
||||
candidate.stores.push_back(store);
|
||||
candidate.def_blocks.insert(store->GetParent());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<PromotionCandidate> candidates;
|
||||
candidates.reserve(groups.size());
|
||||
for (auto& [key, candidate] : groups) {
|
||||
if (candidate.stores.empty()) {
|
||||
continue;
|
||||
}
|
||||
candidate.seed_store = FindSeedStoreInPreheader(loop, candidate, escapes);
|
||||
if (!candidate.seed_store && loop.parent == nullptr) {
|
||||
candidate.seed_store =
|
||||
FindReachingSeedStoreAtLoopEntry(function, loop, candidate, escapes);
|
||||
}
|
||||
if (!candidate.seed_store) {
|
||||
continue;
|
||||
}
|
||||
candidate.initial_value = candidate.seed_store->GetValue();
|
||||
candidate.canonical_ptr = candidate.seed_store->GetPtr();
|
||||
if (!IsSafeToPromoteCandidate(loop, candidate, accesses, iv_stride, dom_info)) {
|
||||
continue;
|
||||
}
|
||||
candidates.push_back(std::move(candidate));
|
||||
}
|
||||
|
||||
std::sort(candidates.begin(), candidates.end(),
|
||||
[](const PromotionCandidate& lhs, const PromotionCandidate& rhs) {
|
||||
return lhs.EstimatedBenefit() > rhs.EstimatedBenefit();
|
||||
});
|
||||
return candidates;
|
||||
}
|
||||
|
||||
bool PromoteLoopMemory(Function& function, const Loop& loop,
|
||||
const DominatorInfo& dom_info) {
|
||||
if (!loop.preheader || !loop.header || !loop.IsInnermost() ||
|
||||
!ShouldAnalyzeLoop(loop)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
loopmem::SimpleInductionVar induction_var;
|
||||
PhiInst* iv = nullptr;
|
||||
int iv_stride = 1;
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
if (loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
|
||||
iv = induction_var.phi;
|
||||
iv_stride = induction_var.stride;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
while (true) {
|
||||
const auto escapes = memutils::AnalyzeEscapes(function);
|
||||
const auto accesses = loopmem::CollectMemoryAccesses(loop, iv, &escapes);
|
||||
auto candidates = CollectCandidates(loop, accesses, escapes, iv_stride, function, dom_info);
|
||||
if (candidates.empty()) {
|
||||
break;
|
||||
}
|
||||
if (!PromoteCandidate(function, loop, candidates.front(), dom_info)) {
|
||||
break;
|
||||
}
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RunLoopMemoryPromotionOnFunction(Function& function) {
|
||||
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (!ShouldAnalyzeFunction(function)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
while (true) {
|
||||
DominatorTree dom_tree(function);
|
||||
LoopInfo loop_info(function, dom_tree);
|
||||
|
||||
bool cfg_changed = false;
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
auto* old_preheader = loop->preheader;
|
||||
auto* preheader = looputils::EnsurePreheader(function, *loop);
|
||||
if (preheader != old_preheader) {
|
||||
changed = true;
|
||||
cfg_changed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (cfg_changed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto dom_info = BuildDominatorInfo(function);
|
||||
bool local_changed = false;
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
local_changed |= PromoteLoopMemory(function, *loop, dom_info);
|
||||
}
|
||||
changed |= local_changed;
|
||||
if (!local_changed) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLoopMemoryPromotion(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunLoopMemoryPromotionOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,506 @@
|
||||
#pragma once
|
||||
|
||||
#include "LoopPassUtils.h"
|
||||
#include "MemoryUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace ir::loopmem {
|
||||
|
||||
struct SimpleInductionVar {
|
||||
PhiInst* phi = nullptr;
|
||||
Value* start = nullptr;
|
||||
Value* latch_value = nullptr;
|
||||
BasicBlock* latch = nullptr;
|
||||
int stride = 0;
|
||||
};
|
||||
|
||||
inline bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
|
||||
PhiInst* phi, SimpleInductionVar& info) {
|
||||
if (!phi || !preheader || phi->GetParent() != loop.header ||
|
||||
!phi->GetType()->IsInt32() || phi->GetNumIncomings() != 2 ||
|
||||
loop.latches.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* latch = loop.latches.front();
|
||||
const int preheader_index = looputils::GetPhiIncomingIndex(phi, preheader);
|
||||
const int latch_index = looputils::GetPhiIncomingIndex(phi, latch);
|
||||
if (preheader_index < 0 || latch_index < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
|
||||
if (!step_inst || step_inst->GetParent() != latch) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int stride = 0;
|
||||
if (step_inst->GetOpcode() == Opcode::Add) {
|
||||
if (step_inst->GetLhs() == phi) {
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = delta->GetValue();
|
||||
} else if (step_inst->GetRhs() == phi) {
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = delta->GetValue();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else if (step_inst->GetOpcode() == Opcode::Sub) {
|
||||
if (step_inst->GetLhs() != phi) {
|
||||
return false;
|
||||
}
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = -delta->GetValue();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stride == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
info.phi = phi;
|
||||
info.start = phi->GetIncomingValue(preheader_index);
|
||||
info.latch_value = phi->GetIncomingValue(latch_index);
|
||||
info.latch = latch;
|
||||
info.stride = stride;
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool GetCanonicalLoopBlocks(const Loop& loop, BasicBlock*& body,
|
||||
BasicBlock*& exit) {
|
||||
body = nullptr;
|
||||
exit = nullptr;
|
||||
if (!loop.header || loop.latches.size() != 1 || loop.block_list.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* condbr = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
|
||||
if (!condbr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* then_block = condbr->GetThenBlock();
|
||||
auto* else_block = condbr->GetElseBlock();
|
||||
const bool then_in_loop = loop.Contains(then_block);
|
||||
const bool else_in_loop = loop.Contains(else_block);
|
||||
if (then_in_loop == else_in_loop) {
|
||||
return false;
|
||||
}
|
||||
|
||||
body = then_in_loop ? then_block : else_block;
|
||||
exit = then_in_loop ? else_block : then_block;
|
||||
if (!body || !exit || body != loop.latches.front() ||
|
||||
body->GetSuccessors().size() != 1 || body->GetSuccessors().front() != loop.header) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
struct AffineExpr {
|
||||
bool valid = false;
|
||||
Value* var = nullptr;
|
||||
std::int64_t coeff = 0;
|
||||
std::int64_t constant = 0;
|
||||
};
|
||||
|
||||
inline AffineExpr MakeConst(std::int64_t value) {
|
||||
return {true, nullptr, 0, value};
|
||||
}
|
||||
|
||||
inline AffineExpr Scale(const AffineExpr& expr, std::int64_t factor) {
|
||||
if (!expr.valid) {
|
||||
return {};
|
||||
}
|
||||
return {true, expr.var, expr.coeff * factor, expr.constant * factor};
|
||||
}
|
||||
|
||||
inline AffineExpr Combine(const AffineExpr& lhs, const AffineExpr& rhs, int sign) {
|
||||
if (!lhs.valid || !rhs.valid) {
|
||||
return {};
|
||||
}
|
||||
if (lhs.var != nullptr && rhs.var != nullptr && lhs.var != rhs.var) {
|
||||
return {};
|
||||
}
|
||||
AffineExpr out;
|
||||
out.valid = true;
|
||||
out.var = lhs.var ? lhs.var : rhs.var;
|
||||
out.coeff = lhs.coeff + sign * rhs.coeff;
|
||||
out.constant = lhs.constant + sign * rhs.constant;
|
||||
return out;
|
||||
}
|
||||
|
||||
inline AffineExpr AnalyzeAffine(Value* value, PhiInst* iv, const Loop& loop) {
|
||||
if (!value) {
|
||||
return {};
|
||||
}
|
||||
if (auto* ci = dyncast<ConstantInt>(value)) {
|
||||
return MakeConst(ci->GetValue());
|
||||
}
|
||||
if (value == iv) {
|
||||
return {true, iv, 1, 0};
|
||||
}
|
||||
if (looputils::IsLoopInvariantValue(loop, value)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
if (auto* zext = dyncast<ZextInst>(value)) {
|
||||
return AnalyzeAffine(zext->GetValue(), iv, loop);
|
||||
}
|
||||
auto* inst = dyncast<Instruction>(value);
|
||||
if (!inst) {
|
||||
return {};
|
||||
}
|
||||
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
return Combine(AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetLhs(), iv, loop),
|
||||
AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetRhs(), iv, loop), +1);
|
||||
case Opcode::Sub:
|
||||
return Combine(AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetLhs(), iv, loop),
|
||||
AnalyzeAffine(static_cast<BinaryInst*>(inst)->GetRhs(), iv, loop), -1);
|
||||
case Opcode::Mul: {
|
||||
auto* bin = static_cast<BinaryInst*>(inst);
|
||||
auto lhs = AnalyzeAffine(bin->GetLhs(), iv, loop);
|
||||
auto rhs = AnalyzeAffine(bin->GetRhs(), iv, loop);
|
||||
if (lhs.valid && lhs.var == nullptr && rhs.valid) {
|
||||
return Scale(rhs, lhs.constant);
|
||||
}
|
||||
if (rhs.valid && rhs.var == nullptr && lhs.valid) {
|
||||
return Scale(lhs, rhs.constant);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
case Opcode::Neg:
|
||||
return Scale(AnalyzeAffine(static_cast<UnaryInst*>(inst)->GetOprd(), iv, loop), -1);
|
||||
default:
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
struct PointerInfo {
|
||||
Value* base = nullptr;
|
||||
AffineExpr byte_offset;
|
||||
bool invariant_address = false;
|
||||
bool distinct_root = false;
|
||||
bool argument_root = false;
|
||||
bool readonly_root = false;
|
||||
bool exact_key_valid = false;
|
||||
memutils::PointerRootKind root_kind = memutils::PointerRootKind::Unknown;
|
||||
memutils::AddressKey exact_key;
|
||||
int access_size = 0;
|
||||
};
|
||||
|
||||
inline Value* StripPointerBase(Value* pointer) {
|
||||
auto* value = pointer;
|
||||
while (auto* gep = dyncast<GetElementPtrInst>(value)) {
|
||||
value = gep->GetPointer();
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
inline std::shared_ptr<Type> AdvanceGEPType(std::shared_ptr<Type> current) {
|
||||
if (current && current->IsArray()) {
|
||||
return current->GetElementType();
|
||||
}
|
||||
return current;
|
||||
}
|
||||
|
||||
inline PointerInfo AnalyzePointer(Value* pointer, PhiInst* iv, const Loop& loop,
|
||||
int access_size,
|
||||
const memutils::EscapeSummary* escapes = nullptr) {
|
||||
PointerInfo info;
|
||||
info.access_size = access_size;
|
||||
info.base = StripPointerBase(pointer);
|
||||
info.root_kind = memutils::ClassifyRoot(info.base, escapes);
|
||||
info.argument_root = info.root_kind == memutils::PointerRootKind::Param;
|
||||
info.readonly_root = info.root_kind == memutils::PointerRootKind::ReadonlyGlobal;
|
||||
info.distinct_root = info.root_kind == memutils::PointerRootKind::Local ||
|
||||
info.root_kind == memutils::PointerRootKind::Global ||
|
||||
info.root_kind == memutils::PointerRootKind::ReadonlyGlobal;
|
||||
info.exact_key_valid =
|
||||
escapes != nullptr && memutils::BuildExactAddressKey(pointer, escapes, info.exact_key);
|
||||
|
||||
info.invariant_address = looputils::IsLoopInvariantValue(loop, pointer);
|
||||
if (!dyncast<GetElementPtrInst>(pointer)) {
|
||||
info.byte_offset = MakeConst(0);
|
||||
return info;
|
||||
}
|
||||
|
||||
auto* gep = static_cast<GetElementPtrInst*>(pointer);
|
||||
std::shared_ptr<Type> current = gep->GetSourceType();
|
||||
AffineExpr total = MakeConst(0);
|
||||
bool all_indices_loop_invariant = looputils::IsLoopInvariantValue(loop, gep->GetPointer());
|
||||
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
|
||||
auto* index = gep->GetIndex(i);
|
||||
all_indices_loop_invariant &= looputils::IsLoopInvariantValue(loop, index);
|
||||
const std::int64_t stride = current ? current->GetSize() : 0;
|
||||
auto term = AnalyzeAffine(index, iv, loop);
|
||||
if (!term.valid) {
|
||||
total = {};
|
||||
} else if (total.valid) {
|
||||
total = Combine(total, Scale(term, stride), +1);
|
||||
}
|
||||
current = AdvanceGEPType(current);
|
||||
}
|
||||
info.invariant_address = all_indices_loop_invariant;
|
||||
info.byte_offset = total;
|
||||
return info;
|
||||
}
|
||||
|
||||
struct MemoryAccessInfo {
|
||||
Instruction* inst = nullptr;
|
||||
Value* pointer = nullptr;
|
||||
PointerInfo ptr;
|
||||
bool is_read = false;
|
||||
bool is_write = false;
|
||||
};
|
||||
|
||||
inline std::vector<MemoryAccessInfo> CollectMemoryAccesses(const Loop& loop,
|
||||
PhiInst* iv,
|
||||
const memutils::EscapeSummary* escapes =
|
||||
nullptr) {
|
||||
std::vector<MemoryAccessInfo> accesses;
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (auto* load = dyncast<LoadInst>(inst)) {
|
||||
accesses.push_back(
|
||||
{inst, load->GetPtr(),
|
||||
AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes),
|
||||
true,
|
||||
false});
|
||||
} else if (auto* store = dyncast<StoreInst>(inst)) {
|
||||
accesses.push_back({inst, store->GetPtr(),
|
||||
AnalyzePointer(store->GetPtr(), iv, loop,
|
||||
store->GetValue()->GetType()->GetSize(), escapes),
|
||||
false, true});
|
||||
} else if (auto* memset = dyncast<MemsetInst>(inst)) {
|
||||
accesses.push_back(
|
||||
{inst, memset->GetDest(),
|
||||
AnalyzePointer(memset->GetDest(), iv, loop, 1, escapes), false, true});
|
||||
}
|
||||
}
|
||||
}
|
||||
return accesses;
|
||||
}
|
||||
|
||||
inline bool SameAffineAddress(const PointerInfo& lhs, const PointerInfo& rhs) {
|
||||
return lhs.base == rhs.base && lhs.byte_offset.valid && rhs.byte_offset.valid &&
|
||||
lhs.byte_offset.var == rhs.byte_offset.var &&
|
||||
lhs.byte_offset.coeff == rhs.byte_offset.coeff &&
|
||||
lhs.byte_offset.constant == rhs.byte_offset.constant;
|
||||
}
|
||||
|
||||
inline bool MayAliasSameIteration(const PointerInfo& lhs, const PointerInfo& rhs) {
|
||||
if (lhs.exact_key_valid && rhs.exact_key_valid) {
|
||||
return memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key);
|
||||
}
|
||||
if (!lhs.base || !rhs.base) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.base != rhs.base) {
|
||||
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.byte_offset.var != rhs.byte_offset.var) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.byte_offset.coeff != rhs.byte_offset.coeff) {
|
||||
return true;
|
||||
}
|
||||
const auto diff = std::llabs(lhs.byte_offset.constant - rhs.byte_offset.constant);
|
||||
const auto overlap = std::min(lhs.access_size, rhs.access_size);
|
||||
return diff < overlap;
|
||||
}
|
||||
|
||||
inline bool HasCrossIterationDependence(const PointerInfo& lhs, const PointerInfo& rhs,
|
||||
int iv_stride) {
|
||||
if (lhs.exact_key_valid && rhs.exact_key_valid &&
|
||||
!memutils::MayAliasConservatively(lhs.exact_key, rhs.exact_key)) {
|
||||
return false;
|
||||
}
|
||||
if (!lhs.base || !rhs.base) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.base != rhs.base) {
|
||||
if (lhs.distinct_root && rhs.distinct_root && !lhs.argument_root && !rhs.argument_root) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (!lhs.byte_offset.valid || !rhs.byte_offset.valid) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.byte_offset.var != rhs.byte_offset.var) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto lhs_step = lhs.byte_offset.coeff * iv_stride;
|
||||
const auto rhs_step = rhs.byte_offset.coeff * iv_stride;
|
||||
if (lhs_step == 0 && rhs_step == 0) {
|
||||
return MayAliasSameIteration(lhs, rhs);
|
||||
}
|
||||
if (lhs_step == rhs_step && lhs_step != 0) {
|
||||
const auto diff = rhs.byte_offset.constant - lhs.byte_offset.constant;
|
||||
return diff != 0 && diff % std::llabs(lhs_step) == 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool CallMayWritePointer(Function* callee, const PointerInfo& ptr) {
|
||||
if (ptr.readonly_root) {
|
||||
return false;
|
||||
}
|
||||
return memutils::CallMayWriteRoot(callee, ptr.root_kind);
|
||||
}
|
||||
|
||||
inline bool IsSafeInvariantLoadToHoist(const Loop& loop, LoadInst* load, PhiInst* iv,
|
||||
int iv_stride,
|
||||
const std::vector<MemoryAccessInfo>& accesses,
|
||||
const memutils::EscapeSummary* escapes = nullptr) {
|
||||
if (!load) {
|
||||
return false;
|
||||
}
|
||||
auto ptr = AnalyzePointer(load->GetPtr(), iv, loop, load->GetType()->GetSize(), escapes);
|
||||
if (!ptr.invariant_address) {
|
||||
return false;
|
||||
}
|
||||
if (ptr.readonly_root) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst == load) {
|
||||
continue;
|
||||
}
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
if (CallMayWritePointer(call->GetCallee(), ptr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& access : accesses) {
|
||||
if (access.inst == load || !access.is_write) {
|
||||
continue;
|
||||
}
|
||||
if (MayAliasSameIteration(ptr, access.ptr)) {
|
||||
return false;
|
||||
}
|
||||
if (HasCrossIterationDependence(ptr, access.ptr, iv_stride)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool HasScalarDependenceAcrossCut(const std::vector<Instruction*>& first_group,
|
||||
const std::unordered_set<Instruction*>& second_set) {
|
||||
for (auto* inst : first_group) {
|
||||
if (!inst || inst->IsVoid()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& use : inst->GetUses()) {
|
||||
auto* user = dyncast<Instruction>(use.GetUser());
|
||||
if (user && second_set.find(user) != second_set.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool HasMemoryDependenceAcrossCut(const std::vector<MemoryAccessInfo>& accesses,
|
||||
const std::unordered_set<Instruction*>& first_set,
|
||||
const std::unordered_set<Instruction*>& second_set,
|
||||
int iv_stride) {
|
||||
for (const auto& lhs : accesses) {
|
||||
if (first_set.find(lhs.inst) == first_set.end()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto& rhs : accesses) {
|
||||
if (second_set.find(rhs.inst) == second_set.end()) {
|
||||
continue;
|
||||
}
|
||||
if (!lhs.is_write && !rhs.is_write) {
|
||||
continue;
|
||||
}
|
||||
if (MayAliasSameIteration(lhs.ptr, rhs.ptr) ||
|
||||
HasCrossIterationDependence(lhs.ptr, rhs.ptr, iv_stride)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool IsLoopParallelizable(const Loop& loop, PhiInst* iv, int iv_stride,
|
||||
const std::vector<MemoryAccessInfo>& accesses) {
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
if (phi != iv) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (auto* call = dyncast<CallInst>(inst)) {
|
||||
auto* callee = call->GetCallee();
|
||||
if (callee == nullptr || callee->HasObservableSideEffects() || callee->IsRecursive()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& access : accesses) {
|
||||
if (CallMayWritePointer(callee, access.ptr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (dyncast<MemsetInst>(inst)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < accesses.size(); ++i) {
|
||||
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
|
||||
if (!accesses[i].is_write && !accesses[j].is_write) {
|
||||
continue;
|
||||
}
|
||||
if (HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr, iv_stride)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace ir::loopmem
|
||||
@ -0,0 +1,440 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir::looputils {
|
||||
|
||||
inline Instruction* GetTerminator(BasicBlock* block) {
|
||||
if (!block || block->GetInstructions().empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* inst = block->GetInstructions().back().get();
|
||||
return inst && inst->IsTerminator() ? inst : nullptr;
|
||||
}
|
||||
|
||||
inline std::size_t GetTerminatorIndex(BasicBlock* block) {
|
||||
if (!block) {
|
||||
return 0;
|
||||
}
|
||||
const auto size = block->GetInstructions().size();
|
||||
if (!block->HasTerminator()) {
|
||||
return size;
|
||||
}
|
||||
return size == 0 ? 0 : size - 1;
|
||||
}
|
||||
|
||||
inline std::size_t GetFirstNonPhiIndex(BasicBlock* block) {
|
||||
if (!block) {
|
||||
return 0;
|
||||
}
|
||||
std::size_t index = 0;
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
if (!dyncast<PhiInst>(inst_ptr.get())) {
|
||||
break;
|
||||
}
|
||||
++index;
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
inline std::string NextSyntheticName(Function& function, const std::string& prefix) {
|
||||
static std::unordered_map<Function*, int> counters;
|
||||
const int id = ++counters[&function];
|
||||
return "%" + prefix + std::to_string(id);
|
||||
}
|
||||
|
||||
inline std::string NextSyntheticBlockName(Function& function,
|
||||
const std::string& prefix) {
|
||||
static std::unordered_map<Function*, int> counters;
|
||||
const int id = ++counters[&function];
|
||||
return prefix + "." + std::to_string(id);
|
||||
}
|
||||
|
||||
inline ConstantInt* ConstInt(int value) {
|
||||
return new ConstantInt(Type::GetInt32Type(), value);
|
||||
}
|
||||
|
||||
inline int GetPhiIncomingIndex(PhiInst* phi, BasicBlock* block) {
|
||||
if (!phi || !block) {
|
||||
return -1;
|
||||
}
|
||||
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
|
||||
if (phi->GetIncomingBlock(i) == block) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
inline bool ReplacePhiIncoming(PhiInst* phi, BasicBlock* old_block,
|
||||
Value* new_value, BasicBlock* new_block) {
|
||||
if (!phi || !old_block || !new_value || !new_block) {
|
||||
return false;
|
||||
}
|
||||
const int index = GetPhiIncomingIndex(phi, old_block);
|
||||
if (index < 0) {
|
||||
return false;
|
||||
}
|
||||
phi->SetOperand(static_cast<std::size_t>(2 * index), new_value);
|
||||
phi->SetOperand(static_cast<std::size_t>(2 * index + 1), new_block);
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool RedirectSuccessorEdge(BasicBlock* pred, BasicBlock* old_succ,
|
||||
BasicBlock* new_succ) {
|
||||
auto* terminator = GetTerminator(pred);
|
||||
if (!terminator || !old_succ || !new_succ) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto* br = dyncast<UncondBrInst>(terminator)) {
|
||||
if (br->GetDest() != old_succ) {
|
||||
return false;
|
||||
}
|
||||
br->SetOperand(0, new_succ);
|
||||
} else if (auto* condbr = dyncast<CondBrInst>(terminator)) {
|
||||
bool changed = false;
|
||||
if (condbr->GetThenBlock() == old_succ) {
|
||||
condbr->SetOperand(1, new_succ);
|
||||
changed = true;
|
||||
}
|
||||
if (condbr->GetElseBlock() == old_succ) {
|
||||
condbr->SetOperand(2, new_succ);
|
||||
changed = true;
|
||||
}
|
||||
if (!changed) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
pred->RemoveSuccessor(old_succ);
|
||||
pred->AddSuccessor(new_succ);
|
||||
return true;
|
||||
}
|
||||
|
||||
inline Instruction* MoveInstructionBeforeTerminator(Instruction* inst,
|
||||
BasicBlock* dest) {
|
||||
if (!inst || !dest) {
|
||||
return nullptr;
|
||||
}
|
||||
auto* src = inst->GetParent();
|
||||
if (!src || src == dest) {
|
||||
return inst;
|
||||
}
|
||||
|
||||
auto& src_insts = src->GetInstructions();
|
||||
auto src_it = std::find_if(src_insts.begin(), src_insts.end(),
|
||||
[&](const std::unique_ptr<Instruction>& current) {
|
||||
return current.get() == inst;
|
||||
});
|
||||
if (src_it == src_insts.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto moved = std::move(*src_it);
|
||||
src_insts.erase(src_it);
|
||||
moved->SetParent(dest);
|
||||
|
||||
auto& dest_insts = dest->GetInstructions();
|
||||
auto insert_it = dest_insts.begin() +
|
||||
static_cast<long long>(GetTerminatorIndex(dest));
|
||||
auto* ptr = moved.get();
|
||||
dest_insts.insert(insert_it, std::move(moved));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
inline bool IsLoopInvariantValue(const Loop& loop, Value* value) {
|
||||
auto* inst = dyncast<Instruction>(value);
|
||||
return inst == nullptr || !loop.Contains(inst->GetParent());
|
||||
}
|
||||
|
||||
inline Value* RemapValue(const std::unordered_map<Value*, Value*>& remap,
|
||||
Value* value) {
|
||||
auto it = remap.find(value);
|
||||
return it == remap.end() ? value : it->second;
|
||||
}
|
||||
|
||||
inline bool IsCloneableInstruction(const Instruction* inst) {
|
||||
if (!inst || inst->IsTerminator() || inst->GetOpcode() == Opcode::Phi) {
|
||||
return false;
|
||||
}
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::FRem:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF:
|
||||
case Opcode::Alloca:
|
||||
case Opcode::Load:
|
||||
case Opcode::Store:
|
||||
case Opcode::Memset:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Zext:
|
||||
case Opcode::Call:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline Instruction* CloneInstruction(Function& function, Instruction* inst,
|
||||
BasicBlock* dest,
|
||||
std::unordered_map<Value*, Value*>& remap,
|
||||
const std::string& prefix) {
|
||||
if (!inst || !dest || !IsCloneableInstruction(inst)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto insert_index = GetTerminatorIndex(dest);
|
||||
const auto name = inst->IsVoid() ? std::string()
|
||||
: NextSyntheticName(function, prefix);
|
||||
|
||||
auto remap_operand = [&](Value* value) { return RemapValue(remap, value); };
|
||||
auto remember = [&](Instruction* clone) {
|
||||
if (clone && !inst->IsVoid()) {
|
||||
remap[inst] = clone;
|
||||
}
|
||||
return clone;
|
||||
};
|
||||
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::FRem:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE: {
|
||||
auto* bin = static_cast<BinaryInst*>(inst);
|
||||
return remember(dest->Insert<BinaryInst>(
|
||||
insert_index, inst->GetOpcode(), inst->GetType(),
|
||||
remap_operand(bin->GetLhs()), remap_operand(bin->GetRhs()), nullptr,
|
||||
name));
|
||||
}
|
||||
case Opcode::Neg:
|
||||
case Opcode::Not:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::IToF: {
|
||||
auto* un = static_cast<UnaryInst*>(inst);
|
||||
return remember(dest->Insert<UnaryInst>(insert_index, inst->GetOpcode(),
|
||||
inst->GetType(),
|
||||
remap_operand(un->GetOprd()),
|
||||
nullptr, name));
|
||||
}
|
||||
case Opcode::Alloca: {
|
||||
auto* alloca = static_cast<AllocaInst*>(inst);
|
||||
return remember(dest->Insert<AllocaInst>(insert_index,
|
||||
alloca->GetAllocatedType(),
|
||||
nullptr, name));
|
||||
}
|
||||
case Opcode::Load: {
|
||||
auto* load = static_cast<LoadInst*>(inst);
|
||||
return remember(dest->Insert<LoadInst>(insert_index, inst->GetType(),
|
||||
remap_operand(load->GetPtr()),
|
||||
nullptr, name));
|
||||
}
|
||||
case Opcode::Store: {
|
||||
auto* store = static_cast<StoreInst*>(inst);
|
||||
return dest->Insert<StoreInst>(insert_index,
|
||||
remap_operand(store->GetValue()),
|
||||
remap_operand(store->GetPtr()), nullptr);
|
||||
}
|
||||
case Opcode::Memset: {
|
||||
auto* memset = static_cast<MemsetInst*>(inst);
|
||||
return dest->Insert<MemsetInst>(insert_index,
|
||||
remap_operand(memset->GetDest()),
|
||||
remap_operand(memset->GetValue()),
|
||||
remap_operand(memset->GetLength()),
|
||||
remap_operand(memset->GetIsVolatile()),
|
||||
nullptr);
|
||||
}
|
||||
case Opcode::GetElementPtr: {
|
||||
auto* gep = static_cast<GetElementPtrInst*>(inst);
|
||||
std::vector<Value*> indices;
|
||||
indices.reserve(gep->GetNumIndices());
|
||||
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
|
||||
indices.push_back(remap_operand(gep->GetIndex(i)));
|
||||
}
|
||||
return remember(dest->Insert<GetElementPtrInst>(
|
||||
insert_index, gep->GetSourceType(), remap_operand(gep->GetPointer()),
|
||||
indices, nullptr, name));
|
||||
}
|
||||
case Opcode::Zext: {
|
||||
auto* zext = static_cast<ZextInst*>(inst);
|
||||
return remember(dest->Insert<ZextInst>(insert_index,
|
||||
remap_operand(zext->GetValue()),
|
||||
inst->GetType(), nullptr, name));
|
||||
}
|
||||
case Opcode::Call: {
|
||||
auto* call = static_cast<CallInst*>(inst);
|
||||
std::vector<Value*> args;
|
||||
const auto original_args = call->GetArguments();
|
||||
args.reserve(original_args.size());
|
||||
for (auto* arg : original_args) {
|
||||
args.push_back(remap_operand(arg));
|
||||
}
|
||||
return remember(dest->Insert<CallInst>(insert_index, call->GetCallee(),
|
||||
args, nullptr, name));
|
||||
}
|
||||
case Opcode::Phi:
|
||||
case Opcode::Br:
|
||||
case Opcode::CondBr:
|
||||
case Opcode::Return:
|
||||
case Opcode::Unreachable:
|
||||
break;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline BasicBlock* EnsurePreheader(Function& function, Loop& loop) {
|
||||
if (loop.preheader) {
|
||||
return loop.preheader;
|
||||
}
|
||||
|
||||
auto* header = loop.header;
|
||||
if (!header) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<BasicBlock*> outside_preds;
|
||||
for (auto* pred : header->GetPredecessors()) {
|
||||
if (!loop.Contains(pred)) {
|
||||
outside_preds.push_back(pred);
|
||||
}
|
||||
}
|
||||
if (outside_preds.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (outside_preds.size() == 1 &&
|
||||
outside_preds.front()->GetSuccessors().size() == 1) {
|
||||
loop.preheader = outside_preds.front();
|
||||
return loop.preheader;
|
||||
}
|
||||
|
||||
auto* preheader = function.CreateBlock(
|
||||
NextSyntheticBlockName(function, header->GetName() + ".preheader"));
|
||||
|
||||
std::size_t phi_insert_index = 0;
|
||||
for (const auto& inst_ptr : header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::vector<int> outside_incomings;
|
||||
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
|
||||
if (!loop.Contains(phi->GetIncomingBlock(i))) {
|
||||
outside_incomings.push_back(i);
|
||||
}
|
||||
}
|
||||
if (outside_incomings.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Value* merged_value = nullptr;
|
||||
if (outside_incomings.size() == 1) {
|
||||
merged_value = phi->GetIncomingValue(outside_incomings.front());
|
||||
} else {
|
||||
auto new_phi = std::make_unique<PhiInst>(
|
||||
phi->GetType(), nullptr,
|
||||
NextSyntheticName(function, "preheader.phi."));
|
||||
auto* new_phi_ptr = new_phi.get();
|
||||
new_phi_ptr->SetParent(preheader);
|
||||
auto& preheader_insts = preheader->GetInstructions();
|
||||
preheader_insts.insert(preheader_insts.begin() +
|
||||
static_cast<long long>(phi_insert_index),
|
||||
std::move(new_phi));
|
||||
++phi_insert_index;
|
||||
|
||||
for (int incoming_index : outside_incomings) {
|
||||
new_phi_ptr->AddIncoming(phi->GetIncomingValue(incoming_index),
|
||||
phi->GetIncomingBlock(incoming_index));
|
||||
}
|
||||
merged_value = new_phi_ptr;
|
||||
}
|
||||
|
||||
for (auto it = outside_incomings.rbegin(); it != outside_incomings.rend();
|
||||
++it) {
|
||||
phi->RemoveOperand(static_cast<std::size_t>(2 * *it + 1));
|
||||
phi->RemoveOperand(static_cast<std::size_t>(2 * *it));
|
||||
}
|
||||
phi->AddIncoming(merged_value, preheader);
|
||||
}
|
||||
|
||||
preheader->Append<UncondBrInst>(header, nullptr);
|
||||
preheader->AddSuccessor(header);
|
||||
header->AddPredecessor(preheader);
|
||||
|
||||
for (auto* pred : outside_preds) {
|
||||
if (RedirectSuccessorEdge(pred, header, preheader)) {
|
||||
preheader->AddPredecessor(pred);
|
||||
header->RemovePredecessor(pred);
|
||||
}
|
||||
}
|
||||
|
||||
loop.preheader = preheader;
|
||||
return preheader;
|
||||
}
|
||||
|
||||
} // namespace ir::looputils
|
||||
@ -0,0 +1,295 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct InductionVarInfo {
|
||||
PhiInst* phi = nullptr;
|
||||
Value* start = nullptr;
|
||||
BasicBlock* latch = nullptr;
|
||||
int stride = 0;
|
||||
};
|
||||
|
||||
Value* BuildMulValue(Function& function, BasicBlock* block, Value* lhs, Value* rhs,
|
||||
const std::string& prefix) {
|
||||
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
|
||||
if (lhs_const->GetValue() == 0) {
|
||||
return looputils::ConstInt(0);
|
||||
}
|
||||
if (lhs_const->GetValue() == 1) {
|
||||
return rhs;
|
||||
}
|
||||
}
|
||||
if (auto* rhs_const = dyncast<ConstantInt>(rhs)) {
|
||||
if (rhs_const->GetValue() == 0) {
|
||||
return looputils::ConstInt(0);
|
||||
}
|
||||
if (rhs_const->GetValue() == 1) {
|
||||
return lhs;
|
||||
}
|
||||
}
|
||||
if (auto* lhs_const = dyncast<ConstantInt>(lhs)) {
|
||||
if (auto* rhs_const = dyncast<ConstantInt>(rhs)) {
|
||||
return looputils::ConstInt(lhs_const->GetValue() * rhs_const->GetValue());
|
||||
}
|
||||
}
|
||||
return block->Insert<BinaryInst>(looputils::GetTerminatorIndex(block), Opcode::Mul,
|
||||
Type::GetInt32Type(), lhs, rhs, nullptr,
|
||||
looputils::NextSyntheticName(function, prefix));
|
||||
}
|
||||
|
||||
Value* BuildScaledValue(Function& function, BasicBlock* block, Value* base,
|
||||
int factor, const std::string& prefix) {
|
||||
if (factor == 0) {
|
||||
return looputils::ConstInt(0);
|
||||
}
|
||||
if (factor == 1) {
|
||||
return base;
|
||||
}
|
||||
if (auto* base_const = dyncast<ConstantInt>(base)) {
|
||||
return looputils::ConstInt(base_const->GetValue() * factor);
|
||||
}
|
||||
if (factor == -1) {
|
||||
return block->Insert<UnaryInst>(looputils::GetTerminatorIndex(block), Opcode::Neg,
|
||||
base->GetType(), base, nullptr,
|
||||
looputils::NextSyntheticName(function, prefix));
|
||||
}
|
||||
return BuildMulValue(function, block, base, looputils::ConstInt(factor), prefix);
|
||||
}
|
||||
|
||||
bool MatchSimpleInductionVariable(const Loop& loop, BasicBlock* preheader,
|
||||
PhiInst* phi, InductionVarInfo& info) {
|
||||
if (!phi || !phi->GetType() || !phi->GetType()->IsInt32() ||
|
||||
phi->GetParent() != loop.header || phi->GetNumIncomings() != 2 ||
|
||||
loop.latches.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* latch = loop.latches.front();
|
||||
int preheader_index = -1;
|
||||
int latch_index = -1;
|
||||
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
|
||||
if (phi->GetIncomingBlock(i) == preheader) {
|
||||
preheader_index = i;
|
||||
} else if (phi->GetIncomingBlock(i) == latch) {
|
||||
latch_index = i;
|
||||
}
|
||||
}
|
||||
if (preheader_index < 0 || latch_index < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* step_inst = dyncast<BinaryInst>(phi->GetIncomingValue(latch_index));
|
||||
if (!step_inst || step_inst->GetParent() != latch) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int stride = 0;
|
||||
if (step_inst->GetOpcode() == Opcode::Add) {
|
||||
if (step_inst->GetLhs() == phi) {
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = delta->GetValue();
|
||||
} else if (step_inst->GetRhs() == phi) {
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetLhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = delta->GetValue();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else if (step_inst->GetOpcode() == Opcode::Sub) {
|
||||
if (step_inst->GetLhs() != phi) {
|
||||
return false;
|
||||
}
|
||||
auto* delta = dyncast<ConstantInt>(step_inst->GetRhs());
|
||||
if (!delta) {
|
||||
return false;
|
||||
}
|
||||
stride = -delta->GetValue();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stride == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
info.phi = phi;
|
||||
info.start = phi->GetIncomingValue(preheader_index);
|
||||
info.latch = latch;
|
||||
info.stride = stride;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsMulCandidate(const Loop& loop, Instruction* inst, PhiInst* phi, Value*& factor) {
|
||||
auto* mul = dyncast<BinaryInst>(inst);
|
||||
if (!mul || mul->GetOpcode() != Opcode::Mul || !mul->GetType()->IsInt32()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (mul->GetLhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetRhs())) {
|
||||
factor = mul->GetRhs();
|
||||
return true;
|
||||
}
|
||||
if (mul->GetRhs() == phi && looputils::IsLoopInvariantValue(loop, mul->GetLhs())) {
|
||||
factor = mul->GetLhs();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* CreateReducedPhi(Function& function, BasicBlock* header, BasicBlock* preheader,
|
||||
const InductionVarInfo& iv, Value* factor) {
|
||||
auto* reduced_phi = header->Insert<PhiInst>(
|
||||
looputils::GetFirstNonPhiIndex(header), Type::GetInt32Type(), nullptr,
|
||||
looputils::NextSyntheticName(function, "lsr.phi."));
|
||||
|
||||
Value* init = BuildMulValue(function, preheader, iv.start, factor, "lsr.init.");
|
||||
reduced_phi->AddIncoming(init, preheader);
|
||||
|
||||
Value* step = BuildScaledValue(function, preheader, factor, std::abs(iv.stride),
|
||||
"lsr.step.");
|
||||
Instruction* next = nullptr;
|
||||
if (iv.stride > 0) {
|
||||
next = iv.latch->Insert<BinaryInst>(
|
||||
looputils::GetTerminatorIndex(iv.latch), Opcode::Add, Type::GetInt32Type(),
|
||||
reduced_phi, step, nullptr,
|
||||
looputils::NextSyntheticName(function, "lsr.next."));
|
||||
} else {
|
||||
next = iv.latch->Insert<BinaryInst>(
|
||||
looputils::GetTerminatorIndex(iv.latch), Opcode::Sub, Type::GetInt32Type(),
|
||||
reduced_phi, step, nullptr,
|
||||
looputils::NextSyntheticName(function, "lsr.next."));
|
||||
}
|
||||
reduced_phi->AddIncoming(next, iv.latch);
|
||||
return reduced_phi;
|
||||
}
|
||||
|
||||
bool ReduceLoopMultiplications(Function& function, const Loop& loop,
|
||||
BasicBlock* preheader) {
|
||||
if (!preheader || loop.latches.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<InductionVarInfo> induction_vars;
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
InductionVarInfo info;
|
||||
if (MatchSimpleInductionVariable(loop, preheader, phi, info)) {
|
||||
induction_vars.push_back(info);
|
||||
}
|
||||
}
|
||||
if (induction_vars.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
std::vector<Instruction*> to_remove;
|
||||
for (const auto& iv : induction_vars) {
|
||||
std::vector<std::pair<Instruction*, Value*>> candidates;
|
||||
for (auto* block : loop.block_list) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst == iv.phi) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Value* factor = nullptr;
|
||||
if (IsMulCandidate(loop, inst, iv.phi, factor)) {
|
||||
candidates.push_back({inst, factor});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidates.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_map<Value*, Value*> reduced_cache;
|
||||
for (const auto& candidate : candidates) {
|
||||
auto* inst = candidate.first;
|
||||
auto* factor = candidate.second;
|
||||
|
||||
auto cache_it = reduced_cache.find(factor);
|
||||
Value* replacement = nullptr;
|
||||
if (cache_it != reduced_cache.end()) {
|
||||
replacement = cache_it->second;
|
||||
} else {
|
||||
replacement = CreateReducedPhi(function, loop.header, preheader, iv, factor);
|
||||
reduced_cache.emplace(factor, replacement);
|
||||
}
|
||||
|
||||
inst->ReplaceAllUsesWith(replacement);
|
||||
to_remove.push_back(inst);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* inst : to_remove) {
|
||||
if (inst && inst->GetParent()) {
|
||||
inst->GetParent()->EraseInstruction(inst);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RunLoopStrengthReductionOnFunction(Function& function) {
|
||||
if (function.IsExternal() || function.GetEntryBlock() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
while (true) {
|
||||
DominatorTree dom_tree(function);
|
||||
LoopInfo loop_info(function, dom_tree);
|
||||
bool local_changed = false;
|
||||
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
auto* old_preheader = loop->preheader;
|
||||
auto* preheader = looputils::EnsurePreheader(function, *loop);
|
||||
bool loop_changed = preheader != old_preheader;
|
||||
loop_changed |= ReduceLoopMultiplications(function, *loop, preheader);
|
||||
if (!loop_changed) {
|
||||
continue;
|
||||
}
|
||||
changed = true;
|
||||
local_changed = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!local_changed) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLoopStrengthReduction(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunLoopStrengthReductionOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,400 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopMemoryUtils.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct CountedLoopInfo {
|
||||
Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* header = nullptr;
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
CondBrInst* branch = nullptr;
|
||||
BinaryInst* compare = nullptr;
|
||||
Opcode compare_opcode = Opcode::ICmpLT;
|
||||
Value* bound = nullptr;
|
||||
loopmem::SimpleInductionVar induction_var;
|
||||
std::vector<PhiInst*> phis;
|
||||
};
|
||||
|
||||
bool HasSyntheticLoopTag(const std::string& name) {
|
||||
return name.find("unroll.") != std::string::npos;
|
||||
}
|
||||
|
||||
bool IsAlreadyTransformedLoop(const Loop& loop, BasicBlock* body) {
|
||||
if (!loop.preheader || !loop.header || !body) {
|
||||
return true;
|
||||
}
|
||||
if (HasSyntheticLoopTag(loop.preheader->GetName()) ||
|
||||
HasSyntheticLoopTag(loop.header->GetName()) ||
|
||||
HasSyntheticLoopTag(body->GetName())) {
|
||||
return true;
|
||||
}
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
const int incoming = looputils::GetPhiIncomingIndex(phi, loop.preheader);
|
||||
if (incoming < 0) {
|
||||
continue;
|
||||
}
|
||||
auto* incoming_phi = dyncast<PhiInst>(phi->GetIncomingValue(incoming));
|
||||
if (incoming_phi && incoming_phi->GetParent() &&
|
||||
HasSyntheticLoopTag(incoming_phi->GetParent()->GetName())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsSupportedCompareOpcode(Opcode opcode) {
|
||||
switch (opcode) {
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpGE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Opcode SwapCompareOpcode(Opcode opcode) {
|
||||
switch (opcode) {
|
||||
case Opcode::ICmpLT:
|
||||
return Opcode::ICmpGT;
|
||||
case Opcode::ICmpLE:
|
||||
return Opcode::ICmpGE;
|
||||
case Opcode::ICmpGT:
|
||||
return Opcode::ICmpLT;
|
||||
case Opcode::ICmpGE:
|
||||
return Opcode::ICmpLE;
|
||||
default:
|
||||
return opcode;
|
||||
}
|
||||
}
|
||||
|
||||
int CountPayloadInstructions(BasicBlock* block) {
|
||||
int count = 0;
|
||||
if (!block) {
|
||||
return 0;
|
||||
}
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator()) {
|
||||
break;
|
||||
}
|
||||
++count;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
int ChooseUnrollFactor(BasicBlock* body) {
|
||||
const int inst_count = CountPayloadInstructions(body);
|
||||
int mem_ops = 0;
|
||||
for (const auto& inst_ptr : body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator()) {
|
||||
break;
|
||||
}
|
||||
if (dyncast<LoadInst>(inst) || dyncast<StoreInst>(inst)) {
|
||||
++mem_ops;
|
||||
}
|
||||
}
|
||||
if (inst_count >= 2 && inst_count <= 6 && mem_ops <= 2) {
|
||||
return 4;
|
||||
}
|
||||
if (inst_count >= 2 && inst_count <= 18) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
bool HasUnsafeLoopCarriedMemoryDependence(
|
||||
const std::vector<loopmem::MemoryAccessInfo>& accesses, int iv_stride) {
|
||||
for (std::size_t i = 0; i < accesses.size(); ++i) {
|
||||
if (accesses[i].is_write &&
|
||||
loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[i].ptr,
|
||||
iv_stride)) {
|
||||
return true;
|
||||
}
|
||||
for (std::size_t j = i + 1; j < accesses.size(); ++j) {
|
||||
if (!accesses[i].is_write && !accesses[j].is_write) {
|
||||
continue;
|
||||
}
|
||||
if (loopmem::HasCrossIterationDependence(accesses[i].ptr, accesses[j].ptr,
|
||||
iv_stride)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MatchCountedLoop(Loop& loop, CountedLoopInfo& info) {
|
||||
if (!loop.preheader || !loop.header || !loop.IsInnermost()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
BasicBlock* body = nullptr;
|
||||
BasicBlock* exit = nullptr;
|
||||
if (!loopmem::GetCanonicalLoopBlocks(loop, body, exit)) {
|
||||
return false;
|
||||
}
|
||||
if (IsAlreadyTransformedLoop(loop, body)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* branch = dyncast<CondBrInst>(looputils::GetTerminator(loop.header));
|
||||
if (!branch || branch->GetThenBlock() != body) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* compare = dyncast<BinaryInst>(branch->GetCondition());
|
||||
if (!compare || !compare->GetType()->IsBool() ||
|
||||
!IsSupportedCompareOpcode(compare->GetOpcode())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool found_iv = false;
|
||||
loopmem::SimpleInductionVar induction_var;
|
||||
std::vector<PhiInst*> phis;
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
phis.push_back(phi);
|
||||
if (!found_iv &&
|
||||
loopmem::MatchSimpleInductionVariable(loop, loop.preheader, phi, induction_var)) {
|
||||
found_iv = true;
|
||||
}
|
||||
}
|
||||
if (!found_iv) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Opcode compare_opcode = compare->GetOpcode();
|
||||
Value* bound = nullptr;
|
||||
if (compare->GetLhs() == induction_var.phi &&
|
||||
looputils::IsLoopInvariantValue(loop, compare->GetRhs())) {
|
||||
bound = compare->GetRhs();
|
||||
} else if (compare->GetRhs() == induction_var.phi &&
|
||||
looputils::IsLoopInvariantValue(loop, compare->GetLhs())) {
|
||||
bound = compare->GetLhs();
|
||||
compare_opcode = SwapCompareOpcode(compare_opcode);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!bound) {
|
||||
return false;
|
||||
}
|
||||
if ((induction_var.stride > 0 &&
|
||||
!(compare_opcode == Opcode::ICmpLT || compare_opcode == Opcode::ICmpLE)) ||
|
||||
(induction_var.stride < 0 &&
|
||||
!(compare_opcode == Opcode::ICmpGT || compare_opcode == Opcode::ICmpGE))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto accesses = loopmem::CollectMemoryAccesses(loop, induction_var.phi);
|
||||
if (HasUnsafeLoopCarriedMemoryDependence(accesses, induction_var.stride)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& inst_ptr : loop.header->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dyncast<PhiInst>(inst) || inst == compare || inst->IsTerminator()) {
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& inst_ptr : body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator()) {
|
||||
continue;
|
||||
}
|
||||
if (!looputils::IsCloneableInstruction(inst) || dyncast<CallInst>(inst) ||
|
||||
dyncast<MemsetInst>(inst) || dyncast<AllocaInst>(inst)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
info.loop = &loop;
|
||||
info.preheader = loop.preheader;
|
||||
info.header = loop.header;
|
||||
info.body = body;
|
||||
info.exit = exit;
|
||||
info.branch = branch;
|
||||
info.compare = compare;
|
||||
info.compare_opcode = compare_opcode;
|
||||
info.bound = bound;
|
||||
info.induction_var = induction_var;
|
||||
info.phis = std::move(phis);
|
||||
return true;
|
||||
}
|
||||
|
||||
Value* BuildAdjustedBound(Function& function, BasicBlock* preheader, Value* bound,
|
||||
int stride, int factor) {
|
||||
const int delta = std::abs(stride) * (factor - 1);
|
||||
if (delta == 0) {
|
||||
return bound;
|
||||
}
|
||||
if (auto* ci = dyncast<ConstantInt>(bound)) {
|
||||
return looputils::ConstInt(stride > 0 ? ci->GetValue() - delta : ci->GetValue() + delta);
|
||||
}
|
||||
return preheader->Insert<BinaryInst>(
|
||||
looputils::GetTerminatorIndex(preheader),
|
||||
stride > 0 ? Opcode::Sub : Opcode::Add, Type::GetInt32Type(), bound,
|
||||
looputils::ConstInt(delta), nullptr,
|
||||
looputils::NextSyntheticName(function, "unroll.bound."));
|
||||
}
|
||||
|
||||
bool RunLoopUnrollOnFunction(Function& function) {
|
||||
if (function.IsExternal() || !function.GetEntryBlock()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
while (true) {
|
||||
DominatorTree dom_tree(function);
|
||||
LoopInfo loop_info(function, dom_tree);
|
||||
bool local_changed = false;
|
||||
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
CountedLoopInfo info;
|
||||
if (!MatchCountedLoop(*loop, info)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int factor = ChooseUnrollFactor(info.body);
|
||||
if (factor <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* unrolled_header =
|
||||
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.header"));
|
||||
auto* unrolled_body =
|
||||
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.body"));
|
||||
auto* unrolled_exit =
|
||||
function.CreateBlock(looputils::NextSyntheticBlockName(function, "unroll.exit"));
|
||||
|
||||
std::unordered_map<Value*, Value*> remap;
|
||||
std::unordered_map<PhiInst*, PhiInst*> unrolled_phis;
|
||||
std::unordered_map<PhiInst*, PhiInst*> exit_phis;
|
||||
std::unordered_map<PhiInst*, Value*> current_phi_values;
|
||||
std::unordered_map<PhiInst*, Value*> latch_values;
|
||||
|
||||
for (auto* phi : info.phis) {
|
||||
auto* cloned_phi = unrolled_header->Append<PhiInst>(
|
||||
phi->GetType(), nullptr,
|
||||
looputils::NextSyntheticName(function, "unroll.phi."));
|
||||
const int preheader_index = looputils::GetPhiIncomingIndex(phi, info.preheader);
|
||||
const int latch_index = looputils::GetPhiIncomingIndex(phi, info.body);
|
||||
if (preheader_index < 0 || latch_index < 0) {
|
||||
continue;
|
||||
}
|
||||
cloned_phi->AddIncoming(phi->GetIncomingValue(preheader_index), info.preheader);
|
||||
remap[phi] = cloned_phi;
|
||||
unrolled_phis.emplace(phi, cloned_phi);
|
||||
current_phi_values.emplace(phi, cloned_phi);
|
||||
latch_values.emplace(phi, phi->GetIncomingValue(latch_index));
|
||||
}
|
||||
|
||||
auto* adjusted_bound = BuildAdjustedBound(function, info.preheader, info.bound,
|
||||
info.induction_var.stride, factor);
|
||||
auto* unrolled_cond = unrolled_header->Append<BinaryInst>(
|
||||
info.compare_opcode, Type::GetBoolType(), unrolled_phis[info.induction_var.phi],
|
||||
adjusted_bound, nullptr,
|
||||
looputils::NextSyntheticName(function, "unroll.cmp."));
|
||||
unrolled_header->Append<CondBrInst>(unrolled_cond, unrolled_body, unrolled_exit, nullptr);
|
||||
unrolled_header->AddPredecessor(info.preheader);
|
||||
unrolled_header->AddSuccessor(unrolled_body);
|
||||
unrolled_header->AddSuccessor(unrolled_exit);
|
||||
|
||||
for (int lane = 0; lane < factor; ++lane) {
|
||||
for (const auto& inst_ptr : info.body->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (inst->IsTerminator()) {
|
||||
continue;
|
||||
}
|
||||
looputils::CloneInstruction(function, inst, unrolled_body, remap,
|
||||
"unroll." + std::to_string(lane) + ".");
|
||||
}
|
||||
|
||||
std::unordered_map<PhiInst*, Value*> next_phi_values;
|
||||
for (const auto& entry : latch_values) {
|
||||
next_phi_values.emplace(entry.first,
|
||||
looputils::RemapValue(remap, entry.second));
|
||||
}
|
||||
for (const auto& entry : next_phi_values) {
|
||||
remap[entry.first] = entry.second;
|
||||
current_phi_values[entry.first] = entry.second;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& entry : unrolled_phis) {
|
||||
entry.second->AddIncoming(current_phi_values[entry.first], unrolled_body);
|
||||
}
|
||||
unrolled_body->Append<UncondBrInst>(unrolled_header, nullptr);
|
||||
unrolled_body->AddPredecessor(unrolled_header);
|
||||
unrolled_body->AddSuccessor(unrolled_header);
|
||||
unrolled_header->AddPredecessor(unrolled_body);
|
||||
|
||||
for (const auto& entry : unrolled_phis) {
|
||||
auto* exit_phi = unrolled_exit->Append<PhiInst>(
|
||||
entry.first->GetType(), nullptr,
|
||||
looputils::NextSyntheticName(function, "unroll.exit."));
|
||||
exit_phi->AddIncoming(entry.second, unrolled_header);
|
||||
exit_phis.emplace(entry.first, exit_phi);
|
||||
}
|
||||
unrolled_exit->Append<UncondBrInst>(info.header, nullptr);
|
||||
unrolled_exit->AddPredecessor(unrolled_header);
|
||||
unrolled_exit->AddSuccessor(info.header);
|
||||
|
||||
if (!looputils::RedirectSuccessorEdge(info.preheader, info.header, unrolled_header)) {
|
||||
continue;
|
||||
}
|
||||
info.header->RemovePredecessor(info.preheader);
|
||||
info.header->AddPredecessor(unrolled_exit);
|
||||
|
||||
for (auto* phi : info.phis) {
|
||||
looputils::ReplacePhiIncoming(phi, info.preheader, exit_phis[phi], unrolled_exit);
|
||||
}
|
||||
|
||||
changed = true;
|
||||
local_changed = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!local_changed) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLoopUnroll(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunLoopUnrollOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,313 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/Analysis.h"
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct UnswitchInfo {
|
||||
Loop* loop = nullptr;
|
||||
BasicBlock* preheader = nullptr;
|
||||
BasicBlock* guard_block = nullptr;
|
||||
CondBrInst* guard = nullptr;
|
||||
Value* condition = nullptr;
|
||||
std::vector<BasicBlock*> order;
|
||||
};
|
||||
|
||||
bool HasSyntheticUnswitchTag(const std::string& name) {
|
||||
return name.find("unswitch.") != std::string::npos;
|
||||
}
|
||||
|
||||
bool IsSafeLoopBlockForUnswitch(BasicBlock* block) {
|
||||
if (!block) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dyncast<CallInst>(inst) || dyncast<MemsetInst>(inst) ||
|
||||
dyncast<AllocaInst>(inst) || dyncast<UnreachableInst>(inst)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CollectLoopDFS(BasicBlock* block, const Loop& loop,
|
||||
std::unordered_set<BasicBlock*>& visited,
|
||||
std::vector<BasicBlock*>& postorder) {
|
||||
if (!block || !loop.Contains(block) || !visited.insert(block).second) {
|
||||
return;
|
||||
}
|
||||
for (auto* succ : block->GetSuccessors()) {
|
||||
if (loop.Contains(succ)) {
|
||||
CollectLoopDFS(succ, loop, visited, postorder);
|
||||
}
|
||||
}
|
||||
postorder.push_back(block);
|
||||
}
|
||||
|
||||
std::vector<BasicBlock*> CollectLoopRPO(const Loop& loop) {
|
||||
std::vector<BasicBlock*> postorder;
|
||||
std::unordered_set<BasicBlock*> visited;
|
||||
CollectLoopDFS(loop.header, loop, visited, postorder);
|
||||
return std::vector<BasicBlock*>(postorder.rbegin(), postorder.rend());
|
||||
}
|
||||
|
||||
void RemovePhiIncomingFromPred(BasicBlock* block, BasicBlock* pred) {
|
||||
if (!block || !pred) {
|
||||
return;
|
||||
}
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
const int index = looputils::GetPhiIncomingIndex(phi, pred);
|
||||
if (index < 0) {
|
||||
continue;
|
||||
}
|
||||
phi->RemoveOperand(static_cast<std::size_t>(2 * index + 1));
|
||||
phi->RemoveOperand(static_cast<std::size_t>(2 * index));
|
||||
}
|
||||
}
|
||||
|
||||
void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) {
|
||||
auto& instructions = block->GetInstructions();
|
||||
if (instructions.empty() || !instructions.back()->IsTerminator()) {
|
||||
return;
|
||||
}
|
||||
instructions.back()->ClearAllOperands();
|
||||
instructions.pop_back();
|
||||
block->Append<UncondBrInst>(dest, nullptr);
|
||||
}
|
||||
|
||||
void ReplaceTerminatorWithCondBr(BasicBlock* block, Value* cond,
|
||||
BasicBlock* then_block, BasicBlock* else_block) {
|
||||
auto& instructions = block->GetInstructions();
|
||||
if (instructions.empty() || !instructions.back()->IsTerminator()) {
|
||||
return;
|
||||
}
|
||||
instructions.back()->ClearAllOperands();
|
||||
instructions.pop_back();
|
||||
block->Append<CondBrInst>(cond, then_block, else_block, nullptr);
|
||||
}
|
||||
|
||||
bool MatchLoopUnswitch(Loop& loop, UnswitchInfo& info) {
|
||||
if (!loop.preheader || !loop.IsInnermost() || loop.block_list.size() > 6) {
|
||||
return false;
|
||||
}
|
||||
if (HasSyntheticUnswitchTag(loop.preheader->GetName()) ||
|
||||
HasSyntheticUnswitchTag(loop.header->GetName())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int instruction_count = 0;
|
||||
for (auto* block : loop.block_list) {
|
||||
if (!IsSafeLoopBlockForUnswitch(block) || HasSyntheticUnswitchTag(block->GetName())) {
|
||||
return false;
|
||||
}
|
||||
instruction_count += static_cast<int>(block->GetInstructions().size());
|
||||
}
|
||||
if (instruction_count > 48) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto* block : loop.block_list) {
|
||||
auto* condbr = dyncast<CondBrInst>(looputils::GetTerminator(block));
|
||||
if (!condbr) {
|
||||
continue;
|
||||
}
|
||||
auto* cond_inst = dyncast<Instruction>(condbr->GetCondition());
|
||||
if (cond_inst && loop.Contains(cond_inst->GetParent())) {
|
||||
continue;
|
||||
}
|
||||
info.loop = &loop;
|
||||
info.preheader = loop.preheader;
|
||||
info.guard_block = block;
|
||||
info.guard = condbr;
|
||||
info.condition = condbr->GetCondition();
|
||||
info.order = CollectLoopRPO(loop);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CloneLoopForUnswitch(Function& function, const UnswitchInfo& info,
|
||||
std::unordered_map<BasicBlock*, BasicBlock*>& block_map,
|
||||
std::unordered_map<Value*, Value*>& value_map) {
|
||||
for (auto* block : info.order) {
|
||||
block_map[block] = function.CreateBlock(
|
||||
looputils::NextSyntheticBlockName(function, "unswitch.loop"));
|
||||
}
|
||||
|
||||
for (auto* block : info.order) {
|
||||
auto* clone = block_map.at(block);
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
auto* cloned_phi = clone->Append<PhiInst>(
|
||||
phi->GetType(), nullptr, looputils::NextSyntheticName(function, "unswitch.phi."));
|
||||
value_map[phi] = cloned_phi;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : info.order) {
|
||||
auto* clone = block_map.at(block);
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dyncast<PhiInst>(inst) || inst->IsTerminator()) {
|
||||
continue;
|
||||
}
|
||||
if (!looputils::CloneInstruction(function, inst, clone, value_map, "unswitch.")) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : info.order) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
auto* cloned_phi = static_cast<PhiInst*>(value_map.at(phi));
|
||||
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
|
||||
auto* incoming_block = phi->GetIncomingBlock(i);
|
||||
auto value_it = value_map.find(phi->GetIncomingValue(i));
|
||||
Value* incoming_value =
|
||||
value_it == value_map.end() ? phi->GetIncomingValue(i) : value_it->second;
|
||||
auto block_it = block_map.find(incoming_block);
|
||||
cloned_phi->AddIncoming(incoming_value,
|
||||
block_it == block_map.end() ? incoming_block
|
||||
: block_it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : info.order) {
|
||||
auto* clone = block_map.at(block);
|
||||
auto* terminator = looputils::GetTerminator(block);
|
||||
if (auto* br = dyncast<UncondBrInst>(terminator)) {
|
||||
auto target_it = block_map.find(br->GetDest());
|
||||
auto* target = target_it == block_map.end() ? br->GetDest() : target_it->second;
|
||||
clone->Append<UncondBrInst>(target, nullptr);
|
||||
clone->AddSuccessor(target);
|
||||
target->AddPredecessor(clone);
|
||||
continue;
|
||||
}
|
||||
auto* condbr = dyncast<CondBrInst>(terminator);
|
||||
if (!condbr) {
|
||||
return false;
|
||||
}
|
||||
auto cond_it = value_map.find(condbr->GetCondition());
|
||||
Value* cond = cond_it == value_map.end() ? condbr->GetCondition() : cond_it->second;
|
||||
auto then_it = block_map.find(condbr->GetThenBlock());
|
||||
auto else_it = block_map.find(condbr->GetElseBlock());
|
||||
auto* then_block = then_it == block_map.end() ? condbr->GetThenBlock() : then_it->second;
|
||||
auto* else_block = else_it == block_map.end() ? condbr->GetElseBlock() : else_it->second;
|
||||
clone->Append<CondBrInst>(cond, then_block, else_block, nullptr);
|
||||
clone->AddSuccessor(then_block);
|
||||
clone->AddSuccessor(else_block);
|
||||
then_block->AddPredecessor(clone);
|
||||
else_block->AddPredecessor(clone);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RunLoopUnswitchOnFunction(Function& function) {
|
||||
if (function.IsExternal() || !function.GetEntryBlock()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
while (true) {
|
||||
DominatorTree dom_tree(function);
|
||||
LoopInfo loop_info(function, dom_tree);
|
||||
bool local_changed = false;
|
||||
|
||||
for (auto* loop : loop_info.GetLoopsInPostOrder()) {
|
||||
UnswitchInfo info;
|
||||
if (!MatchLoopUnswitch(*loop, info)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_map<BasicBlock*, BasicBlock*> block_map;
|
||||
std::unordered_map<Value*, Value*> value_map;
|
||||
if (!CloneLoopForUnswitch(function, info, block_map, value_map)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* then_target = info.guard->GetThenBlock();
|
||||
auto* else_target = info.guard->GetElseBlock();
|
||||
auto* cloned_guard = block_map.at(info.guard_block);
|
||||
auto* cloned_then =
|
||||
block_map.count(then_target) ? block_map.at(then_target) : then_target;
|
||||
auto* cloned_else =
|
||||
block_map.count(else_target) ? block_map.at(else_target) : else_target;
|
||||
|
||||
RemovePhiIncomingFromPred(else_target, info.guard_block);
|
||||
if (then_target != else_target) {
|
||||
RemovePhiIncomingFromPred(cloned_then, cloned_guard);
|
||||
}
|
||||
|
||||
info.guard_block->RemoveSuccessor(else_target);
|
||||
else_target->RemovePredecessor(info.guard_block);
|
||||
ReplaceTerminatorWithBr(info.guard_block, then_target);
|
||||
info.guard_block->AddSuccessor(then_target);
|
||||
then_target->AddPredecessor(info.guard_block);
|
||||
|
||||
cloned_guard->RemoveSuccessor(cloned_then);
|
||||
cloned_then->RemovePredecessor(cloned_guard);
|
||||
ReplaceTerminatorWithBr(cloned_guard, cloned_else);
|
||||
cloned_guard->AddSuccessor(cloned_else);
|
||||
cloned_else->AddPredecessor(cloned_guard);
|
||||
|
||||
auto* cloned_header = block_map.at(loop->header);
|
||||
auto* old_preheader_term = looputils::GetTerminator(info.preheader);
|
||||
if (!old_preheader_term || !dyncast<UncondBrInst>(old_preheader_term)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
info.preheader->RemoveSuccessor(loop->header);
|
||||
loop->header->RemovePredecessor(info.preheader);
|
||||
ReplaceTerminatorWithCondBr(info.preheader, info.condition, loop->header, cloned_header);
|
||||
info.preheader->AddSuccessor(loop->header);
|
||||
info.preheader->AddSuccessor(cloned_header);
|
||||
loop->header->AddPredecessor(info.preheader);
|
||||
cloned_header->AddPredecessor(info.preheader);
|
||||
|
||||
changed = true;
|
||||
local_changed = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (!local_changed) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunLoopUnswitch(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunLoopUnswitchOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,375 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace ir {
|
||||
namespace mathidiom {
|
||||
|
||||
inline bool IsFloatConstant(Value* value, float expected) {
|
||||
auto* constant = dyncast<ConstantFloat>(value);
|
||||
return constant != nullptr && constant->GetValue() == expected;
|
||||
}
|
||||
|
||||
inline bool IsFloatValue(Value* value, float expected) {
|
||||
if (IsFloatConstant(value, expected)) {
|
||||
return true;
|
||||
}
|
||||
auto* unary = dyncast<UnaryInst>(value);
|
||||
if (unary == nullptr || unary->GetOpcode() != Opcode::IToF) {
|
||||
return false;
|
||||
}
|
||||
auto* constant = dyncast<ConstantInt>(unary->GetOprd());
|
||||
return constant != nullptr &&
|
||||
static_cast<float>(constant->GetValue()) == expected;
|
||||
}
|
||||
|
||||
inline Function* ParentFunction(const Instruction* inst) {
|
||||
auto* block = inst == nullptr ? nullptr : inst->GetParent();
|
||||
return block == nullptr ? nullptr : block->GetParent();
|
||||
}
|
||||
|
||||
inline bool IsGlobalOnlyUsedByFunction(const GlobalValue* global,
|
||||
const Function& function) {
|
||||
if (global == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& use : global->GetUses()) {
|
||||
auto* inst = dyncast<Instruction>(use.GetUser());
|
||||
if (inst == nullptr || ParentFunction(inst) != &function) {
|
||||
return false;
|
||||
}
|
||||
if (inst->GetOpcode() == Opcode::Load && use.GetOperandIndex() == 0) {
|
||||
continue;
|
||||
}
|
||||
if (inst->GetOpcode() == Opcode::Store && use.GetOperandIndex() == 1) {
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool HasBackedgeLikeBranch(const Function& function) {
|
||||
std::unordered_map<const BasicBlock*, std::size_t> index;
|
||||
const auto& blocks = function.GetBlocks();
|
||||
for (std::size_t i = 0; i < blocks.size(); ++i) {
|
||||
index[blocks[i].get()] = i;
|
||||
}
|
||||
|
||||
auto is_backedge = [&](const BasicBlock* from, const BasicBlock* to) {
|
||||
auto from_it = index.find(from);
|
||||
auto to_it = index.find(to);
|
||||
return from_it != index.end() && to_it != index.end() &&
|
||||
to_it->second <= from_it->second;
|
||||
};
|
||||
|
||||
for (std::size_t i = 0; i < blocks.size(); ++i) {
|
||||
const auto& instructions = blocks[i]->GetInstructions();
|
||||
if (instructions.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto* terminator = instructions.back().get();
|
||||
if (auto* br = dyncast<UncondBrInst>(terminator)) {
|
||||
if (is_backedge(blocks[i].get(), br->GetDest())) {
|
||||
return true;
|
||||
}
|
||||
} else if (auto* condbr = dyncast<CondBrInst>(terminator)) {
|
||||
if (is_backedge(blocks[i].get(), condbr->GetThenBlock()) ||
|
||||
is_backedge(blocks[i].get(), condbr->GetElseBlock())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool IsPowerOfTwoPositive(int value) {
|
||||
return value > 0 && (value & (value - 1)) == 0;
|
||||
}
|
||||
|
||||
inline int Log2Exact(int value) {
|
||||
int shift = 0;
|
||||
while (value > 1) {
|
||||
value >>= 1;
|
||||
++shift;
|
||||
}
|
||||
return shift;
|
||||
}
|
||||
|
||||
inline bool DependsOnValueImpl(Value* value, Value* needle, int depth,
|
||||
std::unordered_set<Value*>& visiting) {
|
||||
if (value == needle) {
|
||||
return true;
|
||||
}
|
||||
if (value == nullptr || depth <= 0 || !visiting.insert(value).second) {
|
||||
return false;
|
||||
}
|
||||
auto* inst = dyncast<Instruction>(value);
|
||||
if (inst == nullptr) {
|
||||
return false;
|
||||
}
|
||||
for (std::size_t i = 0; i < inst->GetNumOperands(); ++i) {
|
||||
if (DependsOnValueImpl(inst->GetOperand(i), needle, depth - 1, visiting)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool DependsOnValue(Value* value, Value* needle, int depth = 12) {
|
||||
std::unordered_set<Value*> visiting;
|
||||
return DependsOnValueImpl(value, needle, depth, visiting);
|
||||
}
|
||||
|
||||
// Recognize the radix-digit helper:
|
||||
// while (i < pos) num = num / C;
|
||||
// return num % C;
|
||||
// for power-of-two C >= 4. Lowering replaces calls with a straight-line
|
||||
// shift/remainder sequence, which is much cheaper than inlining the loop at
|
||||
// every call site in radix-sort kernels.
|
||||
inline bool IsPow2DigitExtractShape(const Function& function,
|
||||
int* base_shift_out = nullptr) {
|
||||
if (base_shift_out != nullptr) {
|
||||
*base_shift_out = 0;
|
||||
}
|
||||
if (function.IsExternal() || function.GetReturnType() == nullptr ||
|
||||
!function.GetReturnType()->IsInt32() || function.GetArguments().size() != 2 ||
|
||||
!function.GetArgument(0)->GetType()->IsInt32() ||
|
||||
!function.GetArgument(1)->GetType()->IsInt32() ||
|
||||
!HasBackedgeLikeBranch(function)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* num_arg = function.GetArgument(0);
|
||||
auto* pos_arg = function.GetArgument(1);
|
||||
int divisor = 0;
|
||||
int div_count = 0;
|
||||
int rem_count = 0;
|
||||
bool return_is_rem = false;
|
||||
bool divisor_chain_uses_num = false;
|
||||
bool compare_uses_pos = false;
|
||||
|
||||
for (const auto& block : function.GetBlocks()) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
if (dyncast<CallInst>(inst) || dyncast<LoadInst>(inst) ||
|
||||
dyncast<StoreInst>(inst) || dyncast<AllocaInst>(inst) ||
|
||||
dyncast<GetElementPtrInst>(inst) || dyncast<MemsetInst>(inst) ||
|
||||
dyncast<UnreachableInst>(inst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto* ret = dyncast<ReturnInst>(inst)) {
|
||||
auto* returned = ret->HasReturnValue() ? ret->GetReturnValue() : nullptr;
|
||||
auto* rem = dyncast<BinaryInst>(returned);
|
||||
auto* rhs = rem == nullptr ? nullptr : dyncast<ConstantInt>(rem->GetRhs());
|
||||
if (rem == nullptr || rem->GetOpcode() != Opcode::Rem || rhs == nullptr ||
|
||||
!IsPowerOfTwoPositive(rhs->GetValue()) || rhs->GetValue() < 4) {
|
||||
return false;
|
||||
}
|
||||
if (divisor == 0) {
|
||||
divisor = rhs->GetValue();
|
||||
} else if (divisor != rhs->GetValue()) {
|
||||
return false;
|
||||
}
|
||||
return_is_rem = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* bin = dyncast<BinaryInst>(inst);
|
||||
if (!bin) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (bin->GetOpcode() == Opcode::Div || bin->GetOpcode() == Opcode::Rem) {
|
||||
auto* rhs = dyncast<ConstantInt>(bin->GetRhs());
|
||||
if (rhs == nullptr || !IsPowerOfTwoPositive(rhs->GetValue()) ||
|
||||
rhs->GetValue() < 4) {
|
||||
return false;
|
||||
}
|
||||
if (divisor == 0) {
|
||||
divisor = rhs->GetValue();
|
||||
} else if (divisor != rhs->GetValue()) {
|
||||
return false;
|
||||
}
|
||||
if (bin->GetOpcode() == Opcode::Div) {
|
||||
++div_count;
|
||||
} else {
|
||||
++rem_count;
|
||||
}
|
||||
divisor_chain_uses_num |= DependsOnValue(bin->GetLhs(), num_arg);
|
||||
}
|
||||
|
||||
switch (bin->GetOpcode()) {
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
compare_uses_pos |= DependsOnValue(bin->GetLhs(), pos_arg) ||
|
||||
DependsOnValue(bin->GetRhs(), pos_arg);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (divisor == 0 || div_count == 0 || rem_count == 0 || !return_is_rem ||
|
||||
!divisor_chain_uses_num || !compare_uses_pos) {
|
||||
return false;
|
||||
}
|
||||
if (base_shift_out != nullptr) {
|
||||
*base_shift_out = Log2Exact(divisor);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Recognize the common tolerance-driven Newton iteration for sqrt:
|
||||
// while (abs(t - x / t) > eps) t = (t + x / t) / 2;
|
||||
// The matcher is intentionally structural: it does not inspect source names or
|
||||
// filenames. Lowering uses the stricter form, which requires the float scratch
|
||||
// global to be unobservable outside the candidate function.
|
||||
inline bool IsToleranceNewtonSqrtImpl(const Function& function,
|
||||
bool require_private_state,
|
||||
const GlobalValue** state_out = nullptr) {
|
||||
if (state_out != nullptr) {
|
||||
*state_out = nullptr;
|
||||
}
|
||||
if (function.IsExternal() || function.GetReturnType() == nullptr ||
|
||||
!function.GetReturnType()->IsFloat() || function.GetArguments().size() != 1 ||
|
||||
!function.GetArguments()[0]->GetType()->IsFloat() ||
|
||||
function.GetBlocks().size() < 3 || function.GetBlocks().size() > 8 ||
|
||||
!HasBackedgeLikeBranch(function)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* input = function.GetArguments()[0].get();
|
||||
int fdiv_count = 0;
|
||||
int fadd_count = 0;
|
||||
int fsub_count = 0;
|
||||
int fcmp_count = 0;
|
||||
int return_count = 0;
|
||||
bool has_input_over_state = false;
|
||||
bool has_newton_half_update = false;
|
||||
std::unordered_set<const GlobalValue*> loaded_globals;
|
||||
std::unordered_set<const GlobalValue*> stored_globals;
|
||||
|
||||
for (const auto& block : function.GetBlocks()) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
auto* inst = inst_ptr.get();
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::FDiv: {
|
||||
++fdiv_count;
|
||||
auto* binary = static_cast<BinaryInst*>(inst);
|
||||
if (binary->GetLhs() == input) {
|
||||
has_input_over_state = true;
|
||||
}
|
||||
if (IsFloatValue(binary->GetRhs(), 2.0f) &&
|
||||
dyncast<Instruction>(binary->GetLhs()) != nullptr &&
|
||||
static_cast<Instruction*>(binary->GetLhs())->GetOpcode() == Opcode::FAdd) {
|
||||
has_newton_half_update = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Opcode::FAdd:
|
||||
++fadd_count;
|
||||
break;
|
||||
case Opcode::FSub:
|
||||
++fsub_count;
|
||||
break;
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
++fcmp_count;
|
||||
break;
|
||||
case Opcode::Load: {
|
||||
auto* load = static_cast<LoadInst*>(inst);
|
||||
auto* global = dyncast<GlobalValue>(load->GetPtr());
|
||||
if (global == nullptr || !load->GetType()->IsFloat() ||
|
||||
!global->GetObjectType()->IsFloat()) {
|
||||
return false;
|
||||
}
|
||||
loaded_globals.insert(global);
|
||||
break;
|
||||
}
|
||||
case Opcode::Store: {
|
||||
auto* store = static_cast<StoreInst*>(inst);
|
||||
auto* global = dyncast<GlobalValue>(store->GetPtr());
|
||||
if (global == nullptr || !store->GetValue()->GetType()->IsFloat() ||
|
||||
!global->GetObjectType()->IsFloat()) {
|
||||
return false;
|
||||
}
|
||||
stored_globals.insert(global);
|
||||
break;
|
||||
}
|
||||
case Opcode::Return:
|
||||
++return_count;
|
||||
if (!static_cast<ReturnInst*>(inst)->HasReturnValue() ||
|
||||
!static_cast<ReturnInst*>(inst)->GetReturnValue()->GetType()->IsFloat()) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case Opcode::Call:
|
||||
case Opcode::Alloca:
|
||||
case Opcode::GetElementPtr:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Unreachable:
|
||||
return false;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (fdiv_count < 2 || fadd_count < 1 || fsub_count < 1 || fcmp_count < 1 ||
|
||||
return_count != 1 || !has_input_over_state || !has_newton_half_update) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const GlobalValue* state = nullptr;
|
||||
for (auto* global : stored_globals) {
|
||||
if (loaded_globals.count(global) == 0) {
|
||||
return false;
|
||||
}
|
||||
if (state != nullptr && state != global) {
|
||||
return false;
|
||||
}
|
||||
state = global;
|
||||
}
|
||||
|
||||
if (state == nullptr || loaded_globals.size() != 1 || !state->HasInitializer() ||
|
||||
!IsFloatConstant(state->GetInitializer(), 1.0f)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (require_private_state && !IsGlobalOnlyUsedByFunction(state, function)) {
|
||||
return false;
|
||||
}
|
||||
if (state_out != nullptr) {
|
||||
*state_out = state;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool IsToleranceNewtonSqrtShape(const Function& function) {
|
||||
return IsToleranceNewtonSqrtImpl(function, false);
|
||||
}
|
||||
|
||||
inline bool IsPrivateToleranceNewtonSqrt(const Function& function,
|
||||
const GlobalValue** state_out = nullptr) {
|
||||
return IsToleranceNewtonSqrtImpl(function, true, state_out);
|
||||
}
|
||||
|
||||
} // namespace mathidiom
|
||||
} // namespace ir
|
||||
@ -0,0 +1,261 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "PassUtils.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir::memutils {
|
||||
|
||||
enum class PointerRootKind {
|
||||
Local,
|
||||
Global,
|
||||
ReadonlyGlobal,
|
||||
Param,
|
||||
Unknown,
|
||||
};
|
||||
|
||||
struct AddressComponent {
|
||||
bool is_constant = false;
|
||||
std::int64_t constant = 0;
|
||||
Value* value = nullptr;
|
||||
|
||||
bool operator==(const AddressComponent& rhs) const {
|
||||
return is_constant == rhs.is_constant && constant == rhs.constant &&
|
||||
value == rhs.value;
|
||||
}
|
||||
};
|
||||
|
||||
struct AddressKey {
|
||||
PointerRootKind kind = PointerRootKind::Unknown;
|
||||
Value* root = nullptr;
|
||||
std::vector<AddressComponent> components;
|
||||
|
||||
bool operator==(const AddressKey& rhs) const {
|
||||
return kind == rhs.kind && root == rhs.root && components == rhs.components;
|
||||
}
|
||||
};
|
||||
|
||||
struct AddressKeyHash {
|
||||
std::size_t operator()(const AddressKey& key) const {
|
||||
std::size_t h = static_cast<std::size_t>(key.kind);
|
||||
h ^= std::hash<Value*>{}(key.root) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
for (const auto& component : key.components) {
|
||||
h ^= std::hash<bool>{}(component.is_constant) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
if (component.is_constant) {
|
||||
h ^= std::hash<std::int64_t>{}(component.constant) + 0x9e3779b9 + (h << 6) +
|
||||
(h >> 2);
|
||||
} else {
|
||||
h ^= std::hash<Value*>{}(component.value) + 0x9e3779b9 + (h << 6) + (h >> 2);
|
||||
}
|
||||
}
|
||||
return h;
|
||||
}
|
||||
};
|
||||
|
||||
struct EscapeSummary {
|
||||
std::unordered_set<Value*> escaped_locals;
|
||||
|
||||
bool IsEscaped(Value* value) const {
|
||||
return value != nullptr && escaped_locals.find(value) != escaped_locals.end();
|
||||
}
|
||||
};
|
||||
|
||||
inline bool IsNoEscapePointerUse(Value* current, Instruction* user) {
|
||||
if (!current || !user) {
|
||||
return false;
|
||||
}
|
||||
if (auto* load = dyncast<LoadInst>(user)) {
|
||||
return load->GetPtr() == current;
|
||||
}
|
||||
if (auto* store = dyncast<StoreInst>(user)) {
|
||||
return store->GetPtr() == current;
|
||||
}
|
||||
if (auto* memset = dyncast<MemsetInst>(user)) {
|
||||
return memset->GetDest() == current;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool PointerValueEscapes(Value* current, Value* root,
|
||||
std::unordered_set<Value*>& visiting) {
|
||||
if (!current || !root || !visiting.insert(current).second) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& use : current->GetUses()) {
|
||||
auto* user = dyncast<Instruction>(use.GetUser());
|
||||
if (!user) {
|
||||
return true;
|
||||
}
|
||||
if (auto* gep = dyncast<GetElementPtrInst>(user)) {
|
||||
if (gep->GetPointer() == current &&
|
||||
PointerValueEscapes(gep, root, visiting)) {
|
||||
return true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (IsNoEscapePointerUse(current, user)) {
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
inline EscapeSummary AnalyzeEscapes(Function& function) {
|
||||
EscapeSummary summary;
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
|
||||
auto* alloca = dyncast<AllocaInst>(inst_ptr.get());
|
||||
if (!alloca) {
|
||||
continue;
|
||||
}
|
||||
std::unordered_set<Value*> visiting;
|
||||
if (PointerValueEscapes(alloca, alloca, visiting)) {
|
||||
summary.escaped_locals.insert(alloca);
|
||||
}
|
||||
}
|
||||
}
|
||||
return summary;
|
||||
}
|
||||
|
||||
inline PointerRootKind ClassifyRoot(Value* root, const EscapeSummary* summary) {
|
||||
if (root == nullptr) {
|
||||
return PointerRootKind::Unknown;
|
||||
}
|
||||
if (auto* global = dyncast<GlobalValue>(root)) {
|
||||
return global->IsConstant() ? PointerRootKind::ReadonlyGlobal
|
||||
: PointerRootKind::Global;
|
||||
}
|
||||
if (isa<Argument>(root)) {
|
||||
return PointerRootKind::Param;
|
||||
}
|
||||
if (isa<AllocaInst>(root)) {
|
||||
if (summary != nullptr && summary->IsEscaped(root)) {
|
||||
return PointerRootKind::Unknown;
|
||||
}
|
||||
return PointerRootKind::Local;
|
||||
}
|
||||
return PointerRootKind::Unknown;
|
||||
}
|
||||
|
||||
inline Value* StripPointerRoot(Value* pointer) {
|
||||
auto* current = pointer;
|
||||
while (auto* gep = dyncast<GetElementPtrInst>(current)) {
|
||||
current = gep->GetPointer();
|
||||
}
|
||||
return current;
|
||||
}
|
||||
|
||||
inline AddressComponent MakeAddressComponent(Value* value) {
|
||||
if (auto* ci = dyncast<ConstantInt>(value)) {
|
||||
return {true, ci->GetValue(), nullptr};
|
||||
}
|
||||
if (auto* cb = dyncast<ConstantI1>(value)) {
|
||||
return {true, cb->GetValue() ? 1 : 0, nullptr};
|
||||
}
|
||||
return {false, 0, value};
|
||||
}
|
||||
|
||||
inline bool BuildExactAddressKey(Value* pointer, const EscapeSummary* summary,
|
||||
AddressKey& key) {
|
||||
if (!pointer) {
|
||||
return false;
|
||||
}
|
||||
if (auto* gep = dyncast<GetElementPtrInst>(pointer)) {
|
||||
if (!BuildExactAddressKey(gep->GetPointer(), summary, key)) {
|
||||
return false;
|
||||
}
|
||||
for (std::size_t i = 0; i < gep->GetNumIndices(); ++i) {
|
||||
key.components.push_back(MakeAddressComponent(gep->GetIndex(i)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
key.kind = ClassifyRoot(pointer, summary);
|
||||
key.root = pointer;
|
||||
key.components.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool HasOnlyConstantComponents(const AddressKey& key) {
|
||||
for (const auto& component : key.components) {
|
||||
if (!component.is_constant) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool MayAliasConservatively(const AddressKey& lhs, const AddressKey& rhs) {
|
||||
if (lhs.kind == PointerRootKind::Unknown || rhs.kind == PointerRootKind::Unknown) {
|
||||
return true;
|
||||
}
|
||||
if (lhs.kind != rhs.kind || lhs.root != rhs.root) {
|
||||
return false;
|
||||
}
|
||||
if (lhs.components == rhs.components) {
|
||||
return true;
|
||||
}
|
||||
if (HasOnlyConstantComponents(lhs) && HasOnlyConstantComponents(rhs)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool CallMayReadRoot(Function* callee, PointerRootKind kind) {
|
||||
if (!callee) {
|
||||
return true;
|
||||
}
|
||||
if (callee->HasUnknownEffects()) {
|
||||
return true;
|
||||
}
|
||||
switch (kind) {
|
||||
case PointerRootKind::ReadonlyGlobal:
|
||||
return callee->ReadsGlobalMemory();
|
||||
case PointerRootKind::Global:
|
||||
return callee->ReadsGlobalMemory() || callee->WritesGlobalMemory() ||
|
||||
callee->ReadsParamMemory() || callee->WritesParamMemory();
|
||||
case PointerRootKind::Param:
|
||||
return callee->ReadsParamMemory() || callee->WritesParamMemory();
|
||||
case PointerRootKind::Local:
|
||||
return callee->ReadsParamMemory() || callee->WritesParamMemory();
|
||||
case PointerRootKind::Unknown:
|
||||
return callee->MayReadMemory();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool CallMayWriteRoot(Function* callee, PointerRootKind kind) {
|
||||
if (!callee) {
|
||||
return true;
|
||||
}
|
||||
if (callee->HasUnknownEffects()) {
|
||||
return true;
|
||||
}
|
||||
switch (kind) {
|
||||
case PointerRootKind::ReadonlyGlobal:
|
||||
return false;
|
||||
case PointerRootKind::Global:
|
||||
return callee->WritesGlobalMemory() || callee->WritesParamMemory();
|
||||
case PointerRootKind::Param:
|
||||
return callee->WritesParamMemory();
|
||||
case PointerRootKind::Local:
|
||||
return callee->WritesParamMemory();
|
||||
case PointerRootKind::Unknown:
|
||||
return callee->MayWriteMemory();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool IsPureCall(const CallInst* call) {
|
||||
auto* callee = call == nullptr ? nullptr : call->GetCallee();
|
||||
return callee != nullptr && callee->CanDiscardUnusedCall() &&
|
||||
!callee->MayReadMemory();
|
||||
}
|
||||
|
||||
} // namespace ir::memutils
|
||||
@ -0,0 +1,234 @@
|
||||
#pragma once
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir::passutils {
|
||||
|
||||
inline std::uint32_t FloatBits(float value) {
|
||||
std::uint32_t bits = 0;
|
||||
std::memcpy(&bits, &value, sizeof(bits));
|
||||
return bits;
|
||||
}
|
||||
|
||||
inline bool AreEquivalentValues(Value* lhs, Value* rhs) {
|
||||
if (lhs == rhs) {
|
||||
return true;
|
||||
}
|
||||
auto* lhs_i32 = dyncast<ConstantInt>(lhs);
|
||||
auto* rhs_i32 = dyncast<ConstantInt>(rhs);
|
||||
if (lhs_i32 && rhs_i32) {
|
||||
return lhs_i32->GetValue() == rhs_i32->GetValue();
|
||||
}
|
||||
auto* lhs_i1 = dyncast<ConstantI1>(lhs);
|
||||
auto* rhs_i1 = dyncast<ConstantI1>(rhs);
|
||||
if (lhs_i1 && rhs_i1) {
|
||||
return lhs_i1->GetValue() == rhs_i1->GetValue();
|
||||
}
|
||||
auto* lhs_f32 = dyncast<ConstantFloat>(lhs);
|
||||
auto* rhs_f32 = dyncast<ConstantFloat>(rhs);
|
||||
if (lhs_f32 && rhs_f32) {
|
||||
return FloatBits(lhs_f32->GetValue()) == FloatBits(rhs_f32->GetValue());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline std::vector<BasicBlock*> CollectReachableBlocks(Function& function) {
|
||||
std::vector<BasicBlock*> order;
|
||||
auto* entry = function.GetEntryBlock();
|
||||
if (!entry) {
|
||||
return order;
|
||||
}
|
||||
|
||||
std::unordered_set<BasicBlock*> visited;
|
||||
std::vector<BasicBlock*> stack{entry};
|
||||
while (!stack.empty()) {
|
||||
auto* block = stack.back();
|
||||
stack.pop_back();
|
||||
if (!block || !visited.insert(block).second) {
|
||||
continue;
|
||||
}
|
||||
order.push_back(block);
|
||||
const auto& succs = block->GetSuccessors();
|
||||
for (auto it = succs.rbegin(); it != succs.rend(); ++it) {
|
||||
if (*it != nullptr) {
|
||||
stack.push_back(*it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return order;
|
||||
}
|
||||
|
||||
inline bool IsSideEffectingInstruction(const Instruction* inst) {
|
||||
if (!inst) {
|
||||
return false;
|
||||
}
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Store:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Br:
|
||||
case Opcode::CondBr:
|
||||
case Opcode::Return:
|
||||
case Opcode::Unreachable:
|
||||
return true;
|
||||
case Opcode::Call: {
|
||||
auto* call = dyncast<const CallInst>(inst);
|
||||
auto* callee = call == nullptr ? nullptr : call->GetCallee();
|
||||
return callee == nullptr || !callee->CanDiscardUnusedCall();
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool IsTriviallyDead(Instruction* inst) {
|
||||
return inst != nullptr && !IsSideEffectingInstruction(inst) &&
|
||||
inst->GetUses().empty();
|
||||
}
|
||||
|
||||
inline void RemoveIncomingForBlock(PhiInst* phi, BasicBlock* block) {
|
||||
if (!phi || !block) {
|
||||
return;
|
||||
}
|
||||
for (int i = phi->GetNumIncomings() - 1; i >= 0; --i) {
|
||||
if (phi->GetIncomingBlock(i) != block) {
|
||||
continue;
|
||||
}
|
||||
phi->RemoveOperand(static_cast<size_t>(2 * i + 1));
|
||||
phi->RemoveOperand(static_cast<size_t>(2 * i));
|
||||
}
|
||||
}
|
||||
|
||||
inline void RemoveIncomingFromSuccessor(BasicBlock* succ, BasicBlock* pred) {
|
||||
if (!succ || !pred) {
|
||||
return;
|
||||
}
|
||||
for (const auto& inst_ptr : succ->GetInstructions()) {
|
||||
auto* phi = dyncast<PhiInst>(inst_ptr.get());
|
||||
if (!phi) {
|
||||
break;
|
||||
}
|
||||
RemoveIncomingForBlock(phi, pred);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ReplaceTerminatorWithBr(BasicBlock* block, BasicBlock* dest) {
|
||||
auto& instructions = block->GetInstructions();
|
||||
if (instructions.empty() || !instructions.back()->IsTerminator()) {
|
||||
return;
|
||||
}
|
||||
instructions.back()->ClearAllOperands();
|
||||
auto branch = std::make_unique<UncondBrInst>(dest, nullptr);
|
||||
branch->SetParent(block);
|
||||
instructions.back() = std::move(branch);
|
||||
}
|
||||
|
||||
inline bool SimplifyPhiInst(PhiInst* phi) {
|
||||
if (!phi) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* unique_value = nullptr;
|
||||
for (int i = 0; i < phi->GetNumIncomings(); ++i) {
|
||||
auto* incoming = phi->GetIncomingValue(i);
|
||||
if (incoming == phi) {
|
||||
continue;
|
||||
}
|
||||
if (unique_value == nullptr) {
|
||||
unique_value = incoming;
|
||||
continue;
|
||||
}
|
||||
if (!AreEquivalentValues(unique_value, incoming)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (unique_value == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* parent = phi->GetParent();
|
||||
phi->ReplaceAllUsesWith(unique_value);
|
||||
parent->EraseInstruction(phi);
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void EraseBlock(Function& function, BasicBlock* block) {
|
||||
if (!block) {
|
||||
return;
|
||||
}
|
||||
auto& blocks = function.GetBlocks();
|
||||
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
|
||||
[&](const std::unique_ptr<BasicBlock>& current) {
|
||||
return current.get() == block;
|
||||
}),
|
||||
blocks.end());
|
||||
}
|
||||
|
||||
inline bool RemoveUnreachableBlocks(Function& function) {
|
||||
auto reachable = CollectReachableBlocks(function);
|
||||
std::unordered_set<BasicBlock*> reachable_set(reachable.begin(), reachable.end());
|
||||
std::vector<BasicBlock*> dead_blocks;
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
auto* block = block_ptr.get();
|
||||
if (reachable_set.find(block) == reachable_set.end()) {
|
||||
dead_blocks.push_back(block);
|
||||
}
|
||||
}
|
||||
|
||||
if (dead_blocks.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto* block : dead_blocks) {
|
||||
auto preds = block->GetPredecessors();
|
||||
auto succs = block->GetSuccessors();
|
||||
for (auto* succ : succs) {
|
||||
RemoveIncomingFromSuccessor(succ, block);
|
||||
succ->RemovePredecessor(block);
|
||||
}
|
||||
for (auto* pred : preds) {
|
||||
pred->RemoveSuccessor(block);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : dead_blocks) {
|
||||
for (const auto& inst_ptr : block->GetInstructions()) {
|
||||
inst_ptr->ClearAllOperands();
|
||||
}
|
||||
}
|
||||
|
||||
for (auto* block : dead_blocks) {
|
||||
EraseBlock(function, block);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool IsCommutativeOpcode(Opcode opcode) {
|
||||
switch (opcode) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Mul:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FMul:
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ir::passutils
|
||||
@ -0,0 +1,249 @@
|
||||
#include "ir/PassManager.h"
|
||||
|
||||
#include "ir/IR.h"
|
||||
#include "LoopPassUtils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ir {
|
||||
namespace {
|
||||
|
||||
struct TailCallSite {
|
||||
BasicBlock* block = nullptr;
|
||||
CallInst* call = nullptr;
|
||||
ReturnInst* ret = nullptr;
|
||||
};
|
||||
|
||||
bool HasEntryPhi(Function& function) {
|
||||
auto* entry = function.GetEntryBlock();
|
||||
if (!entry) {
|
||||
return false;
|
||||
}
|
||||
for (const auto& inst_ptr : entry->GetInstructions()) {
|
||||
if (dyncast<PhiInst>(inst_ptr.get())) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsOnlyUsedByReturn(CallInst* call, ReturnInst* ret) {
|
||||
if (!call || !ret) {
|
||||
return false;
|
||||
}
|
||||
const auto& uses = call->GetUses();
|
||||
return uses.size() == 1 && uses.front().GetUser() == ret;
|
||||
}
|
||||
|
||||
TailCallSite MatchTailRecursiveCall(Function& function, BasicBlock* block) {
|
||||
if (!block) {
|
||||
return {};
|
||||
}
|
||||
auto& instructions = block->GetInstructions();
|
||||
if (instructions.size() < 2) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto* ret = dyncast<ReturnInst>(instructions.back().get());
|
||||
if (!ret) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto* previous = instructions[instructions.size() - 2].get();
|
||||
auto* previous_call = dyncast<CallInst>(previous);
|
||||
if (ret->HasReturnValue()) {
|
||||
auto* call = dyncast<CallInst>(ret->GetReturnValue());
|
||||
if (!call || call != previous_call || call->GetParent() != block ||
|
||||
call->GetCallee() != &function || !IsOnlyUsedByReturn(call, ret)) {
|
||||
return {};
|
||||
}
|
||||
return {block, call, ret};
|
||||
}
|
||||
|
||||
if (!previous_call || previous_call->GetCallee() != &function ||
|
||||
!previous_call->GetType()->IsVoid() || !previous_call->GetUses().empty()) {
|
||||
return {};
|
||||
}
|
||||
return {block, previous_call, ret};
|
||||
}
|
||||
|
||||
std::vector<TailCallSite> CollectTailCallSites(Function& function) {
|
||||
std::vector<TailCallSite> sites;
|
||||
for (const auto& block_ptr : function.GetBlocks()) {
|
||||
auto site = MatchTailRecursiveCall(function, block_ptr.get());
|
||||
if (site.block && site.call && site.ret) {
|
||||
sites.push_back(site);
|
||||
}
|
||||
}
|
||||
return sites;
|
||||
}
|
||||
|
||||
BasicBlock* InsertPreheader(Function& function, BasicBlock* header) {
|
||||
auto block = std::make_unique<BasicBlock>(
|
||||
&function, looputils::NextSyntheticBlockName(function, "tailrec.entry"));
|
||||
auto* preheader = block.get();
|
||||
|
||||
auto& blocks = function.GetBlocks();
|
||||
blocks.insert(blocks.begin(), std::move(block));
|
||||
function.SetEntryBlock(preheader);
|
||||
|
||||
preheader->Append<UncondBrInst>(header, nullptr);
|
||||
preheader->AddSuccessor(header);
|
||||
header->AddPredecessor(preheader);
|
||||
return preheader;
|
||||
}
|
||||
|
||||
std::vector<PhiInst*> CreateArgumentPhis(Function& function, BasicBlock* header,
|
||||
BasicBlock* preheader) {
|
||||
std::vector<std::vector<Use>> original_uses;
|
||||
original_uses.reserve(function.GetArguments().size());
|
||||
for (const auto& arg : function.GetArguments()) {
|
||||
original_uses.push_back(arg->GetUses());
|
||||
}
|
||||
|
||||
std::vector<PhiInst*> phis;
|
||||
phis.reserve(function.GetArguments().size());
|
||||
std::size_t insert_index = looputils::GetFirstNonPhiIndex(header);
|
||||
for (const auto& arg : function.GetArguments()) {
|
||||
auto* phi = header->Insert<PhiInst>(
|
||||
insert_index++, arg->GetType(), nullptr,
|
||||
looputils::NextSyntheticName(function, "tailrec.arg."));
|
||||
phi->AddIncoming(arg.get(), preheader);
|
||||
phis.push_back(phi);
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < function.GetArguments().size(); ++i) {
|
||||
for (const auto& use : original_uses[i]) {
|
||||
if (auto* user = use.GetUser()) {
|
||||
user->SetOperand(use.GetOperandIndex(), phis[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return phis;
|
||||
}
|
||||
|
||||
void ReplaceTerminatorWithBranch(BasicBlock* block, BasicBlock* dest) {
|
||||
auto& instructions = block->GetInstructions();
|
||||
instructions.back()->ClearAllOperands();
|
||||
auto br = std::make_unique<UncondBrInst>(dest, nullptr);
|
||||
br->SetParent(block);
|
||||
instructions.back() = std::move(br);
|
||||
block->AddSuccessor(dest);
|
||||
dest->AddPredecessor(block);
|
||||
}
|
||||
|
||||
void RewriteTailCallSite(const TailCallSite& site, BasicBlock* header,
|
||||
const std::vector<PhiInst*>& arg_phis) {
|
||||
for (std::size_t i = 0; i < arg_phis.size(); ++i) {
|
||||
arg_phis[i]->AddIncoming(site.call->GetOperand(i + 1), site.block);
|
||||
}
|
||||
|
||||
ReplaceTerminatorWithBranch(site.block, header);
|
||||
site.block->EraseInstruction(site.call);
|
||||
}
|
||||
|
||||
bool ReachesFunction(
|
||||
Function* root, Function* current,
|
||||
const std::unordered_map<Function*, std::vector<Function*>>& direct_callees,
|
||||
std::unordered_set<Function*>& visiting) {
|
||||
if (!root || !current || current->IsExternal()) {
|
||||
return false;
|
||||
}
|
||||
if (!visiting.insert(current).second) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto it = direct_callees.find(current);
|
||||
if (it == direct_callees.end()) {
|
||||
return false;
|
||||
}
|
||||
for (auto* callee : it->second) {
|
||||
if (callee == root) {
|
||||
return true;
|
||||
}
|
||||
if (ReachesFunction(root, callee, direct_callees, visiting)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void RecomputeRecursiveFlags(Module& module) {
|
||||
std::unordered_map<Function*, std::vector<Function*>> direct_callees;
|
||||
for (const auto& function_ptr : module.GetFunctions()) {
|
||||
auto* function = function_ptr.get();
|
||||
if (!function || function->IsExternal()) {
|
||||
continue;
|
||||
}
|
||||
auto& callees = direct_callees[function];
|
||||
for (const auto& block_ptr : function->GetBlocks()) {
|
||||
for (const auto& inst_ptr : block_ptr->GetInstructions()) {
|
||||
auto* call = dyncast<CallInst>(inst_ptr.get());
|
||||
auto* callee = call ? call->GetCallee() : nullptr;
|
||||
if (callee && !callee->IsExternal() &&
|
||||
std::find(callees.begin(), callees.end(), callee) == callees.end()) {
|
||||
callees.push_back(callee);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& function_ptr : module.GetFunctions()) {
|
||||
auto* function = function_ptr.get();
|
||||
if (!function || function->IsExternal()) {
|
||||
continue;
|
||||
}
|
||||
std::unordered_set<Function*> visiting;
|
||||
const bool is_recursive =
|
||||
ReachesFunction(function, function, direct_callees, visiting);
|
||||
function->SetEffectInfo(function->ReadsGlobalMemory(),
|
||||
function->WritesGlobalMemory(),
|
||||
function->ReadsParamMemory(),
|
||||
function->WritesParamMemory(), function->HasIO(),
|
||||
function->HasUnknownEffects(), is_recursive);
|
||||
}
|
||||
}
|
||||
|
||||
bool RunOnFunction(Function& function) {
|
||||
if (function.IsExternal() || !function.GetEntryBlock() || HasEntryPhi(function)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto sites = CollectTailCallSites(function);
|
||||
if (sites.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* header = function.GetEntryBlock();
|
||||
auto* preheader = InsertPreheader(function, header);
|
||||
auto arg_phis = CreateArgumentPhis(function, header, preheader);
|
||||
|
||||
for (const auto& site : sites) {
|
||||
RewriteTailCallSite(site, header, arg_phis);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunTailRecursionElim(Module& module) {
|
||||
bool changed = false;
|
||||
for (const auto& function_ptr : module.GetFunctions()) {
|
||||
if (function_ptr) {
|
||||
changed |= RunOnFunction(*function_ptr);
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
RecomputeRecursiveFlags(module);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
@ -0,0 +1,140 @@
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace mir {
|
||||
namespace {
|
||||
|
||||
bool IsHoistCandidate(const MachineFunction& function, int object_index, int use_count) {
|
||||
const auto& object = function.GetStackObject(object_index);
|
||||
if (object.kind != StackObjectKind::Local) {
|
||||
return false;
|
||||
}
|
||||
if (use_count < 2) {
|
||||
return false;
|
||||
}
|
||||
if (object.size >= 4096) {
|
||||
return true;
|
||||
}
|
||||
return object.size >= 256 && use_count >= 4;
|
||||
}
|
||||
|
||||
bool IsPlainFrameLea(const MachineInstr& inst, int object_index) {
|
||||
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress() ||
|
||||
inst.GetOperands().empty() || inst.GetOperands()[0].GetKind() != OperandKind::VReg) {
|
||||
return false;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
return address.base_kind == AddrBaseKind::FrameObject &&
|
||||
address.base_index == object_index && address.const_offset == 0 &&
|
||||
address.scaled_vregs.empty();
|
||||
}
|
||||
|
||||
std::size_t FindEntryInsertPos(const MachineBasicBlock& block) {
|
||||
const auto& instructions = block.GetInstructions();
|
||||
std::size_t pos = 0;
|
||||
while (pos < instructions.size() &&
|
||||
instructions[pos].GetOpcode() == MachineInstr::Opcode::Arg) {
|
||||
++pos;
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RunAddressHoisting(MachineModule& module) {
|
||||
for (auto& function : module.GetFunctions()) {
|
||||
if (!function || function->GetBlocks().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::unordered_map<int, int> use_counts;
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
if (address.base_kind == AddrBaseKind::FrameObject && address.base_index >= 0) {
|
||||
++use_counts[address.base_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<int, int> base_vregs;
|
||||
for (const auto& [object_index, count] : use_counts) {
|
||||
if (!IsHoistCandidate(*function, object_index, count)) {
|
||||
continue;
|
||||
}
|
||||
base_vregs.emplace(object_index, -1);
|
||||
}
|
||||
if (base_vregs.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
const auto& address = inst.GetAddress();
|
||||
auto it = base_vregs.find(address.base_index);
|
||||
if (it == base_vregs.end()) {
|
||||
continue;
|
||||
}
|
||||
if (it->second >= 0) {
|
||||
continue;
|
||||
}
|
||||
if (IsPlainFrameLea(inst, address.base_index)) {
|
||||
it->second = inst.GetOperands()[0].GetVReg();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto& entry_block = *function->GetBlocks().front();
|
||||
auto& entry_insts = entry_block.GetInstructions();
|
||||
std::size_t insert_pos = FindEntryInsertPos(entry_block);
|
||||
|
||||
for (auto& [object_index, base_vreg] : base_vregs) {
|
||||
if (base_vreg >= 0) {
|
||||
continue;
|
||||
}
|
||||
base_vreg = function->NewVReg(ValueType::Ptr);
|
||||
MachineInstr lea(MachineInstr::Opcode::Lea, {MachineOperand::VReg(base_vreg)});
|
||||
AddressExpr address;
|
||||
address.base_kind = AddrBaseKind::FrameObject;
|
||||
address.base_index = object_index;
|
||||
lea.SetAddress(std::move(address));
|
||||
entry_insts.insert(entry_insts.begin() + static_cast<std::ptrdiff_t>(insert_pos),
|
||||
std::move(lea));
|
||||
++insert_pos;
|
||||
}
|
||||
|
||||
for (auto& block : function->GetBlocks()) {
|
||||
for (auto& inst : block->GetInstructions()) {
|
||||
if (!inst.HasAddress()) {
|
||||
continue;
|
||||
}
|
||||
auto& address = inst.GetAddress();
|
||||
auto it = base_vregs.find(address.base_index);
|
||||
if (it == base_vregs.end()) {
|
||||
continue;
|
||||
}
|
||||
if (IsPlainFrameLea(inst, address.base_index) &&
|
||||
inst.GetOperands()[0].GetKind() == OperandKind::VReg &&
|
||||
inst.GetOperands()[0].GetVReg() == it->second) {
|
||||
continue;
|
||||
}
|
||||
if (address.base_kind != AddrBaseKind::FrameObject || address.base_index < 0) {
|
||||
continue;
|
||||
}
|
||||
address.base_kind = AddrBaseKind::VReg;
|
||||
address.base_index = it->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,24 +1,25 @@
|
||||
add_library(mir_core STATIC
|
||||
MIRContext.cpp
|
||||
MIRFunction.cpp
|
||||
MIRBasicBlock.cpp
|
||||
MIRInstr.cpp
|
||||
Register.cpp
|
||||
Lowering.cpp
|
||||
RegAlloc.cpp
|
||||
FrameLowering.cpp
|
||||
AsmPrinter.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(mir_core PUBLIC
|
||||
build_options
|
||||
ir
|
||||
)
|
||||
|
||||
add_subdirectory(passes)
|
||||
|
||||
add_library(mir INTERFACE)
|
||||
target_link_libraries(mir INTERFACE
|
||||
mir_core
|
||||
mir_passes
|
||||
)
|
||||
add_library(mir_core STATIC
|
||||
MIRContext.cpp
|
||||
MIRFunction.cpp
|
||||
MIRBasicBlock.cpp
|
||||
MIRInstr.cpp
|
||||
Register.cpp
|
||||
Lowering.cpp
|
||||
AddressHoisting.cpp
|
||||
RegAlloc.cpp
|
||||
FrameLowering.cpp
|
||||
AsmPrinter.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(mir_core PUBLIC
|
||||
build_options
|
||||
ir
|
||||
)
|
||||
|
||||
add_subdirectory(passes)
|
||||
|
||||
add_library(mir INTERFACE)
|
||||
target_link_libraries(mir INTERFACE
|
||||
mir_core
|
||||
mir_passes
|
||||
)
|
||||
|
||||
@ -1,45 +1,40 @@
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/Log.h"
|
||||
#include <string>
|
||||
|
||||
namespace mir {
|
||||
namespace {
|
||||
|
||||
int AlignTo(int value, int align) {
|
||||
if (align <= 1) {
|
||||
return value;
|
||||
}
|
||||
return ((value + align - 1) / align) * align;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RunFrameLowering(MachineFunction& function) {
|
||||
int cursor = 0;
|
||||
for (const auto& slot : function.GetFrameSlots()) {
|
||||
cursor += slot.size;
|
||||
if (-cursor < -256) {
|
||||
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
|
||||
void RunFrameLowering(MachineModule& module) {
|
||||
for (auto& function : module.GetFunctions()) {
|
||||
for (int reg : function->GetUsedCalleeSavedGPRs()) {
|
||||
function->CreateStackObject(8, 8, StackObjectKind::SavedGPR,
|
||||
"save.x" + std::to_string(reg));
|
||||
}
|
||||
for (int reg : function->GetUsedCalleeSavedFPRs()) {
|
||||
function->CreateStackObject(8, 8, StackObjectKind::SavedFPR,
|
||||
"save.v" + std::to_string(reg));
|
||||
}
|
||||
}
|
||||
|
||||
cursor = 0;
|
||||
for (const auto& slot : function.GetFrameSlots()) {
|
||||
cursor += slot.size;
|
||||
function.GetFrameSlot(slot.index).offset = -cursor;
|
||||
}
|
||||
function.SetFrameSize(AlignTo(cursor, 16));
|
||||
|
||||
auto& insts = function.GetEntry().GetInstructions();
|
||||
std::vector<MachineInstr> lowered;
|
||||
lowered.emplace_back(Opcode::Prologue);
|
||||
for (const auto& inst : insts) {
|
||||
if (inst.GetOpcode() == Opcode::Ret) {
|
||||
lowered.emplace_back(Opcode::Epilogue);
|
||||
int cursor = 0;
|
||||
const int object_count = static_cast<int>(function->GetStackObjects().size());
|
||||
for (int i = 0; i < object_count; ++i) {
|
||||
auto& object = function->GetStackObject(i);
|
||||
cursor = AlignTo(cursor, object.align);
|
||||
cursor += object.size;
|
||||
object.offset = -cursor;
|
||||
}
|
||||
lowered.push_back(inst);
|
||||
function->SetFrameSize(AlignTo(cursor, 16));
|
||||
}
|
||||
insts = std::move(lowered);
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
} // namespace mir
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,33 +1,106 @@
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/Log.h"
|
||||
|
||||
namespace mir {
|
||||
|
||||
MachineFunction::MachineFunction(std::string name)
|
||||
: name_(std::move(name)), entry_("entry") {}
|
||||
MachineFunction::MachineFunction(std::string name, ValueType return_type,
|
||||
std::vector<ValueType> param_types)
|
||||
: name_(std::move(name)),
|
||||
return_type_(return_type),
|
||||
param_types_(std::move(param_types)) {}
|
||||
|
||||
MachineBasicBlock* MachineFunction::CreateBlock(const std::string& name) {
|
||||
auto block = std::make_unique<MachineBasicBlock>(name);
|
||||
auto* ptr = block.get();
|
||||
blocks_.push_back(std::move(block));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
int MachineFunction::NewVReg(ValueType type) {
|
||||
const int id = static_cast<int>(vregs_.size());
|
||||
vregs_.push_back({id, type});
|
||||
allocations_.push_back({});
|
||||
return id;
|
||||
}
|
||||
|
||||
const VRegInfo& MachineFunction::GetVRegInfo(int id) const {
|
||||
if (id < 0 || id >= static_cast<int>(vregs_.size())) {
|
||||
throw std::out_of_range("virtual register index out of range");
|
||||
}
|
||||
return vregs_[static_cast<size_t>(id)];
|
||||
}
|
||||
|
||||
VRegInfo& MachineFunction::GetVRegInfo(int id) {
|
||||
if (id < 0 || id >= static_cast<int>(vregs_.size())) {
|
||||
throw std::out_of_range("virtual register index out of range");
|
||||
}
|
||||
return vregs_[static_cast<size_t>(id)];
|
||||
}
|
||||
|
||||
int MachineFunction::CreateFrameIndex(int size) {
|
||||
int index = static_cast<int>(frame_slots_.size());
|
||||
frame_slots_.push_back(FrameSlot{index, size, 0});
|
||||
int MachineFunction::CreateStackObject(int size, int align, StackObjectKind kind,
|
||||
const std::string& name) {
|
||||
const int index = static_cast<int>(stack_objects_.size());
|
||||
stack_objects_.push_back({index, kind, size, align, 0, name});
|
||||
return index;
|
||||
}
|
||||
|
||||
FrameSlot& MachineFunction::GetFrameSlot(int index) {
|
||||
if (index < 0 || index >= static_cast<int>(frame_slots_.size())) {
|
||||
throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
|
||||
StackObject& MachineFunction::GetStackObject(int index) {
|
||||
if (index < 0 || index >= static_cast<int>(stack_objects_.size())) {
|
||||
throw std::out_of_range("stack object index out of range");
|
||||
}
|
||||
return stack_objects_[static_cast<size_t>(index)];
|
||||
}
|
||||
|
||||
const StackObject& MachineFunction::GetStackObject(int index) const {
|
||||
if (index < 0 || index >= static_cast<int>(stack_objects_.size())) {
|
||||
throw std::out_of_range("stack object index out of range");
|
||||
}
|
||||
return frame_slots_[index];
|
||||
return stack_objects_[static_cast<size_t>(index)];
|
||||
}
|
||||
|
||||
const FrameSlot& MachineFunction::GetFrameSlot(int index) const {
|
||||
if (index < 0 || index >= static_cast<int>(frame_slots_.size())) {
|
||||
throw std::runtime_error(FormatError("mir", "非法 FrameIndex"));
|
||||
void MachineFunction::SetAllocation(int vreg, Allocation allocation) {
|
||||
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
|
||||
throw std::out_of_range("allocation index out of range");
|
||||
}
|
||||
return frame_slots_[index];
|
||||
allocations_[static_cast<size_t>(vreg)] = allocation;
|
||||
}
|
||||
|
||||
const Allocation& MachineFunction::GetAllocation(int vreg) const {
|
||||
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
|
||||
throw std::out_of_range("allocation index out of range");
|
||||
}
|
||||
return allocations_[static_cast<size_t>(vreg)];
|
||||
}
|
||||
|
||||
Allocation& MachineFunction::GetAllocation(int vreg) {
|
||||
if (vreg < 0 || vreg >= static_cast<int>(allocations_.size())) {
|
||||
throw std::out_of_range("allocation index out of range");
|
||||
}
|
||||
return allocations_[static_cast<size_t>(vreg)];
|
||||
}
|
||||
|
||||
void MachineFunction::AddUsedCalleeSavedGPR(int reg_index) {
|
||||
if (std::find(used_callee_saved_gprs_.begin(), used_callee_saved_gprs_.end(),
|
||||
reg_index) == used_callee_saved_gprs_.end()) {
|
||||
used_callee_saved_gprs_.push_back(reg_index);
|
||||
}
|
||||
}
|
||||
|
||||
void MachineFunction::AddUsedCalleeSavedFPR(int reg_index) {
|
||||
if (std::find(used_callee_saved_fprs_.begin(), used_callee_saved_fprs_.end(),
|
||||
reg_index) == used_callee_saved_fprs_.end()) {
|
||||
used_callee_saved_fprs_.push_back(reg_index);
|
||||
}
|
||||
}
|
||||
|
||||
MachineFunction* MachineModule::AddFunction(
|
||||
std::unique_ptr<MachineFunction> function) {
|
||||
auto* ptr = function.get();
|
||||
functions_.push_back(std::move(function));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
} // namespace mir
|
||||
|
||||
@ -1,23 +1,186 @@
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace mir {
|
||||
|
||||
Operand::Operand(Kind kind, PhysReg reg, int imm)
|
||||
: kind_(kind), reg_(reg), imm_(imm) {}
|
||||
MachineOperand::MachineOperand(OperandKind kind, int vreg, std::int64_t imm,
|
||||
std::string text)
|
||||
: kind_(kind), vreg_(vreg), imm_(imm), text_(std::move(text)) {}
|
||||
|
||||
MachineOperand MachineOperand::VReg(int reg) {
|
||||
return MachineOperand(OperandKind::VReg, reg, 0, "");
|
||||
}
|
||||
|
||||
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); }
|
||||
MachineOperand MachineOperand::Imm(std::int64_t value) {
|
||||
return MachineOperand(OperandKind::Imm, -1, value, "");
|
||||
}
|
||||
|
||||
Operand Operand::Imm(int value) {
|
||||
return Operand(Kind::Imm, PhysReg::W0, value);
|
||||
MachineOperand MachineOperand::Block(std::string name) {
|
||||
return MachineOperand(OperandKind::Block, -1, 0, std::move(name));
|
||||
}
|
||||
|
||||
Operand Operand::FrameIndex(int index) {
|
||||
return Operand(Kind::FrameIndex, PhysReg::W0, index);
|
||||
MachineOperand MachineOperand::Symbol(std::string name) {
|
||||
return MachineOperand(OperandKind::Symbol, -1, 0, std::move(name));
|
||||
}
|
||||
|
||||
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
|
||||
MachineInstr::MachineInstr(Opcode opcode, std::vector<MachineOperand> operands)
|
||||
: opcode_(opcode), operands_(std::move(operands)) {}
|
||||
|
||||
} // namespace mir
|
||||
bool MachineInstr::IsTerminator() const {
|
||||
return opcode_ == Opcode::Br || opcode_ == Opcode::CondBr ||
|
||||
opcode_ == Opcode::Ret || opcode_ == Opcode::Unreachable;
|
||||
}
|
||||
|
||||
std::vector<int> MachineInstr::GetDefs() const {
|
||||
switch (opcode_) {
|
||||
case Opcode::Arg:
|
||||
case Opcode::Copy:
|
||||
case Opcode::Load:
|
||||
case Opcode::Lea:
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::ModMul:
|
||||
case Opcode::ModPow:
|
||||
case Opcode::DigitExtractPow2:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::FSqrt:
|
||||
case Opcode::FNeg:
|
||||
case Opcode::ICmp:
|
||||
case Opcode::FCmp:
|
||||
case Opcode::ZExt:
|
||||
case Opcode::ItoF:
|
||||
case Opcode::FtoI:
|
||||
if (!operands_.empty() && operands_[0].GetKind() == OperandKind::VReg) {
|
||||
return {operands_[0].GetVReg()};
|
||||
}
|
||||
return {};
|
||||
case Opcode::Call:
|
||||
if (call_return_type_ != ValueType::Void && !operands_.empty() &&
|
||||
operands_[0].GetKind() == OperandKind::VReg) {
|
||||
return {operands_[0].GetVReg()};
|
||||
}
|
||||
return {};
|
||||
case Opcode::Store:
|
||||
case Opcode::Br:
|
||||
case Opcode::CondBr:
|
||||
case Opcode::Ret:
|
||||
case Opcode::Memset:
|
||||
case Opcode::Unreachable:
|
||||
return {};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<int> MachineInstr::GetUses() const {
|
||||
std::vector<int> uses;
|
||||
auto push_vreg = [&](const MachineOperand& operand) {
|
||||
if (operand.GetKind() == OperandKind::VReg) {
|
||||
uses.push_back(operand.GetVReg());
|
||||
}
|
||||
};
|
||||
auto push_addr_uses = [&]() {
|
||||
if (!has_address_) {
|
||||
return;
|
||||
}
|
||||
if (address_.base_kind == AddrBaseKind::VReg && address_.base_index >= 0) {
|
||||
uses.push_back(address_.base_index);
|
||||
}
|
||||
for (const auto& term : address_.scaled_vregs) {
|
||||
uses.push_back(term.first);
|
||||
}
|
||||
};
|
||||
|
||||
switch (opcode_) {
|
||||
case Opcode::Arg:
|
||||
case Opcode::Br:
|
||||
case Opcode::Unreachable:
|
||||
break;
|
||||
case Opcode::Copy:
|
||||
case Opcode::ZExt:
|
||||
case Opcode::ItoF:
|
||||
case Opcode::FtoI:
|
||||
case Opcode::FSqrt:
|
||||
case Opcode::FNeg:
|
||||
if (operands_.size() >= 2) {
|
||||
push_vreg(operands_[1]);
|
||||
}
|
||||
break;
|
||||
case Opcode::Load:
|
||||
case Opcode::Lea:
|
||||
push_addr_uses();
|
||||
break;
|
||||
case Opcode::Store:
|
||||
if (!operands_.empty()) {
|
||||
push_vreg(operands_[0]);
|
||||
}
|
||||
push_addr_uses();
|
||||
break;
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Rem:
|
||||
case Opcode::ModMul:
|
||||
case Opcode::ModPow:
|
||||
case Opcode::DigitExtractPow2:
|
||||
case Opcode::And:
|
||||
case Opcode::Or:
|
||||
case Opcode::Xor:
|
||||
case Opcode::Shl:
|
||||
case Opcode::AShr:
|
||||
case Opcode::LShr:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv:
|
||||
case Opcode::ICmp:
|
||||
case Opcode::FCmp:
|
||||
if (operands_.size() >= 2) {
|
||||
push_vreg(operands_[1]);
|
||||
}
|
||||
if (operands_.size() >= 3) {
|
||||
push_vreg(operands_[2]);
|
||||
}
|
||||
break;
|
||||
case Opcode::CondBr:
|
||||
if (!operands_.empty()) {
|
||||
push_vreg(operands_[0]);
|
||||
}
|
||||
break;
|
||||
case Opcode::Call: {
|
||||
size_t arg_begin = call_return_type_ == ValueType::Void ? 0 : 1;
|
||||
for (size_t i = arg_begin; i < operands_.size(); ++i) {
|
||||
push_vreg(operands_[i]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Opcode::Ret:
|
||||
if (!operands_.empty()) {
|
||||
push_vreg(operands_[0]);
|
||||
}
|
||||
break;
|
||||
case Opcode::Memset:
|
||||
if (!operands_.empty()) {
|
||||
push_vreg(operands_[0]);
|
||||
}
|
||||
if (operands_.size() >= 2) {
|
||||
push_vreg(operands_[1]);
|
||||
}
|
||||
push_addr_uses();
|
||||
break;
|
||||
}
|
||||
return uses;
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
|
||||
@ -0,0 +1,239 @@
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace mir {
|
||||
namespace {
|
||||
|
||||
using BlockList = std::vector<std::unique_ptr<MachineBasicBlock>>;
|
||||
|
||||
int FindBlockIndex(const MachineFunction& function, const std::string& name) {
|
||||
const auto& blocks = function.GetBlocks();
|
||||
for (size_t i = 0; i < blocks.size(); ++i) {
|
||||
if (blocks[i] && blocks[i]->GetName() == name) {
|
||||
return static_cast<int>(i);
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<int> CollectSuccessors(const MachineFunction& function, int index) {
|
||||
std::vector<int> succs;
|
||||
const auto& blocks = function.GetBlocks();
|
||||
if (index < 0 || index >= static_cast<int>(blocks.size()) || !blocks[index]) {
|
||||
return succs;
|
||||
}
|
||||
const auto& instructions = blocks[index]->GetInstructions();
|
||||
if (instructions.empty()) {
|
||||
return succs;
|
||||
}
|
||||
const auto& term = instructions.back();
|
||||
if (term.GetOpcode() == MachineInstr::Opcode::Br && !term.GetOperands().empty() &&
|
||||
term.GetOperands()[0].GetKind() == OperandKind::Block) {
|
||||
const int succ = FindBlockIndex(function, term.GetOperands()[0].GetText());
|
||||
if (succ >= 0) {
|
||||
succs.push_back(succ);
|
||||
}
|
||||
return succs;
|
||||
}
|
||||
if (term.GetOpcode() == MachineInstr::Opcode::CondBr &&
|
||||
term.GetOperands().size() >= 3) {
|
||||
for (size_t i = 1; i <= 2; ++i) {
|
||||
if (term.GetOperands()[i].GetKind() != OperandKind::Block) {
|
||||
continue;
|
||||
}
|
||||
const int succ = FindBlockIndex(function, term.GetOperands()[i].GetText());
|
||||
if (succ >= 0 &&
|
||||
std::find(succs.begin(), succs.end(), succ) == succs.end()) {
|
||||
succs.push_back(succ);
|
||||
}
|
||||
}
|
||||
}
|
||||
return succs;
|
||||
}
|
||||
|
||||
std::vector<int> BuildPredecessorCount(const MachineFunction& function) {
|
||||
std::vector<int> preds(function.GetBlocks().size(), 0);
|
||||
for (size_t i = 0; i < function.GetBlocks().size(); ++i) {
|
||||
for (int succ : CollectSuccessors(function, static_cast<int>(i))) {
|
||||
++preds[static_cast<size_t>(succ)];
|
||||
}
|
||||
}
|
||||
return preds;
|
||||
}
|
||||
|
||||
bool IsTrivialJumpBlock(const MachineFunction& function, int index) {
|
||||
const auto& blocks = function.GetBlocks();
|
||||
if (index < 0 || index >= static_cast<int>(blocks.size()) || !blocks[index]) {
|
||||
return false;
|
||||
}
|
||||
const auto& instructions = blocks[index]->GetInstructions();
|
||||
return instructions.size() == 1 &&
|
||||
instructions.front().GetOpcode() == MachineInstr::Opcode::Br &&
|
||||
!instructions.front().GetOperands().empty() &&
|
||||
instructions.front().GetOperands()[0].GetKind() == OperandKind::Block;
|
||||
}
|
||||
|
||||
std::string ResolveJumpChain(const MachineFunction& function, const std::string& target) {
|
||||
std::string current = target;
|
||||
std::unordered_set<std::string> visited{current};
|
||||
|
||||
while (true) {
|
||||
const int index = FindBlockIndex(function, current);
|
||||
if (index < 0 || !IsTrivialJumpBlock(function, index)) {
|
||||
return current;
|
||||
}
|
||||
const auto& inst = function.GetBlocks()[static_cast<size_t>(index)]->GetInstructions().front();
|
||||
const std::string& next = inst.GetOperands()[0].GetText();
|
||||
if (!visited.insert(next).second) {
|
||||
return current;
|
||||
}
|
||||
current = next;
|
||||
}
|
||||
}
|
||||
|
||||
bool RewriteBranchTargets(MachineFunction& function) {
|
||||
bool changed = false;
|
||||
for (auto& block : function.GetBlocks()) {
|
||||
if (!block || block->GetInstructions().empty()) {
|
||||
continue;
|
||||
}
|
||||
auto& term = block->GetInstructions().back();
|
||||
auto& operands = term.GetOperands();
|
||||
if (term.GetOpcode() == MachineInstr::Opcode::Br && !operands.empty() &&
|
||||
operands[0].GetKind() == OperandKind::Block) {
|
||||
const std::string resolved = ResolveJumpChain(function, operands[0].GetText());
|
||||
if (resolved != operands[0].GetText()) {
|
||||
operands[0] = MachineOperand::Block(resolved);
|
||||
changed = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (term.GetOpcode() != MachineInstr::Opcode::CondBr || operands.size() < 3) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i <= 2; ++i) {
|
||||
if (operands[i].GetKind() != OperandKind::Block) {
|
||||
continue;
|
||||
}
|
||||
const std::string resolved = ResolveJumpChain(function, operands[i].GetText());
|
||||
if (resolved != operands[i].GetText()) {
|
||||
operands[i] = MachineOperand::Block(resolved);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (operands[1].GetKind() == OperandKind::Block &&
|
||||
operands[2].GetKind() == OperandKind::Block &&
|
||||
operands[1].GetText() == operands[2].GetText()) {
|
||||
term = MachineInstr(MachineInstr::Opcode::Br, {operands[1]});
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RemoveUnreachableBlocks(MachineFunction& function) {
|
||||
auto& blocks = function.GetBlocks();
|
||||
if (blocks.empty() || !blocks.front()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> reachable;
|
||||
std::vector<std::string> stack{blocks.front()->GetName()};
|
||||
while (!stack.empty()) {
|
||||
std::string name = stack.back();
|
||||
stack.pop_back();
|
||||
if (!reachable.insert(name).second) {
|
||||
continue;
|
||||
}
|
||||
const int index = FindBlockIndex(function, name);
|
||||
if (index < 0) {
|
||||
continue;
|
||||
}
|
||||
for (int succ : CollectSuccessors(function, index)) {
|
||||
stack.push_back(blocks[static_cast<size_t>(succ)]->GetName());
|
||||
}
|
||||
}
|
||||
|
||||
const size_t old_size = blocks.size();
|
||||
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
|
||||
[&](const std::unique_ptr<MachineBasicBlock>& block) {
|
||||
return block && reachable.count(block->GetName()) == 0;
|
||||
}),
|
||||
blocks.end());
|
||||
return blocks.size() != old_size;
|
||||
}
|
||||
|
||||
bool MergeLinearBlocks(MachineFunction& function) {
|
||||
auto preds = BuildPredecessorCount(function);
|
||||
auto& blocks = function.GetBlocks();
|
||||
|
||||
for (size_t i = 0; i < blocks.size(); ++i) {
|
||||
auto& block = blocks[i];
|
||||
if (!block || block->GetInstructions().empty()) {
|
||||
continue;
|
||||
}
|
||||
auto& insts = block->GetInstructions();
|
||||
auto& term = insts.back();
|
||||
if (term.GetOpcode() != MachineInstr::Opcode::Br || term.GetOperands().empty() ||
|
||||
term.GetOperands()[0].GetKind() != OperandKind::Block) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int succ_index = FindBlockIndex(function, term.GetOperands()[0].GetText());
|
||||
if (succ_index <= 0 || succ_index == static_cast<int>(i) ||
|
||||
preds[static_cast<size_t>(succ_index)] != 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& succ = blocks[static_cast<size_t>(succ_index)];
|
||||
if (!succ || succ->GetInstructions().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
insts.pop_back();
|
||||
auto& succ_insts = succ->GetInstructions();
|
||||
insts.insert(insts.end(),
|
||||
std::make_move_iterator(succ_insts.begin()),
|
||||
std::make_move_iterator(succ_insts.end()));
|
||||
blocks.erase(blocks.begin() + succ_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RunCFGCleanupOnFunction(MachineFunction& function) {
|
||||
bool changed = false;
|
||||
while (true) {
|
||||
bool local_changed = false;
|
||||
local_changed |= RewriteBranchTargets(function);
|
||||
local_changed |= RemoveUnreachableBlocks(function);
|
||||
if (MergeLinearBlocks(function)) {
|
||||
local_changed = true;
|
||||
}
|
||||
changed |= local_changed;
|
||||
if (!local_changed) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunCFGCleanup(MachineModule& module) {
|
||||
bool changed = false;
|
||||
for (auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunCFGCleanupOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
@ -0,0 +1,257 @@
|
||||
#include "mir/MIR.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace mir {
|
||||
namespace {
|
||||
|
||||
struct RematDef {
|
||||
enum class Kind { Invalid, ImmCopy, Lea };
|
||||
|
||||
Kind kind = Kind::Invalid;
|
||||
ValueType type = ValueType::Void;
|
||||
MachineOperand source;
|
||||
AddressExpr address;
|
||||
};
|
||||
|
||||
bool IsCheapRematerializableDef(const MachineInstr& inst, RematDef& def) {
|
||||
const auto defs = inst.GetDefs();
|
||||
if (defs.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inst.GetOpcode() == MachineInstr::Opcode::Copy) {
|
||||
const auto& operands = inst.GetOperands();
|
||||
if (operands.size() < 2 || operands[1].GetKind() != OperandKind::Imm) {
|
||||
return false;
|
||||
}
|
||||
def.kind = RematDef::Kind::ImmCopy;
|
||||
def.type = inst.GetValueType();
|
||||
def.source = operands[1];
|
||||
return true;
|
||||
}
|
||||
|
||||
if (inst.GetOpcode() != MachineInstr::Opcode::Lea || !inst.HasAddress()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& address = inst.GetAddress();
|
||||
if (address.base_kind == AddrBaseKind::VReg || !address.scaled_vregs.empty()) {
|
||||
return false;
|
||||
}
|
||||
def.kind = RematDef::Kind::Lea;
|
||||
def.type = ValueType::Ptr;
|
||||
def.address = address;
|
||||
return true;
|
||||
}
|
||||
|
||||
MachineInstr BuildRematInstr(int dst_vreg, const RematDef& def) {
|
||||
switch (def.kind) {
|
||||
case RematDef::Kind::ImmCopy: {
|
||||
MachineInstr inst(MachineInstr::Opcode::Copy,
|
||||
{MachineOperand::VReg(dst_vreg), def.source});
|
||||
inst.SetValueType(def.type);
|
||||
return inst;
|
||||
}
|
||||
case RematDef::Kind::Lea: {
|
||||
MachineInstr inst(MachineInstr::Opcode::Lea, {MachineOperand::VReg(dst_vreg)});
|
||||
inst.SetAddress(def.address);
|
||||
inst.SetValueType(ValueType::Ptr);
|
||||
return inst;
|
||||
}
|
||||
case RematDef::Kind::Invalid:
|
||||
break;
|
||||
}
|
||||
return MachineInstr(MachineInstr::Opcode::Unreachable, {});
|
||||
}
|
||||
|
||||
bool RewriteMappedOperand(MachineOperand& operand,
|
||||
const std::unordered_map<int, int>& rename_map) {
|
||||
if (operand.GetKind() != OperandKind::VReg) {
|
||||
return false;
|
||||
}
|
||||
auto it = rename_map.find(operand.GetVReg());
|
||||
if (it == rename_map.end() || it->second == operand.GetVReg()) {
|
||||
return false;
|
||||
}
|
||||
operand = MachineOperand::VReg(it->second);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RewriteMappedAddress(AddressExpr& address,
|
||||
const std::unordered_map<int, int>& rename_map) {
|
||||
bool changed = false;
|
||||
if (address.base_kind == AddrBaseKind::VReg && address.base_index >= 0) {
|
||||
auto it = rename_map.find(address.base_index);
|
||||
if (it != rename_map.end() && it->second != address.base_index) {
|
||||
address.base_index = it->second;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
for (auto& term : address.scaled_vregs) {
|
||||
auto it = rename_map.find(term.first);
|
||||
if (it != rename_map.end() && it->second != term.first) {
|
||||
term.first = it->second;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RewriteUses(MachineInstr& inst, const std::unordered_map<int, int>& rename_map) {
|
||||
bool changed = false;
|
||||
auto& operands = inst.GetOperands();
|
||||
switch (inst.GetOpcode()) {
|
||||
case MachineInstr::Opcode::Copy:
|
||||
case MachineInstr::Opcode::ZExt:
|
||||
case MachineInstr::Opcode::ItoF:
|
||||
case MachineInstr::Opcode::FtoI:
|
||||
case MachineInstr::Opcode::FSqrt:
|
||||
case MachineInstr::Opcode::FNeg:
|
||||
if (operands.size() >= 2) {
|
||||
changed |= RewriteMappedOperand(operands[1], rename_map);
|
||||
}
|
||||
break;
|
||||
case MachineInstr::Opcode::Store:
|
||||
if (!operands.empty()) {
|
||||
changed |= RewriteMappedOperand(operands[0], rename_map);
|
||||
}
|
||||
break;
|
||||
case MachineInstr::Opcode::Add:
|
||||
case MachineInstr::Opcode::Sub:
|
||||
case MachineInstr::Opcode::Mul:
|
||||
case MachineInstr::Opcode::Div:
|
||||
case MachineInstr::Opcode::Rem:
|
||||
case MachineInstr::Opcode::ModMul:
|
||||
case MachineInstr::Opcode::ModPow:
|
||||
case MachineInstr::Opcode::DigitExtractPow2:
|
||||
case MachineInstr::Opcode::And:
|
||||
case MachineInstr::Opcode::Or:
|
||||
case MachineInstr::Opcode::Xor:
|
||||
case MachineInstr::Opcode::Shl:
|
||||
case MachineInstr::Opcode::AShr:
|
||||
case MachineInstr::Opcode::LShr:
|
||||
case MachineInstr::Opcode::FAdd:
|
||||
case MachineInstr::Opcode::FSub:
|
||||
case MachineInstr::Opcode::FMul:
|
||||
case MachineInstr::Opcode::FDiv:
|
||||
case MachineInstr::Opcode::ICmp:
|
||||
case MachineInstr::Opcode::FCmp:
|
||||
if (operands.size() >= 2) {
|
||||
changed |= RewriteMappedOperand(operands[1], rename_map);
|
||||
}
|
||||
if (operands.size() >= 3) {
|
||||
changed |= RewriteMappedOperand(operands[2], rename_map);
|
||||
}
|
||||
break;
|
||||
case MachineInstr::Opcode::CondBr:
|
||||
if (!operands.empty()) {
|
||||
changed |= RewriteMappedOperand(operands[0], rename_map);
|
||||
}
|
||||
break;
|
||||
case MachineInstr::Opcode::Call: {
|
||||
const size_t arg_begin = inst.GetCallReturnType() == ValueType::Void ? 0 : 1;
|
||||
for (size_t i = arg_begin; i < operands.size(); ++i) {
|
||||
changed |= RewriteMappedOperand(operands[i], rename_map);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case MachineInstr::Opcode::Ret:
|
||||
if (!operands.empty()) {
|
||||
changed |= RewriteMappedOperand(operands[0], rename_map);
|
||||
}
|
||||
break;
|
||||
case MachineInstr::Opcode::Memset:
|
||||
if (!operands.empty()) {
|
||||
changed |= RewriteMappedOperand(operands[0], rename_map);
|
||||
}
|
||||
if (operands.size() >= 2) {
|
||||
changed |= RewriteMappedOperand(operands[1], rename_map);
|
||||
}
|
||||
break;
|
||||
case MachineInstr::Opcode::Arg:
|
||||
case MachineInstr::Opcode::Load:
|
||||
case MachineInstr::Opcode::Lea:
|
||||
case MachineInstr::Opcode::Br:
|
||||
case MachineInstr::Opcode::Unreachable:
|
||||
break;
|
||||
}
|
||||
if (inst.HasAddress()) {
|
||||
changed |= RewriteMappedAddress(inst.GetAddress(), rename_map);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RunSpillReductionOnFunction(MachineFunction& function) {
|
||||
bool changed = false;
|
||||
|
||||
for (auto& block_ptr : function.GetBlocks()) {
|
||||
auto& instructions = block_ptr->GetInstructions();
|
||||
std::unordered_map<int, RematDef> available_defs;
|
||||
std::unordered_map<int, RematDef> after_call_defs;
|
||||
std::unordered_map<int, int> rename_map;
|
||||
bool after_call = false;
|
||||
|
||||
for (size_t i = 0; i < instructions.size(); ++i) {
|
||||
if (after_call) {
|
||||
const auto uses = instructions[i].GetUses();
|
||||
for (int use : uses) {
|
||||
if (rename_map.count(use) != 0) {
|
||||
continue;
|
||||
}
|
||||
auto it = after_call_defs.find(use);
|
||||
if (it == after_call_defs.end()) {
|
||||
continue;
|
||||
}
|
||||
const int new_vreg = function.NewVReg(function.GetVRegInfo(use).type);
|
||||
instructions.insert(instructions.begin() + static_cast<long long>(i),
|
||||
BuildRematInstr(new_vreg, it->second));
|
||||
++i;
|
||||
rename_map[use] = new_vreg;
|
||||
available_defs[new_vreg] = it->second;
|
||||
changed = true;
|
||||
}
|
||||
RewriteUses(instructions[i], rename_map);
|
||||
}
|
||||
|
||||
const auto defs = instructions[i].GetDefs();
|
||||
for (int def : defs) {
|
||||
available_defs.erase(def);
|
||||
after_call_defs.erase(def);
|
||||
rename_map.erase(def);
|
||||
}
|
||||
|
||||
RematDef def;
|
||||
if (IsCheapRematerializableDef(instructions[i], def)) {
|
||||
for (int vreg : defs) {
|
||||
available_defs[vreg] = def;
|
||||
}
|
||||
}
|
||||
|
||||
if (instructions[i].GetOpcode() == MachineInstr::Opcode::Call ||
|
||||
instructions[i].GetOpcode() == MachineInstr::Opcode::Memset) {
|
||||
after_call_defs = available_defs;
|
||||
rename_map.clear();
|
||||
after_call = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RunSpillReduction(MachineModule& module) {
|
||||
bool changed = false;
|
||||
for (auto& function : module.GetFunctions()) {
|
||||
if (function) {
|
||||
changed |= RunSpillReductionOnFunction(*function);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
} // namespace mir
|
||||
@ -1 +0,0 @@
|
||||
21
|
||||
@ -1,9 +0,0 @@
|
||||
int main(){
|
||||
const int a[4][2] = {{1, 2}, {3, 4}, {}, 7};
|
||||
const int N = 3;
|
||||
int b[4][2] = {};
|
||||
int c[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
int d[N + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
|
||||
int e[4][2][1] = {{d[2][1], {c[2][1]}}, {3, 4}, {5, 6}, {7, 8}};
|
||||
return e[3][1][0] + e[0][0][0] + e[0][1][0] + d[3][0];
|
||||
}
|
||||
@ -1 +0,0 @@
|
||||
9
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue